meow committed
Commit 710e818
1 Parent(s): cb900d8
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitignore +193 -0
  2. README.md +3 -3
  3. confs_new/dyn_grab_arti_shadow_dm.conf +288 -0
  4. confs_new/dyn_grab_arti_shadow_dm_curriculum.conf +326 -0
  5. confs_new/dyn_grab_arti_shadow_dm_singlestage.conf +318 -0
  6. confs_new/dyn_grab_pointset_mano.conf +215 -0
  7. confs_new/dyn_grab_pointset_mano_dyn.conf +218 -0
  8. confs_new/dyn_grab_pointset_mano_dyn_optacts.conf +218 -0
  9. confs_new/dyn_grab_pointset_points_dyn.conf +257 -0
  10. confs_new/dyn_grab_pointset_points_dyn_retar.conf +274 -0
  11. confs_new/dyn_grab_pointset_points_dyn_retar_pts.conf +281 -0
  12. confs_new/dyn_grab_pointset_points_dyn_retar_pts_opts.conf +287 -0
  13. confs_new/dyn_grab_pointset_points_dyn_s1.conf +256 -0
  14. confs_new/dyn_grab_pointset_points_dyn_s2.conf +258 -0
  15. confs_new/dyn_grab_pointset_points_dyn_s3.conf +259 -0
  16. confs_new/dyn_grab_pointset_points_dyn_s4.conf +259 -0
  17. confs_new/dyn_grab_sparse_retar.conf +214 -0
  18. exp_runner_stage_1.py +0 -0
  19. models/data_utils_torch.py +1547 -0
  20. models/dataset.py +359 -0
  21. models/dataset_wtime.py +403 -0
  22. models/dyn_model_act.py +0 -0
  23. models/dyn_model_act_v2.py +0 -0
  24. models/dyn_model_act_v2_deformable.py +1582 -0
  25. models/dyn_model_utils.py +1369 -0
  26. models/embedder.py +51 -0
  27. models/fields.py +0 -0
  28. models/fields_old.py +0 -0
  29. models/renderer.py +641 -0
  30. models/renderer_def.py +725 -0
  31. models/renderer_def_multi_objs.py +1088 -0
  32. models/renderer_def_multi_objs_compositional.py +1510 -0
  33. models/renderer_def_multi_objs_rigidtrans_forward.py +1603 -0
  34. models/test.js +0 -0
  35. pre-requirements.txt +2 -0
  36. requirements.txt +27 -0
  37. scripts_demo/train_grab_pointset_points_dyn_s1.sh +28 -0
  38. scripts_new/train_grab_mano.sh +26 -0
  39. scripts_new/train_grab_mano_wreact.sh +26 -0
  40. scripts_new/train_grab_mano_wreact_optacts.sh +26 -0
  41. scripts_new/train_grab_pointset.sh +99 -0
  42. scripts_new/train_grab_pointset_points_dyn.sh +27 -0
  43. scripts_new/train_grab_pointset_points_dyn_retar.sh +36 -0
  44. scripts_new/train_grab_pointset_points_dyn_retar_pts.sh +36 -0
  45. scripts_new/train_grab_pointset_points_dyn_s1.sh +27 -0
  46. scripts_new/train_grab_pointset_points_dyn_s2.sh +27 -0
  47. scripts_new/train_grab_pointset_points_dyn_s3.sh +27 -0
  48. scripts_new/train_grab_pointset_points_dyn_s4.sh +27 -0
  49. scripts_new/train_grab_shadow_multistages.sh +102 -0
  50. scripts_new/train_grab_shadow_singlestage.sh +101 -0
.gitignore ADDED
@@ -0,0 +1,193 @@
+ # Created by .ignore support plugin (hsz.mobi)
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ *.npy
+
+ *.zip
+
+ *.ply
+
+
+ #### old files ####
+ exp_runner_arti_*
+
+ exp_runner.py
+ exp_runner_arti.py
+ exp_runner_dyn_model.py
+ exp_runner_sim.py
+ get-pip.py
+
+ test.py
+
+
+ #### old scripts and data&exp folders ###
+ scripts/
+ confs/
+ data/
+ ckpts/
+ exp/
+ uni_rep/
+
+ */*_local.sh
+ */*_local.conf
+
+ # Distribution / packaging
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ rsc/
+ raw_data/
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *,cover
+ .hypothesis/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # IPython Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # dotenv
+ .env
+
+ # virtualenv
+ venv/
+ ENV/
+
+ # Spyder project settings
+ .spyderproject
+
+ # Rope project settings
+ .ropeproject
+ ### VirtualEnv template
+ # Virtualenv
+ # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+ .Python
+ [Bb]in
+ [Ii]nclude
+ [Ll]ib
+ [Ll]ib64
+ [Ll]ocal
+ # [Ss]cripts
+ pyvenv.cfg
+ .venv
+ pip-selfcheck.json
+ ### JetBrains template
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff:
+ .idea/workspace.xml
+ .idea/tasks.xml
+ .idea/dictionaries
+ .idea/vcs.xml
+ .idea/jsLibraryMappings.xml
+
+ # Sensitive or high-churn files:
+ .idea/dataSources.ids
+ .idea/dataSources.xml
+ .idea/dataSources.local.xml
+ .idea/sqlDataSources.xml
+ .idea/dynamic.xml
+ .idea/uiDesigner.xml
+
+ # Gradle:
+ .idea/gradle.xml
+ .idea/libraries
+
+ # Mongo Explorer plugin:
+ .idea/mongoSettings.xml
+
+ .idea/
+
+ ## File-based project format:
+ *.iws
+
+
+ ## Plugin-specific files:
+
+ # IntelliJ
+ /out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ data
+ public_data
+ exp
+ tmp
+
+ .models/*.npy
+ .models/*.ply
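
The rules above mix the stock Python/virtualenv/JetBrains templates with project-specific entries (retired runners, data/, ckpts/, exp/, and the */*_local.sh / */*_local.conf local-override patterns). A quick way to sanity-check which rule catches a given path is git's own check-ignore; a minimal Python sketch, assuming a checkout with this .gitignore and git on PATH (the sample paths are hypothetical):

    import subprocess

    # "git check-ignore -v PATH" prints "source:line:pattern<TAB>path" for each
    # ignored path and exits non-zero when no rule matches.
    for path in ["exp_runner.py", "ckpts/model.pth", "scripts_new/run_local.sh"]:
        result = subprocess.run(["git", "check-ignore", "-v", path],
                                capture_output=True, text=True)
        print(path, "->", result.stdout.strip() or "not ignored")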
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Quasi Physical Sims
- emoji: 📚
- colorFrom: red
- colorTo: gray
+ emoji: 🏃
+ colorFrom: gray
+ colorTo: indigo
  sdk: gradio
  sdk_version: 4.25.0
  app_file: app.py
confs_new/dyn_grab_arti_shadow_dm.conf ADDED
@@ -0,0 +1,288 @@
+ general {
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ # learning_rate = 5e-6
+ learning_rate_actions = 5e-6
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ # batch_size = 128 # 64
+ # batch_size = 4000
+ # batch_size = 3096 # 64
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+
+ load_redmax_robot_actions = ""
+ penetrating_depth_penalty = 0.0
+ train_states = True
+
+ minn_dist_threshold = 0.000
+ obj_mass = 100.0
+ obj_mass = 30.0
+
+ # use_mano_hand_for_test = False
+ use_mano_hand_for_test = True
+
+ # train_residual_friction = False
+ train_residual_friction = True
+
+ # use_LBFGS = True
+ use_LBFGS = False
+
+
+ use_mano_hand_for_test = False
+ train_residual_friction = True
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ # gt_act_xs_def = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ use_delta_bending = True
+
+ use_passive_nets = True
+ use_split_network = True
+
+ n_timesteps = 60
+
+
+ # using_delta_glb_trans = True
+ using_delta_glb_trans = False
+ # train_multi_seqs = True
+
+
+ # optimize_with_intermediates = False
+ optimize_with_intermediates = True
+
+
+ loss_tangential_diff_coef = 1000
+ loss_tangential_diff_coef = 0
+
+
+
+ optimize_active_object = True
+
+ no_friction_constraint = False
+
+ optimize_glb_transformations = True
+
+
+ sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+ mano_mult_const_after_cent = 1.0
+ sim_num_steps = 1000000
+
+ bending_net_type = "active_force_field_v18"
+
+
+ ### try to train the residual friction ? ###
+ train_residual_friction = True
+ optimize_rules = True
+
+ load_optimized_init_actions = ""
+
+
+ use_optimizable_params = True
+
+
+ ### grab train seq 224 ###
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ # penetration_proj_k_to_robot_friction = 40000000.0
+ # penetration_proj_k_to_robot_friction = 100000000.0 # as friction coefs here #
+ use_same_contact_spring_k = False
+
+
+ # sim_model_path = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+
+
+ optimize_rules = False
+
+ penetration_determining = "sdf_of_canon"
+
+ optimize_rules = False
+
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+
+ penetration_proj_k_to_robot_friction = 100000000.0 # as friction coefs here # ## confs ##
+ penetration_proj_k_to_robot = 40000000.0 #
+
+
+ penetrating_depth_penalty = 1
+
+ minn_dist_threshold_robot_to_obj = 0.0
+
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+ minn_dist_threshold_robot_to_obj = 0.1
+
+
+ ### kinematics configs ###
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+ # ckpt_fn = ""
+ # load_optimized_init_transformations = ""
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+ optimize_rules = True
+
+ ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
+ optimize_rules = False
+
+ optimize_rules = True
+
+ ## opt robot ##
+ opt_robo_glb_trans = True
+ opt_robo_glb_rot = False # opt rot # ## opt rot ##
+ opt_robo_states = True
+
+
+ use_multi_stages = False
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+ penetration_proj_k_to_robot = 40000
+ penetration_proj_k_to_robot_friction = 100000
+
+ drive_robot = "actions"
+ opt_robo_glb_trans = False
+ opt_robo_states = True
+ opt_robo_glb_rot = False
+
+ train_with_forces_to_active = True
+
+
+
+ load_redmax_robot_actions_fn = ""
+
+
+ optimize_rules = False
+
+
+
+ loss_scale_coef = 1.0
+
+
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+
+
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
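
Note how this conf assigns the same key repeatedly (optimize_rules alone flips between True and False about a dozen times). These files are HOCON, and NeuS-style runners load them with pyhocon after substituting CASE_NAME, so the last assignment to a key is the one in effect; the earlier assignments are a record of superseded experiments. A minimal sketch (pyhocon assumed; "grab_102" is a placeholder case name):

    from pyhocon import ConfigFactory

    text = open("confs_new/dyn_grab_arti_shadow_dm.conf").read()
    model = ConfigFactory.parse_string(text.replace("CASE_NAME", "grab_102"))["model"]

    print(model["optimize_rules"])                    # False -- last of many assignments
    print(model["penetration_proj_k_to_robot"])       # 40000 -- overrides the earlier 40000000.0
    print(model["minn_dist_threshold_robot_to_obj"])  # 0.1

Reading the file bottom-up therefore gives the effective configuration.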
confs_new/dyn_grab_arti_shadow_dm_curriculum.conf ADDED
@@ -0,0 +1,326 @@
+ general {
+ # base_exp_dir = exp/CASE_NAME/wmask
+ base_exp_dir = /data2/datasets/xueyi/neus/exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_curriculum_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_actions = 5e-6
+ # learning_rate = 5e-6
+ # learning_rate = 5e-5
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ # batch_size = 128 # 64
+ # batch_size = 4000
+ # batch_size = 3096 # 64
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+
+ penetration_proj_k_to_robot = 40
+
+ penetrating_depth_penalty = 1.0
+ penetrating_depth_penalty = 0.0
+ train_states = True
+ penetration_proj_k_to_robot = 4000000000.0
+
+
+ minn_dist_threshold = 0.000
+ # minn_dist_threshold = 0.01
+ obj_mass = 100.0
+ obj_mass = 30.0
+
+ optimize_rules = True
+
+ use_mano_hand_for_test = False
+ use_mano_hand_for_test = True
+
+ train_residual_friction = False
+ train_residual_friction = True
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False
+ train_residual_friction = True
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ # gt_act_xs_def = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+ use_delta_bending = True
+ use_passive_nets = True
+ # use_passive_nets = False # sv mesh root #
+
+ use_split_network = True
+
+ penetration_determining = "plane_primitives"
+
+
+ n_timesteps = 3 #
+ # n_timesteps = 5 #
+ n_timesteps = 7
+ n_timesteps = 60
+
+
+
+
+ using_delta_glb_trans = True
+ using_delta_glb_trans = False
+
+ optimize_with_intermediates = False
+ optimize_with_intermediates = True
+
+
+ loss_tangential_diff_coef = 1000
+ loss_tangential_diff_coef = 0
+
+
+
+ optimize_active_object = False
+ optimize_active_object = True
+
+ # optimize_expanded_pts = False
+ # optimize_expanded_pts = True
+
+ no_friction_constraint = False
+
+ optimize_glb_transformations = True
+ sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+ mano_mult_const_after_cent = 1.0
+ sim_num_steps = 1000000
+
+ bending_net_type = "active_force_field_v18"
+
+
+ ### try to train the residual friction ? ###
+ train_residual_friction = True
+ optimize_rules = True
+ ### cube ###
+ load_optimized_init_actions = ""
+
+ optimize_rules = False
+
+
+ ## optimize rules ## penetration proj k to robot ##
+ optimize_rules = True
+ penetration_proj_k_to_robot = 4000000.0
+ use_optimizable_params = True
+
+ penetration_determining = "ball_primitives" # using ball primitives
+ optimize_rules = True #
+ penetration_proj_k_to_robot = 4000000.0 #
+ use_optimizable_params = True
+ train_with_forces_to_active = False
+
+ # penetration_determining = "ball_primitives"
+ ### obj sdf and normals for collision detection and responses ##
+ ### grab train seq 54; cylinder ###
+ penetration_determining = "sdf_of_canon"
+ optimize_rules = True
+ train_with_forces_to_active = False
+
+ ### grab train seq 1 ###
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+
+ ### grab train seq 224 ###
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ penetration_proj_k_to_robot_friction = 40000000.0
+ penetration_proj_k_to_robot_friction = 100000000.0
+ use_same_contact_spring_k = False
+ sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+
+
+ penetration_determining = "sdf_of_canon"
+ optimize_rules = True
+ # optimize_rules = True
+
+ optimize_rules = False
+
+ optimize_rules = True
+
+
+ optimize_rules = False
+
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+ penetration_proj_k_to_robot_friction = 100000000.0
+ penetration_proj_k_to_robot = 40000000.0
+
+
+ penetrating_depth_penalty = 1
+
+ minn_dist_threshold_robot_to_obj = 0.0
+
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+ load_optimized_init_transformations = ""
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+
+ optim_sim_model_params_from_mano = False
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+
+ ### kinematics configs ###
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+ # ckpt_fn = ""
+ load_optimized_init_transformations = ""
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+ optim_sim_model_params_from_mano = False
+
+ optimize_rules = True
+
+ ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
+ ckpt_fn = "/data2/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_states_optrobot__seq_102_optactswreacts_redmaxacts_rules_/checkpoints/ckpt_035459.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
+
+
+ optimize_rules = True
+
+ ## opt robot ##
+ opt_robo_glb_trans = True
+ opt_robo_glb_rot = False # opt rot # ## opt rot ##
+ opt_robo_states = True
+
+
+ load_redmax_robot_actions_fn = "ckpts/grab/102/diffhand_act.npy"
+
+
+
+ ckpt_fn = ""
+
+ use_multi_stages = True
+ train_with_forces_to_active = True
+
+
+ # optimize_rules = False
+ loss_scale_coef = 1.0 ## loss scale coef ## loss scale coef ####
+
+
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+
+ # optimizable_rigid_translations = False #
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
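
The curriculum conf above hard-codes an absolute, machine-specific base_exp_dir (/data2/datasets/xueyi/...) and an absolute intermediate ckpt_fn. Because a later HOCON assignment overrides an earlier one, a local run can append overrides instead of editing the file; a sketch, again assuming pyhocon:

    from pyhocon import ConfigFactory

    text = open("confs_new/dyn_grab_arti_shadow_dm_curriculum.conf").read()
    text = text.replace("CASE_NAME", "grab_102")  # placeholder case name
    # the appended assignment wins over the absolute path hard-coded above
    conf = ConfigFactory.parse_string(text + '\ngeneral.base_exp_dir = "exp/grab_102/wmask"\n')
    print(conf["general.base_exp_dir"])  # exp/grab_102/wmask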
confs_new/dyn_grab_arti_shadow_dm_singlestage.conf ADDED
@@ -0,0 +1,318 @@
+ general {
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_curriculum_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_actions = 5e-6
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+
+ penetration_proj_k_to_robot = 40
+
+ penetrating_depth_penalty = 1.0
+ penetrating_depth_penalty = 0.0
+ train_states = True
+ penetration_proj_k_to_robot = 4000000000.0
+
+
+ minn_dist_threshold = 0.000
+ # minn_dist_threshold = 0.01
+ obj_mass = 100.0
+ obj_mass = 30.0
+
+ optimize_rules = True
+
+ use_mano_hand_for_test = False
+ use_mano_hand_for_test = True
+
+ train_residual_friction = False
+ train_residual_friction = True
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False
+ train_residual_friction = True
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ # gt_act_xs_def = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+ use_delta_bending = True
+ use_passive_nets = True
+ # use_passive_nets = False # sv mesh root #
+
+ use_split_network = True
+
+ penetration_determining = "plane_primitives"
+
+
+ n_timesteps = 3 #
+ # n_timesteps = 5 #
+ n_timesteps = 7
+ n_timesteps = 60
+
+
+
+
+ using_delta_glb_trans = True
+ using_delta_glb_trans = False
+
+ optimize_with_intermediates = False
+ optimize_with_intermediates = True
+
+
+ loss_tangential_diff_coef = 1000
+ loss_tangential_diff_coef = 0
+
+
+
+ optimize_active_object = False
+ optimize_active_object = True
+
+ # optimize_expanded_pts = False
+ # optimize_expanded_pts = True
+
+ no_friction_constraint = False
+
+ optimize_glb_transformations = True
+ sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+ mano_mult_const_after_cent = 1.0
+ sim_num_steps = 1000000
+
+ bending_net_type = "active_force_field_v18"
+
+
+ ### try to train the residual friction ? ###
+ train_residual_friction = True
+ optimize_rules = True
+ ### cube ###
+ load_optimized_init_actions = ""
+
+ optimize_rules = False
+
+
+ ## optimize rules ## penetration proj k to robot ##
+ optimize_rules = True
+ penetration_proj_k_to_robot = 4000000.0
+ use_optimizable_params = True
+
+ penetration_determining = "ball_primitives" # using ball primitives
+ optimize_rules = True #
+ penetration_proj_k_to_robot = 4000000.0 #
+ use_optimizable_params = True
+ train_with_forces_to_active = False
+
+ # penetration_determining = "ball_primitives"
+ ### obj sdf and normals for collision detection and responses ##
+ ### grab train seq 54; cylinder ###
+ penetration_determining = "sdf_of_canon"
+ optimize_rules = True
+ train_with_forces_to_active = False
+
+ ### grab train seq 1 ###
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+
+ ### grab train seq 224 ###
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ penetration_proj_k_to_robot_friction = 40000000.0
+ penetration_proj_k_to_robot_friction = 100000000.0
+ use_same_contact_spring_k = False
+ sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+
+
+ penetration_determining = "sdf_of_canon"
+ optimize_rules = True
+ # optimize_rules = True
+
+ optimize_rules = False
+
+ optimize_rules = True
+
+
+ optimize_rules = False
+
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+ penetration_proj_k_to_robot_friction = 100000000.0
+ penetration_proj_k_to_robot = 40000000.0
+
+
+ penetrating_depth_penalty = 1
+
+ minn_dist_threshold_robot_to_obj = 0.0
+
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+ optim_sim_model_params_from_mano = False
+ optimize_rules = False
+
+ load_optimized_init_transformations = ""
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+
+ optim_sim_model_params_from_mano = False
+
+ minn_dist_threshold_robot_to_obj = 0.1
+
+
+ ### kinematics configs ###
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+ # ckpt_fn = ""
+ load_optimized_init_transformations = ""
+ optim_sim_model_params_from_mano = True
+ optimize_rules = True
+ minn_dist_threshold_robot_to_obj = 0.0
+
+ optim_sim_model_params_from_mano = False
+
+ optimize_rules = True
+
+ ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
+
+
+ optimize_rules = True
+
+ ## opt robot ##
+ opt_robo_glb_trans = True
+ opt_robo_glb_rot = False
+ opt_robo_states = True
+
+
+ load_redmax_robot_actions_fn = "ckpts/grab/102/diffhand_act.npy"
+
+
+
+ ckpt_fn = ""
+
+ use_multi_stages = False
+ train_with_forces_to_active = True
+
+
+ # optimize_rules = False
+ loss_scale_coef = 1.0 ## loss scale coef ## loss scale coef ####
+
+
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+
+ # optimizable_rigid_translations = False #
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
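
This single-stage conf is nearly identical to the curriculum one; the effective differences come down to use_multi_stages (False here, True above) plus the general-section base_exp_dir. With last-assignment-wins semantics, the cleanest comparison is between the resolved trees; a sketch (pyhocon assumed, paths relative to the repo root):

    from pyhocon import ConfigFactory

    def resolved_model(path):
        text = open(path).read().replace("CASE_NAME", "grab_102")
        return ConfigFactory.parse_string(text)["model"]

    cur = resolved_model("confs_new/dyn_grab_arti_shadow_dm_curriculum.conf")
    single = resolved_model("confs_new/dyn_grab_arti_shadow_dm_singlestage.conf")
    for key in sorted(set(cur.keys()) | set(single.keys())):
        if cur.get(key, None) != single.get(key, None):
            print(key, ":", cur.get(key, None), "->", single.get(key, None))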
confs_new/dyn_grab_pointset_mano.conf ADDED
@@ -0,0 +1,215 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+
+ # tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
+ tag = "train_dyn_mano_acts_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### ######
+ # drive_pointset = "states"
+ fix_obj = True # to track the hand only
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = ""
+ load_optimized_init_transformations = ""
+ ckpt_fn = ""
+ retar_only_glb = True
+ # use_multi_stages = True
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
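
This conf drives the kinematic MANO tracking stage (fix_obj = True, retar_only_glb = True) and points at per-sequence assets under data/grab/102/. A small preflight check that those assets exist before launching training (pyhocon assumed):

    import os
    from pyhocon import ConfigFactory

    text = open("confs_new/dyn_grab_pointset_mano.conf").read()
    model = ConfigFactory.parse_string(text.replace("CASE_NAME", "grab_102"))["model"]
    for key in ("obj_sdf_fn", "kinematic_mano_gt_sv_fn", "scaled_obj_mesh_fn"):
        path = model[key]
        print(key, "->", path, "OK" if os.path.exists(path) else "MISSING")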
confs_new/dyn_grab_pointset_mano_dyn.conf ADDED
@@ -0,0 +1,218 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+
+ # tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
+ # tag = "train_dyn_mano_acts_"
+ tag = "train_dyn_mano_acts_wreact_optps_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: optimize for the parameters ######
+ # drive_pointset = "states"
+ fix_obj = False
+ optimize_rules = True
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ load_optimized_init_transformations = ""
+ ckpt_fn = "ckpts/grab/102/dyn_mano_arti.pth"
+ # retar_only_glb = True
+ # use_multi_stages = True
+ ###### Stage 1: optimize for the parameters ######
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
+
+
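
Unlike the previous conf, this one releases the object (fix_obj = False) and optimizes the contact-model parameters (optimize_rules = True), warm-starting both actions and weights from ckpts/grab/102/dyn_mano_arti.pth. A quick way to peek at such a checkpoint before training (torch assumed; the checkpoint layout is an assumption, adjust to whatever the runner actually saved):

    import torch

    state = torch.load("ckpts/grab/102/dyn_mano_arti.pth", map_location="cpu")
    if isinstance(state, dict):
        for name in list(state)[:10]:  # first few top-level keys
            print(name)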
confs_new/dyn_grab_pointset_mano_dyn_optacts.conf ADDED
@@ -0,0 +1,218 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+
+ # tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
+ # tag = "train_dyn_mano_acts_"
+ tag = "train_dyn_mano_acts_wreact_optps_optacts_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: optimize for the parameters ######
+ # drive_pointset = "states"
+ fix_obj = False
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ load_optimized_init_transformations = ""
+ ckpt_fn = "ckpts/grab/102/dyn_mano_arti.pth"
+ # retar_only_glb = True
+ # use_multi_stages = True
+ ###### Stage 1: optimize for the parameters ######
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
+
+
confs_new/dyn_grab_pointset_points_dyn.conf ADDED
@@ -0,0 +1,257 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+
+
+
+
+
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ drive_pointset = "states"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ drive_pointset = "actions"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = True
+ train_pointset_acts_via_deltas = True
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ ##### model parameters optimized from the MANO hand trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ use_opt_rigid_translations=True
+
+ train_def = True
+ optimizable_rigid_translations=True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips=[4],
+ use_viewdirs=True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4 ,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
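
This conf documents all four point-set stages inline, and since HOCON keeps only the last assignment to a key, leaving every stage block uncommented means Stage 4 is the one actually run; to execute an earlier stage, comment out the blocks below it (exactly as the already-retired block above the contact-spring settings was handled). Printing the resolved keys confirms which stage a conf will run (pyhocon assumed):

    from pyhocon import ConfigFactory

    text = open("confs_new/dyn_grab_pointset_points_dyn.conf").read()
    model = ConfigFactory.parse_string(text.replace("CASE_NAME", "grab_102"))["model"]
    for key in ("drive_pointset", "fix_obj", "optimize_rules",
                "train_pointset_acts_via_deltas", "ckpt_fn",
                "load_optimized_init_actions"):
        print(key, "=", model.get(key, None))
    # With the file as committed this reports the Stage 4 settings:
    # drive_pointset = actions, fix_obj = False, optimize_rules = False, ...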
confs_new/dyn_grab_pointset_points_dyn_retar.conf ADDED
@@ -0,0 +1,274 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ # drive_pointset = "states"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = False
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = True
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ # ##### model parameters optimized from the MANO hand trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ ###### Retargeting Stage 1 ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti.pth"
+ finger_cd_loss = 1.0
+ optimize_pointset_motion_only = True
+ ###### Retargeting Stage 1 ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
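For reference, `finger_cd_loss = 1.0` above weights a Chamfer-style correspondence term between the retargeted Shadow-hand points and the MANO reference points. A hedged PyTorch sketch of such a term; the tensor names are illustrative, not identifiers from this repo:

    import torch

    def chamfer_loss(shadow_pts: torch.Tensor, mano_pts: torch.Tensor) -> torch.Tensor:
        # shadow_pts: (N, 3) retargeted hand points; mano_pts: (M, 3) reference points
        d = torch.cdist(shadow_pts, mano_pts)  # (N, M) pairwise Euclidean distances
        return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()

    loss = 1.0 * chamfer_loss(torch.rand(64, 3), torch.rand(778, 3))  # weight from the conf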
confs_new/dyn_grab_pointset_points_dyn_retar_pts.conf ADDED
@@ -0,0 +1,281 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ # drive_pointset = "states"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = False
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = True
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ # ##### model parameters optimized from the MANO hand trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ ###### Retargeting Stage 1 ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar.pth"
+ finger_cd_loss = 1.0
+ optimize_pointset_motion_only = True
+ ###### Retargeting Stage 1 ######
+
+
+ ###### Retargeting Stage 2 ######
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti.pth"
+ ###### Retargeting Stage 2 ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
confs_new/dyn_grab_pointset_points_dyn_retar_pts_opts.conf ADDED
@@ -0,0 +1,287 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_"
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_optrules_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ # drive_pointset = "states"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = False
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ # ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = True
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ # ##### model parameters optimized from the MANO hand trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ ###### Retargeting Stage 1 ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar.pth"
+ finger_cd_loss = 1.0
+ optimize_pointset_motion_only = True
+ ###### Retargeting Stage 1 ######
+
+
+ ###### Retargeting Stage 2 ######
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti.pth"
+ ###### Retargeting Stage 2 ######
+
+
+ ###### Retargeting Stage 3 ######
+ load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti_optimized_pts.pth"
+ optimize_anchored_pts = False
+ optimize_rules = True
+ optimize_pointset_motion_only = False
+ ###### Retargeting Stage 3 ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
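Note that this conf assigns load_optimized_init_transformations three times, once per Retargeting Stage block; since HOCON keeps the last assignment, the file as written runs Stage 3. A small self-contained check of that resolution rule, using pyhocon and illustrative file names:

    from pyhocon import ConfigFactory

    c = ConfigFactory.parse_string("""
    model {
        load_optimized_init_transformations = "stage1.pth"
        load_optimized_init_transformations = "stage2.pth"
        load_optimized_init_transformations = "stage3.pth"
    }
    """)
    assert c["model.load_optimized_init_transformations"] == "stage3.pth"

To run an earlier stage, comment out the later assignments, exactly as the Stage 1-4 blocks do in the other confs.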
confs_new/dyn_grab_pointset_points_dyn_s1.conf ADDED
@@ -0,0 +1,256 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20
+ val_mesh_freq = 20
+ report_freq = 10
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ drive_pointset = "states"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ # ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = True
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ # ##### model parameters optimized from the MANO hand trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
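The four dyn_grab_pointset_points_dyn_s*.conf files form a chain: each stage initializes from the previous stage's checkpoint via load_optimized_init_actions (and, from Stage 3 on, ckpt_fn). A sketch of the chain as data, paraphrased from the conf comments; the stage descriptions are the comments' wording, not verified against the training code, and "<sN checkpoint>" stands for the exp/.../checkpoints path the later confs point at:

    # hypothetical summary structure for the staged pipeline
    STAGES = {
        "s1": {"drive_pointset": "states",  "init_from": "ckpts/grab/102/dyn_mano_arti.pth"},
        "s2": {"drive_pointset": "actions", "init_from": "<s1 checkpoint>"},
        "s3": {"drive_pointset": "actions", "init_from": "<s2 checkpoint>", "optimize_rules": True},
        "s4": {"drive_pointset": "actions", "init_from": "<s3 checkpoint>"},
    }
    for name, stage in STAGES.items():
        print(name, stage)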
confs_new/dyn_grab_pointset_points_dyn_s2.conf ADDED
@@ -0,0 +1,258 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ drive_pointset = "states"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ drive_pointset = "actions"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt" ## pre-optimized ckpts
+ load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_/checkpoints/ckpt_004000.pth"
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = True
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ # ##### model parameters optimized from the MANO hand trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ # ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
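Stage 2 above points load_optimized_init_actions at the checkpoint written by the Stage 1 run. A hedged sketch of the warm start this implies; the checkpoint layout and key names are assumptions, since the real loading code is in exp_runner_stage_1.py:

    import torch

    CKPT = ("exp/hand_test_routine_2_light_color_wtime_active_passive/"
            "wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_"
            "mano_pointset_acts_optstates_/checkpoints/ckpt_004000.pth")

    def warm_start(model: torch.nn.Module, ckpt_path: str = CKPT) -> None:
        ckpt = torch.load(ckpt_path, map_location="cpu")
        state = ckpt.get("model", ckpt)  # assumed: weights may be nested under "model"
        own = model.state_dict()
        own.update({k: v for k, v in state.items() if k in own})  # keep matching keys only
        model.load_state_dict(own)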
confs_new/dyn_grab_pointset_points_dyn_s3.conf ADDED
@@ -0,0 +1,259 @@
+ general {
+
+
+ base_exp_dir = exp/CASE_NAME/wmask
+
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_"
+
+ recording = [
+ ./,
+ ./models
+ ]
+ }
+
+ dataset {
+ data_dir = public_data/CASE_NAME/
+ render_cameras_name = cameras_sphere.npz
+ object_cameras_name = cameras_sphere.npz
+ obj_idx = 102
+ }
+
+ train {
+ learning_rate = 5e-4
+ learning_rate_alpha = 0.05
+ end_iter = 300000
+
+ batch_size = 1024 # 64
+ validate_resolution_level = 4
+ warm_up_end = 5000
+ anneal_end = 0
+ use_white_bkgd = False
+
+ # save_freq = 10000
+ save_freq = 10000
+ val_freq = 20 # 2500
+ val_mesh_freq = 20 # 5000
+ report_freq = 10
+ ### igr weight ###
+ igr_weight = 0.1
+ mask_weight = 0.1
+ }
+
+ model {
+
+ optimize_dyn_actions = True
+
+
+ optimize_robot = True
+
+ use_penalty_based_friction = True
+
+ use_split_params = False
+
+ use_sqr_spring_stiffness = True
+
+ use_pre_proj_frictions = True
+
+
+ use_sqrt_dist = True
+ contact_maintaining_dist_thres = 0.2
+
+ robot_actions_diff_coef = 0.001
+
+
+ use_sdf_as_contact_dist = True
+
+
+ #
+ use_contact_dist_as_sdf = False
+
+ use_glb_proj_delta = True
+
+
+ # penetration_proj_k_to_robot = 30
+ penetrating_depth_penalty = 1.0
+ train_states = True
+
+
+ minn_dist_threshold = 0.000
+ obj_mass = 30.0
+
+
+ use_LBFGS = True
+ use_LBFGS = False
+
+ use_mano_hand_for_test = False # use the dynamic mano model here #
+
+ extract_delta_mesh = False
+ freeze_weights = True
+ gt_act_xs_def = False
+ use_bending_network = True
+ ### for ts = 3 ###
+ # use_delta_bending = False
+ ### for ts = 3 ###
+
+
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
+
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
+
+ bending_net_type = "active_force_field_v18"
+ sim_num_steps = 1000000
+ n_timesteps = 60
+ optim_sim_model_params_from_mano = False
+ penetration_determining = "sdf_of_canon"
+ train_with_forces_to_active = False
+ loss_scale_coef = 1000.0
+ use_same_contact_spring_k = False
+ use_optimizable_params = True #
+ train_residual_friction = True
+ mano_mult_const_after_cent = 1.0
+ optimize_glb_transformations = True
+ no_friction_constraint = False
+ optimize_active_object = True
+ loss_tangential_diff_coef = 0
+ optimize_with_intermediates = True
+ using_delta_glb_trans = False
+ train_multi_seqs = False
+ use_split_network = True
+ use_delta_bending = True
+
+
+ ###### threshold, ks settings 1, optimize acts ######
+ # drive_pointset = "actions"
+ # fix_obj = True
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
+ ###### threshold, ks settings 1, optimize acts ######
+
+
+ ##### contact spring model settings ####
+ minn_dist_threshold_robot_to_obj = 0.1
+ penetration_proj_k_to_robot_friction = 10000000.0
+ penetration_proj_k_to_robot = 4000000.0
+ ##### contact spring model settings ####
+
+
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+ drive_pointset = "states"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = False
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
+
+
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+ drive_pointset = "actions"
+ fix_obj = True
+ optimize_rules = False
+ train_pointset_acts_via_deltas = True
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
+
+
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+ drive_pointset = "actions"
+ fix_obj = False
+ optimize_rules = True
+ train_pointset_acts_via_deltas = True
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
+ load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_/checkpoints/ckpt_004000.pth"
+ ##### model parameters optimized from the MANO hand trajectory #####
+ ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
+ ckpt_fn = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_/checkpoints/ckpt_004000.pth"
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
+
+
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+ # drive_pointset = "actions"
+ # fix_obj = False
+ # optimize_rules = False
+ # train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
+ # ##### model parameters optimized from the MANO hand expanded set trajectory #####
+ # ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
+ # load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
+ # ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
+
+
+ use_opt_rigid_translations = True
+
+ train_def = True
+ optimizable_rigid_translations = True
+
+ nerf {
+ D = 8,
+ d_in = 4,
+ d_in_view = 3,
+ W = 256,
+ multires = 10,
+ multires_view = 4,
+ output_ch = 4,
+ skips = [4],
+ use_viewdirs = True
+ }
+
+ sdf_network {
+ d_out = 257,
+ d_in = 3,
+ d_hidden = 256,
+ n_layers = 8,
+ skip_in = [4],
+ multires = 6,
+ bias = 0.5,
+ scale = 1.0,
+ geometric_init = True,
+ weight_norm = True,
+ }
+
+ variance_network {
+ init_val = 0.3
+ }
+
+ rendering_network {
+ d_feature = 256,
+ mode = idr,
+ d_in = 9,
+ d_out = 3,
+ d_hidden = 256,
+ n_layers = 4,
+ weight_norm = True,
+ multires_view = 4,
+ squeeze_out = True,
+ }
+
+ neus_renderer {
+ n_samples = 64,
+ n_importance = 64,
+ n_outside = 0,
+ up_sample_steps = 4,
+ perturb = 1.0,
+ }
+
+ bending_network {
+ multires = 6,
+ bending_latent_size = 32,
+ d_in = 3,
+ rigidity_hidden_dimensions = 64,
+ rigidity_network_depth = 5,
+ use_rigidity_network = False,
+ bending_n_timesteps = 10,
+ }
+ }
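The "contact spring model settings" shared by these confs parameterize a penalty-based contact model: penetration_proj_k_to_robot and penetration_proj_k_to_robot_friction act as spring stiffnesses, and minn_dist_threshold_robot_to_obj sets the activation band. A hedged sketch of the normal-force half of such a model; the actual formulation lives in models/dyn_model_act*.py and may differ:

    import torch

    K_CONTACT = 4000000.0  # penetration_proj_k_to_robot
    THRESHOLD = 0.1        # minn_dist_threshold_robot_to_obj

    def contact_force(sdf_vals: torch.Tensor, normals: torch.Tensor) -> torch.Tensor:
        # sdf_vals: (N,) signed distances of hand points to the object surface
        # normals:  (N, 3) outward surface normals at the closest points
        depth = torch.clamp(THRESHOLD - sdf_vals, min=0.0)  # penetration past the band
        return K_CONTACT * depth.unsqueeze(-1) * normals    # spring-like restoring force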
confs_new/dyn_grab_pointset_points_dyn_s4.conf ADDED
@@ -0,0 +1,259 @@
1
+ general {
2
+
3
+
4
+ base_exp_dir = exp/CASE_NAME/wmask
5
+
6
+ tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_optacts_"
7
+
8
+ recording = [
9
+ ./,
10
+ ./models
11
+ ]
12
+ }
13
+
14
+ dataset {
15
+ data_dir = public_data/CASE_NAME/
16
+ render_cameras_name = cameras_sphere.npz
17
+ object_cameras_name = cameras_sphere.npz
18
+ obj_idx = 102
19
+ }
20
+
21
+ train {
22
+ learning_rate = 5e-4
23
+ learning_rate_alpha = 0.05
24
+ end_iter = 300000
25
+
26
+ batch_size = 1024 # 64
27
+ validate_resolution_level = 4
28
+ warm_up_end = 5000
29
+ anneal_end = 0
30
+ use_white_bkgd = False
31
+
32
+ # save_freq = 10000
33
+ save_freq = 10000
34
+ val_freq = 20 # 2500
35
+ val_mesh_freq = 20 # 5000
36
+ report_freq = 10
37
+ ### igr weight ###
38
+ igr_weight = 0.1
39
+ mask_weight = 0.1
40
+ }
41
+
42
+ model {
43
+
44
+ optimize_dyn_actions = True
45
+
46
+
47
+ optimize_robot = True
48
+
49
+ use_penalty_based_friction = True
50
+
51
+ use_split_params = False
52
+
53
+ use_sqr_spring_stiffness = True
54
+
55
+ use_pre_proj_frictions = True
56
+
57
+
58
+
59
+ use_sqrt_dist = True
60
+ contact_maintaining_dist_thres = 0.2
61
+
62
+ robot_actions_diff_coef = 0.001
63
+
64
+
65
+ use_sdf_as_contact_dist = True
66
+
67
+
68
+ #
69
+ use_contact_dist_as_sdf = False
70
+
71
+ use_glb_proj_delta = True
72
+
73
+
74
+
75
+ # penetration_proj_k_to_robot = 30
76
+ penetrating_depth_penalty = 1.0
77
+ train_states = True
78
+
79
+
80
+
81
+ minn_dist_threshold = 0.000
82
+ obj_mass = 30.0
83
+
84
+
85
+ use_LBFGS = True
86
+ use_LBFGS = False
87
+
88
+ use_mano_hand_for_test = False # use the dynamic mano model here #
89
+
90
+ extract_delta_mesh = False
91
+ freeze_weights = True
92
+ gt_act_xs_def = False
93
+ use_bending_network = True
94
+ ### for ts = 3 ###
95
+ # use_delta_bending = False
96
+ ### for ts = 3 ###
97
+
98
+
99
+
100
+
101
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
102
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
103
+
104
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
105
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
106
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
107
+
108
+ bending_net_type = "active_force_field_v18"
109
+ sim_num_steps = 1000000
110
+ n_timesteps = 60
111
+ optim_sim_model_params_from_mano = False
112
+ penetration_determining = "sdf_of_canon"
113
+ train_with_forces_to_active = False
114
+ loss_scale_coef = 1000.0
115
+ use_same_contact_spring_k = False
116
+ use_optimizable_params = True #
117
+ train_residual_friction = True
118
+ mano_mult_const_after_cent = 1.0
119
+ optimize_glb_transformations = True
120
+ no_friction_constraint = False
121
+ optimize_active_object = True
122
+ loss_tangential_diff_coef = 0
123
+ optimize_with_intermediates = True
124
+ using_delta_glb_trans = False
125
+ train_multi_seqs = False
126
+ use_split_network = True
127
+ use_delta_bending = True
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+ ###### threshold, ks settings 1, optimize acts ######
138
+ # drive_pointset = "actions"
139
+ # fix_obj = True
140
+ # optimize_rules = False
141
+ # train_pointset_acts_via_deltas = True
142
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
143
+ # load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
144
+ ###### threshold, ks settings 1, optimize acts ######
145
+
146
+
147
+ ##### contact spring model settings ####
148
+ minn_dist_threshold_robot_to_obj = 0.1
149
+ penetration_proj_k_to_robot_friction = 10000000.0
150
+ penetration_proj_k_to_robot = 4000000.0
151
+ ##### contact spring model settings ####
152
+
153
+
154
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
155
+ drive_pointset = "states"
156
+ fix_obj = True
157
+ optimize_rules = False
158
+ train_pointset_acts_via_deltas = False
159
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
160
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
161
+
162
+
163
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
164
+ drive_pointset = "actions"
165
+ fix_obj = True
166
+ optimize_rules = False
167
+ train_pointset_acts_via_deltas = True
168
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
169
+ ###### Stage 2: threshold, ks settings 1, optimize acts ######
170
+
171
+
172
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
173
+ drive_pointset = "actions"
174
+ fix_obj = False
175
+ optimize_rules = True
176
+ train_pointset_acts_via_deltas = True
177
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
178
+ ##### model parameters optimized from the MANO hand trajectory #####
179
+ ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
180
+ ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
181
+
182
+
183
+ ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
184
+ drive_pointset = "actions"
185
+ fix_obj = False
186
+ optimize_rules = False
187
+ train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
188
+ ##### model parameters optimized from the MANO hand expanded set trajectory #####
189
+ ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
190
+ load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
191
+ ckpt_fn = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_/checkpoints/ckpt_044000.pth"
192
+ load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_/checkpoints/ckpt_044000.pth"
193
+ ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
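+
+ # Note on the four stage blocks above: HOCON keeps the last assignment for a
+ # repeated key, so with every block uncommented the Stage 4 values are the ones
+ # in effect; run an earlier stage by commenting out the later blocks. The
+ # checkpoints chain stage to stage: states (1) -> acts (2) -> optimized system
+ # parameters (3) -> acts refined under those parameters (4).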
194
+
195
+
196
+ use_opt_rigid_translations = True
197
+
198
+ train_def = True
199
+ optimizable_rigid_translations = True
200
+
201
+ nerf {
202
+ D = 8,
203
+ d_in = 4,
204
+ d_in_view = 3,
205
+ W = 256,
206
+ multires = 10,
207
+ multires_view = 4,
208
+ output_ch = 4,
209
+ skips = [4],
210
+ use_viewdirs = True
211
+ }
212
+
213
+ sdf_network {
214
+ d_out = 257,
215
+ d_in = 3,
216
+ d_hidden = 256,
217
+ n_layers = 8,
218
+ skip_in = [4],
219
+ multires = 6,
220
+ bias = 0.5,
221
+ scale = 1.0,
222
+ geometric_init = True,
223
+ weight_norm = True,
224
+ }
225
+
226
+ variance_network {
227
+ init_val = 0.3
228
+ }
229
+
230
+ rendering_network {
231
+ d_feature = 256,
232
+ mode = idr,
233
+ d_in = 9,
234
+ d_out = 3,
235
+ d_hidden = 256,
236
+ n_layers = 4,
237
+ weight_norm = True,
238
+ multires_view = 4,
239
+ squeeze_out = True,
240
+ }
241
+
242
+ neus_renderer {
243
+ n_samples = 64,
244
+ n_importance = 64,
245
+ n_outside = 0,
246
+ up_sample_steps = 4,
247
+ perturb = 1.0,
248
+ }
249
+
250
+ bending_network {
251
+ multires = 6,
252
+ bending_latent_size = 32,
253
+ d_in = 3,
254
+ rigidity_hidden_dimensions = 64,
255
+ rigidity_network_depth = 5,
256
+ use_rigidity_network = False,
257
+ bending_n_timesteps = 10,
258
+ }
259
+ }
confs_new/dyn_grab_sparse_retar.conf ADDED
@@ -0,0 +1,214 @@
1
+ general {
2
+
3
+
4
+ base_exp_dir = exp/CASE_NAME/wmask
5
+
6
+
7
+ tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
8
+
9
+ recording = [
10
+ ./,
11
+ ./models
12
+ ]
13
+ }
14
+
15
+ dataset {
16
+ data_dir = public_data/CASE_NAME/
17
+ render_cameras_name = cameras_sphere.npz
18
+ object_cameras_name = cameras_sphere.npz
19
+ obj_idx = 102
20
+ }
21
+
22
+ train {
23
+ learning_rate = 5e-4
24
+ learning_rate_alpha = 0.05
25
+ end_iter = 300000
26
+
27
+ batch_size = 1024
28
+ validate_resolution_level = 4
29
+ warm_up_end = 5000
30
+ anneal_end = 0
31
+ use_white_bkgd = False
32
+
33
+ # save_freq = 10000
34
+ save_freq = 10000
35
+ val_freq = 20
36
+ val_mesh_freq = 20
37
+ report_freq = 10
38
+ igr_weight = 0.1
39
+ mask_weight = 0.1
40
+ }
41
+
42
+ model {
43
+
44
+ optimize_dyn_actions = True
45
+
46
+
47
+ optimize_robot = True
48
+
49
+ use_penalty_based_friction = True
50
+
51
+ use_split_params = False
52
+
53
+ use_sqr_spring_stiffness = True
54
+
55
+ use_pre_proj_frictions = True
56
+
57
+
58
+
59
+ use_sqrt_dist = True
60
+ contact_maintaining_dist_thres = 0.2
61
+
62
+ robot_actions_diff_coef = 0.001
63
+
64
+
65
+ use_sdf_as_contact_dist = True
66
+
67
+
68
+ #
69
+ use_contact_dist_as_sdf = False
70
+
71
+ use_glb_proj_delta = True
72
+
73
+
74
+
75
+ # penetration_proj_k_to_robot = 30
76
+ penetrating_depth_penalty = 1.0
77
+ train_states = True
78
+
79
+
80
+
81
+ minn_dist_threshold = 0.000
82
+ obj_mass = 30.0
83
+
84
+
85
+ # use_LBFGS = True
87
+ use_LBFGS = False # the later assignment wins, so LBFGS is disabled here
87
+
88
+ use_mano_hand_for_test = False # use the dynamic mano model here #
89
+
90
+ extract_delta_mesh = False
91
+ freeze_weights = True
92
+ gt_act_xs_def = False
93
+ use_bending_network = True
94
+ ### for ts = 3 ###
95
+ # use_delta_bending = False
96
+ ### for ts = 3 ###
97
+
98
+
99
+
100
+
101
+ sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
102
+ mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
103
+
104
+ obj_sdf_fn = "data/grab/102/102_obj.npy"
105
+ kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
106
+ scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
107
+
108
+ bending_net_type = "active_force_field_v18"
109
+ sim_num_steps = 1000000
110
+ n_timesteps = 60
111
+ optim_sim_model_params_from_mano = False
112
+ penetration_determining = "sdf_of_canon"
113
+ train_with_forces_to_active = False
114
+ loss_scale_coef = 1000.0
115
+ use_same_contact_spring_k = False
116
+ use_optimizable_params = True
117
+ train_residual_friction = True
118
+ mano_mult_const_after_cent = 1.0
119
+ optimize_glb_transformations = True
120
+ no_friction_constraint = False
121
+ optimize_active_object = True
122
+ loss_tangential_diff_coef = 0
123
+ optimize_with_intermediates = True
124
+ using_delta_glb_trans = False
125
+ train_multi_seqs = False
126
+ use_split_network = True
127
+ use_delta_bending = True
128
+
129
+
130
+
131
+
132
+ ##### contact spring model settings ####
133
+ minn_dist_threshold_robot_to_obj = 0.1
134
+ penetration_proj_k_to_robot_friction = 10000000.0
135
+ penetration_proj_k_to_robot = 4000000.0
136
+ ##### contact spring model settings ####
137
+
138
+
139
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
140
+ drive_pointset = "states"
141
+ fix_obj = True
142
+ optimize_rules = False
143
+ train_pointset_acts_via_deltas = False
144
+ load_optimized_init_actions = ""
145
+ load_optimized_init_transformations = ""
146
+ ckpt_fn = ""
147
+ retar_only_glb = True
148
+ # use_multi_stages = True
149
+ ###### Stage 1: threshold, ks settings 1, optimize offsets ######
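+
+ # The empty load_optimized_init_actions / load_optimized_init_transformations /
+ # ckpt_fn fields mean this sparse-retargeting stage starts from scratch, and
+ # retar_only_glb = True suggests only the global rigid transformation is
+ # retargeted at this stage.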
150
+
151
+ use_opt_rigid_translations = True
152
+
153
+ train_def = True
154
+ optimizable_rigid_translations = True
155
+
156
+ nerf {
157
+ D = 8,
158
+ d_in = 4,
159
+ d_in_view = 3,
160
+ W = 256,
161
+ multires = 10,
162
+ multires_view = 4,
163
+ output_ch = 4,
164
+ skips = [4],
165
+ use_viewdirs = True
166
+ }
167
+
168
+ sdf_network {
169
+ d_out = 257,
170
+ d_in = 3,
171
+ d_hidden = 256,
172
+ n_layers = 8,
173
+ skip_in = [4],
174
+ multires = 6,
175
+ bias = 0.5,
176
+ scale = 1.0,
177
+ geometric_init = True,
178
+ weight_norm = True,
179
+ }
180
+
181
+ variance_network {
182
+ init_val = 0.3
183
+ }
184
+
185
+ rendering_network {
186
+ d_feature = 256,
187
+ mode = idr,
188
+ d_in = 9,
189
+ d_out = 3,
190
+ d_hidden = 256,
191
+ n_layers = 4,
192
+ weight_norm = True,
193
+ multires_view = 4,
194
+ squeeze_out = True,
195
+ }
196
+
197
+ neus_renderer {
198
+ n_samples = 64,
199
+ n_importance = 64,
200
+ n_outside = 0,
201
+ up_sample_steps = 4,
202
+ perturb = 1.0,
203
+ }
204
+
205
+ bending_network {
206
+ multires = 6,
207
+ bending_latent_size = 32,
208
+ d_in = 3,
209
+ rigidity_hidden_dimensions = 64,
210
+ rigidity_network_depth = 5,
211
+ use_rigidity_network = False,
212
+ bending_n_timesteps = 10,
213
+ }
214
+ }
exp_runner_stage_1.py ADDED
The diff for this file is too large to render. See raw diff
 
models/data_utils_torch.py ADDED
@@ -0,0 +1,1547 @@
1
+ """Mesh data utilities."""
2
+ # from re import I # unused leftover import
3
+ # import matplotlib.pyplot as plt
4
+ # from mpl_toolkits import mplot3d # pylint: disable=unused-import
5
+ # from mpl_toolkits.mplot3d.art3d import Poly3DCollection
6
+ import networkx as nx
7
+ import numpy as np
8
+ import six
9
+ import os
10
+ import math
11
+ import torch
12
+ import models.fields as fields
13
+ # from options.options import opt
+ opt = None # fallback so the opt-dependent sorting branch below can be skipped cleanly
14
+ # from polygen_torch.
15
+ # from utils.constants import MASK_GRID_VALIE
16
+
17
+ try:
18
+ from torch_cluster import fps
19
+ except ImportError:
20
+ pass
21
+
22
+ MAX_RANGE = 0.1
23
+ MIN_RANGE = -0.1
24
+
25
+ # import open3d as o3d
26
+
27
+ def calculate_correspondences(last_mesh, bending_network, tot_timesteps, delta_bending):
28
+ # iterate over all the timesteps and collect the bended (back-warped) meshes
29
+ timestep_to_pts = {
30
+ tot_timesteps - 1: last_mesh.detach().cpu().numpy()
31
+ }
32
+ if delta_bending:
33
+
34
+ for i_ts in range(tot_timesteps - 1, -1, -1):
35
+ if isinstance(bending_network, list):
36
+ tot_offsets = []
37
+ for i_bending, cur_bending_net in enumerate(bending_network):
38
+ cur_def_pts = cur_bending_net(last_mesh, i_ts)
39
+ tot_offsets.append(cur_def_pts - last_mesh)
40
+ tot_offsets = torch.stack(tot_offsets, dim=0)
41
+ tot_offsets = torch.sum(tot_offsets, dim=0)
42
+ last_mesh = last_mesh + tot_offsets
43
+ else:
44
+ last_mesh = bending_network(last_mesh, i_ts)
45
+ timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
46
+ elif isinstance(bending_network, fields.BendingNetworkRigidTrans):
47
+ for i_ts in range(tot_timesteps - 1, -1, -1):
48
+ last_mesh = bending_network.forward_delta(last_mesh, i_ts)
49
+ timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
50
+ else:
51
+ raise ValueError(f"the function is designed for delta bending")
52
+ return timestep_to_pts
53
+ # pass
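+
+ # Usage sketch (hypothetical shapes; `net` is a trained delta bending network
+ # or a list of them):
+ def _example_calculate_correspondences(last_mesh, net):
+     # last_mesh: (N, 3) vertices at the final frame; returns a dict mapping
+     # each timestep to the back-warped point positions.
+     return calculate_correspondences(last_mesh, net, tot_timesteps=60, delta_bending=True)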
54
+
55
+ def joint_infos_to_numpy(joint_infos):
56
+ joint_infos_np = []
57
+ for part_joint_info in joint_infos:
58
+ for zz in ["dir", "center"]:
59
+ # if isinstance(part_joint_info["axis"][zz], np.array):
60
+ part_joint_info["axis"][zz] = part_joint_info["axis"][zz].detach().numpy()
61
+ joint_infos_np.append(part_joint_info)
62
+ return joint_infos_np
63
+
64
+
65
+ def normalie_pc_bbox_batched(pc, rt_stats=False):
66
+ pc_min = torch.min(pc, dim=1, keepdim=True)[0]
67
+ pc_max = torch.max(pc, dim=1, keepdim=True)[0]
68
+ pc_center = 0.5 * (pc_min + pc_max)
69
+
70
+ pc = pc - pc_center
71
+ extents = pc_max - pc_min
72
+ scale = torch.sqrt(torch.sum(extents ** 2, dim=-1, keepdim=True))
73
+
74
+ pc = pc / torch.clamp(scale, min=1e-6)
75
+ if rt_stats:
76
+ return pc, pc_center, scale
77
+ else:
78
+ return pc
79
+
80
+
81
+
82
+ def scale_vertices_to_target_scale(vertices, target_scale):
83
+ # vertices: bsz x N x 3;
84
+ # target_scale: bsz x 1
85
+ # normalized_vertices = normalize_vertices_scale_torch(vertices)
86
+ normalized_vertices = normalie_pc_bbox_batched(vertices)
87
+ normalized_vertices = normalized_vertices * target_scale.unsqueeze(1)
88
+ return normalized_vertices
89
+
90
+ def compute_normals_o3d(verts, faces): #### assume no batching... ####
91
+ mesh = o3d.geometry.TriangleMesh()
92
+ # o3d_mesh_b.vertices = verts_b
93
+ # o3d_mesh_b.triangles = np.array(faces_b, dtype=np.long)
94
+ mesh.vertices = o3d.utility.Vector3dVector(verts.detach().cpu().numpy())
95
+ mesh.triangles = o3d.utility.Vector3iVector(faces.detach().cpu().numpy().astype(np.int64))
96
+ mesh.compute_vertex_normals(normalized=True) # fills mesh.vertex_normals in place
97
+ verts_normals = torch.from_numpy(np.asarray(mesh.vertex_normals)).float().cuda() # needs the open3d import (commented out above)
98
+ return verts_normals
99
+
100
+
101
+ def get_vals_via_nearest_neighbours(pts_src, pts_tar, val_tar):
102
+ ### n_src x 3 ---> n_src x n_tar
103
+ dist_src_tar = torch.sum((pts_src.unsqueeze(-2) - pts_tar.unsqueeze(0)) ** 2, dim=-1)
104
+ minn_dists_src_tar, minn_dists_tar_idxes = torch.min(dist_src_tar, dim=-1) ### n_src
105
+ selected_src_val = batched_index_select(values=val_tar, indices=minn_dists_tar_idxes, dim=0) ### n_src x dim
106
+ return selected_src_val
107
+
108
+
109
+
110
+ ## sample a connected component starting from selected_vert
111
+ def sample_bfs_component(selected_vert, faces, max_num_grids):
112
+ vert_idx_to_adj_verts = {}
113
+ for i_f, cur_f in enumerate(faces):
114
+ # for i0, v0 in enumerate(cur_f):
115
+ for i0 in range(len(cur_f)):
116
+ v0 = cur_f[i0] - 1
117
+ i1 = (i0 + 1) % len(cur_f)
118
+ v1 = cur_f[i1] - 1
119
+ if v0 not in vert_idx_to_adj_verts:
120
+ vert_idx_to_adj_verts[v0] = {v1: 1}
121
+ else:
122
+ vert_idx_to_adj_verts[v0][v1] = 1
123
+ if v1 not in vert_idx_to_adj_verts:
124
+ vert_idx_to_adj_verts[v1] = {v0: 1}
125
+ else:
126
+ vert_idx_to_adj_verts[v1][v0] = 1
127
+ vert_idx_to_visited = {} # whether visited #
128
+ vis_que = [selected_vert]
129
+ vert_idx_to_visited[selected_vert] = 1
130
+ visited = [selected_vert]
131
+ while len(vis_que) > 0 and len(visited) < max_num_grids:
132
+ cur_frnt_vert = vis_que[0]
133
+ vis_que.pop(0)
134
+ if cur_frnt_vert in vert_idx_to_adj_verts:
135
+ cur_frnt_vert_adjs = vert_idx_to_adj_verts[cur_frnt_vert]
136
+ for adj_vert in cur_frnt_vert_adjs:
137
+ if adj_vert not in vert_idx_to_visited:
138
+ vert_idx_to_visited[adj_vert] = 1
139
+ vis_que.append(adj_vert)
140
+ visited.append(adj_vert)
141
+ if len(visited) >= max_num_grids:
142
+ visited = visited[: max_num_grids - 1]
143
+ return visited
144
+
145
+ def select_faces_via_verts(selected_verts, faces):
146
+ if not isinstance(selected_verts, list):
147
+ selected_verts = selected_verts.tolist()
148
+ # selected_verts_dict = {ii: 1 for ii in selected_verts}
149
+ old_idx_to_new_idx = {v + 1: ii + 1 for ii, v in enumerate(selected_verts)} ####### v + 1: ii + 1 --> for the selected_verts
150
+ new_faces = []
151
+ for i_f, cur_f in enumerate(faces):
152
+ cur_new_f = []
153
+ valid = True
154
+ for cur_v in cur_f:
155
+ if cur_v in old_idx_to_new_idx:
156
+ cur_new_f.append(old_idx_to_new_idx[cur_v])
157
+ else:
158
+ valid = False
159
+ break
160
+ if valid:
161
+ new_faces.append(cur_new_f)
162
+ return new_faces
163
+
164
+
165
+ def convert_grid_content_to_grid_pts(content_value, grid_size):
166
+ flat_grid = torch.zeros([grid_size ** 3], dtype=torch.long)
167
+ cur_idx = flat_grid.size(0) - 1
168
+ while content_value > 0:
169
+ flat_grid[cur_idx] = content_value % grid_size
170
+ content_value = content_value // grid_size
171
+ cur_idx -= 1
172
+ grid_pts = flat_grid.contiguous().view(grid_size, grid_size, grid_size).contiguous()
173
+ return grid_pts
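+
+ # Hypothetical inverse sketch: re-encode a (gs, gs, gs) digit grid into the
+ # single integer that convert_grid_content_to_grid_pts decodes, treating the
+ # flattened entries as base-`grid_size` digits (most significant first).
+ def _example_encode_grid_pts(grid_pts, grid_size=2):
+     content_value = 0
+     for digit in grid_pts.contiguous().view(-1).tolist():
+         content_value = content_value * grid_size + int(digit)
+     return content_value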
174
+
175
+ # 0.2
176
+ def warp_coord(sampled_gradients, val, reflect=False): # val from [0.0, 1.0] # from the 0.0
177
+ # assume single value as inputs
178
+ grad_values = sampled_gradients.tolist()
179
+ # mid_val
180
+ mid_val = grad_values[0] * 0.2 + grad_values[1] * 0.2 + grad_values[2] * 0.1
181
+ if reflect:
182
+ grad_values[3] = grad_values[1]
183
+ grad_values[4] = grad_values[0]
184
+
185
+ # if not reflect:
186
+ accum_val = 0.0
187
+ for i_val in range(len(grad_values)):
188
+ if val > 0.2 * (i_val + 1) and i_val < 4: # if i_val == 4, directly use the remaining length * corresponding gradient value
189
+ accum_val += grad_values[i_val] * 0.2
190
+ else:
191
+ accum_val += grad_values[i_val] * (val - 0.2 * i_val)
192
+ break
193
+ return accum_val # modified value
194
+
195
+ def random_shift(vertices, shift_factor=0.25):
196
+ """Apply random shift to vertices."""
197
+ # max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
198
+
199
+ # max_shift_pos = tf.maximum(max_shift_pos, 1e-9)
200
+
201
+ # max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
202
+ # max_shift_neg = tf.maximum(max_shift_neg, 1e-9)
203
+
204
+ # shift = tfd.TruncatedNormal(
205
+ # tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
206
+ # max_shift_pos).sample()
207
+ # shift = tf.cast(shift, tf.int32)
208
+ # vertices += shift
209
+
210
+ minn_tensor = torch.tensor([1e-9], dtype=torch.float32)
211
+
212
+ max_shift_pos = (255 - torch.max(vertices, dim=0)[0]).float()
213
+ max_shift_pos = torch.maximum(max_shift_pos, minn_tensor)
214
+ max_shift_neg = (torch.min(vertices, dim=0)[0]).float()
215
+ max_shift_neg = torch.maximum(max_shift_neg, minn_tensor)
216
+
217
+ shift = torch.zeros((1, 3), dtype=torch.float32)
218
+ # torch.nn.init.trunc_normal_(shift, 0., shift_factor * 255., -max_shift_neg, max_shift_pos)
219
+ for i_s in range(shift.size(-1)):
220
+ cur_axis_max_shift_neg = max_shift_neg[i_s].item()
221
+ cur_axis_max_shift_pos = max_shift_pos[i_s].item()
222
+ cur_axis_shift = torch.zeros((1,), dtype=torch.float32)
223
+
224
+ torch.nn.init.trunc_normal_(cur_axis_shift, 0., shift_factor * 255., -cur_axis_max_shift_neg, cur_axis_max_shift_pos)
225
+ shift[:, i_s] = cur_axis_shift.item()
226
+
227
+ shift = shift.long()
228
+ vertices += shift
229
+
230
+ return vertices
231
+
232
+ def safe_transpose(tsr, dima, dimb):
233
+ tsr = tsr.contiguous().transpose(dima, dimb).contiguous()
234
+ return tsr
235
+
236
+ def batched_index_select(values, indices, dim = 1):
237
+ value_dims = values.shape[(dim + 1):]
238
+ values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
239
+ indices = indices[(..., *((None,) * len(value_dims)))]
240
+ indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
241
+ value_expand_len = len(indices_shape) - (dim + 1)
242
+ values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
243
+
244
+ value_expand_shape = [-1] * len(values.shape)
245
+ expand_slice = slice(dim, (dim + value_expand_len))
246
+ value_expand_shape[expand_slice] = indices.shape[expand_slice]
247
+ values = values.expand(*value_expand_shape)
248
+
249
+ dim += value_expand_len
250
+ return values.gather(dim, indices)
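+
+ # Usage sketch for batched_index_select (hypothetical shapes): gather K rows
+ # per batch element along dim 1.
+ def _example_batched_index_select():
+     values = torch.randn(2, 100, 3) # bsz x N x C
+     indices = torch.randint(0, 100, (2, 16)) # bsz x K
+     return batched_index_select(values, indices, dim=1) # -> (2, 16, 3)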
251
+
252
+ def read_obj_file_ours(obj_fn, sub_one=False):
253
+ vertices = []
254
+ faces = []
255
+ with open(obj_fn, "r") as rf:
256
+ for line in rf:
257
+ items = line.strip().split(" ")
258
+ if items[0] == 'v':
259
+ cur_verts = items[1:]
260
+ cur_verts = [float(vv) for vv in cur_verts]
261
+ vertices.append(cur_verts)
262
+ elif items[0] == 'f':
263
+ cur_faces = items[1:] # faces
264
+ cur_face_idxes = []
265
+ for cur_f in cur_faces:
266
+ try:
267
+ cur_f_idx = int(cur_f.split("/")[0])
268
+ except:
269
+ cur_f_idx = int(cur_f.split("//")[0])
270
+ cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
271
+ faces.append(cur_face_idxes)
272
+ rf.close()
273
+ vertices = np.array(vertices, dtype=np.float64)
274
+ return vertices, faces
275
+
276
+ def read_obj_file(obj_file):
277
+ """Read vertices and faces from already opened file."""
278
+ vertex_list = []
279
+ flat_vertices_list = []
280
+ flat_vertices_indices = {}
281
+ flat_triangles = []
282
+
283
+ for line in obj_file:
284
+ tokens = line.split()
285
+ if not tokens:
286
+ continue
287
+ line_type = tokens[0]
288
+ # We skip lines not starting with v or f.
289
+ if line_type == 'v': #
290
+ vertex_list.append([float(x) for x in tokens[1:]])
291
+ elif line_type == 'f':
292
+ triangle = []
293
+ for i in range(len(tokens) - 1):
294
+ vertex_name = tokens[i + 1]
295
+ if vertex_name in flat_vertices_indices: # triangles
296
+ triangle.append(flat_vertices_indices[vertex_name])
297
+ continue
298
+ flat_vertex = []
299
+ for index in six.ensure_str(vertex_name).split('/'):
300
+ if not index:
301
+ continue
302
+ # obj triangle indices are 1 indexed, so subtract 1 here.
303
+ flat_vertex += vertex_list[int(index) - 1]
304
+ flat_vertex_index = len(flat_vertices_list)
305
+ flat_vertices_list.append(flat_vertex)
306
+ # flat_vertex_index
307
+ flat_vertices_indices[vertex_name] = flat_vertex_index
308
+ triangle.append(flat_vertex_index)
309
+ flat_triangles.append(triangle)
310
+
311
+ return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
312
+
313
+
334
+
335
+ def merge_meshes(vertices_list, faces_list):
336
+ tot_vertices = []
337
+ tot_faces = []
338
+ nn_verts = 0
339
+ for cur_vertices, cur_faces in zip(vertices_list, faces_list):
340
+ tot_vertices.append(cur_vertices)
341
+ new_cur_faces = []
342
+ for cur_face_idx in cur_faces:
343
+ new_cur_face_idx = [vert_idx + nn_verts for vert_idx in cur_face_idx]
344
+ new_cur_faces.append(new_cur_face_idx)
345
+ nn_verts += cur_vertices.shape[0]
346
+ tot_faces += new_cur_faces # get total-faces
347
+ tot_vertices = np.concatenate(tot_vertices, axis=0)
348
+ return tot_vertices, tot_faces
349
+
350
+
351
+ def read_obj(obj_path):
352
+ """Open .obj file from the path provided and read vertices and faces."""
353
+
354
+ with open(obj_path) as obj_file:
355
+ return read_obj_file_ours(obj_path, sub_one=True)
356
+ # return read_obj_file(obj_file)
357
+
358
+
359
+
360
+
361
+ def center_vertices(vertices):
362
+ """Translate the vertices so that bounding box is centered at zero."""
363
+ vert_min = vertices.min(axis=0)
364
+ vert_max = vertices.max(axis=0)
365
+ vert_center = 0.5 * (vert_min + vert_max)
366
+ return vertices - vert_center
367
+
368
+
369
+ def normalize_vertices_scale(vertices):
370
+ """Scale the vertices so that the long diagonal of the bounding box is one."""
371
+ vert_min = vertices.min(axis=0)
372
+ vert_max = vertices.max(axis=0)
373
+ extents = vert_max - vert_min
374
+ scale = np.sqrt(np.sum(extents**2)) # normalize the diagonal line to 1.
375
+ return vertices / scale
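+
+ # Preprocessing sketch: combining the two helpers above centers the bounding
+ # box at the origin and scales its long diagonal to unit length.
+ def _example_center_and_normalize(vertices):
+     vertices = center_vertices(vertices)
+     vertices = normalize_vertices_scale(vertices)
+     return vertices # np.linalg.norm(v.max(0) - v.min(0)) is now ~1.0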
376
+
377
+
378
+ def get_vertices_center(vertices):
379
+ vert_min = vertices.min(axis=0)
380
+ vert_max = vertices.max(axis=0)
381
+ vert_center = 0.5 * (vert_min + vert_max)
382
+ return vert_center
383
+
384
+ def get_batched_vertices_center(vertices):
385
+ vert_min = vertices.min(axis=1)
386
+ vert_max = vertices.max(axis=1)
387
+ vert_center = 0.5 * (vert_min + vert_max)
388
+ return vert_center
389
+
390
+ def get_vertices_scale(vertices):
391
+ vert_min = vertices.min(axis=0)
392
+ vert_max = vertices.max(axis=0)
393
+ extents = vert_max - vert_min
394
+ scale = np.sqrt(np.sum(extents**2))
395
+ return scale
396
+
397
+ def quantize_verts(verts, n_bits=8, min_range=None, max_range=None):
398
+ """Convert vertices in [-1., 1.] to discrete values in [0, n_bits**2 - 1]."""
399
+ min_range = -0.5 if min_range is None else min_range
400
+ max_range = 0.5 if max_range is None else max_range
401
+ range_quantize = 2**n_bits - 1
402
+ verts_quantize = (verts - min_range) * range_quantize / (
403
+ max_range - min_range)
404
+ return verts_quantize.astype('int32')
405
+
406
+ def quantize_verts_torch(verts, n_bits=8, min_range=None, max_range=None):
407
+ min_range = -0.5 if min_range is None else min_range
408
+ max_range = 0.5 if max_range is None else max_range
409
+ range_quantize = 2**n_bits - 1
410
+ verts_quantize = (verts - min_range) * range_quantize / (
411
+ max_range - min_range)
412
+ return verts_quantize.long()
413
+
414
+ def dequantize_verts(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
415
+ """Convert quantized vertices to floats."""
416
+ min_range = -0.5 if min_range is None else min_range
417
+ max_range = 0.5 if max_range is None else max_range
418
+ range_quantize = 2**n_bits - 1
419
+ verts = verts.astype('float32')
420
+ verts = verts * (max_range - min_range) / range_quantize + min_range
421
+ if add_noise:
422
+ verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
423
+ return verts
424
+
425
+ def dequantize_verts_torch(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
426
+ min_range = -0.5 if min_range is None else min_range
427
+ max_range = 0.5 if max_range is None else max_range
428
+ range_quantize = 2**n_bits - 1
429
+ verts = verts.float()
430
+ verts = verts * (max_range - min_range) / range_quantize + min_range
431
+ # if add_noise:
432
+ # verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
433
+ return verts
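+
+ # Round-trip sketch: quantizing then dequantizing loses at most one grid step,
+ # i.e. (max_range - min_range) / (2**n_bits - 1), for verts inside the range.
+ def _example_quantization_round_trip(verts, n_bits=8):
+     quantized = quantize_verts(verts, n_bits) # int32 in [0, 2**n_bits - 1]
+     return dequantize_verts(quantized, n_bits) # float32 back in [-0.5, 0.5]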
434
+
435
+
436
+ ### dump vertices and faces to the obj file
437
+ def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
438
+ """Write vertices and faces to obj."""
439
+ if transpose:
440
+ vertices = vertices[:, [1, 2, 0]]
441
+ vertices *= scale
442
+ if faces is not None:
443
+ if min(min(faces)) == 0:
444
+ f_add = 1
445
+ else:
446
+ f_add = 0
447
+ else:
448
+ faces = []
449
+ with open(file_path, 'w') as f:
450
+ for v in vertices:
451
+ f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
452
+ for face in faces:
453
+ line = 'f'
454
+ for i in face:
455
+ line += ' {}'.format(i + f_add)
456
+ line += '\n'
457
+ f.write(line)
458
+
459
+
460
+ def face_to_cycles(face):
461
+ """Find cycles in face."""
462
+ g = nx.Graph()
463
+ for v in range(len(face) - 1):
464
+ g.add_edge(face[v], face[v + 1])
465
+ g.add_edge(face[-1], face[0])
466
+ return list(nx.cycle_basis(g))
467
+
468
+
469
+ def flatten_faces(faces):
470
+ """Converts from list of faces to flat face array with stopping indices."""
471
+ if not faces:
472
+ return np.array([0])
473
+ else:
474
+ l = [f + [-1] for f in faces[:-1]]
475
+ l += [faces[-1] + [-2]]
476
+ return np.array([item for sublist in l for item in sublist]) + 2 # pylint: disable=g-complex-comprehension
477
+
478
+
479
+ def unflatten_faces(flat_faces):
480
+ """Converts from flat face sequence to a list of separate faces."""
481
+ def group(seq):
482
+ g = []
483
+ for el in seq:
484
+ if el == 0 or el == -1:
485
+ yield g
486
+ g = []
487
+ else:
488
+ g.append(el - 1)
489
+ yield g
490
+ outputs = list(group(flat_faces - 1))[:-1]
491
+ # Remove empty faces
492
+ return [o for o in outputs if len(o) > 2]
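+
+ # Token-convention sketch: flatten_faces ends every face but the last with a
+ # new-face token (1) and the last face with a stop token (0); indices are
+ # shifted by +2 so those two values stay reserved. unflatten_faces undoes this.
+ def _example_faces_round_trip():
+     flat = flatten_faces([[0, 1, 2]]) # -> array([2, 3, 4, 0])
+     return unflatten_faces(flat) # -> [[0, 1, 2]]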
493
+
494
+
495
+
496
+ def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8, remove_du=True):
497
+ """Quantize vertices, remove resulting duplicates and reindex faces."""
498
+ vertices = quantize_verts(vertices, quantization_bits)
499
+ vertices, inv = np.unique(vertices, axis=0, return_inverse=True) # return inverse and unique the vertices
500
+
501
+ #
502
+ if opt is not None and opt.dataset.sort_dist:
503
+ if opt.model.debug:
504
+ print("sorting via dist...")
505
+ vertices_max = np.max(vertices, axis=0)
506
+ vertices_min = np.min(vertices, axis=0)
507
+ dist_vertices = np.minimum(np.abs(vertices - np.array([[vertices_min[0], vertices_min[1], 0]])), np.abs(vertices - np.array([[vertices_max[0], vertices_max[1], 0]])))
508
+ dist_vertices = np.concatenate([dist_vertices[:, 0:1] + dist_vertices[:, 1:2], dist_vertices[:, 2:]], axis=-1)
509
+ sort_inds = np.lexsort(dist_vertices.T)
510
+ else:
511
+ # Sort vertices by z then y then x.
512
+ sort_inds = np.lexsort(vertices.T) # sorted indices...
513
+ vertices = vertices[sort_inds]
514
+
515
+ # Re-index faces and tris to re-ordered vertices.
516
+ faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
517
+ if tris is not None:
518
+ tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])
519
+
520
+ # Merging duplicate vertices and re-indexing the faces causes some faces to
521
+ # contain loops (e.g [2, 3, 5, 2, 4]). Split these faces into distinct
522
+ # sub-faces.
523
+ sub_faces = []
524
+ for f in faces:
525
+ cliques = face_to_cycles(f)
526
+ for c in cliques:
527
+ c_length = len(c)
528
+ # Only append faces with more than two verts.
529
+ if c_length > 2:
530
+ d = np.argmin(c)
531
+ # Cyclically permute faces just that first index is the smallest.
532
+ sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
533
+ faces = sub_faces
534
+ if tris is not None:
535
+ tris = np.array([v for v in tris if len(set(v)) == len(v)])
536
+
537
+ # Sort faces by lowest vertex indices. If two faces have the same lowest
538
+ # index then sort by next lowest and so on.
539
+ faces.sort(key=lambda f: tuple(sorted(f)))
540
+ if tris is not None:
541
+ tris = tris.tolist()
542
+ tris.sort(key=lambda f: tuple(sorted(f)))
543
+ tris = np.array(tris)
544
+
545
+ # After removing degenerate faces some vertices are now unreferenced. # Vertices
546
+ # Remove these. # Vertices
547
+ num_verts = vertices.shape[0]
548
+ # print(f"remove_du: {remove_du}")
549
+ if remove_du: ##### num_verts
550
+ print("Removing du..")
551
+ try:
552
+ vert_connected = np.equal(
553
+ np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
554
+ vertices = vertices[vert_connected]
555
+
556
+
557
+ # Re-index faces and tris to re-ordered vertices.
558
+ vert_indices = (
559
+ np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
560
+ faces = [vert_indices[f].tolist() for f in faces]
561
+ except:
562
+ pass
563
+ if tris is not None:
564
+ tris = np.array([vert_indices[t].tolist() for t in tris])
565
+
566
+ return vertices, faces, tris
567
+
568
+
569
+ def process_mesh(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
570
+ """Process mesh vertices and faces."""
571
+
572
+
573
+
574
+ # Transpose so that z-axis is vertical.
575
+ vertices = vertices[:, [2, 0, 1]]
576
+
577
+ if recenter_mesh:
578
+ # Translate the vertices so that bounding box is centered at zero.
579
+ vertices = center_vertices(vertices)
580
+
581
+ # Scale the vertices so that the long diagonal of the bounding box is equal
582
+ # to one.
583
+ vertices = normalize_vertices_scale(vertices)
584
+
585
+ # Quantize and sort vertices, remove resulting duplicates, sort and reindex
586
+ # faces.
587
+ vertices, faces, _ = quantize_process_mesh(
588
+ vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
589
+
590
+ # unflatten_faces = np.array(faces, dtype=np.long) ### start from zero
591
+
592
+ # Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
593
+ faces = flatten_faces(faces)
594
+
595
+ # Discard degenerate meshes without faces.
596
+ return {
597
+ 'vertices': vertices,
598
+ 'faces': faces,
599
+ }
600
+
601
+
602
+ def process_mesh_ours(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
603
+ """Process mesh vertices and faces."""
604
+ # Transpose so that z-axis is vertical.
605
+ vertices = vertices[:, [2, 0, 1]]
606
+
607
+ if recenter_mesh:
608
+ # Translate the vertices so that bounding box is centered at zero.
609
+ vertices = center_vertices(vertices)
610
+
611
+ # Scale the vertices so that the long diagonal of the bounding box is equal
612
+ # to one.
613
+ vertices = normalize_vertices_scale(vertices)
614
+
615
+ # Quantize and sort vertices, remove resulting duplicates, sort and reindex
616
+ # faces.
617
+ quant_vertices, faces, _ = quantize_process_mesh(
618
+ vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
619
+ vertices = dequantize_verts(quant_vertices) #### dequantize vertices ####
620
+ ### vertices: nn_verts x 3
621
+ # try:
622
+ # # print("faces", faces)
623
+ # unflatten_faces = np.array(faces, dtype=np.long)
624
+ # except:
625
+ # print("faces", faces)
626
+ # raise ValueError("Something bad happened when processing meshes...")
627
+
628
+ # Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
629
+
630
+ faces = flatten_faces(faces)
631
+
632
+ # Discard degenerate meshes without faces.
633
+ return {
634
+ 'vertices': quant_vertices,
635
+ 'faces': faces,
636
+ # 'unflatten_faces': unflatten_faces,
637
+ 'dequant_vertices': vertices,
638
+ 'class_label': 0
639
+ }
640
+
641
+ def read_mesh_from_obj_file(fn, quantization_bits=8, recenter_mesh=True, remove_du=True):
642
+ vertices, faces = read_obj(fn)
643
+ # print(vertices.shape)
644
+ mesh_dict = process_mesh_ours(vertices, faces, quantization_bits=quantization_bits, recenter_mesh=recenter_mesh, remove_du=remove_du)
645
+ return mesh_dict
646
+
647
+ def process_mesh_list(vertices, faces, quantization_bits=8, recenter_mesh=True):
648
+ """Process mesh vertices and faces."""
649
+
650
+ vertices = [cur_vert[:, [2, 0, 1]] for cur_vert in vertices]
651
+
652
+ tot_vertices = np.concatenate(vertices, axis=0) # center and scale of those vertices
653
+ vert_center = get_vertices_center(tot_vertices)
654
+ vert_scale = get_vertices_scale(tot_vertices)
655
+
656
+ processed_vertices, processed_faces = [], []
657
+
658
+ for cur_verts, cur_faces in zip(vertices, faces):
659
+ # print(f"current vertices: {cur_verts.shape}, faces: {len(cur_faces)}")
660
+ normalized_verts = (cur_verts - vert_center) / vert_scale
661
+ cur_processed_verts, cur_processed_faces, _ = quantize_process_mesh(
662
+ normalized_verts, cur_faces, quantization_bits=quantization_bits
663
+ )
664
+ processed_vertices.append(cur_processed_verts)
665
+ processed_faces.append(cur_processed_faces)
666
+ vertices, faces = merge_meshes(processed_vertices, processed_faces)
667
+ faces = flatten_faces(faces=faces)
668
+
669
+
670
+ # Discard degenerate meshes without faces.
671
+ return {
672
+ 'vertices': vertices,
673
+
674
+ 'faces': faces,
675
+
676
+ }
677
+
678
+
679
+ def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
680
+
681
+ if not os.path.exists(sv_mesh_folder):
682
+ os.mkdir(sv_mesh_folder)
683
+
684
+ part_vertex_samples = [v_sample['left'], v_sample['rgt']]
685
+ part_face_samples = [f_sample['left'], f_sample['rgt']]
686
+
687
+
688
+ tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
689
+ tot_n_part = 2
690
+
691
+ if predict_joint:
692
+ pred_dir = v_sample['joint_dir']
693
+ pred_pvp = v_sample['joint_pvp']
694
+ print("pred_dir", pred_dir.shape, pred_dir)
695
+ print("pred_pvp", pred_pvp.shape, pred_pvp)
696
+ else:
697
+ pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
698
+
699
+
700
+
701
+
702
+ tot_mesh_list = []
703
+ for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
704
+ mesh_list = []
705
+ for i_n in range(tot_n_samples):
706
+ mesh_list.append(
707
+ {
708
+ 'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
709
+ 'faces': unflatten_faces(
710
+ cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
711
+ }
712
+ )
713
+ tot_mesh_list.append(mesh_list)
714
+ # and write this obj file?
715
+ # write_obj(vertices, faces, file_path, transpose=True, scale=1.):
716
+ # write mesh objs
717
+ for i_n in range(tot_n_samples):
718
+ cur_mesh = mesh_list[i_n]
719
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
720
+ cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
721
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
722
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
723
+
724
+
725
+
726
+ ###### plot mesh (predicted) ######
727
+ tot_samples_mesh_dict = []
728
+ for i_s in range(tot_n_samples):
729
+ cur_s_tot_vertices = []
730
+ cur_s_tot_faces = []
731
+ cur_s_n_vertices = 0
732
+
733
+ for i_p in range(tot_n_part):
734
+ cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
735
+ cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
736
+ cur_s_cur_part_mesh_dict['faces']
737
+ cur_s_cur_part_new_faces = []
738
+ for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
739
+ cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
740
+ cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
741
+ cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
742
+ cur_s_tot_vertices.append(cur_s_cur_part_vertices)
743
+ cur_s_tot_faces += cur_s_cur_part_new_faces
744
+
745
+ cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
746
+ cur_s_mesh_dict = {
747
+ 'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
748
+ }
749
+ tot_samples_mesh_dict.append(cur_s_mesh_dict)
750
+
751
+ for i_s in range(tot_n_samples):
752
+ cur_mesh = tot_samples_mesh_dict[i_s]
753
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
754
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}.obj")
755
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
756
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
757
+ ###### plot mesh (predicted) ######
758
+
759
+
760
+ ###### plot mesh (translated) ######
761
+ tot_samples_mesh_dict = []
762
+ for i_s in range(tot_n_samples):
763
+ cur_s_tot_vertices = []
764
+ cur_s_tot_faces = []
765
+ cur_s_n_vertices = 0
766
+ cur_s_pred_pvp = pred_pvp[i_s]
767
+
768
+ for i_p in range(tot_n_part):
769
+ cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
770
+ cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
771
+ cur_s_cur_part_mesh_dict['faces']
772
+ cur_s_cur_part_new_faces = []
773
+ for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
774
+ cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
775
+ cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
776
+ cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
777
+
778
+ if i_p == 1:
779
+ # min_rngs = cur_s_cur_part_vertices.min(1)
780
+ # max_rngs = cur_s_cur_part_vertices.max(1)
781
+ min_rngs = cur_s_cur_part_vertices.min(0)
782
+ max_rngs = cur_s_cur_part_vertices.max(0)
783
+ # shifted; cur_s_pred_pvp
784
+ # shifted = np.array([0., cur_s_pred_pvp[1] - max_rngs[1], cur_s_pred_pvp[2] - min_rngs[2]], dtype=np.float)
785
+ # shifted = np.reshape(shifted, [1, 3]) #
786
+ cur_s_pred_pvp = np.array([0., max_rngs[1], min_rngs[2]], dtype=np.float32)
787
+ pvp_sample_pred_err = np.sum((cur_s_pred_pvp - pred_pvp[i_s]) ** 2)
788
+ # print prediction err, pred pvp and real pvp
789
+ # print("cur_s, sample_pred_pvp_err:", pvp_sample_pred_err.item(), ";real val:", cur_s_pred_pvp, "; pred_val:", pred_pvp[i_s])
790
+ pred_pvp[i_s] = cur_s_pred_pvp
791
+ shifted = np.zeros((1, 3), dtype=np.float32)
792
+ cur_s_cur_part_vertices = cur_s_cur_part_vertices + shifted # shift vertices... # min_rngs
793
+ # shifted
794
+ cur_s_tot_vertices.append(cur_s_cur_part_vertices)
795
+ cur_s_tot_faces += cur_s_cur_part_new_faces
796
+
797
+ cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
798
+ cur_s_mesh_dict = {
799
+ 'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
800
+ }
801
+ tot_samples_mesh_dict.append(cur_s_mesh_dict)
802
+
803
+ for i_s in range(tot_n_samples):
804
+ cur_mesh = tot_samples_mesh_dict[i_s]
805
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
806
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_shifted.obj")
807
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
808
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
809
+ ###### plot mesh (translated) ######
810
+
811
+
812
+
813
+ ###### plot mesh (rotated) ######
814
+ if predict_joint:
815
+ from revolute_transform import revoluteTransform
816
+ tot_samples_mesh_dict = []
817
+ for i_s in range(tot_n_samples):
818
+ cur_s_tot_vertices = []
819
+ cur_s_tot_faces = []
820
+ cur_s_n_vertices = 0
821
+
822
+ # cur_s_pred_dir = pred_dir[i_s]
823
+ cur_s_pred_pvp = pred_pvp[i_s]
824
+ print("current pred dir:", cur_s_pred_dir, "; current pred pvp:", cur_s_pred_pvp)
825
+ cur_s_pred_dir = np.array([1.0, 0.0, 0.0], dtype=np.float)
826
+ # cur_s_pred_pvp = cur_s_pred_pvp[[1, 2, 0]]
827
+
828
+ for i_p in range(tot_n_part):
829
+ cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
830
+ cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
831
+ cur_s_cur_part_mesh_dict['faces']
832
+
833
+ if i_p == 1:
834
+ cur_s_cur_part_vertices, _ = revoluteTransform(cur_s_cur_part_vertices, cur_s_pred_pvp, cur_s_pred_dir, 0.5 * np.pi) # reverse revolute vertices of the upper piece
835
+ cur_s_cur_part_vertices = cur_s_cur_part_vertices[:, :3] #
836
+ cur_s_cur_part_new_faces = []
837
+ for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
838
+ cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
839
+ cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
840
+ cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
841
+ cur_s_tot_vertices.append(cur_s_cur_part_vertices)
842
+ # print(f"i_s: {i_s}, i_p: {i_p}, n_vertices: {cur_s_cur_part_vertices.shape[0]}")
843
+ cur_s_tot_faces += cur_s_cur_part_new_faces
844
+
845
+ cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
846
+ # print(f"i_s: {i_s}, n_cur_s_tot_vertices: {cur_s_tot_vertices.shape[0]}")
847
+ cur_s_mesh_dict = {
848
+ 'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
849
+ }
850
+ tot_samples_mesh_dict.append(cur_s_mesh_dict)
851
+ # plot_meshes(tot_samples_mesh_dict, ax_lims=0.5, mesh_sv_fn=f"./figs/training_step_{n}_part_{tot_n_part}_rot.png") # plot the mesh;
852
+ for i_s in range(tot_n_samples):
853
+ cur_mesh = tot_samples_mesh_dict[i_s]
854
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
855
+ # rotated mesh fn
856
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_rot.obj")
857
+ # write object in the file...
858
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
859
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
860
+
861
+
862
+
863
+
864
+ def sample_pts_from_mesh(vertices, faces, npoints=512, minus_one=True):
865
+ return vertices # early return: the area-weighted surface sampling below is currently disabled
866
+ sampled_pcts = []
867
+ pts_to_seg_idx = []
868
+ # triangles and pts
869
+ minus_val = 0 if not minus_one else 1
870
+ for i in range(len(faces)): #
871
+ cur_face = faces[i]
872
+ n_tris = len(cur_face) - 2
873
+ v_as, v_bs, v_cs = [cur_face[0] for _ in range(n_tris)], cur_face[1: len(cur_face) - 1], cur_face[2: len(cur_face)]
874
+ for v_a, v_b, v_c in zip(v_as, v_bs, v_cs):
875
+
876
+ v_a, v_b, v_c = vertices[v_a - minus_val], vertices[v_b - minus_val], vertices[v_c - minus_val]
877
+ ab, ac = v_b - v_a, v_c - v_a
878
+ cos_ab_ac = (np.sum(ab * ac) / np.clip(np.sqrt(np.sum(ab ** 2)) * np.sqrt(np.sum(ac ** 2)), a_min=1e-9,
879
+ a_max=9999999.)).item()
880
+ sin_ab_ac = math.sqrt(min(max(0., 1. - cos_ab_ac ** 2), 1.))
881
+ cur_area = 0.5 * sin_ab_ac * np.sqrt(np.sum(ab ** 2)).item() * np.sqrt(np.sum(ac ** 2)).item()
882
+
883
+ cur_sampled_pts = int(cur_area * npoints)
884
+ cur_sampled_pts = 1 if cur_sampled_pts == 0 else cur_sampled_pts
885
+ # if cur_sampled_pts == 0:
886
+
887
+ tmp_x, tmp_y = np.random.uniform(0, 1., (cur_sampled_pts,)).tolist(), np.random.uniform(0., 1., (
888
+ cur_sampled_pts,)).tolist()
889
+
890
+ for xx, yy in zip(tmp_x, tmp_y):
891
+ sqrt_xx, sqrt_yy = math.sqrt(xx), math.sqrt(yy)
892
+ aa = 1. - sqrt_xx
893
+ bb = sqrt_xx * (1. - yy)
894
+ cc = yy * sqrt_xx
895
+ cur_pos = v_a * aa + v_b * bb + v_c * cc
896
+ sampled_pcts.append(cur_pos)
897
+ # pts_to_seg_idx.append(cur_tri_seg)
898
+
899
+ # seg_idx_to_sampled_pts = {}
900
+ sampled_pcts = np.array(sampled_pcts, dtype=np.float64)
901
+
902
+ return sampled_pcts
903
+
904
+
905
+ def fps_fr_numpy(np_pts, n_sampling=4096):
906
+ pts = torch.from_numpy(np_pts).float().cuda()
907
+ pts_fps_idx = farthest_point_sampling(pts.unsqueeze(0), n_sampling=n_sampling) # farthes points sampling ##
908
+ pts = pts[pts_fps_idx].cpu()
909
+ return pts
910
+
911
+
912
+ def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
913
+ bz, N = pos.size(0), pos.size(1)
914
+ feat_dim = pos.size(-1)
915
+ device = pos.device
916
+ sampling_ratio = float(n_sampling / N)
917
+ pos_float = pos.float()
918
+
919
+ batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
920
+ mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
921
+
922
+ batch = batch * mult_one
923
+ batch = batch.view(-1)
924
+ pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous() # (bz x N, 3)
925
+ # sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
926
+ # batch = torch.zeros((N, ), dtype=torch.long, device=device)
927
+ sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
928
+ # shape of sampled_idx?
929
+ return sampled_idx
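+
+ # Usage sketch (requires torch_cluster and CUDA): sample 512 of 4096 points
+ # per batch element; the returned indices address the flattened (bz * N) array.
+ def _example_fps():
+     pos = torch.rand(2, 4096, 3).cuda()
+     sampled_idx = farthest_point_sampling(pos, n_sampling=512)
+     return pos.view(-1, 3)[sampled_idx]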
930
+
931
+
932
+ def plot_sampled_meshes_single_part(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
933
+
934
+ if not os.path.exists(sv_mesh_folder):
935
+ os.mkdir(sv_mesh_folder)
936
+
937
+ part_vertex_samples = [v_sample]
938
+ part_face_samples = [f_sample]
939
+
940
+
941
+ tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
942
+ tot_n_part = 2
943
+
944
+ # not predict joints here
945
+ # if predict_joint:
946
+ # pred_dir = v_sample['joint_dir']
947
+ # pred_pvp = v_sample['joint_pvp']
948
+ # print("pred_dir", pred_dir.shape, pred_dir)
949
+ # print("pred_pvp", pred_pvp.shape, pred_pvp)
950
+ # else:
951
+ # pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
952
+
953
+
954
+ tot_mesh_list = []
955
+ for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
956
+ mesh_list = []
957
+ for i_n in range(tot_n_samples):
958
+ mesh_list.append(
959
+ {
960
+ 'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
961
+ 'faces': unflatten_faces(
962
+ cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
963
+ }
964
+ )
965
+ tot_mesh_list.append(mesh_list)
966
+
967
+ for i_n in range(tot_n_samples):
968
+ cur_mesh = mesh_list[i_n]
969
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
970
+ # cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
971
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
972
+ print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
973
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
974
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
975
+
976
+
977
+ def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True): # note: redefines plot_sampled_meshes above; this later single-part version is the one in effect
978
+
979
+ if not os.path.exists(sv_mesh_folder):
980
+ os.mkdir(sv_mesh_folder)
981
+
982
+ part_vertex_samples = [v_sample]
983
+ part_face_samples = [f_sample]
984
+
985
+
986
+ tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
987
+ # tot_n_part = 2
988
+
989
+ # not predict joints here
990
+ # if predict_joint:
991
+ # pred_dir = v_sample['joint_dir']
992
+ # pred_pvp = v_sample['joint_pvp']
993
+ # print("pred_dir", pred_dir.shape, pred_dir)
994
+ # print("pred_pvp", pred_pvp.shape, pred_pvp)
995
+ # else:
996
+ # pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
997
+
998
+
999
+ tot_mesh_list = []
1000
+ for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
1001
+ mesh_list = []
1002
+ for i_n in range(tot_n_samples):
1003
+ mesh_list.append(
1004
+ {
1005
+ 'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
1006
+ 'faces': unflatten_faces(
1007
+ cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
1008
+ }
1009
+ )
1010
+ tot_mesh_list.append(mesh_list)
1011
+
1012
+ for i_n in range(tot_n_samples):
1013
+ cur_mesh = mesh_list[i_n]
1014
+ cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
1015
+ # cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
1016
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
1017
+ print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
1018
+ if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
1019
+ write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
1020
+
1021
+
1022
+ def filter_masked_vertices(vertices, mask_indicator):
1023
+ # vertices: n_verts x 3
1024
+ mask_indicator = np.reshape(mask_indicator, (vertices.shape[0], 3))
1025
+ tot_vertices = []
1026
+ for i_v in range(vertices.shape[0]):
1027
+ cur_vert = vertices[i_v]
1028
+ cur_vert_indicator = mask_indicator[i_v][0].item()
1029
+ if cur_vert_indicator < 0.5:
1030
+ tot_vertices.append(cur_vert)
1031
+ tot_vertices = np.array(tot_vertices)
1032
+ return tot_vertices
1033
+
1034
+
1035
+ def plot_sampled_ar_subd_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, ):
1036
+ if not os.path.exists(sv_mesh_folder): ### vertices_mask
1037
+ os.mkdir(sv_mesh_folder)
1038
+ ### v_sample: bsz x nn_verts x 3
1039
+ vertices_mask = v_sample['vertices_mask']
1040
+ vertices = v_sample['vertices']
1041
+ faces = f_sample['faces']
1042
+ num_face_indices = f_sample['num_face_indices'] #### num_faces_indices
1043
+ bsz = vertices.shape[0]
1044
+
1045
+ for i_bsz in range(bsz):
1046
+ cur_vertices = vertices[i_bsz]
1047
+ cur_vertices_mask = vertices_mask[i_bsz]
1048
+ cur_faces = faces[i_bsz]
1049
+ cur_num_face_indices = num_face_indices[i_bsz]
1050
+ cur_nn_verts = cur_vertices_mask.sum(-1).item()
1051
+ cur_nn_verts = int(cur_nn_verts)
1052
+ cur_vertices = cur_vertices[:cur_nn_verts]
1053
+ cur_faces = unflatten_faces(
1054
+ cur_faces[:int(cur_num_face_indices)])
1055
+
1056
+ cur_num_faces = len(cur_faces)
1057
+ cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_inst_{i_bsz}.obj")
1058
+ # cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
1059
+ print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_nn_verts}, nn_faces: {cur_num_faces}")
1060
+ if cur_nn_verts > 0 and cur_num_faces > 0:
1061
+ write_obj(cur_vertices, cur_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
1062
+
1063
+
1064
+
1065
def plot_sampled_meshes_single_part_for_pretraining(v_sample, f_sample, context, sv_mesh_folder, cur_step=0, predict_joint=True, with_context=True):

    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    part_vertex_samples = [v_sample]
    part_face_samples = [f_sample]

    context_vertices = [context['vertices']]
    context_faces = [context['faces']]
    context_vertices_mask = [context['vertices_mask']]
    context_faces_mask = [context['faces_mask']]

    tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
    tot_n_part = 2

    # joints are not predicted here
    # if predict_joint:
    #     pred_dir = v_sample['joint_dir']
    #     pred_pvp = v_sample['joint_pvp']
    #     print("pred_dir", pred_dir.shape, pred_dir)
    #     print("pred_pvp", pred_pvp.shape, pred_pvp)
    # else:
    #     pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)

    tot_mesh_list = []
    for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
        mesh_list = []
        context_mesh_list = []
        for i_n in range(tot_n_samples):
            mesh_list.append(
                {
                    'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
                    'faces': unflatten_faces(
                        cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
                }
            )

            cur_context_vertices = context_vertices[i_p][i_n]
            cur_context_faces = context_faces[i_p][i_n]
            cur_context_vertices_mask = context_vertices_mask[i_p][i_n]
            cur_context_faces_mask = context_faces_mask[i_p][i_n]
            cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
            cur_nn_faces = np.sum(cur_context_faces_mask).item()
            cur_nn_vertices, cur_nn_faces = int(cur_nn_vertices), int(cur_nn_faces)
            cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
            if 'mask_vertices_flat_indicator' in context:
                cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
                cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
                cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
            cur_context_faces = cur_context_faces[:cur_nn_faces]  # truncate to the valid context faces
            context_mesh_dict = {
                'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': unflatten_faces(cur_context_faces)
            }
            context_mesh_list.append(context_mesh_dict)

        tot_mesh_list.append(mesh_list)

        # if with_context:
        for i_n in range(tot_n_samples):
            cur_mesh = mesh_list[i_n]
            cur_context_mesh = context_mesh_list[i_n]
            cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
            cur_context_vertices, cur_context_faces = cur_context_mesh['vertices'], cur_context_mesh['faces']
            # cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
            cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
            cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
            print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
            if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
                write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
            if cur_context_vertices.shape[0] > 0 and len(cur_context_faces) > 0:
                write_obj(cur_context_vertices, cur_context_faces, cur_context_mesh_sv_fn, transpose=True, scale=1.)


def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):

    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    mesh_list = []
    context_mesh_list = []
    tot_n_samples = v_sample['vertices'].shape[0]

    for i_n in range(tot_n_samples):
        mesh_list.append(
            {
                'vertices': v_sample['vertices'][i_n][:v_sample['num_vertices'][i_n]],
                'faces': []
            }
        )

        cur_context_vertices = context['vertices'][i_n]
        cur_context_vertices_mask = context['vertices_mask'][i_n]
        cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
        cur_nn_vertices = int(cur_nn_vertices)
        cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
        if 'mask_vertices_flat_indicator' in context:
            cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
            cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
            cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
        context_mesh_dict = {
            'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': []
        }
        context_mesh_list.append(context_mesh_dict)

    # tot_mesh_list.append(mesh_list)

    # if with_context:
    for i_n in range(tot_n_samples):
        cur_mesh = mesh_list[i_n]
        cur_context_mesh = context_mesh_list[i_n]
        cur_mesh_vertices = cur_mesh['vertices']
        cur_context_vertices = cur_context_mesh['vertices']
        # cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
        cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}.obj")
        cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}_context.obj")
        print(f"saving the sample to {cur_mesh_sv_fn}, context sample to {cur_context_mesh_sv_fn}")
        if cur_mesh_vertices.shape[0] > 0:
            write_obj(cur_mesh_vertices, None, cur_mesh_sv_fn, transpose=True, scale=1.)
        if cur_context_vertices.shape[0] > 0:
            write_obj(cur_context_vertices, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)

def get_grid_content_from_grids(grid_xyzs, grid_values, grid_size=2):
    cur_bsz_grid_xyzs = grid_xyzs  # grid_length x 3; grid pts for a single batch
    cur_bsz_grid_values = grid_values  # grid_length x gs x gs x gs
    pts = []
    for i_grid in range(cur_bsz_grid_xyzs.shape[0]):
        cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
        if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
            break
        if len(cur_bsz_grid_values.shape) > 1:
            cur_grid_values = cur_bsz_grid_values[i_grid]
        else:
            cur_grid_content = cur_bsz_grid_values[i_grid].item()
            if cur_grid_content >= MASK_GRID_VALIE:
                continue
            inde = 2
            cur_grid_values = []
            for i_s in range(grid_size ** 3):
                cur_mod_value = cur_grid_content % inde
                cur_grid_content = cur_grid_content // inde
                cur_grid_values = [cur_mod_value] + cur_grid_values  # prepend, so higher bits end up at the front of the list
            # np.int64 here (np.long was removed from recent NumPy releases)
            cur_grid_values = np.array(cur_grid_values, dtype=np.int64)
            cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
        for i_x in range(cur_grid_values.shape[0]):
            for i_y in range(cur_grid_values.shape[1]):
                for i_z in range(cur_grid_values.shape[2]):
                    cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
                    if cur_grid_xyz_value > 0.5:
                        cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
                        pts.append([cur_x, cur_y, cur_z])
    return pts

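# Illustrative sketch (not part of this commit): how get_grid_content_from_grids
# above decodes a packed integer into occupancy bits. With grid_size = 2, each
# grid cell stores 2**3 = 8 binary sub-voxel occupancies in a single integer.
def _decode_grid_content_example(content=0b10100001, grid_size=2):
    bits = []
    for _ in range(grid_size ** 3):
        bits = [content % 2] + bits  # prepend, so the most significant bit ends up first
        content = content // 2
    occ = np.reshape(np.array(bits, dtype=np.int64), (grid_size, grid_size, grid_size))
    return occ  # occ[i_x, i_y, i_z] == 1 marks an occupied sub-voxel of the cell
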
def plot_grids_for_pretraining(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):

    ##### plot grids
    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    # part_vertex_samples = [v_sample]  # vertex samples
    # part_face_samples = [f_sample]  # face samples

    grid_xyzs = v_sample['grid_xyzs']
    grid_values = v_sample['grid_values']

    bsz = grid_xyzs.shape[0]
    grid_size = opt.vertex_model.grid_size

    for i_bsz in range(bsz):
        cur_bsz_grid_xyzs = grid_xyzs[i_bsz]  # grid_length x 3
        cur_bsz_grid_values = grid_values[i_bsz]  # grid_length x gs x gs x gs
        pts = []
        # NOTE: this inlines the same decoding as get_grid_content_from_grids above
        for i_grid in range(cur_bsz_grid_xyzs.shape[0]):
            cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
            if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
                break
            if len(cur_bsz_grid_values.shape) > 1:
                cur_grid_values = cur_bsz_grid_values[i_grid]
            else:
                cur_grid_content = cur_bsz_grid_values[i_grid].item()
                if cur_grid_content >= MASK_GRID_VALIE:
                    continue
                inde = 2
                cur_grid_values = []
                for i_s in range(grid_size ** 3):
                    cur_mod_value = cur_grid_content % inde
                    cur_grid_content = cur_grid_content // inde
                    cur_grid_values = [cur_mod_value] + cur_grid_values  # higher bits go to the front of the list
                cur_grid_values = np.array(cur_grid_values, dtype=np.int64)
                cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
            for i_x in range(cur_grid_values.shape[0]):
                for i_y in range(cur_grid_values.shape[1]):
                    for i_z in range(cur_grid_values.shape[2]):
                        cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
                        if cur_grid_xyz_value > 0.5:
                            cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
                            pts.append([cur_x, cur_y, cur_z])

        if len(pts) == 0:
            print("zzz, len(pts) == 0")
            continue
        pts = np.array(pts, dtype=np.float32)
        # pts = center_vertices(pts)
        # pts = normalize_vertices_scale(pts)
        pts = pts[:, [2, 1, 0]]
        cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj" if sv_mesh_fn is None else sv_mesh_fn)

        print(f"Saving obj to {cur_mesh_sv_fn}")
        write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)

def plot_grids_for_pretraining_obj_corpus(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):

    ##### plot grids
    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    grid_xyzs = v_sample['grid_xyzs']
    grid_values = v_sample['grid_values']

    bsz = grid_xyzs.shape[0]
    grid_size = opt.vertex_model.grid_size

    for i_bsz in range(bsz):
        cur_bsz_grid_xyzs = grid_xyzs[i_bsz]  # grid_length x 3
        cur_bsz_grid_values = grid_values[i_bsz]  # grid_length x gs x gs x gs
        part_pts = []
        pts = []
        for i_grid in range(cur_bsz_grid_xyzs.shape[0]):
            cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
            # an all -1 coordinate separates object parts
            if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
                part_pts.append(pts)
                pts = []
                continue
            elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
                continue
            if len(cur_bsz_grid_values.shape) > 1:
                cur_grid_values = cur_bsz_grid_values[i_grid]
            else:
                cur_grid_content = cur_bsz_grid_values[i_grid].item()
                if cur_grid_content >= MASK_GRID_VALIE:  # mask grid value
                    continue
                inde = 2
                cur_grid_values = []
                for i_s in range(grid_size ** 3):
                    cur_mod_value = cur_grid_content % inde
                    cur_grid_content = cur_grid_content // inde
                    cur_grid_values = [cur_mod_value] + cur_grid_values  # higher bits go to the front of the list
                cur_grid_values = np.array(cur_grid_values, dtype=np.int64)
                cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
            for i_x in range(cur_grid_values.shape[0]):
                for i_y in range(cur_grid_values.shape[1]):
                    for i_z in range(cur_grid_values.shape[2]):
                        cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
                        # grid xyz values
                        if cur_grid_xyz_value > 0.5:
                            cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
                            pts.append([cur_x, cur_y, cur_z])

        if len(pts) > 0:
            part_pts.append(pts)
            pts = []
        tot_nn_pts = sum([len(aa) for aa in part_pts])
        if tot_nn_pts == 0:
            print("zzz, tot_nn_pts == 0")
            continue

        for i_p, pts in enumerate(part_pts):
            if len(pts) == 0:
                continue
            pts = np.array(pts, dtype=np.float32)
            pts = center_vertices(pts)
            # pts = normalize_vertices_scale(pts)
            pts = pts[:, [2, 1, 0]]
            cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)

            print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
            write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)

def plot_grids_for_pretraining_obj_part(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):

    ##### plot grids
    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    grid_xyzs = v_sample['grid_xyzs']
    grid_values = v_sample['grid_values']

    bsz = grid_xyzs.shape[0]
    grid_size = opt.vertex_model.grid_size

    for i_bsz in range(bsz):
        cur_bsz_grid_xyzs = grid_xyzs[i_bsz]  # grid_length x 3
        cur_bsz_grid_values = grid_values[i_bsz]  # grid_length x gs x gs x gs
        part_pts = []
        pts = []
        for i_grid in range(cur_bsz_grid_xyzs.shape[0]):
            cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
            # (-1, -1, -1) ends the sequence; (-1, -1, 0) separates parts
            if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
                part_pts.append(pts)
                pts = []
                break
            elif cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == 0:
                part_pts.append(pts)
                pts = []
                continue
            elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
                continue
            if len(cur_bsz_grid_values.shape) > 1:
                cur_grid_values = cur_bsz_grid_values[i_grid]
            else:
                cur_grid_content = cur_bsz_grid_values[i_grid].item()
                if cur_grid_content >= MASK_GRID_VALIE:  # invalid or dummy content values
                    continue
                inde = 2
                cur_grid_values = []
                for i_s in range(grid_size ** 3):
                    cur_mod_value = cur_grid_content % inde
                    cur_grid_content = cur_grid_content // inde
                    cur_grid_values = [cur_mod_value] + cur_grid_values  # higher bits go to the front of the list
                cur_grid_values = np.array(cur_grid_values, dtype=np.int64)
                cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
            for i_x in range(cur_grid_values.shape[0]):
                for i_y in range(cur_grid_values.shape[1]):
                    for i_z in range(cur_grid_values.shape[2]):
                        cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
                        # grid xyz values
                        if cur_grid_xyz_value > 0.5:
                            cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
                            pts.append([cur_x, cur_y, cur_z])

        if len(pts) > 0:
            part_pts.append(pts)
            pts = []
        tot_nn_pts = sum([len(aa) for aa in part_pts])
        if tot_nn_pts == 0:
            print("zzz, tot_nn_pts == 0")
            continue

        for i_p, pts in enumerate(part_pts):
            if len(pts) == 0:
                continue
            pts = np.array(pts, dtype=np.float32)
            pts = center_vertices(pts)
            # pts = normalize_vertices_scale(pts)
            pts = pts[:, [2, 1, 0]]
            cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)

            print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
            write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)

# NOTE: redefines plot_grids_for_pretraining_ml above; this grid-based version is the one in effect.
def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):

    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    grid_xyzs = v_sample['grid_xyzs']
    grid_values = v_sample['grid_values']

    context_grid_xyzs = context['grid_xyzs'] - 1
    # context_grid_values = context['grid_content']
    context_grid_values = context['mask_grid_content']

    bsz = grid_xyzs.shape[0]
    grid_size = opt.vertex_model.grid_size

    for i_bsz in range(bsz):
        cur_bsz_grid_pts = get_grid_content_from_grids(grid_xyzs[i_bsz], grid_values[i_bsz], grid_size=grid_size)
        cur_context_grid_pts = get_grid_content_from_grids(context_grid_xyzs[i_bsz], context_grid_values[i_bsz], grid_size=grid_size)

        if len(cur_bsz_grid_pts) > 0 and len(cur_context_grid_pts) > 0:
            cur_bsz_grid_pts = np.array(cur_bsz_grid_pts, dtype=np.float32)
            cur_bsz_grid_pts = center_vertices(cur_bsz_grid_pts)
            cur_bsz_grid_pts = normalize_vertices_scale(cur_bsz_grid_pts)
            cur_bsz_grid_pts = cur_bsz_grid_pts[:, [2, 1, 0]]
            #### plot current mesh / sampled points ####
            cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
            print(f"Saving predicted grid points to {cur_mesh_sv_fn}")
            write_obj(cur_bsz_grid_pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)

            cur_context_grid_pts = np.array(cur_context_grid_pts, dtype=np.float32)
            cur_context_grid_pts = center_vertices(cur_context_grid_pts)
            cur_context_grid_pts = normalize_vertices_scale(cur_context_grid_pts)
            cur_context_grid_pts = cur_context_grid_pts[:, [2, 1, 0]]
            #### plot current mesh / sampled points ####
            cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_context.obj")
            print(f"Saving context grid points to {cur_context_mesh_sv_fn}")
            write_obj(cur_context_grid_pts, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)

        # if len(cur_bsz_grid_pts) == 0:
        #     print("zzz, len(pts) == 0")
        #     continue
        # pts = np.array(pts, dtype=np.float32)
        # pts = center_vertices(pts)
        # pts = normalize_vertices_scale(pts)
        # pts = pts[:, [2, 1, 0]]
        # cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
        # print(f"Saving obj to {cur_mesh_sv_fn}")
        # write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)

def plot_sampled_meshes_single_part_for_sampling(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):

    if not os.path.exists(sv_mesh_folder):
        os.mkdir(sv_mesh_folder)

    part_vertex_samples = [v_sample]
    part_face_samples = [f_sample]

    tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
    tot_n_part = 2

    # joints are not predicted here
    # if predict_joint:
    #     pred_dir = v_sample['joint_dir']
    #     pred_pvp = v_sample['joint_pvp']
    #     print("pred_dir", pred_dir.shape, pred_dir)
    #     print("pred_pvp", pred_pvp.shape, pred_pvp)
    # else:
    #     pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)

    tot_mesh_list = []
    for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
        mesh_list = []
        for i_n in range(tot_n_samples):
            mesh_list.append(
                {
                    'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
                    'faces': unflatten_faces(
                        cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
                }
            )
        tot_mesh_list.append(mesh_list)

        for i_n in range(tot_n_samples):
            cur_mesh = mesh_list[i_n]
            cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
            # cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
            cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
            print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
            if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
                write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)

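# Illustrative usage sketch (shapes here are hypothetical, not part of this
# commit): the plotting helpers above expect padded sample dicts with
# per-sample vertex/face counts alongside the flattened arrays.
def _example_plot_sampled_meshes():
    n_samples, max_verts, max_face_inds = 4, 400, 2000
    v_sample = {
        'vertices': np.zeros((n_samples, max_verts, 3), dtype=np.float32),
        'num_vertices': np.full((n_samples,), 120, dtype=np.int64),
    }
    f_sample = {
        'faces': np.zeros((n_samples, max_face_inds), dtype=np.int64),  # flattened face index lists
        'num_face_indices': np.full((n_samples,), 800, dtype=np.int64),
    }
    # writes step_0_part_0_ins_{i}.obj files into ./meshes (dummy geometry here)
    plot_sampled_meshes_single_part_for_sampling(v_sample, f_sample, "./meshes", cur_step=0)
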
models/dataset.py ADDED
@@ -0,0 +1,359 @@
import torch
import torch.nn.functional as F
import cv2 as cv
import numpy as np
import os
from glob import glob
from icecream import ic
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

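# Illustrative check (not part of this commit): load_K_Rt_from_P undoes
# P = K @ [R | t]; recomposing the returned pieces should reproduce P.
def _check_load_K_Rt_from_P():
    K = np.array([[256., 0., 256.],
                  [0., 256., 256.],
                  [0., 0., 1.]], dtype=np.float32)
    P = K @ np.eye(4, dtype=np.float32)[:3, :4]  # identity extrinsics for simplicity
    intrinsics, pose = load_K_Rt_from_P(None, P)
    w2c = np.linalg.inv(pose)  # pose is camera-to-world, so invert to get [R | t] back
    P_rec = intrinsics[:3, :3] @ w2c[:3, :4]
    assert np.allclose(P_rec, P, atol=1e-4)
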
def filter_iamges_via_pixel_values(data_dir):
    images_lis = sorted(glob(os.path.join(data_dir, 'image/*.png')))
    n_images = len(images_lis)
    images_np = np.stack([cv.imread(im_name) for im_name in images_lis]) / 255.0
    print(f"images_np: {images_np.shape}")
    # nn_frames x res x res x 3
    images_np = 1. - images_np
    has_density_values = (np.sum(images_np, axis=-1) > 0.7).astype(np.float32)
    has_density_values = np.sum(np.sum(has_density_values, axis=-1), axis=-1)
    tot_res_nns = float(images_np.shape[1] * images_np.shape[2])
    has_density_ratio = has_density_values / tot_res_nns  # fraction of pixels with density per frame
    print(f"has_density_values: {has_density_values.shape}")
    paried_has_density_ratio_list = [(i_fr, has_density_ratio[i_fr].item()) for i_fr in range(has_density_ratio.shape[0])]
    paried_has_density_ratio_list = sorted(paried_has_density_ratio_list, key=lambda ii: ii[1], reverse=True)
    mid_rnk_value = len(paried_has_density_ratio_list) // 4
    print("mid value of the density ratio")
    print(paried_has_density_ratio_list[mid_rnk_value])
    iamge_idx = paried_has_density_ratio_list[mid_rnk_value][0]
    print(f"image idx: {images_lis[iamge_idx]}")
    print(paried_has_density_ratio_list[:mid_rnk_value])
    tot_selected_img_idx_list = [ii[0] for ii in paried_has_density_ratio_list[:mid_rnk_value]]
    tot_selected_img_idx_list = sorted(tot_selected_img_idx_list)
    print(len(tot_selected_img_idx_list))
    # print(tot_selected_img_idx_list[54])
    print(tot_selected_img_idx_list)

class Dataset:
    def __init__(self, conf):
        super(Dataset, self).__init__()
        print('Load data: Begin')
        self.device = torch.device('cuda')
        self.conf = conf

        # hand-picked view subsets; only the last assignment below takes effect,
        # and it is overridden again after the images are loaded
        self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 13, 14, 15, 35, 36, 42, 43, 44, 48, 49, 50, 51, 55, 56, 57, 61, 62, 63, 69, 84, 90, 91, 92, 96, 97]
        # self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 28, 29, 35, 36, 37, 40, 41, 70, 71, 79, 82, 83, 84, 85, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 133, 134, 135, 139, 174, 175, 176, 177, 180, 188, 189, 190, 191, 194, 195]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97, 98, 99, 104, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 134, 135, 139, 174, 175, 176, 177, 180, 181, 182, 183, 188, 189, 190, 191, 194, 195]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 20, 21, 22, 23, 26, 27, 70, 78, 83, 84, 85, 91, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 119, 120, 121, 124, 125, 175, 176, 181, 182, 188, 189, 190, 191, 194, 195]
        self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)

        self.data_dir = conf.get_string('data_dir')
        self.render_cameras_name = conf.get_string('render_cameras_name')
        self.object_cameras_name = conf.get_string('object_cameras_name')

        ## camera outside sphere ##
        self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
        self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)

        camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
        # camera_dict = np.load("/home/xueyi/diffsim/NeuS/public_data/dtu_scan24/cameras_sphere.npz")
        self.camera_dict = camera_dict  # render camera dict
        # number of pixels in the views -> very thin geometry is not useful
        self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))

        # self.images_lis = self.images_lis[:1]  # total views and poses of the camera; select cameras for rendering

        self.n_images = len(self.images_lis)
        self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0

        # keep every frame, overriding the hand-picked subsets above
        self.selected_img_idxes_list = list(range(self.images_np.shape[0]))
        self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)

        self.images_np = self.images_np[self.selected_img_idxes]  # keep the selected images

        ### deal with the background carefully: reload (over 255), select, and invert ###
        # NOTE: this reload replaces the /256 load above
        self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0
        self.images_np = self.images_np[self.selected_img_idxes]
        self.images_np = 1. - self.images_np

        self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png')))

        # self.masks_lis = self.masks_lis[:1]

        try:
            self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0
            self.masks_np = self.masks_np[self.selected_img_idxes]
        except Exception:
            self.masks_np = self.images_np.copy()

        # world_mat is a projection matrix from world to image
        self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.scale_mats_np = []

        # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
        self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.intrinsics_all = []
        self.pose_all = []

        # for idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, self.world_mats_np)):
        for idx in self.selected_img_idxes_list:
            scale_mat = self.scale_mats_np[idx]
            world_mat = self.world_mats_np[idx]

            if "hand" in self.data_dir:
                intrinsics = np.eye(4)
                fov = 512. / 2.  # focal length in pixels (named `fov` in this codebase)
                res = 512.
                intrinsics[:3, :3] = np.array([
                    [fov, 0, 0.5 * res],
                    [0, fov, 0.5 * res],
                    [0, 0, 1]
                ], dtype=np.float32)
                pose = camera_dict['camera_mat_%d' % idx].astype(np.float32)
            else:
                P = world_mat @ scale_mat
                P = P[:3, :4]
                intrinsics, pose = load_K_Rt_from_P(None, P)

            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        ### images, masks ###
        self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu()  # [n_images, H, W, 3]
        self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu()  # [n_images, H, W, 3]
        self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device)  # [n_images, 4, 4]
        self.intrinsics_all_inv = torch.inverse(self.intrinsics_all)  # [n_images, 4, 4]
        self.focal = self.intrinsics_all[0][0, 0]
        self.pose_all = torch.stack(self.pose_all).to(self.device)  # [n_images, 4, 4]
        self.H, self.W = self.images.shape[1], self.images.shape[2]
        self.image_pixels = self.H * self.W

        object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
        object_bbox_max = np.array([ 1.01,  1.01,  1.01, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
        object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        self.object_bbox_min = object_bbox_min[:3, 0]
        self.object_bbox_max = object_bbox_max[:3, 0]

        self.n_images = self.images.size(0)

        print('Load data: End')

    # NOTE: written at method level but without `self`; it is a standalone helper.
    def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
        i, j = torch.meshgrid(
            torch.linspace(0, W-1, W, device=c2w.device),
            torch.linspace(0, H-1, H, device=c2w.device))
        i = i.t().float()
        j = j.t().float()
        if mode == 'lefttop':
            pass
        elif mode == 'center':
            i, j = i+0.5, j+0.5
        elif mode == 'random':
            i = i+torch.rand_like(i)
            j = j+torch.rand_like(j)
        else:
            raise NotImplementedError

        if flip_x:
            i = i.flip((1,))
        if flip_y:
            j = j.flip((0,))
        if inverse_y:
            dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
        else:
            dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
        # Rotate ray directions from camera frame to the world frame
        rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals [c2w.dot(dir) for dir in dirs]
        # Translate camera frame's origin to the world frame. It is the origin of all rays.
        rays_o = c2w[:3,3].expand(rays_d.shape)
        return rays_o, rays_d

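    # Illustrative sanity check (not part of this commit): with the pinhole
    # model in get_rays above (non-inverse_y branch), the principal-point pixel
    # maps straight down the camera's optical axis.
    @staticmethod
    def _check_center_ray_direction():
        fx = fy = 256.0
        cx = cy = 256.0
        d = np.array([(cx - cx) / fx, -(cy - cy) / fy, -1.0])
        assert np.allclose(d, [0.0, 0.0, -1.0])  # camera looks down its -z axis
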
    def gen_rays_at(self, img_idx, resolution_level=1):
        """
        Generate rays at world space from one camera.
        """
        l = resolution_level
        tx = torch.linspace(0, self.W - 1, self.W // l)
        ty = torch.linspace(0, self.H - 1, self.H // l)
        pixels_x, pixels_y = torch.meshgrid(tx, ty)

        ##### previous method #####
        # p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        # # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        # p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        # rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        # rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        # rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape)  # W, H, 3
        ##### previous method #####

        fov = 512.; res = 512.
        K = np.array([
            [fov, 0, 0.5 * res],
            [0, fov, 0.5 * res],
            [0, 0, 1]
        ], dtype=np.float32)
        K = torch.from_numpy(K).float().cuda()

        # ### `center` mode ### #
        c2w = self.pose_all[img_idx]
        pixels_x, pixels_y = pixels_x + 0.5, pixels_y + 0.5

        dirs = torch.stack([(pixels_x - K[0][2]) / K[0][0], -(pixels_y - K[1][2]) / K[1][1], -torch.ones_like(pixels_x)], -1)
        rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
        rays_o = c2w[:3, 3].expand(rays_v.shape)
        # (see the `previous method` block above for the intrinsics-inverse variant)
        return rays_o.transpose(0, 1), rays_v.transpose(0, 1)

    def gen_random_rays_at(self, img_idx, batch_size):
        """
        Generate random rays at world space from one camera.
        """
        pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
        pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
        color = self.images[img_idx][(pixels_y, pixels_x)]  # batch_size, 3
        mask = self.masks[img_idx][(pixels_y, pixels_x)]  # batch_size, 3

        ##### previous method #####
        # p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float()  # batch_size, 3
        # # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1).float()  # batch_size, 3
        # p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze()  # batch_size, 3
        # rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # batch_size, 3
        # rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze()  # batch_size, 3
        # rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape)  # batch_size, 3
        ##### previous method #####

        fov = 512.; res = 512.
        K = np.array([
            [fov, 0, 0.5 * res],
            [0, fov, 0.5 * res],
            [0, 0, 1]
        ], dtype=np.float32)
        K = torch.from_numpy(K).float().cuda()

        # ### `center` mode ### #
        c2w = self.pose_all[img_idx]
        pixels_x, pixels_y = pixels_x + 0.5, pixels_y + 0.5

        dirs = torch.stack([(pixels_x - K[0][2]) / K[0][0], -(pixels_y - K[1][2]) / K[1][1], -torch.ones_like(pixels_x)], -1)
        rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
        rays_o = c2w[:3, 3].expand(rays_v.shape)

        return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda()  # batch_size, 10

    def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
        """
        Interpolate pose between two cameras.
        """
        l = resolution_level
        tx = torch.linspace(0, self.W - 1, self.W // l)
        ty = torch.linspace(0, self.H - 1, self.H // l)
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
        pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
        pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
        pose_0 = np.linalg.inv(pose_0)
        pose_1 = np.linalg.inv(pose_1)
        rot_0 = pose_0[:3, :3]
        rot_1 = pose_1[:3, :3]
        rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
        key_times = [0, 1]
        slerp = Slerp(key_times, rots)
        rot = slerp(ratio)
        pose = np.diag([1.0, 1.0, 1.0, 1.0])
        pose = pose.astype(np.float32)
        pose[:3, :3] = rot.as_matrix()
        pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
        pose = np.linalg.inv(pose)
        rot = torch.from_numpy(pose[:3, :3]).cuda()
        trans = torch.from_numpy(pose[:3, 3]).cuda()
        rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = trans[None, None, :3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1), rays_v.transpose(0, 1)

336
+
337
+ def near_far_from_sphere(self, rays_o, rays_d):
338
+ a = torch.sum(rays_d**2, dim=-1, keepdim=True)
339
+ b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
340
+ mid = 0.5 * (-b) / a
341
+ near = mid - 1.0
342
+ far = mid + 1.0
343
+ return near, far
344
+
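    # Worked example (not part of this commit): near_far_from_sphere minimizes
    # ||o + t*d||^2, giving t* = -<o, d>/||d||^2 = 0.5 * (-b) / a; for unit-norm
    # directions, the unit scene sphere is then bracketed by t* -/+ 1.
    @staticmethod
    def _near_far_example():
        rays_o = torch.tensor([[0.0, 0.0, 3.0]])
        rays_d = torch.tensor([[0.0, 0.0, -1.0]])  # unit-length direction
        a = torch.sum(rays_d ** 2, dim=-1, keepdim=True)            # 1
        b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)  # -6
        mid = 0.5 * (-b) / a  # 3: the ray passes the origin at t = 3
        assert float(mid) == 3.0  # near = 2 and far = 4 bracket the unit sphere
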
    ## image_at ##
    def image_at(self, idx, resolution_level):
        if self.selected_img_idxes_list is not None:
            img = cv.imread(self.images_lis[self.selected_img_idxes_list[idx]])
        else:
            img = cv.imread(self.images_lis[idx])
        return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)


if __name__ == '__main__':
    # NOTE: only the last data_dir assignment takes effect
    data_dir = "/data/datasets/genn/diffsim/diffredmax/save_res/goal_optimize_model_hand_sphere_test_obj_type_active_nfr_10_view_divide_0.5_n_views_7_three_planes_False_recon_dvgo_new_Nposes_7_routine_2"
    data_dir = "/data/datasets/genn/diffsim/neus/public_data/hand_test"
    data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2"
    data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2_light_color"
    filter_iamges_via_pixel_values(data_dir=data_dir)
models/dataset_wtime.py ADDED
@@ -0,0 +1,403 @@
import torch
import torch.nn.functional as F
import cv2 as cv
import numpy as np
import os
from glob import glob
from icecream import ic
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose


def filter_iamges_via_pixel_values(data_dir):
    images_lis = sorted(glob(os.path.join(data_dir, 'image/*.png')))
    n_images = len(images_lis)
    images_np = np.stack([cv.imread(im_name) for im_name in images_lis]) / 255.0
    print(f"images_np: {images_np.shape}")
    # nn_frames x res x res x 3
    images_np = 1. - images_np
    has_density_values = (np.sum(images_np, axis=-1) > 0.7).astype(np.float32)
    has_density_values = np.sum(np.sum(has_density_values, axis=-1), axis=-1)
    tot_res_nns = float(images_np.shape[1] * images_np.shape[2])
    has_density_ratio = has_density_values / tot_res_nns  # fraction of pixels with density per frame
    print(f"has_density_values: {has_density_values.shape}")
    paried_has_density_ratio_list = [(i_fr, has_density_ratio[i_fr].item()) for i_fr in range(has_density_ratio.shape[0])]
    paried_has_density_ratio_list = sorted(paried_has_density_ratio_list, key=lambda ii: ii[1], reverse=True)
    mid_rnk_value = len(paried_has_density_ratio_list) // 4
    print("mid value of the density ratio")
    print(paried_has_density_ratio_list[mid_rnk_value])
    iamge_idx = paried_has_density_ratio_list[mid_rnk_value][0]
    print(f"image idx: {images_lis[iamge_idx]}")
    print(paried_has_density_ratio_list[:mid_rnk_value])
    tot_selected_img_idx_list = [ii[0] for ii in paried_has_density_ratio_list[:mid_rnk_value]]
    tot_selected_img_idx_list = sorted(tot_selected_img_idx_list)
    print(len(tot_selected_img_idx_list))
    # print(tot_selected_img_idx_list[54])
    print(tot_selected_img_idx_list)

class Dataset:
    def __init__(self, conf, time_idx, mode='train'):
        super(Dataset, self).__init__()
        print('Load data: Begin')
        self.device = torch.device('cuda')
        self.conf = conf

        # hand-picked view subsets; only the last assignment below takes effect
        self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 13, 14, 15, 35, 36, 42, 43, 44, 48, 49, 50, 51, 55, 56, 57, 61, 62, 63, 69, 84, 90, 91, 92, 96, 97]
        # self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 28, 29, 35, 36, 37, 40, 41, 70, 71, 79, 82, 83, 84, 85, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 133, 134, 135, 139, 174, 175, 176, 177, 180, 188, 189, 190, 191, 194, 195]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97, 98, 99, 104, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 134, 135, 139, 174, 175, 176, 177, 180, 181, 182, 183, 188, 189, 190, 191, 194, 195]

        self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 20, 21, 22, 23, 26, 27, 70, 78, 83, 84, 85, 91, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 119, 120, 121, 124, 125, 175, 176, 181, 182, 188, 189, 190, 191, 194, 195]
        self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)

        self.data_dir = conf.get_string('data_dir')
        self.data_dir = os.path.join(self.data_dir, f"{time_idx}")  # per-timestep subdirectory

        self.render_cameras_name = conf.get_string('render_cameras_name')
        self.object_cameras_name = conf.get_string('object_cameras_name')

        ## camera outside sphere ##
        self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
        self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)

        camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
        # camera_dict = np.load("/home/xueyi/diffsim/NeuS/public_data/dtu_scan24/cameras_sphere.npz")
        self.camera_dict = camera_dict  # render camera dict
        # number of pixels in the views -> very thin geometry is not useful
        self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))

        self.n_images = len(self.images_lis)

        if mode == 'train_from_model_rules':
            # replicate the first frame for every view
            self.images_np = cv.imread(self.images_lis[0]) / 256.0
            print(self.images_np.shape)
            self.images_np = np.reshape(self.images_np, (1, self.images_np.shape[0], self.images_np.shape[1], self.images_np.shape[2]))
            self.images_np = [self.images_np for _ in range(len(self.images_lis))]
            self.images_np = np.concatenate(self.images_np, axis=0)
        else:
            # load once, cache the processed stack as a sidecar .npy
            presaved_imags_npy_fn = os.path.join(self.data_dir, "processed_images.npy")
            if not os.path.exists(presaved_imags_npy_fn):
                self.images_np = []
                for i_im_idx, im_name in enumerate(self.images_lis):
                    print(f"loading {i_im_idx} / {len(self.images_lis)}")
                    cur_im = cv.imread(im_name)
                    self.images_np.append(cur_im)
                self.images_np = np.stack(self.images_np) / 256.0
                np.save(presaved_imags_npy_fn, self.images_np)
            else:
                print(f"Loading from {presaved_imags_npy_fn}")
                self.images_np = np.load(presaved_imags_npy_fn, allow_pickle=True)

        # self.selected_img_idxes_list = list(range(self.images_np.shape[0]))
        # self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)

        self.images_np = self.images_np[self.selected_img_idxes]  # keep the selected images

        # deal with the background: invert so the background is dark
        # self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0
        # self.images_np = self.images_np[self.selected_img_idxes]
        self.images_np = 1. - self.images_np

        self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png')))

        if mode == 'train_from_model_rules':
            self.masks_np = cv.imread(self.masks_lis[0]) / 256.0
            print("masks shape:", self.masks_np.shape)
            self.masks_np = np.reshape(self.masks_np, (1, self.masks_np.shape[0], self.masks_np.shape[1], self.masks_np.shape[2]))
            self.masks_np = [self.masks_np for _ in range(len(self.masks_lis))]
            self.masks_np = np.concatenate(self.masks_np, axis=0)
        else:
            presaved_masks_npy_fn = os.path.join(self.data_dir, "processed_masks.npy")
            # self.masks_lis = self.masks_lis[:1]

            if not os.path.exists(presaved_masks_npy_fn):
                try:
                    self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0
                    self.masks_np = self.masks_np[self.selected_img_idxes]
                except Exception:
                    self.masks_np = self.images_np.copy()
                np.save(presaved_masks_npy_fn, self.masks_np)
            else:
                print(f"Loading from {presaved_masks_npy_fn}")
                self.masks_np = np.load(presaved_masks_npy_fn, allow_pickle=True)

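        # Illustrative sketch of the cache-or-build pattern used above (names
        # hypothetical, not part of this commit): preprocess once, then reload
        # the sidecar .npy on later runs.
        def _load_or_build(cache_fn, build_fn):
            if os.path.exists(cache_fn):
                return np.load(cache_fn, allow_pickle=True)
            arr = build_fn()
            np.save(cache_fn, arr)
            return arr
        # e.g. masks = _load_or_build(presaved_masks_npy_fn,
        #                             lambda: np.stack([cv.imread(f) for f in self.masks_lis]) / 256.0)
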
        # world_mat is a projection matrix from world to image
        self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.scale_mats_np = []

        # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
        self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.intrinsics_all = []
        self.pose_all = []

        # for idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, self.world_mats_np)):
        for idx in self.selected_img_idxes_list:
            scale_mat = self.scale_mats_np[idx]
            world_mat = self.world_mats_np[idx]

            if "hand" in self.data_dir:
                intrinsics = np.eye(4)
                fov = 512. / 2.  # focal length in pixels (named `fov` in this codebase)
                res = 512.
                intrinsics[:3, :3] = np.array([
                    [fov, 0, 0.5 * res],
                    [0, fov, 0.5 * res],
                    [0, 0, 1]
                ], dtype=np.float32)
                pose = camera_dict['camera_mat_%d' % idx].astype(np.float32)
            else:
                P = world_mat @ scale_mat
                P = P[:3, :4]
                intrinsics, pose = load_K_Rt_from_P(None, P)

            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        ### images, masks ###
        self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu()  # [n_images, H, W, 3]
        self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu()  # [n_images, H, W, 3]
        self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device)  # [n_images, 4, 4]
        self.intrinsics_all_inv = torch.inverse(self.intrinsics_all)  # [n_images, 4, 4]
        self.focal = self.intrinsics_all[0][0, 0]
        self.pose_all = torch.stack(self.pose_all).to(self.device)  # [n_images, 4, 4]
        self.H, self.W = self.images.shape[1], self.images.shape[2]
        self.image_pixels = self.H * self.W

        object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
        object_bbox_max = np.array([ 1.01,  1.01,  1.01, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
        object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        self.object_bbox_min = object_bbox_min[:3, 0]
        self.object_bbox_max = object_bbox_max[:3, 0]

        self.n_images = self.images.size(0)

        print('Load data: End')

    # NOTE: written at method level but without `self`; it is a standalone helper.
    def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
        i, j = torch.meshgrid(
            torch.linspace(0, W-1, W, device=c2w.device),
            torch.linspace(0, H-1, H, device=c2w.device))
        i = i.t().float()
        j = j.t().float()
        if mode == 'lefttop':
            pass
        elif mode == 'center':
            i, j = i+0.5, j+0.5
        elif mode == 'random':
            i = i+torch.rand_like(i)
            j = j+torch.rand_like(j)
        else:
            raise NotImplementedError

        if flip_x:
            i = i.flip((1,))
        if flip_y:
            j = j.flip((0,))
        if inverse_y:
            dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
        else:
            dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
        # Rotate ray directions from camera frame to the world frame
        rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals [c2w.dot(dir) for dir in dirs]
        # Translate camera frame's origin to the world frame. It is the origin of all rays.
        rays_o = c2w[:3,3].expand(rays_d.shape)
        return rays_o, rays_d

258
+
259
+ def gen_rays_at(self, img_idx, resolution_level=1):
260
+ """
261
+ Generate rays at world space from one camera.
262
+ """
263
+ l = resolution_level
264
+ tx = torch.linspace(0, self.W - 1, self.W // l)
265
+ ty = torch.linspace(0, self.H - 1, self.H // l)
266
+ pixels_x, pixels_y = torch.meshgrid(tx, ty)
267
+
268
+ ##### previous method #####
269
+ # p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
270
+ # # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
271
+ # p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
272
+ # rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
273
+ # rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
274
+ # rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
275
+ ##### previous method #####
276
+
277
+ fov = 512.; res = 512.
278
+ K = np.array([
279
+ [fov, 0, 0.5* res],
280
+ [0, fov, 0.5* res],
281
+ [0, 0, 1]
282
+ ], dtype=np.float32)
283
+ K = torch.from_numpy(K).float().cuda()
284
+
285
+
286
+ # ### `center` mode ### #
287
+ c2w = self.pose_all[img_idx]
288
+ pixels_x, pixels_y = pixels_x+0.5, pixels_y+0.5
289
+
290
+ dirs = torch.stack([(pixels_x-K[0][2])/K[0][0], -(pixels_y-K[1][2])/K[1][1], -torch.ones_like(pixels_x)], -1)
291
+ rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
292
+ rays_o = c2w[:3,3].expand(rays_v.shape)
293
+ # dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
294
+
295
+ # p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
296
+ # # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
297
+ # p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
298
+ # rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
299
+ # rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
300
+ # rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
301
+ return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
302
+
    def gen_random_rays_at(self, img_idx, batch_size):
        """
        Generate random rays at world space from one camera.
        """
        img_idx = img_idx.cpu()
        pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]).cpu()
        pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]).cpu()

        # print(self.images.device, img_idx.device, pixels_y.device)
        color = self.images[img_idx][(pixels_y, pixels_x)]  # batch_size, 3
        mask = self.masks[img_idx][(pixels_y, pixels_x)]  # batch_size, 3

        ##### previous method #####
        # p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float()  # batch_size, 3
        # # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1).float()  # batch_size, 3
        # p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze()  # batch_size, 3
        # rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # batch_size, 3
        # rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze()  # batch_size, 3
        # rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape)  # batch_size, 3
        ##### previous method #####

        fov = 512.; res = 512.
        K = np.array([
            [fov, 0, 0.5 * res],
            [0, fov, 0.5 * res],
            [0, 0, 1]
        ], dtype=np.float32)
        K = torch.from_numpy(K).float().cuda()

        # ### `center` mode ### #
        c2w = self.pose_all[img_idx]

        pixels_x = pixels_x.cuda()
        pixels_y = pixels_y.cuda()
        pixels_x, pixels_y = pixels_x + 0.5, pixels_y + 0.5

        dirs = torch.stack([(pixels_x - K[0][2]) / K[0][0], -(pixels_y - K[1][2]) / K[1][1], -torch.ones_like(pixels_x)], -1)
        rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
        rays_o = c2w[:3, 3].expand(rays_v.shape)

        return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda()  # batch_size, 10

    def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
        """
        Interpolate pose between two cameras.
        """
        l = resolution_level
        tx = torch.linspace(0, self.W - 1, self.W // l)
        ty = torch.linspace(0, self.H - 1, self.H // l)
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
        pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
        pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
        pose_0 = np.linalg.inv(pose_0)
        pose_1 = np.linalg.inv(pose_1)
        rot_0 = pose_0[:3, :3]
        rot_1 = pose_1[:3, :3]
        rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
        key_times = [0, 1]
        slerp = Slerp(key_times, rots)
        rot = slerp(ratio)
        pose = np.diag([1.0, 1.0, 1.0, 1.0])
        pose = pose.astype(np.float32)
        pose[:3, :3] = rot.as_matrix()
        pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
        pose = np.linalg.inv(pose)
        rot = torch.from_numpy(pose[:3, :3]).cuda()
        trans = torch.from_numpy(pose[:3, 3]).cuda()
        rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = trans[None, None, :3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1), rays_v.transpose(0, 1)

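    # Illustrative sketch (not part of this commit): gen_rays_between above
    # slerps the world-to-camera rotations and lerps the translations, then
    # inverts back to a camera-to-world pose.
    @staticmethod
    def _interp_pose_example(pose_0_c2w, pose_1_c2w, ratio):
        w2c_0 = np.linalg.inv(pose_0_c2w)
        w2c_1 = np.linalg.inv(pose_1_c2w)
        rots = Rot.from_matrix(np.stack([w2c_0[:3, :3], w2c_1[:3, :3]]))
        rot = Slerp([0, 1], rots)(ratio)  # geodesic interpolation of rotations
        pose = np.eye(4, dtype=np.float32)
        pose[:3, :3] = rot.as_matrix()
        pose[:3, 3] = (1.0 - ratio) * w2c_0[:3, 3] + ratio * w2c_1[:3, 3]
        return np.linalg.inv(pose)  # camera-to-world of the interpolated view
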
    def near_far_from_sphere(self, rays_o, rays_d):
        a = torch.sum(rays_d**2, dim=-1, keepdim=True)
        b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
        mid = 0.5 * (-b) / a
        near = mid - 1.0
        far = mid + 1.0
        return near, far

    def image_at(self, idx, resolution_level):
        if self.selected_img_idxes_list is not None:
            img = cv.imread(self.images_lis[self.selected_img_idxes_list[idx]])
        else:
            img = cv.imread(self.images_lis[idx])
        return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)


if __name__ == '__main__':
    # NOTE: only the last data_dir assignment takes effect
    data_dir = "/data/datasets/genn/diffsim/diffredmax/save_res/goal_optimize_model_hand_sphere_test_obj_type_active_nfr_10_view_divide_0.5_n_views_7_three_planes_False_recon_dvgo_new_Nposes_7_routine_2"
    data_dir = "/data/datasets/genn/diffsim/neus/public_data/hand_test"
    data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2"
    data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2_light_color"
    filter_iamges_via_pixel_values(data_dir=data_dir)
models/dyn_model_act.py ADDED
The diff for this file is too large to render. See raw diff
 
models/dyn_model_act_v2.py ADDED
The diff for this file is too large to render. See raw diff
 
models/dyn_model_act_v2_deformable.py ADDED
@@ -0,0 +1,1582 @@
import math
# import torch
# from ..utils import Timer
import numpy as np
# import torch.nn.functional as F
import os

import argparse

from xml.etree.ElementTree import ElementTree

import trimesh
import torch
import torch.nn as nn
# import List
# class link; joint; body

from scipy.spatial.transform import Rotation as R
from torch.distributions.uniform import Uniform

# deformable articulated objects with the articulated models #

DAMPING = 1.0
DAMPING = 0.3

def plane_rotation_matrix_from_angle_xz(angle):
    sin_ = torch.sin(angle)
    cos_ = torch.cos(angle)
    zero_padding = torch.zeros_like(cos_)
    one_padding = torch.ones_like(cos_)
    col_a = torch.stack(
        [cos_, zero_padding, sin_], dim=0
    )
    col_b = torch.stack(
        [zero_padding, one_padding, zero_padding], dim=0
    )
    col_c = torch.stack(
        [-1. * sin_, zero_padding, cos_], dim=0
    )
    rot_mtx = torch.stack(
        [col_a, col_b, col_c], dim=-1
    )
    return rot_mtx

def plane_rotation_matrix_from_angle(angle):
    ## angle of the in-plane (2D) rotation ##
    sin_ = torch.sin(angle)
    cos_ = torch.cos(angle)
    col_a = torch.stack(
        [cos_, sin_], dim=0  ### first column of the rotation matrix
    )
    col_b = torch.stack(
        [-1. * sin_, cos_], dim=0  ### second column of the rotation matrix
    )
    rot_mtx = torch.stack(
        [col_a, col_b], dim=-1  ### rotation matrix
    )
    return rot_mtx

def rotation_matrix_from_axis_angle(axis, angle):
    # sin_ = np.sin(angle) # ti.math.sin(angle)
    # cos_ = np.cos(angle) # ti.math.cos(angle)
    sin_ = torch.sin(angle)
    cos_ = torch.cos(angle)
    u_x, u_y, u_z = axis[0], axis[1], axis[2]
    u_xx = u_x * u_x
    u_yy = u_y * u_y
    u_zz = u_z * u_z
    u_xy = u_x * u_y
    u_xz = u_x * u_z
    u_yz = u_y * u_z

    row_a = torch.stack(
        [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
    )
    row_b = torch.stack(
        [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
    )
    row_c = torch.stack(
        [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
    )

    ### assemble the rotation matrix (dim=-1 stacks row_a/row_b/row_c as columns) ###
    rot_mtx = torch.stack(
        [row_a, row_b, row_c], dim=-1
    )

    return rot_mtx

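
# --- Illustrative sanity check (a sketch; not called anywhere in this file). ---
# rotation_matrix_from_axis_angle assembles the Rodrigues rotation
#     R = cos(t) * I + sin(t) * [u]_x + (1 - cos(t)) * u u^T;
# torch.stack(..., dim=-1) places each stacked vector in a *column*, so the
# row_a/row_b/row_c names above actually fill columns, and the result is still
# the standard matrix. A 90-degree turn about z should map x to y:
def _check_rotation_matrix_from_axis_angle():
    axis = torch.tensor([0., 0., 1.])
    angle = torch.tensor(math.pi / 2.)
    rot = rotation_matrix_from_axis_angle(axis, angle)
    return rot @ torch.tensor([1., 0., 0.])  # approximately (0., 1., 0.)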

def update_quaternion(delta_angle, prev_quat):
    s1 = 0
    s2 = prev_quat[0]
    v2 = prev_quat[1:]
    v1 = delta_angle / 2
    new_v = s1 * v2 + s2 * v1 + torch.cross(v1, v2)
    new_s = s1 * s2 - torch.sum(v1 * v2)
    new_quat = torch.cat([new_s.unsqueeze(0), new_v], dim=0)
    return new_quat


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)  # unbind along the last (quaternion) dim
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )

    return o.reshape(quaternions.shape[:-1] + (3, 3))

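
# --- Illustrative sketch (assumes small per-step angles; not used elsewhere): ---
# update_quaternion returns the delta quaternion 0.5 * (0, delta_angle) (x) q, so
# q <- q + update_quaternion(delta_angle, q), followed by renormalization,
# integrates an incremental rotation; quaternion_to_matrix recovers the matrix.
def _integrate_axis_angle_steps(delta_angle, n_steps=100):
    quat = torch.tensor([1., 0., 0., 0.])
    for _ in range(n_steps):
        quat = quat + update_quaternion(delta_angle, quat)
        quat = quat / torch.norm(quat)  # keep it a unit quaternion
    return quaternion_to_matrix(quat)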

class Inertial:
    def __init__(self, origin_rpy, origin_xyz, mass, inertia) -> None:
        self.origin_rpy = origin_rpy
        self.origin_xyz = origin_xyz
        self.mass = mass
        self.inertia = inertia
        if torch.sum(self.inertia).item() < 1e-4:
            self.inertia = self.inertia + torch.eye(3, dtype=torch.float32).cuda()
        pass

class Visual:
    def __init__(self, visual_xyz, visual_rpy, geometry_mesh_fn, geometry_mesh_scale) -> None:
        # self.visual_origin = visual_origin
        self.visual_xyz = visual_xyz
        self.visual_rpy = visual_rpy
        self.mesh_nm = geometry_mesh_fn.split("/")[-1].split(".")[0]
        mesh_root = "/home/xueyi/diffsim/NeuS/rsc/mano"  ## root of the MANO mesh assets ##
        if not os.path.exists(mesh_root):
            mesh_root = "/data/xueyi/diffsim/NeuS/rsc/mano"
        self.mesh_root = mesh_root
        self.geometry_mesh_fn = os.path.join(mesh_root, geometry_mesh_fn)
        self.geometry_mesh_scale = geometry_mesh_scale
        # vertices are scaled, then translated by the visual xyz #
        self.vertices, self.faces = self.load_geoemtry_mesh()
        self.cur_expanded_visual_pts = None
        pass

    def load_geoemtry_mesh(self, ):
        mesh = trimesh.load_mesh(self.geometry_mesh_fn)
        vertices = mesh.vertices
        faces = mesh.faces

        vertices = torch.from_numpy(vertices).float().cuda()
        faces = torch.from_numpy(faces).long().cuda()

        vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)

        return vertices, faces

    # init_visual_meshes = get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes)
    def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes):
        # cur_vertices = torch.matmul(parent_rot, self.vertices.transpose(1, 0)).contiguous().transpose(1, 0).contiguous() + parent_trans.unsqueeze(0)
        cur_vertices = self.vertices
        # print(f"adding mesh loaded from {self.geometry_mesh_fn}")
        init_visual_meshes['vertices'].append(cur_vertices)  # current (untransformed) vertices #
        init_visual_meshes['faces'].append(self.faces)
        return init_visual_meshes

    def expand_visual_pts(self, ):
        expand_factor = 0.2
        nn_expand_pts = 20

        expand_factor = 0.4
        nn_expand_pts = 40  ### number of expanded points per source vertex ###
        expand_save_fn = f"{self.mesh_nm}_expanded_pts_factor_{expand_factor}_nnexp_{nn_expand_pts}.npy"
        expand_save_fn = os.path.join(self.mesh_root, expand_save_fn)

        if not os.path.exists(expand_save_fn):
            cur_expanded_visual_pts = []
            if self.cur_expanded_visual_pts is None:
                cur_src_pts = self.vertices
            else:
                cur_src_pts = self.cur_expanded_visual_pts
            maxx_verts, _ = torch.max(cur_src_pts, dim=0)
            minn_verts, _ = torch.min(cur_src_pts, dim=0)
            extent_verts = maxx_verts - minn_verts  ## (3,)-dim vector
            norm_extent_verts = torch.norm(extent_verts, dim=-1).item()  ## scalar norm of the extent
            expand_r = norm_extent_verts * expand_factor
            # nn_expand_pts = 5 # expand the vertices to 5 times the original count
            for i_pts in range(self.vertices.size(0)):
                cur_pts = cur_src_pts[i_pts]
                # sample around cur_pts: offsets are drawn uniformly from
                # (-expand_r, expand_r) per axis, in the shape (nn_expand_pts, 3)
                offset_dist = Uniform(-1. * expand_r, expand_r)
                offset_vec = offset_dist.sample((nn_expand_pts, 3)).cuda()
                cur_expanded_pts = cur_pts + offset_vec
                cur_expanded_visual_pts.append(cur_expanded_pts)
            cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
            np.save(expand_save_fn, cur_expanded_visual_pts.detach().cpu().numpy())
        else:
            print(f"Loading visual pts from {expand_save_fn}")  # load from the cached file #
            cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
            cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float().cuda()
        self.cur_expanded_visual_pts = cur_expanded_visual_pts  # expanded visual pts #
        return self.cur_expanded_visual_pts

    def get_transformed_visual_pts(self, visual_pts_list):
        visual_pts_list.append(self.cur_expanded_visual_pts)
        return visual_pts_list


## Link_urdf: expand the visual pts to form the expanded visual grid pts; #
# use get_name_to_visual_pts_faces to get the transformed visual pts and faces. #
class Link_urdf:
    def __init__(self, name, inertial: Inertial, visual: Visual = None) -> None:

        self.name = name
        self.inertial = inertial
        self.visual = visual  # visual meshes #

        # self.joint = joint
        # self.body = body
        # self.children = children

        self.link_idx = ...

        # self.args = args

        self.joint = None  # joint name to joint struct
        self.children = {}  # joint name to child link name

        ### dyn_model_act ###
        # parent_rot_mtx, parent_trans_vec #
        # self.parent_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
        # self.parent_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)
        # self.curr_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
        # self.curr_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)
        # self.tot_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
        # self.tot_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)

    # expand visual pts #
    def expand_visual_pts(self, expanded_visual_pts, link_name_to_visited, link_name_to_link_struct):
        link_name_to_visited[self.name] = 1
        if self.visual is not None:
            cur_expanded_visual_pts = self.visual.expand_visual_pts()
            expanded_visual_pts.append(cur_expanded_visual_pts)

        for cur_link in self.children:
            cur_link_struct = link_name_to_link_struct[self.children[cur_link]]
            cur_link_name = cur_link_struct.name
            if cur_link_name in link_name_to_visited:
                continue
            expanded_visual_pts = cur_link_struct.expand_visual_pts(expanded_visual_pts, link_name_to_visited, link_name_to_link_struct)
        return expanded_visual_pts

    def get_transformed_visual_pts(self, visual_pts_list, link_name_to_visited, link_name_to_link_struct):
        link_name_to_visited[self.name] = 1

        if self.joint is not None:
            for cur_joint_name in self.joint:
                cur_joint = self.joint[cur_joint_name]
                cur_child_name = self.children[cur_joint_name]
                cur_child = link_name_to_link_struct[cur_child_name]
                # print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
                if cur_child_name in link_name_to_visited:
                    continue
                cur_child_visual_pts_list = []
                cur_child_visual_pts_list = cur_child.get_transformed_visual_pts(cur_child_visual_pts_list, link_name_to_visited, link_name_to_link_struct)

                if len(cur_child_visual_pts_list) > 0:
                    cur_child_visual_pts = torch.cat(cur_child_visual_pts_list, dim=0)
                    # translate into the joint frame, then apply the joint's current rotation/translation #
                    cur_child_visual_pts = cur_child_visual_pts + cur_joint.origin_xyz.unsqueeze(0)
                    cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state()
                    cur_child_visual_pts = torch.matmul(cur_joint_rot, cur_child_visual_pts.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)

                    visual_pts_list.append(cur_child_visual_pts)
                    # (the commented-out joint-position bookkeeping that used to live here
                    #  is performed, uncommented, in get_init_visual_meshes below)

        if self.visual is not None:
            visual_pts_list = self.visual.get_transformed_visual_pts(visual_pts_list)

        return visual_pts_list

    # use both the articulated motion and the free-form motion #
    def set_initial_state(self, states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):

        link_name_to_visited[self.name] = 1

        if self.joint is not None:
            for cur_joint_name in self.joint:
                cur_joint = self.joint[cur_joint_name]
                cur_joint_name = cur_joint.name
                cur_child = self.children[cur_joint_name]
                cur_child_struct = link_name_to_link_struct[cur_child]
                cur_child_name = cur_child_struct.name

                if cur_child_name in link_name_to_visited:
                    continue
                if cur_joint.type in ['revolute']:
                    cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]  # action joint name to joint idx #
                    cur_state = states[cur_joint_idx]  ### joint state ###
                    cur_joint.set_initial_state(cur_state)
                cur_child_struct.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)

    def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes, link_name_to_link_struct, link_name_to_visited):
        link_name_to_visited[self.name] = 1

        if self.joint is not None:
            # print(f"name: {self.name}, keys: {self.joint.keys()}")
            for cur_joint_name in self.joint:
                cur_joint = self.joint[cur_joint_name]
                cur_child_name = self.children[cur_joint_name]
                cur_child = link_name_to_link_struct[cur_child_name]
                # print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
                joint_origin_xyz = cur_joint.origin_xyz
                if cur_child_name in link_name_to_visited:
                    continue
                cur_child_visual_pts = {'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []}
                cur_child_visual_pts = cur_child.get_init_visual_meshes(parent_rot, parent_trans + joint_origin_xyz, cur_child_visual_pts, link_name_to_link_struct, link_name_to_visited)
                cur_child_verts, cur_child_faces = cur_child_visual_pts['vertices'], cur_child_visual_pts['faces']
                cur_child_link_idxes = cur_child_visual_pts['link_idxes']
                cur_transformed_joint_pos = cur_child_visual_pts['transformed_joint_pos']
                joint_link_idxes = cur_child_visual_pts['joint_link_idxes']
                if len(cur_child_verts) > 0:
                    cur_child_verts, cur_child_faces = merge_meshes(cur_child_verts, cur_child_faces)
                    cur_child_verts = cur_child_verts + cur_joint.origin_xyz.unsqueeze(0)
                    cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state()
                    cur_child_verts = torch.matmul(cur_joint_rot, cur_child_verts.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)

                if len(cur_transformed_joint_pos) > 0:
                    cur_transformed_joint_pos = torch.cat(cur_transformed_joint_pos, dim=0)
                    cur_transformed_joint_pos = cur_transformed_joint_pos + cur_joint.origin_xyz.unsqueeze(0)
                    cur_transformed_joint_pos = torch.matmul(cur_joint_rot, cur_transformed_joint_pos.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
                    cur_joint_pos = cur_joint_trans.unsqueeze(0).clone()
                    cur_transformed_joint_pos = torch.cat(
                        [cur_transformed_joint_pos, cur_joint_pos], dim=0  ##### joint poses #####
                    )
                else:
                    cur_transformed_joint_pos = cur_joint_trans.unsqueeze(0).clone()

                if len(joint_link_idxes) > 0:
                    joint_link_idxes = torch.cat(joint_link_idxes, dim=-1)  ### joint link idxes ###
                    cur_joint_idx = cur_child.link_idx
                    joint_link_idxes = torch.cat(
                        [joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long).cuda()], dim=-1
                    )
                else:
                    joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).cuda().view(1,)

                cur_child_link_idxes = torch.cat(cur_child_link_idxes, dim=-1)
                init_visual_meshes['vertices'].append(cur_child_verts)
                init_visual_meshes['faces'].append(cur_child_faces)
                init_visual_meshes['link_idxes'].append(cur_child_link_idxes)
                init_visual_meshes['transformed_joint_pos'].append(cur_transformed_joint_pos)
                init_visual_meshes['joint_link_idxes'].append(joint_link_idxes)

        else:
            joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
            # self.parent_rot_mtx = parent_rot
            # self.parent_trans_vec = parent_trans + joint_origin_xyz

        if self.visual is not None:
            init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes)
            cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
            cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long).cuda() + self.link_idx
            init_visual_meshes['link_idxes'].append(cur_link_idxes)

        return init_visual_meshes  ## init visual meshes ##

    # calculate inertia #
    def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
        link_name_to_visited[self.name] = 1
        self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32).cuda()

        if self.joint is not None:
            for joint_nm in self.joint:
                cur_joint = self.joint[joint_nm]
                cur_child = self.children[joint_nm]
                cur_child_struct = link_name_to_link_struct[cur_child]
                cur_child_name = cur_child_struct.name
                if cur_child_name in link_name_to_visited:
                    continue
                joint_rot, joint_trans = cur_joint.compute_transformation_from_current_state(n_grad=True)
                # rotate the child inertia into this link's frame: R I R^T #
                child_inertia = cur_child_struct.calculate_inertia(link_name_to_visited, link_name_to_link_struct)
                child_inertia = torch.matmul(
                    joint_rot.detach(), torch.matmul(child_inertia, joint_rot.detach().transpose(1, 0).contiguous())
                ).detach()
                self.cur_inertia += child_inertia
        self.cur_inertia += self.inertial.inertia.detach()
        return self.cur_inertia

    def set_delta_state_and_update(self, states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct):

        link_name_to_visited[self.name] = 1

        if self.joint is not None:
            for cur_joint_name in self.joint:

                cur_joint = self.joint[cur_joint_name]  # joint struct

                cur_child = self.children[cur_joint_name]  # child link name

                cur_child_struct = link_name_to_link_struct[cur_child]

                cur_child_name = cur_child_struct.name

                if cur_child_name in link_name_to_visited:
                    continue

                if cur_joint.type in ['revolute']:
                    cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
                    cur_state = states[cur_joint_idx]
                    # set the delta state and advance this joint #
                    cur_joint.set_delta_state_and_update(cur_state, cur_timestep)

                cur_child_struct.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct)

    # set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
    def set_actions_and_update_states(self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):

        link_name_to_visited[self.name] = 1

        if self.joint is not None:
            for cur_joint_name in self.joint:

                cur_joint = self.joint[cur_joint_name]  # joint struct

                cur_child = self.children[cur_joint_name]  # child link name

                cur_child_struct = link_name_to_link_struct[cur_child]

                cur_child_name = cur_child_struct.name

                if cur_child_name in link_name_to_visited:
                    continue

                cur_child_inertia = cur_child_struct.cur_inertia

                if cur_joint.type in ['revolute']:
                    cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
                    cur_action = actions[cur_joint_idx]
                    # set the action and advance this joint's state #
                    cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia.detach())

                cur_child_struct.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)

    def set_init_states_target_value(self, init_states):
        if self.joint.type == 'revolute':
            self.joint_angle = init_states[self.joint.joint_idx]
            joint_axis = self.joint.axis
            self.rot_vec = self.joint_angle * joint_axis
            self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
            self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
            self.joint.timestep_to_states[0] = self.joint.state.detach()
            self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach()  ## velocity ##
        for cur_link in self.children:
            cur_link.set_init_states_target_value(init_states)

    # should forward one single step -> use the action #
    def set_init_states(self, ):
        self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
        self.joint.timestep_to_states[0] = self.joint.state.detach()
        self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach()  ## velocity ##
        for cur_link in self.children:
            cur_link.set_init_states()

    def get_visual_pts(self, visual_pts_list):
        visual_pts_list = self.body.get_visual_pts(visual_pts_list)
        for cur_link in self.children:
            visual_pts_list = cur_link.get_visual_pts(visual_pts_list)
        visual_pts_list = torch.cat(visual_pts_list, dim=0)
        return visual_pts_list

    def get_visual_faces_list(self, visual_faces_list):
        visual_faces_list = self.body.get_visual_faces_list(visual_faces_list)
        for cur_link in self.children:
            visual_faces_list = cur_link.get_visual_faces_list(visual_faces_list)
        return visual_faces_list

    def set_state(self, name_to_state):
        self.joint.set_state(name_to_state=name_to_state)
        for child_link in self.children:
            child_link.set_state(name_to_state)

    def set_state_via_vec(self, state_vec):
        self.joint.set_state_via_vec(state_vec)
        for child_link in self.children:
            child_link.set_state_via_vec(state_vec)


class Joint_Limit:
    def __init__(self, effort, lower, upper, velocity) -> None:
        self.effort = effort
        self.lower = lower
        self.velocity = velocity
        self.upper = upper
        pass

# Joint_urdf(name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit)
class Joint_urdf:

    def __init__(self, name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit) -> None:
        self.name = name
        self.type = joint_type
        self.parent_link = parent_link
        self.child_link = child_link
        self.origin_xyz = origin_xyz
        self.axis_xyz = axis_xyz
        self.limit = limit

        # joint angle / joint state bookkeeping #
        self.timestep_to_vels = {}
        self.timestep_to_states = {}

        self.init_pos = self.origin_xyz.clone()

        #### only the current state #### # joint urdf #
        self.state = nn.Parameter(
            torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
        )
        self.action = nn.Parameter(
            torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
        )
        # self.rot_mtx = np.eye(3, dtype=np.float32)
        # self.trans_vec = np.zeros((3,), dtype=np.float32)
        self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
        self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)

    def set_initial_state(self, state):
        # the joint angle serves as the state value #
        self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach()  ## velocity ##
        delta_rot_vec = self.axis_xyz * state
        # self.timestep_to_states[0] = state.detach()
        cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
        init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
        self.timestep_to_states[0] = init_state.detach()
        self.state = init_state

    def set_delta_state_and_update(self, state, cur_timestep):
        self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
        delta_rot_vec = self.axis_xyz * state
        if cur_timestep == 0:
            prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
        else:
            prev_state = self.timestep_to_states[cur_timestep - 1].detach()
        cur_state = prev_state + update_quaternion(delta_rot_vec, prev_state)
        self.timestep_to_states[cur_timestep] = cur_state.detach()
        self.state = cur_state

    def compute_transformation_from_current_state(self, n_grad=False):
        # composed later with the parent's rot mtx and trans vec #
        if self.type == "revolute":
            if n_grad:
                rot_mtx = quaternion_to_matrix(self.state.detach())
            else:
                rot_mtx = quaternion_to_matrix(self.state)
            # rotate about the joint origin: p' = R p + (o - R o) #
            trans_vec = self.origin_xyz - torch.matmul(rot_mtx, self.origin_xyz.view(3, 1)).view(3).contiguous()
            self.rot_mtx = rot_mtx
            self.trans_vec = trans_vec
        elif self.type == "fixed":
            rot_mtx = torch.eye(3, dtype=torch.float32).cuda()
            trans_vec = torch.zeros((3,), dtype=torch.float32).cuda()
            # trans_vec = self.origin_xyz
            self.rot_mtx = rot_mtx
            self.trans_vec = trans_vec
        else:
            pass
        return self.rot_mtx, self.trans_vec

    # set actions and update states #
    def set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):

        # timestep_to_vels, timestep_to_states, state #
        if self.type in ['revolute']:

            self.action = action
            # TODO: check whether the following is correct #
            torque = self.action * self.axis_xyz

            # (an earlier per-point inertia accumulation over self.visual_pts /
            #  self.visual_pts_mass used to live here; the link inertia is now
            #  passed in as cur_inertia instead)
            inertia_inv = torch.linalg.inv(cur_inertia).detach()

            delta_omega = torch.matmul(inertia_inv, torque.unsqueeze(-1)).squeeze(-1)

            # TODO: should dt be an optimizable constant? should it share the value optimized for the passive object? #
            delta_angular_vel = delta_omega * time_cons  # * self.args.dt
            delta_angular_vel = delta_angular_vel.squeeze(0)
            if cur_timestep > 0:  ## integrate from cur_timestep - 1 ##
                prev_angular_vel = self.timestep_to_vels[cur_timestep - 1].detach()
                cur_angular_vel = prev_angular_vel + delta_angular_vel * DAMPING
            else:
                cur_angular_vel = delta_angular_vel

            self.timestep_to_vels[cur_timestep] = cur_angular_vel.detach()

            cur_delta_quat = cur_angular_vel * time_cons  # * self.args.dt
            cur_delta_quat = cur_delta_quat.squeeze(0)
            cur_state = self.timestep_to_states[cur_timestep].detach()  # quaternion #
            nex_state = cur_state + update_quaternion(cur_delta_quat, cur_state)
            self.timestep_to_states[cur_timestep + 1] = nex_state.detach()
            self.state = nex_state  # set the joint state #

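
# --- Summary of the update performed in set_actions_and_update_states above ---
# (names here are notation, not new code): for a scalar action a about joint
# axis u, with dt = time_cons and link inertia I,
#     tau     = a * u
#     d_omega = I^{-1} tau * dt
#     omega_t = omega_{t-1} + DAMPING * d_omega      (DAMPING scales the increment;
#                                                     it does not decay the old velocity)
#     q_{t+1} = q_t + 0.5 * (0, omega_t * dt) (x) q_t
# i.e. the torque is integrated once into an angular velocity and once more into
# the orientation quaternion -- semi-implicit Euler on the joint state.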

# get_transformed_visual_pts() --- transformed visual pts #
# Use the transformed visual pts with the articulated motion field, then add a
# free motion field on top: the articulated model predicts per-joint
# transformations, while the free motion field is an implicit, per-timestep
# motion field predicted by the network. #

class Robot_urdf:
    def __init__(self, links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx) -> None:
        self.links = links
        self.link_name_to_link_idxes = link_name_to_link_idxes
        self.link_name_to_link_struct = link_name_to_link_struct
        self.joint_name_to_joint_idx = joint_name_to_joint_idx
        self.actions_joint_name_to_joint_idx = actions_joint_name_to_joint_idx

        # particles: how to sample and how to expand them? one option is to use
        # weights in the model dict and sample from grids; the expanded
        # particles are then used to conduct the simulation. #

        self.init_vertices, self.init_faces = self.get_init_visual_pts()

        init_visual_pts_sv_fn = "robot_expanded_visual_pts.npy"
        np.save(init_visual_pts_sv_fn, self.init_vertices.detach().cpu().numpy())

        joint_name_to_joint_idx_sv_fn = "mano_joint_name_to_joint_idx.npy"
        np.save(joint_name_to_joint_idx_sv_fn, self.joint_name_to_joint_idx)

        actions_joint_name_to_joint_idx_sv_fn = "mano_actions_joint_name_to_joint_idx.npy"
        np.save(actions_joint_name_to_joint_idx_sv_fn, self.actions_joint_name_to_joint_idx)

        tot_joints = len(self.joint_name_to_joint_idx)
        tot_actions_joints = len(self.actions_joint_name_to_joint_idx)

        print(f"tot_joints: {tot_joints}, tot_actions_joints: {tot_actions_joints}")

        pass

    def expand_visual_pts(self, ):
        link_name_to_visited = {}
        # action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]
        expanded_visual_pts = []
        # expand the visual pts of the whole tree, starting from the palm link #
        expanded_visual_pts = palm_link.expand_visual_pts(expanded_visual_pts, link_name_to_visited, self.link_name_to_link_struct)
        expanded_visual_pts = torch.cat(expanded_visual_pts, dim=0)
        return expanded_visual_pts

    def get_transformed_visual_pts(self, ):
        init_visual_pts = []
        link_name_to_visited = {}

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        ### gather the transformed visual pts of the whole tree from the palm link ###
        init_visual_pts = palm_link.get_transformed_visual_pts(init_visual_pts, link_name_to_visited, self.link_name_to_link_struct)

        init_visual_pts = torch.cat(init_visual_pts, dim=0)
        return init_visual_pts

    ### sampling issue? --- TODO ###
    def get_init_visual_pts(self, ):
        init_visual_meshes = {
            'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
        }
        init_parent_rot = torch.eye(3, dtype=torch.float32).cuda()
        init_parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        link_name_to_visited = {}

        init_visual_meshes = palm_link.get_init_visual_meshes(init_parent_rot, init_parent_trans, init_visual_meshes, self.link_name_to_link_struct, link_name_to_visited)

        self.link_idxes = torch.cat(init_visual_meshes['link_idxes'], dim=-1)
        self.transformed_joint_pos = torch.cat(init_visual_meshes['transformed_joint_pos'], dim=0)
        self.joint_link_idxes = torch.cat(init_visual_meshes['joint_link_idxes'], dim=-1)

        init_vertices, init_faces = merge_meshes(init_visual_meshes['vertices'], init_visual_meshes['faces'])
        return init_vertices, init_faces

    def set_delta_state_and_update(self, states, cur_timestep):
        action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        link_name_to_visited = {}

        palm_link.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, self.link_name_to_link_struct)

    # cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia)
    def set_actions_and_update_states(self, actions, cur_timestep, time_cons,):
        # self.actions_joint_name_to_joint_idx maps action joint names to joint idxes #
        action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        link_name_to_visited = {}

        palm_link.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)

    ### TODO: add the contact torque when calculating the next-step states ###
    ### TODO: not an accurate implementation, since different joints should be considered for one single link ###
    ### TODO: the articulated force model is not as simple as this one ... ###
    def set_contact_forces(self, hard_selected_forces, hard_selected_manipulating_points, hard_selected_sampled_input_pts_idxes):
        # transformed_joint_pos, joint_link_idxes, link_idxes #
        selected_pts_link_idxes = self.link_idxes[hard_selected_sampled_input_pts_idxes]

        # map each link idx to its (last seen) transformed joint position #
        self.link_idx_to_transformed_joint_pos = {}
        for i_link in range(self.transformed_joint_pos.size(0)):
            cur_link_idx = self.link_idxes[i_link].item()
            cur_link_pos = self.transformed_joint_pos[i_link]
            self.link_idx_to_transformed_joint_pos[cur_link_idx] = cur_link_pos

        # accumulate per-link contact torques #
        self.link_idx_to_contact_forces = {}
        for i_c_pts in range(hard_selected_forces.size(0)):
            cur_contact_force = hard_selected_forces[i_c_pts]
            cur_link_idx = selected_pts_link_idxes[i_c_pts].item()
            cur_link_pos = self.link_idx_to_transformed_joint_pos[cur_link_idx]
            cur_link_action_pos = hard_selected_manipulating_points[i_c_pts]
            # (action_pos - link_pos) x (-contact_force) #
            cur_contact_torque = torch.cross(
                cur_link_action_pos - cur_link_pos, -cur_contact_force
            )
            if cur_link_idx not in self.link_idx_to_contact_forces:
                self.link_idx_to_contact_forces[cur_link_idx] = [cur_contact_torque]
            else:
                self.link_idx_to_contact_forces[cur_link_idx].append(cur_contact_torque)
        for link_idx in self.link_idx_to_contact_forces:
            self.link_idx_to_contact_forces[link_idx] = torch.stack(self.link_idx_to_contact_forces[link_idx], dim=0)
            self.link_idx_to_contact_forces[link_idx] = torch.sum(self.link_idx_to_contact_forces[link_idx], dim=0)
        for link_idx, link_struct in enumerate(self.links):
            if link_idx in self.link_idx_to_contact_forces:
                cur_link_contact_force = self.link_idx_to_contact_forces[link_idx]
                link_struct.contact_torque = cur_link_contact_force
            else:
                link_struct.contact_torque = None

    ### set the (optimizable) initial states ###
    def set_initial_state(self, states):
        action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx

        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        link_name_to_visited = {}

        palm_link.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)

    ### after each timestep: re-calculate the inertia matrix from the current simulated states, then set the new actions and forward the simulation ###
    def calculate_inertia(self):
        palm_idx = self.link_name_to_link_idxes["palm"]
        palm_link = self.links[palm_idx]

        link_name_to_visited = {}

        palm_link.calculate_inertia(link_name_to_visited, self.link_name_to_link_struct)

983
+ def parse_nparray_from_string(strr, args=None):
984
+ vals = strr.split(" ")
985
+ vals = [float(val) for val in vals]
986
+ vals = np.array(vals, dtype=np.float32)
987
+ vals = torch.from_numpy(vals).float()
988
+ ## vals ##
989
+ vals = nn.Parameter(vals.cuda(), requires_grad=True)
990
+
991
+ return vals
992
+
993
+
994
+ ### parse link data ###
995
+ def parse_link_data(link, args):
996
+
997
+ link_name = link.attrib["name"]
998
+ # print(f"parsing link: {link_name}") ## joints body meshes #
999
+
1000
+ joint = link.find("./joint")
1001
+
1002
+ joint_name = joint.attrib["name"]
1003
+ joint_type = joint.attrib["type"]
1004
+ if joint_type in ["revolute"]: ## a general xml parser here?
1005
+ axis = joint.attrib["axis"]
1006
+ axis = parse_nparray_from_string(axis, args=args)
1007
+ else:
1008
+ axis = None
1009
+ pos = joint.attrib["pos"] #
1010
+ pos = parse_nparray_from_string(pos, args=args)
1011
+ quat = joint.attrib["quat"]
1012
+ quat = parse_nparray_from_string(quat, args=args)
1013
+
1014
+ try:
1015
+ frame = joint.attrib["frame"]
1016
+ except:
1017
+ frame = "WORLD"
1018
+
1019
+ if joint_type not in ["fixed"]:
1020
+ damping = joint.attrib["damping"]
1021
+ damping = float(damping)
1022
+ else:
1023
+ damping = 0.0
1024
+
1025
+ cur_joint = Joint(joint_name, joint_type, axis, pos, quat, frame, damping, args=args)
1026
+
1027
+ body = link.find("./body")
1028
+ body_name = body.attrib["name"]
1029
+ body_type = body.attrib["type"]
1030
+ if body_type == "mesh":
1031
+ filename = body.attrib["filename"]
1032
+ else:
1033
+ filename = ""
1034
+
1035
+ if body_type == "sphere":
1036
+ radius = body.attrib["radius"]
1037
+ radius = float(radius)
1038
+ else:
1039
+ radius = 0.
1040
+
1041
+ pos = body.attrib["pos"]
1042
+ pos = parse_nparray_from_string(pos, args=args)
1043
+ quat = body.attrib["quat"]
1044
+ quat = joint.attrib["quat"]
1045
+ try:
1046
+ transform_type = body.attrib["transform_type"]
1047
+ except:
1048
+ transform_type = "OBJ_TO_WORLD"
1049
+ density = body.attrib["density"]
1050
+ density = float(density)
1051
+ mu = body.attrib["mu"]
1052
+ mu = float(mu)
1053
+ try: ## rgba ##
1054
+ rgba = body.attrib["rgba"]
1055
+ rgba = parse_nparray_from_string(rgba, args=args)
1056
+ except:
1057
+ rgba = np.zeros((4,), dtype=np.float32)
1058
+
1059
+ cur_body = Body(body_name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args=args)
1060
+
1061
+ children_link = []
1062
+ links = link.findall("./link")
1063
+ for child_link in links: #
1064
+ cur_child_link = parse_link_data(child_link, args=args)
1065
+ children_link.append(cur_child_link)
1066
+
1067
+ link_name = link.attrib["name"]
1068
+ link_obj = Link(link_name, joint=cur_joint, body=cur_body, children=children_link, args=args)
1069
+ return link_obj
1070
+
1071
+
1072
+ ### parse link data ###
1073
+ def parse_link_data_urdf(link):
1074
+
1075
+ link_name = link.attrib["name"]
1076
+ # print(f"parsing link: {link_name}") ## joints body meshes #
1077
+
1078
+ inertial = link.find("./inertial")
1079
+
1080
+ origin = inertial.find("./origin")
1081
+ inertial_pos = origin.attrib["xyz"]
1082
+ inertial_pos = parse_nparray_from_string(inertial_pos)
1083
+
1084
+ inertial_rpy = origin.attrib["rpy"]
1085
+ inertial_rpy = parse_nparray_from_string(inertial_rpy)
1086
+
1087
+ inertial_mass = inertial.find("./mass")
1088
+ inertial_mass = inertial_mass.attrib["value"]
1089
+
1090
+ inertial_inertia = inertial.find("./inertia")
1091
+ inertial_ixx = inertial_inertia.attrib["ixx"]
1092
+ inertial_ixx = float(inertial_ixx)
1093
+ inertial_ixy = inertial_inertia.attrib["ixy"]
1094
+ inertial_ixy = float(inertial_ixy)
1095
+ inertial_ixz = inertial_inertia.attrib["ixz"]
1096
+ inertial_ixz = float(inertial_ixz)
1097
+ inertial_iyy = inertial_inertia.attrib["iyy"]
1098
+ inertial_iyy = float(inertial_iyy)
1099
+ inertial_iyz = inertial_inertia.attrib["iyz"]
1100
+ inertial_iyz = float(inertial_iyz)
1101
+ inertial_izz = inertial_inertia.attrib["izz"]
1102
+ inertial_izz = float(inertial_izz)
1103
+
1104
+ inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32).cuda()
1105
+ inertial_inertia_mtx[0, 0] = inertial_ixx
1106
+ inertial_inertia_mtx[0, 1] = inertial_ixy
1107
+ inertial_inertia_mtx[0, 2] = inertial_ixz
1108
+ inertial_inertia_mtx[1, 0] = inertial_ixy
1109
+ inertial_inertia_mtx[1, 1] = inertial_iyy
1110
+ inertial_inertia_mtx[1, 2] = inertial_iyz
1111
+ inertial_inertia_mtx[2, 0] = inertial_ixz
1112
+ inertial_inertia_mtx[2, 1] = inertial_iyz
1113
+ inertial_inertia_mtx[2, 2] = inertial_izz
1114
+
1115
+ # [xx, xy, xz] #
1116
+ # [0, yy, yz] #
1117
+ # [0, 0, zz] #
1118
+
1119
+ # a strange inertia value ... #
1120
+ # TODO: how to compute the inertia matrix? #
1121
+
1122
+ visual = link.find("./visual")
1123
+
1124
+ if visual is not None:
1125
+ origin = visual.find("./origin")
1126
+ visual_pos = origin.attrib["xyz"]
1127
+ visual_pos = parse_nparray_from_string(visual_pos)
1128
+ visual_rpy = origin.attrib["rpy"]
1129
+ visual_rpy = parse_nparray_from_string(visual_rpy)
1130
+ geometry = visual.find("./geometry")
1131
+ geometry_mesh = geometry.find("./mesh")
1132
+ mesh_fn = geometry_mesh.attrib["filename"]
1133
+ mesh_scale = geometry_mesh.attrib["scale"]
1134
+
1135
+ mesh_scale = parse_nparray_from_string(mesh_scale)
1136
+ mesh_fn = str(mesh_fn)
1137
+
1138
+
1139
+ link_struct = Link_urdf(name=link_name, inertial=Inertial(origin_rpy=inertial_rpy, origin_xyz=inertial_pos, mass=inertial_mass, inertia=inertial_inertia_mtx), visual=Visual(visual_rpy=visual_rpy, visual_xyz=visual_pos, geometry_mesh_fn=mesh_fn, geometry_mesh_scale=mesh_scale) if visual is not None else None)
1140
+
1141
+ return link_struct
1142
+
1143
+ def parse_joint_data_urdf(joint):
1144
+ joint_name = joint.attrib["name"]
1145
+ joint_type = joint.attrib["type"]
1146
+
1147
+ parent = joint.find("./parent")
1148
+ child = joint.find("./child")
1149
+ parent_name = parent.attrib["link"]
1150
+ child_name = child.attrib["link"]
1151
+
1152
+ joint_origin = joint.find("./origin")
1153
+ # if joint_origin.
1154
+ try:
1155
+ origin_xyz = joint_origin.attrib["xyz"]
1156
+ origin_xyz = parse_nparray_from_string(origin_xyz)
1157
+ except:
1158
+ origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
1159
+
1160
+ joint_axis = joint.find("./axis")
1161
+ if joint_axis is not None:
1162
+ joint_axis = joint_axis.attrib["xyz"]
1163
+ joint_axis = parse_nparray_from_string(joint_axis)
1164
+ else:
1165
+ joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32).cuda()
1166
+
1167
+ joint_limit = joint.find("./limit")
1168
+ if joint_limit is not None:
1169
+ joint_lower = joint_limit.attrib["lower"]
1170
+ joint_lower = float(joint_lower)
1171
+ joint_upper = joint_limit.attrib["upper"]
1172
+ joint_upper = float(joint_upper)
1173
+ joint_effort = joint_limit.attrib["effort"]
1174
+ joint_effort = float(joint_effort)
1175
+ joint_velocity = joint_limit.attrib["velocity"]
1176
+ joint_velocity = float(joint_velocity)
1177
+ else:
1178
+ joint_lower = -0.5000
1179
+ joint_upper = 1.57
1180
+ joint_effort = 1000
1181
+ joint_velocity = 0.5
1182
+
1183
+ # cosntruct the joint data #
1184
+ joint_limit = Joint_Limit(effort=joint_effort, lower=joint_lower, upper=joint_upper, velocity=joint_velocity)
1185
+ cur_joint_struct = Joint_urdf(joint_name, joint_type, parent_name, child_name, origin_xyz, joint_axis, joint_limit)
1186
+ return cur_joint_struct
1187
+
1188
+
1189
+
1190
+ def parse_data_from_urdf(xml_fn):
1191
+
1192
+ tree = ElementTree()
1193
+ tree.parse(xml_fn)
1194
+ print(f"{xml_fn}")
1195
+ ### get total robots ###
1196
+ # robots = tree.findall("link")
1197
+ cur_robot = tree
1198
+ # i_robot = 0 #
1199
+ # tot_robots = [] #
1200
+ # for cur_robot in robots: #
1201
+ # print(f"Getting robot: {i_robot}") #
1202
+ # i_robot += 1 #
1203
+ # print(f"len(robots): {len(robots)}") #
1204
+ # cur_robot = robots[0] #
1205
+ cur_links = cur_robot.findall("./link")
1206
+ # i_link = 0
1207
+ link_name_to_link_idxes = {}
1208
+ cur_robot_links = []
1209
+ link_name_to_link_struct = {}
1210
+ for i_link_idx, cur_link in enumerate(cur_links):
1211
+ cur_link_struct = parse_link_data_urdf(cur_link)
1212
+ print(f"Adding link {cur_link_struct.name}")
1213
+ cur_link_struct.link_idx = i_link_idx
1214
+ cur_robot_links.append(cur_link_struct)
1215
+
1216
+ link_name_to_link_idxes[cur_link_struct.name] = i_link_idx
1217
+ link_name_to_link_struct[cur_link_struct.name] = cur_link_struct
1218
+ # for cur_link in cur_links:
1219
+ # cur_robot_links.append(parse_link_data_urdf(cur_link, args=args))
1220
+
1221
+ print(f"link_name_to_link_struct: {len(link_name_to_link_struct)}, ")
1222
+
1223
+ tot_robot_joints = []
1224
+
1225
+ joint_name_to_joint_idx = {}
1226
+
1227
+ actions_joint_name_to_joint_idx = {}
1228
+
1229
+ cur_joints = cur_robot.findall("./joint")
1230
+ for i_joint, cur_joint in enumerate(cur_joints):
1231
+ cur_joint_struct = parse_joint_data_urdf(cur_joint)
1232
+ cur_joint_parent_link = cur_joint_struct.parent_link
1233
+ cur_joint_child_link = cur_joint_struct.child_link
1234
+
1235
+ cur_joint_idx = len(tot_robot_joints)
1236
+ cur_joint_name = cur_joint_struct.name
1237
+
1238
+ joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
1239
+
1240
+ cur_joint_type = cur_joint_struct.type
1241
+ if cur_joint_type in ['revolute']:
1242
+ actions_joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
1243
+
1244
+
1245
+ #### add the current joint to tot joints ###
1246
+ tot_robot_joints.append(cur_joint_struct)
1247
+
1248
+ parent_link_idx = link_name_to_link_idxes[cur_joint_parent_link]
1249
+ cur_parent_link_struct = cur_robot_links[parent_link_idx]
1250
+
1251
+
1252
+ child_link_idx = link_name_to_link_idxes[cur_joint_child_link]
1253
+ cur_child_link_struct = cur_robot_links[child_link_idx]
1254
+ # parent link struct #
1255
+ if link_name_to_link_struct[cur_joint_parent_link].joint is not None:
1256
+ link_name_to_link_struct[cur_joint_parent_link].joint[cur_joint_struct.name] = cur_joint_struct
1257
+ link_name_to_link_struct[cur_joint_parent_link].children[cur_joint_struct.name] = cur_child_link_struct.name
1258
+ # cur_child_link_struct
1259
+ # cur_parent_link_struct.joint.append(cur_joint_struct)
1260
+ # cur_parent_link_struct.children.append(cur_child_link_struct)
1261
+ else:
1262
+ link_name_to_link_struct[cur_joint_parent_link].joint = {
1263
+ cur_joint_struct.name: cur_joint_struct
1264
+ }
1265
+ link_name_to_link_struct[cur_joint_parent_link].children = {
1266
+ cur_joint_struct.name: cur_child_link_struct.name
1267
+ # cur_child_link_struct
1268
+ }
1269
+ # cur_parent_link_struct.joint = [cur_joint_struct]
1270
+ # cur_parent_link_struct.children.append(cur_child_link_struct)
1271
+ # pass
1272
+
1273
+
1274
+ cur_robot_obj = Robot_urdf(cur_robot_links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx)
1275
+ # tot_robots.append(cur_robot_obj)
1276
+
1277
+ # for the joint robots #
1278
+ # for every joint
1279
+ # tot_actuators = []
1280
+ # actuators = tree.findall("./actuator/motor")
1281
+ # joint_nm_to_joint_idx = {}
1282
+ # i_act = 0
1283
+ # for cur_act in actuators:
1284
+ # cur_act_joint_nm = cur_act.attrib["joint"]
1285
+ # joint_nm_to_joint_idx[cur_act_joint_nm] = i_act
1286
+ # i_act += 1 ### add the act ###
1287
+
1288
+ # tot_robots[0].set_joint_idx(joint_nm_to_joint_idx) ### set joint idx here ### # tot robots #
1289
+ # tot_robots[0].get_nn_pts()
1290
+ # tot_robots[1].get_nn_pts()
1291
+
1292
+ return cur_robot_obj
1293
+
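+ # Hedged usage sketch ("hand.urdf" is a placeholder path): the returned
+ # Robot_urdf is consumed exactly as in __main__ below.
+ # robot = parse_data_from_urdf("hand.urdf")
+ # verts, faces = robot.get_init_visual_pts()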
1294
+
1295
+ def get_name_to_state_from_str(states_str):
1296
+ tot_states = states_str.split(" ")
1297
+ tot_states = [float(cur_state) for cur_state in tot_states]
1298
+ joint_name_to_state = {}
1299
+ for i in range(len(tot_states)):
1300
+ cur_joint_name = f"joint{i + 1}"
1301
+ cur_joint_state = tot_states[i]
1302
+ joint_name_to_state[cur_joint_name] = cur_joint_state
1303
+ return joint_name_to_state
1304
+
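+ # Worked example; joint names are assumed to follow the 1-based "joint{i}" scheme:
+ # get_name_to_state_from_str("0.0 0.5 1.0")
+ # -> {"joint1": 0.0, "joint2": 0.5, "joint3": 1.0}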
1305
+
1306
+ def merge_meshes(verts_list, faces_list):
1307
+ nn_verts = 0
1308
+ tot_verts_list = []
1309
+ tot_faces_list = []
1310
+ for i_vv, cur_verts in enumerate(verts_list):
1311
+ cur_verts_nn = cur_verts.size(0)
1312
+ tot_verts_list.append(cur_verts)
1313
+ tot_faces_list.append(faces_list[i_vv] + nn_verts)
1314
+ nn_verts = nn_verts + cur_verts_nn
1315
+ tot_verts_list = torch.cat(tot_verts_list, dim=0)
1316
+ tot_faces_list = torch.cat(tot_faces_list, dim=0)
1317
+ return tot_verts_list, tot_faces_list
1318
+
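+ # Minimal sketch of the index offsetting above: faces of each later mesh are
+ # shifted by the running vertex count, so indices stay valid after merging.
+ # va, fa = torch.rand(3, 3), torch.tensor([[0, 1, 2]])
+ # vb, fb = torch.rand(3, 3), torch.tensor([[0, 1, 2]])
+ # v, f = merge_meshes([va, vb], [fa, fb]) # v: (6, 3); f == [[0, 1, 2], [3, 4, 5]]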
1319
+
1320
+
1321
+ class RobotAgent: # robot and the robot #
1322
+ def __init__(self, xml_fn) -> None:
1323
+ self.xml_fn = xml_fn
1324
+ # self.args = args
1325
+
1326
+ ##
1327
+ active_robot = parse_data_from_urdf(xml_fn)
1328
+
1329
+ self.time_constant = nn.Embedding(
1330
+ num_embeddings=3, embedding_dim=1
1331
+ ).cuda()
1332
+ torch.nn.init.ones_(self.time_constant.weight) #
1333
+ self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
1334
+
1335
+ self.optimizable_actions = nn.Embedding(
1336
+ num_embeddings=100, embedding_dim=60,
1337
+ ).cuda()
1338
+ torch.nn.init.zeros_(self.optimizable_actions.weight) #
1339
+
1340
+ self.learning_rate = 5e-4
1341
+
1342
+ self.active_robot = active_robot
1343
+
1344
+ # # get init states # #
1345
+ self.set_init_states()
1346
+ init_visual_pts = self.get_init_state_visual_pts()
1347
+ self.init_visual_pts = init_visual_pts
1348
+
1349
+
1350
+ def set_init_states_target_value(self, init_states):
1351
+ # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
1352
+ # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
1353
+
1354
+ # tot_init_states = {}
1355
+ # tot_init_states['glb_rot'] = glb_rot;
1356
+ # tot_init_states['glb_trans'] = glb_trans;
1357
+ # tot_init_states['links_init_states'] = init_states
1358
+ # self.active_robot.set_init_states_target_value(tot_init_states)
1359
+ # init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
1360
+ self.active_robot.set_initial_state(init_states)
1361
+
1362
+ def set_init_states(self):
1363
+ # glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
1364
+ # glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
1365
+
1366
+ # ### random rotation ###
1367
+ # # glb_rot_np = R.random().as_matrix()
1368
+ # # glb_rot = torch.from_numpy(glb_rot_np).float().cuda()
1369
+ # ### random rotation ###
1370
+
1371
+ # # glb_rot, glb_trans #
1372
+ # init_states = {} # init states #
1373
+ # init_states['glb_rot'] = glb_rot; #
1374
+ # init_states['glb_trans'] = glb_trans;
1375
+ # self.active_robot.set_init_states(init_states)
1376
+
1377
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
1378
+ self.active_robot.set_initial_state(init_joint_states)
1379
+
1380
+ def get_init_state_visual_pts(self, ):
1381
+ # visual_pts_list = [] # compute the transformation via current state #
1382
+ # visual_pts_list, visual_pts_mass_list = self.active_robot.compute_transformation_via_current_state( visual_pts_list)
1383
+ cur_verts, cur_faces = self.active_robot.get_init_visual_pts()
1384
+ self.faces = cur_faces
1385
+ # init_visual_pts = visual_pts_list
1386
+ return cur_verts
1387
+
1388
+ def set_actions_and_update_states(self, actions, cur_timestep):
1389
+ #
1390
+ time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
1391
+ self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
1392
+ pass
1393
+
1394
+ def forward_stepping_test(self, ):
1395
+ # delta_glb_rot; delta_glb_trans #
1396
+ timestep_to_visual_pts = {}
1397
+ for i_step in range(50):
1398
+ actions = {}
1399
+ actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
1400
+ actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
1401
+ actions_link_actions = torch.ones((22, ), dtype=torch.float32).cuda()
1402
+ # actions_link_actions = actions_link_actions * 0.2
1403
+ actions_link_actions = actions_link_actions * -1. #
1404
+ actions['link_actions'] = actions_link_actions
1405
+ self.set_actions_and_update_states(actions=actions, cur_timestep=i_step)
1406
+
1407
+ cur_visual_pts = self.get_init_state_visual_pts() # use self, not the module-level robot_agent
1408
+ cur_visual_pts = cur_visual_pts.detach().cpu().numpy()
1409
+ timestep_to_visual_pts[i_step + 1] = cur_visual_pts
1410
+ return timestep_to_visual_pts
1411
+
1412
+ def initialize_optimization(self, reference_pts_dict):
1413
+ self.n_timesteps = 50
1414
+ # self.n_timesteps = 19 # first 19-timesteps optimization #
1415
+ self.nn_tot_optimization_iters = 100
1416
+ # self.nn_tot_optimization_iters = 57
1417
+ # TODO: load reference points #
1418
+ self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
1419
+ self.ts_to_reference_pts = {
1420
+ ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float().cuda() for ts in self.ts_to_reference_pts # reference keys are 2 * i_ts (see __main__ below); remap to 1-based timesteps
1421
+ }
1422
+
1423
+
1424
+ def forward_stepping_optimization(self, ):
1425
+ nn_tot_optimization_iters = self.nn_tot_optimization_iters
1426
+ params_to_train = []
1427
+ params_to_train += list(self.optimizable_actions.parameters())
1428
+ self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
1429
+
1430
+ for i_iter in range(nn_tot_optimization_iters):
1431
+
1432
+ tot_losses = []
1433
+ ts_to_robot_points = {}
1434
+ for cur_ts in range(self.n_timesteps):
1435
+ # print(f"iter: {i_iter}, cur_ts: {cur_ts}")
1436
+ # actions = {}
1437
+ # actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
1438
+ # actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
1439
+ actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long).cuda() + cur_ts).squeeze(0)
1440
+ # actions_link_actions = actions_link_actions * 0.2
1441
+ # actions_link_actions = actions_link_actions * -1. #
1442
+ # actions['link_actions'] = actions_link_actions
1443
+ # self.set_actions_and_update_states(actions=actions, cur_timestep=cur_ts) # update the interaction #
1444
+
1445
+ with torch.no_grad():
1446
+ self.active_robot.calculate_inertia()
1447
+
1448
+ self.active_robot.set_actions_and_update_states(actions_link_actions, cur_ts, 0.2)
1449
+
1450
+ cur_visual_pts, cur_faces = self.active_robot.get_init_visual_pts()
1451
+ ts_to_robot_points[cur_ts + 1] = cur_visual_pts.clone()
1452
+
1453
+ cur_reference_pts = self.ts_to_reference_pts[cur_ts + 1]
1454
+ diff = torch.sum((cur_visual_pts - cur_reference_pts) ** 2, dim=-1)
1455
+ diff = diff.mean()
1456
+
1457
+ # diff.
1458
+ self.optimizer.zero_grad()
1459
+ diff.backward(retain_graph=True)
1460
+ # diff.backward(retain_graph=False)
1461
+ self.optimizer.step()
1462
+
1463
+ tot_losses.append(diff.item())
1464
+
1465
+
1466
+ loss = sum(tot_losses) / float(len(tot_losses))
1467
+ print(f"Iter: {i_iter}, average loss: {loss}")
1468
+ # print(f"Iter: {i_iter}, average loss: {loss.item()}, start optimizing")
1469
+ # self.optimizer.zero_grad()
1470
+ # loss.backward()
1471
+ # self.optimizer.step()
1472
+
1473
+ self.ts_to_robot_points = {
1474
+ ts: ts_to_robot_points[ts].detach().cpu().numpy() for ts in ts_to_robot_points
1475
+ }
1476
+ self.ts_to_ref_points = {
1477
+ ts: self.ts_to_reference_pts[ts].detach().cpu().numpy() for ts in ts_to_robot_points
1478
+ }
1479
+ return self.ts_to_robot_points, self.ts_to_ref_points
1480
+
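+ # Hedged observation on the loop above: the optimizer steps once per timestep
+ # (greedy per-step fitting, hence retain_graph=True) instead of accumulating a
+ # single trajectory loss per iteration; the commented-out lines after the loop
+ # sketch that alternative.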
1481
+
1482
+
1483
+
1484
+ def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
1485
+ # sin_ = np.sin(angle) # ti.math.sin(angle)
1486
+ # cos_ = np.cos(angle) # ti.math.cos(angle)
1487
+ sin_ = torch.sin(angle) # ti.math.sin(angle)
1488
+ cos_ = torch.cos(angle) # ti.math.cos(angle)
1489
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
1490
+ u_xx = u_x * u_x
1491
+ u_yy = u_y * u_y
1492
+ u_zz = u_z * u_z
1493
+ u_xy = u_x * u_y
1494
+ u_xz = u_x * u_z
1495
+ u_yz = u_y * u_z ##
1496
+
1497
+
1498
+ row_a = torch.stack(
1499
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
1500
+ )
1501
+ # print(f"row_a: {row_a.size()}")
1502
+ row_b = torch.stack(
1503
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
1504
+ )
1505
+ # print(f"row_b: {row_b.size()}")
1506
+ row_c = torch.stack(
1507
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
1508
+ )
1509
+ # print(f"row_c: {row_c.size()}")
1510
+
1511
+ ### rot_mtx for the rot_mtx ###
1512
+ rot_mtx = torch.stack(
1513
+ [row_a, row_b, row_c], dim=-1 ### assemble the rotation matrix ###
1514
+ )
1515
+
1516
+ return rot_mtx
1517
+
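+ # Hedged sanity check for the Rodrigues construction above: for a unit axis
+ # the result should be orthonormal, i.e. R @ R.T ~= I.
+ # axis = torch.tensor([0., 0., 1.]); angle = torch.tensor(0.5)
+ # R = rotation_matrix_from_axis_angle(axis, angle)
+ # assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)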
1518
+
1519
+
1520
+ #### Big TODO: the external contact forces from the manipulated object to the robot ####
1521
+ if __name__=='__main__':
1522
+
1523
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
1524
+ robot_agent = RobotAgent(urdf_fn)
1525
+
1526
+ ref_dict_npy = "reference_verts.npy"
1527
+ robot_agent.initialize_optimization(ref_dict_npy)
1528
+ ts_to_robot_points, ts_to_ref_points = robot_agent.forward_stepping_optimization()
1529
+ np.save(f"ts_to_robot_points.npy", ts_to_robot_points)
1530
+ np.save(f"ts_to_ref_points.npy", ts_to_ref_points)
1531
+ exit(0)
1532
+
1533
+ urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
1534
+ cur_robot = parse_data_from_urdf(urdf_fn)
1535
+ # self.init_vertices, self.init_faces
1536
+ init_vertices, init_faces = cur_robot.init_vertices, cur_robot.init_faces
1537
+
1538
+
1539
+
1540
+ init_vertices = init_vertices.detach().cpu().numpy()
1541
+ init_faces = init_faces.detach().cpu().numpy()
1542
+
1543
+
1544
+ ## initial states here ##
1545
+ # mesh_obj = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
1546
+ # mesh_obj.export(f"hand_urdf.ply")
1547
+
1548
+ ##### Test the set initial state function #####
1549
+ init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
1550
+ cur_robot.set_initial_state(init_joint_states)
1551
+ ##### Test the set initial state function #####
1552
+
1553
+
1554
+
1555
+
1556
+ cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32).cuda()
1557
+ cur_ones_actions = torch.ones((60, ), dtype=torch.float32).cuda() # * 100
1558
+
1559
+ ts_to_mesh_verts = {}
1560
+ for i_ts in range(50):
1561
+ cur_robot.calculate_inertia()
1562
+
1563
+ cur_robot.set_actions_and_update_states(cur_ones_actions, i_ts, 0.2) ###
1564
+
1565
+
1566
+ cur_verts, cur_faces = cur_robot.get_init_visual_pts()
1567
+ cur_mesh = trimesh.Trimesh(vertices=cur_verts.detach().cpu().numpy(), faces=cur_faces.detach().cpu().numpy())
1568
+
1569
+ ts_to_mesh_verts[i_ts + i_ts] = cur_verts.detach().cpu().numpy() # keyed at 2 * i_ts; initialize_optimization remaps via ts // 2 + 1
1570
+ # cur_mesh.export(f"stated_mano_mesh.ply")
1571
+ # cur_mesh.export(f"zero_actioned_mano_mesh.ply")
1572
+ cur_mesh.export(f"ones_actioned_mano_mesh_ts_{i_ts}.ply")
1573
+
1574
+ np.save(f"reference_verts.npy", ts_to_mesh_verts)
1575
+
1576
+ exit(0)
1577
+
1578
+ xml_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere.xml"
1579
+ robot_agent = RobotAgent(xml_fn=xml_fn) # RobotAgent.__init__ takes only xml_fn
1580
+ init_visual_pts = robot_agent.init_visual_pts.detach().cpu().numpy()
1581
+ exit(0)
1582
+
models/dyn_model_utils.py ADDED
@@ -0,0 +1,1369 @@
1
+
2
+ import math
3
+ # import torch
4
+ # from ..utils import Timer
5
+ import numpy as np
6
+ # import torch.nn.functional as F
7
+ import os
8
+
9
+ import argparse
10
+
11
+ from xml.etree.ElementTree import ElementTree
12
+
13
+ import trimesh
14
+ import torch
15
+ import torch.nn as nn
16
+ # import List
17
+ # class link; joint; body
18
+ ###
19
+
20
+ ## calculate transformation to the frame ##
21
+
22
+ ## the joint idx ##
23
+ ##### th_cuda_idx #####
24
+
25
+ # name to the main axis? #
26
+
27
+ # def get_body_name_to_main_axis()
28
+ # another one is just getting the joint offset positions?
29
+ # and after the revolute transformation # # all revolute joint points ####
30
+ def get_body_name_to_main_axis():
31
+ # negative y; positive x #
32
+ body_name_to_main_axis = {
33
+ "body2": -2, "body6": 1, "body10": 1, "body14": 1, "body17": 1
34
+ }
35
+ return body_name_to_main_axis ## get the body name to main axis ##
36
+
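+ # Hedged decoding of the axis codes above, based on how they are consumed in
+ # Body.get_visual_counterparts below: 1 -> take the point with max x as the
+ # body's main-axis (fingertip) point, -2 -> take the point with min y.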
37
+ ## insert one
38
+ def plane_rotation_matrix_from_angle_xz(angle):
39
+ ## angle of
40
+ sin_ = torch.sin(angle)
41
+ cos_ = torch.cos(angle)
42
+ zero_padding = torch.zeros_like(cos_)
43
+ one_padding = torch.ones_like(cos_)
44
+ col_a = torch.stack(
45
+ [cos_, zero_padding, sin_], dim=0
46
+ )
47
+ col_b = torch.stack(
48
+ [zero_padding, one_padding, zero_padding], dim=0
49
+ )
50
+ col_c = torch.stack(
51
+ [-1. * sin_, zero_padding, cos_], dim=0
52
+ )
53
+ rot_mtx = torch.stack(
54
+ [col_a, col_b, col_c], dim=-1
55
+ )
56
+ # col_a = torch.stack(
57
+ # [cos_, sin_], dim=0 ### col of the rotation matrix
58
+ # )
59
+ # col_b = torch.stack(
60
+ # [-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
61
+ # )
62
+ # rot_mtx = torch.stack(
63
+ # [col_a, col_b], dim=-1 ### rotation matrix
64
+ # )
65
+ return rot_mtx
66
+
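+ # Hedged note: the matrix above embeds a planar rotation in the XZ plane (a
+ # rotation about the Y axis, sign convention aside); the Y component is kept.
+ # R = plane_rotation_matrix_from_angle_xz(torch.tensor(0.3))
+ # assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-6)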
67
+ def plane_rotation_matrix_from_angle(angle):
68
+ ## angle of
69
+ sin_ = torch.sin(angle)
70
+ cos_ = torch.cos(angle)
71
+ col_a = torch.stack(
72
+ [cos_, sin_], dim=0 ### col of the rotation matrix
73
+ )
74
+ col_b = torch.stack(
75
+ [-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
76
+ )
77
+ rot_mtx = torch.stack(
78
+ [col_a, col_b], dim=-1 ### rotation matrix
79
+ )
80
+ return rot_mtx
81
+
82
+ def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
83
+ # sin_ = np.sin(angle) # ti.math.sin(angle)
84
+ # cos_ = np.cos(angle) # ti.math.cos(angle)
85
+ sin_ = torch.sin(angle) # ti.math.sin(angle)
86
+ cos_ = torch.cos(angle) # ti.math.cos(angle)
87
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
88
+ u_xx = u_x * u_x
89
+ u_yy = u_y * u_y
90
+ u_zz = u_z * u_z
91
+ u_xy = u_x * u_y
92
+ u_xz = u_x * u_z
93
+ u_yz = u_y * u_z ##
94
+ # rot_mtx = np.stack(
95
+ # [np.array([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=np.float32),
96
+ # np.array([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=np.float32),
97
+ # np.array([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=np.float32)
98
+ # ], axis=-1 ### np stack
99
+ # ) ## a single
100
+
101
+ # rot_mtx = torch.stack(
102
+ # [
103
+ # torch.tensor([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=torch.float32),
104
+ # torch.tensor([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=torch.float32),
105
+ # torch.tensor([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=torch.float32)
106
+ # ], dim=-1 ## stack those torch tensors ##
107
+ # )
108
+
109
+ row_a = torch.stack(
110
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
111
+ )
112
+ # print(f"row_a: {row_a.size()}")
113
+ row_b = torch.stack(
114
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
115
+ )
116
+ # print(f"row_b: {row_b.size()}")
117
+ row_c = torch.stack(
118
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
119
+ )
120
+ # print(f"row_c: {row_c.size()}")
121
+
122
+ ### rot_mtx for the rot_mtx ###
123
+ rot_mtx = torch.stack(
124
+ [row_a, row_b, row_c], dim=-1 ### assemble the rotation matrix ###
125
+ )
126
+
127
+ # rot_mtx = torch.stack(
128
+ # [
129
+
130
+ # torch.tensor([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=torch.float32),
131
+ # torch.tensor([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=torch.float32),
132
+ # torch.tensor([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=torch.float32)
133
+ # ], dim=-1 ## stack those torch tensors ##
134
+ # )
135
+
136
+ # rot_mtx_numpy = rot_mtx.to_numpy()
137
+ # rot_mtx_at_rot_mtx = rot_mtx @ rot_mtx.transpose()
138
+ # print(rot_mtx_at_rot_mtx)
139
+ return rot_mtx
140
+
141
+ ## joint name = "joint3" ##
142
+ # <joint name="joint3" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e7"/>
143
+ class Joint:
144
+ def __init__(self, name, joint_type, axis, pos, quat, frame, damping, args) -> None:
145
+ self.name = name
146
+ self.type = joint_type
147
+ self.axis = axis
148
+ self.pos = pos
149
+ self.quat = quat
150
+ self.frame = frame
151
+ self.damping = damping
152
+
153
+ self.args = args
154
+
155
+ #### TODO: the dimension of the state vector ? ####
156
+ # self.state = 0. ## parameter
157
+ self.state = nn.Parameter(
158
+ torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
159
+ )
160
+ # self.rot_mtx = np.eye(3, dtypes=np.float32)
161
+ # self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
162
+ self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True)
163
+ self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True)
164
+ # self.rot_mtx = np.eye(3, dtype=np.float32)
165
+ # self.trans_vec = np.zeros((3,), dtype=np.float32)
166
+
167
+ self.axis_rot_mtx = torch.tensor(
168
+ [
169
+ [1, 0, 0], [0, -1, 0], [0, 0, -1]
170
+ ], dtype=torch.float32
171
+ ).cuda(self.args.th_cuda_idx)
172
+
173
+ self.joint_idx = -1
174
+
175
+ self.transformed_joint_pts = self.pos.clone()
176
+
177
+ def print_grads(self, ):
178
+ print(f"rot_mtx: {self.rot_mtx.grad}")
179
+ print(f"trans_vec: {self.trans_vec.grad}")
180
+
181
+ def clear_grads(self,):
182
+ if self.rot_mtx.grad is not None:
183
+ self.rot_mtx.grad.data = self.rot_mtx.grad.data * 0.
184
+ if self.trans_vec.grad is not None:
185
+ self.trans_vec.grad.data = self.trans_vec.grad.data * 0.
186
+
187
+ def compute_transformation(self,):
188
+ # use the state to transform them # # transform # ## transform the state ##
189
+ # use the state to transform them # # transform them for the state #
190
+ if self.type == "revolute":
191
+ # print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
192
+ # rotation matrix from the axis angle #
193
+ rot_mtx = rotation_matrix_from_axis_angle(self.axis, self.state)
194
+ # rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
195
+ # trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
196
+ # self.rot_mtx = np.copy(rot_mtx)
197
+ # self.trans_vec = np.copy(trans_vec)
198
+ trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
199
+ self.rot_mtx = rot_mtx
200
+ self.trans_vec = trans_vec
201
+ else:
202
+ ### TODO: implement transformations for joints in other types ###
203
+ pass
204
+
205
+ def set_state(self, name_to_state):
206
+ if self.name in name_to_state:
207
+ # self.state = name_to_state["name"]
208
+ self.state = name_to_state[self.name] ##
209
+
210
+ def set_state_via_vec(self, state_vec): ### transform points via the state vectors here ###
211
+ if self.joint_idx >= 0:
212
+ self.state = state_vec[self.joint_idx] ## give the parameter to the parameters ##
213
+
214
+ def set_joint_idx(self, joint_name_to_idx):
215
+ if self.name in joint_name_to_idx:
216
+ self.joint_idx = joint_name_to_idx[self.name]
217
+
218
+
219
+ def set_args(self, args):
220
+ self.args = args
221
+
222
+
223
+ def compute_transformation_via_state_vals(self, state_vals):
224
+ if self.joint_idx >= 0:
225
+ cur_joint_state = state_vals[self.joint_idx]
226
+ else:
227
+ cur_joint_state = self.state
228
+ # use the state to transform them # # transform # ## transform the state ##
229
+ # use the state to transform them # # transform them for the state #
230
+ if self.type == "revolute":
231
+ # print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
232
+ # rotation matrix from the axis angle #
233
+ rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
234
+ # rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
235
+ # trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
236
+ # self.rot_mtx = np.copy(rot_mtx)
237
+ # self.trans_vec = np.copy(trans_vec)
238
+ trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
239
+ self.rot_mtx = rot_mtx
240
+ self.trans_vec = trans_vec
241
+ elif self.type == "free2d":
242
+ cur_joint_state = state_vals # still only for the current scene #
243
+ # cur_joint_state
244
+ cur_joint_rot_val = state_vals[2]
245
+ ### rot_mtx ### ### rot_mtx ###
246
+ rot_mtx = plane_rotation_matrix_from_angle_xz(cur_joint_rot_val)
247
+ # rot_mtx = plane_rotation_matrix_from_angle(cur_joint_rot_val) ### 2 x 2 rot matrix #
248
+ # cur joint rot val #
249
+ # rot mtx of the rotation
250
+ # xy_val =
251
+ # axis_rot_mtx
252
+ # R_axis^T ( R R_axis (p) + trans (with the y-axis padded) )
253
+ cur_trans_vec = torch.stack(
254
+ [state_vals[0], torch.zeros_like(state_vals[0]), state_vals[1]], dim=0
255
+ )
256
+ # cur_trans_vec #
257
+ rot_mtx = torch.matmul(self.axis_rot_mtx.transpose(1, 0), torch.matmul(rot_mtx, self.axis_rot_mtx))
258
+ trans_vec = torch.matmul(self.axis_rot_mtx.transpose(1, 0), cur_trans_vec.unsqueeze(-1).contiguous()).squeeze(-1).contiguous() + self.pos
259
+
260
+ self.rot_mtx = rot_mtx
261
+ self.trans_vec = trans_vec ## rot_mtx and trans_vec #
262
+ else:
263
+ ### TODO: implement transformations for joints in other types ###
264
+ pass
265
+ return self.rot_mtx, self.trans_vec
266
+
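+ # Hedged reading of the "free2d" branch above: state_vals is assumed to be a
+ # planar pose (x, z, theta); the planar rotation R and the y-padded translation
+ # t are conjugated through axis_rot_mtx, i.e. p' = R_axis^T (R R_axis p + t),
+ # before self.pos is added as a fixed offset.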
267
+ def transform_joints_via_parent_rot_trans_infos(self, parent_rot_mtx, parent_trans_vec):
268
+ #
269
+ # if self.type == "revolute" or self.type == "free2d":
270
+ transformed_joint_pts = torch.matmul(parent_rot_mtx, self.pos.view(3, 1).contiguous()).view(3).contiguous() + parent_trans_vec
271
+ # else:
272
+ self.transformed_joint_pts = transformed_joint_pts ### get self transformed joint pts here ###
273
+ return transformed_joint_pts
274
+ # if self.joint_idx >= 0:
275
+ # cur_joint_state = state_vals[self.joint_idx]
276
+ # else:
277
+ # cur_joint_state = self.state # state #
278
+ # # use the state to transform them # # transform ### transform the state ##
279
+ # # use the state to transform them # # transform them for the state # transform for the state #
280
+ # if self.type == "revolute":
281
+ # # print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
282
+ # # rotation matrix from the axis angle #
283
+ # rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
284
+ # # rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
285
+ # # trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
286
+ # # self.rot_mtx = np.copy(rot_mtx)
287
+ # # self.trans_vec = np.copy(trans_vec)
288
+ # trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
289
+ # self.rot_mtx = rot_mtx
290
+ # self.trans_vec = trans_vec
291
+ # elif self.type == "free2d":
292
+ # cur_joint_state = state_vals # still only for the current scene #
293
+ # # cur_joint_state
294
+ # cur_joint_rot_val = state_vals[2]
295
+ # ### rot_mtx ### ### rot_mtx ###
296
+ # rot_mtx = plane_rotation_matrix_from_angle_xz(cur_joint_rot_val)
297
+ # # rot_mtx = plane_rotation_matrix_from_angle(cur_joint_rot_val) ### 2 x 2 rot matrix #
298
+ # # cur joint rot val #
299
+ # # rot mtx of the rotation
300
+ # # xy_val =
301
+ # # axis_rot_mtx
302
+ # # R_axis^T ( R R_axis (p) + trans (with the y-axis padded) )
303
+ # cur_trans_vec = torch.stack(
304
+ # [state_vals[0], torch.zeros_like(state_vals[0]), state_vals[1]], dim=0
305
+ # )
306
+ # # cur_trans_vec #
307
+ # rot_mtx = torch.matmul(self.axis_rot_mtx.transpose(1, 0), torch.matmul(rot_mtx, self.axis_rot_mtx))
308
+ # trans_vec = torch.matmul(self.axis_rot_mtx.transpose(1, 0), cur_trans_vec.unsqueeze(-1).contiguous()).squeeze(-1).contiguous() + self.pos
309
+
310
+ # self.rot_mtx = rot_mtx
311
+ # self.trans_vec = trans_vec ## rot_mtx and trans_vec #
312
+ # else:
313
+ # ### TODO: implement transformations for joints in other types ###
314
+ # pass
315
+ # return self.rot_mtx, self.trans_vec
316
+
317
+ ## fixed joint -> translation and rotation ##
318
+ ## revolute joint -> can be actuated ##
319
+ ## set states and compute the transformations in a top-down manner ##
320
+
321
+ ## transform the robot -> a list of qs ##
322
+ ## a list of qs ##
323
+ ## transform from the root of the robot; pass qs from the root to the leaf node ##
324
+ ## visual meshes or visual meshes from the basic description of robots ##
325
+ ## visual meshes; or visual points ##
326
+ ## visual meshes -> transform them into the visual density values here ##
327
+ ## visual meshes -> transform them into the ## into the visual counterparts ##
328
+ ## ## visual meshes -> ## ## ##
329
+ # <body name="body0" type="mesh" filename="hand/body0.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.700000 0.700000 0.700000 1"/>
330
+ class Body:
331
+ def __init__(self, name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args) -> None:
332
+ self.name = name
333
+ self.body_type = body_type
334
+ ### for mesh object ###
335
+ self.filename = filename
336
+ self.args = args
337
+
338
+ self.pos = pos
339
+ self.quat = quat
340
+ self.transform_type = transform_type
341
+ self.density = density
342
+ self.mu = mu
343
+ self.rgba = rgba
344
+
345
+ ### for sphere object ###
346
+ self.radius = radius
347
+ ## or vertices here #
348
+ ## pass them to the child and treat them as the parent transformation ##
349
+
350
+ self.visual_pts_ref = None
351
+ self.visual_faces_ref = None
352
+
353
+ self.visual_pts = None ## visual pts and
354
+
355
+ self.body_name_to_main_axis = get_body_name_to_main_axis() ### get the body name to main axis here #
356
+
357
+ self.get_visual_counterparts()
358
+
359
+
360
+ def update_radius(self,):
361
+ self.radius.data = self.radius.data - self.radius.grad.data
362
+
363
+ self.radius.grad.data = self.radius.grad.data * 0.
364
+
365
+
366
+ def update_xml_file(self,):
367
+ xml_content_with_flexible_radius = f"""<redmax model="hand">
368
+ <!-- 1) change the damping value here? -->
369
+ <!-- 2) change the center of mass -->
370
+ <option integrator="BDF2" timestep="0.01" gravity="0. 0. -0.000098"/>
371
+ <ground pos="0 0 -10" normal="0 0 1"/>
372
+ <default>
373
+ <general_primitive_contact kn="1e6" kt="1e3" mu="0.8" damping="3e1" />
374
+ </default>
375
+
376
+ <robot>
377
+ <link name="link0">
378
+ <joint name="joint0" type="fixed" pos="0 0 0" quat="1 0 0 0" frame="WORLD"/>
379
+ <body name="body0" type="mesh" filename="hand/body0.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.700000 0.700000 0.700000 1"/>
380
+ <link name="link1">
381
+ <joint name="joint1" type="revolute" axis="0.000000 0.000000 -1.000000" pos="-3.300000 -5.689700 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
382
+ <body name="body1" type="mesh" filename="hand/body1.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
383
+ <link name="link2">
384
+ <joint name="joint2" type="revolute" axis="1.000000 0.000000 0.000000" pos="-3.300000 -7.680000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
385
+ <body name="body2" type="mesh" filename="hand/body2.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
386
+ </link>
387
+ </link>
388
+ <link name="link3">
389
+ <!-- revolute joint -->
390
+ <joint name="joint3" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
391
+ <body name="body3" type="mesh" filename="hand/body3.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
392
+ <link name="link4">
393
+ <joint name="joint4" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
394
+ <body name="body4" type="mesh" filename="hand/body4.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
395
+ <link name="link5">
396
+ <joint name="joint5" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
397
+ <body name="body5" type="mesh" filename="hand/body5.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
398
+ <link name="link6">
399
+ <joint name="joint6" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
400
+ <body name="body6" type="mesh" filename="hand/body6.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
401
+ </link>
402
+ </link>
403
+ </link>
404
+ </link>
405
+ <link name="link7">
406
+ <joint name="joint7" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
407
+ <body name="body7" type="mesh" filename="hand/body7.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
408
+ <link name="link8">
409
+ <joint name="joint8" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
410
+ <body name="body8" type="mesh" filename="hand/body8.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
411
+ <link name="link9">
412
+ <joint name="joint9" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
413
+ <body name="body9" type="mesh" filename="hand/body9.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
414
+ <link name="link10">
415
+ <joint name="joint10" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
416
+ <body name="body10" type="mesh" filename="hand/body10.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
417
+ </link>
418
+ </link>
419
+ </link>
420
+ </link>
421
+ <link name="link11">
422
+ <joint name="joint11" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
423
+ <body name="body11" type="mesh" filename="hand/body11.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
424
+ <link name="link12">
425
+ <joint name="joint12" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
426
+ <body name="body12" type="mesh" filename="hand/body12.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
427
+ <link name="link13">
428
+ <joint name="joint13" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
429
+ <body name="body13" type="mesh" filename="hand/body13.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
430
+ <link name="link14">
431
+ <joint name="joint14" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
432
+ <body name="body14" type="mesh" filename="hand/body14.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
433
+ </link>
434
+ </link>
435
+ </link>
436
+ </link>
437
+ <link name="link15">
438
+ <joint name="joint15" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
439
+ <body name="body15" type="mesh" filename="hand/body15.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
440
+ <link name="link16">
441
+ <joint name="joint16" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
442
+ <body name="body16" type="mesh" filename="hand/body16.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
443
+ <link name="link17">
444
+ <joint name="joint17" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
445
+ <body name="body17" type="mesh" filename="hand/body17.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
446
+ </link>
447
+ </link>
448
+ </link>
449
+ </link>
450
+ </robot>
451
+
452
+ <robot>
453
+ <link name="sphere">
454
+ <joint name="sphere" type="free2d" pos = "10. 0.0 3.5" quat="1 -1 0 0" format="LOCAL" damping="0"/>
455
+ <body name="sphere" type="sphere" radius="{self.radius[0].detach().cpu().item()}" pos="0 0 0" quat="1 0 0 0" density="0.5" mu="0" texture="resources/textures/sphere.jpg"/>
456
+ </link>
457
+ </robot>
458
+
459
+ <contact>
460
+ <ground_contact body="sphere" kn="1e6" kt="1e3" mu="0.8" damping="3e1"/>
461
+ <general_primitive_contact general_body="body0" primitive_body="sphere"/>
462
+ <general_primitive_contact general_body="body1" primitive_body="sphere"/>
463
+ <general_primitive_contact general_body="body2" primitive_body="sphere"/>
464
+ <general_primitive_contact general_body="body3" primitive_body="sphere"/>
465
+ <general_primitive_contact general_body="body4" primitive_body="sphere"/>
466
+ <general_primitive_contact general_body="body5" primitive_body="sphere"/>
467
+ <general_primitive_contact general_body="body6" primitive_body="sphere"/>
468
+ <general_primitive_contact general_body="body7" primitive_body="sphere"/>
469
+ <general_primitive_contact general_body="body8" primitive_body="sphere"/>
470
+ <general_primitive_contact general_body="body9" primitive_body="sphere"/>
471
+ <general_primitive_contact general_body="body10" primitive_body="sphere"/>
472
+ <general_primitive_contact general_body="body11" primitive_body="sphere"/>
473
+ <general_primitive_contact general_body="body12" primitive_body="sphere"/>
474
+ <general_primitive_contact general_body="body13" primitive_body="sphere"/>
475
+ <general_primitive_contact general_body="body14" primitive_body="sphere"/>
476
+ <general_primitive_contact general_body="body15" primitive_body="sphere"/>
477
+ <general_primitive_contact general_body="body16" primitive_body="sphere"/>
478
+ <general_primitive_contact general_body="body17" primitive_body="sphere"/>
479
+ </contact>
480
+
481
+ <actuator>
482
+ <motor joint="joint1" ctrl="force" ctrl_range="-3e5 3e5"/>
483
+ <motor joint="joint2" ctrl="force" ctrl_range="-3e5 3e5"/>
484
+ <motor joint="joint3" ctrl="force" ctrl_range="-3e5 3e5"/>
485
+ <motor joint="joint4" ctrl="force" ctrl_range="-3e5 3e5"/>
486
+ <motor joint="joint5" ctrl="force" ctrl_range="-3e5 3e5"/>
487
+ <motor joint="joint6" ctrl="force" ctrl_range="-3e5 3e5"/>
488
+ <motor joint="joint7" ctrl="force" ctrl_range="-3e5 3e5"/>
489
+ <motor joint="joint8" ctrl="force" ctrl_range="-3e5 3e5"/>
490
+ <motor joint="joint9" ctrl="force" ctrl_range="-3e5 3e5"/>
491
+ <motor joint="joint10" ctrl="force" ctrl_range="-3e5 3e5"/>
492
+ <motor joint="joint11" ctrl="force" ctrl_range="-3e5 3e5"/>
493
+ <motor joint="joint12" ctrl="force" ctrl_range="-3e5 3e5"/>
494
+ <motor joint="joint13" ctrl="force" ctrl_range="-3e5 3e5"/>
495
+ <motor joint="joint14" ctrl="force" ctrl_range="-3e5 3e5"/>
496
+ <motor joint="joint15" ctrl="force" ctrl_range="-3e5 3e5"/>
497
+ <motor joint="joint16" ctrl="force" ctrl_range="-3e5 3e5"/>
498
+ <motor joint="joint17" ctrl="force" ctrl_range="-3e5 3e5"/>
499
+ </actuator>
500
+ </redmax>
501
+ """
502
+ xml_loading_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere_free_sphere_geo_test.xml"
503
+ with open(xml_loading_fn, "w") as wf:
504
+ wf.write(xml_content_with_flexible_radius)
505
+ wf.close()
506
+
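+ # Hedged note on update_xml_file above: it re-emits the full hand+sphere
+ # redmax scene with the current learnable radius baked into the
+ # <body name="sphere"> line, so an external simulator run reloads geometry
+ # that matches the optimized value.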
507
+ ### get visual pts colors ###
508
+ def get_visual_pts_colors(self, ):
509
+ tot_visual_pts_nn = self.visual_pts_ref.size(0)
510
+ # self.pts_rgba = [torch.from_numpy(self.rgba).float().cuda(self.args.th_cuda_idx) for _ in range(tot_visual_pts_nn)] # total visual pts nn
511
+ self.pts_rgba = [torch.tensor(self.rgba.data).cuda(self.args.th_cuda_idx) for _ in range(tot_visual_pts_nn)] # one rgba value per visual point
512
+ self.pts_rgba = torch.stack(self.pts_rgba, dim=0) #
513
+ return self.pts_rgba
514
+
515
+ def get_visual_counterparts(self,):
516
+ ### TODO: implement this for visual counterparts ### mid line regression and name to body mapping relations --- for each body, how to calculate the midline and other properties?
517
+ ######## get body type ########## get visual midline of the input mesh and the mesh vertices? ######## # skeleton of the hand -> 21 points ? retarget from this hand to the mano hand and use the mano hand priors?
518
+ if self.body_type == "sphere":
519
+ filename = "/home/xueyi/diffsim/DiffHand/examples/save_res/hand_sphere_demo/meshes/18.obj"
520
+ if not os.path.exists(filename):
521
+ filename = "/data/xueyi/diffsim/DiffHand/assets/18.obj"
522
+ else:
523
+ filename = self.filename
524
+ rt_asset_path = "/home/xueyi/diffsim/DiffHand/assets" ### assets folder ###
525
+ if not os.path.exists(rt_asset_path):
526
+ rt_asset_path = "/data/xueyi/diffsim/DiffHand/assets"
527
+ filename = os.path.join(rt_asset_path, filename)
528
+ body_mesh = trimesh.load(filename, process=False)
529
+ # verts = np.array(body_mesh.vertices)
530
+ # faces = np.array(body_mesh.faces, dtype=np.long)
531
+
532
+ # self.visual_pts_ref = np.copy(verts) ## verts ##
533
+ # self.visual_faces_ref = np.copy(faces) ## faces
534
+ # self.visual_pts_ref #
535
+
536
+ #### body_mesh.vertices ####
537
+ # verts = torch.tensor(body_mesh.vertices, dtype=torch.float32).cuda(self.args.th_cuda_idx)
538
+ # faces = torch.tensor(body_mesh.faces, dtype=torch.long).cuda(self.args.th_cuda_idx)
539
+ #### body_mesh.vertices ####
540
+
541
+ # self.pos = nn.Parameter(
542
+ # torch.tensor([0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
543
+ # )
544
+
545
+ self.pos = nn.Parameter(
546
+ torch.tensor(self.pos.detach().cpu().tolist(), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
547
+ )
548
+
549
+ ### Step 1 ### -> set the pos to the correct initial pose ###
550
+
551
+ self.radius = nn.Parameter(
552
+ torch.tensor([self.args.initial_radius], dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
553
+ )
554
+ ### visual pts ref ### ## body_mesh.vertices -> #
555
+ self.visual_pts_ref = torch.tensor(body_mesh.vertices, dtype=torch.float32).cuda(self.args.th_cuda_idx)
556
+
557
+ # if self.name == "sphere":
558
+ # self.visual_pts_ref = self.visual_pts_ref / 2. # the initial radius
559
+ # self.visual_pts_ref = self.visual_pts_ref * self.radius ## multiple the initla radius #
560
+
561
+ # self.visual_pts_ref = nn.Parameter(
562
+ # torch.tensor(body_mesh.vertices, dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
563
+ # )
564
+ # self.visual_faces_ref = nn.Parameter(
565
+ # torch.tensor(body_mesh.faces, dtype=torch.long, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
566
+ # )
567
+ self.visual_faces_ref = torch.tensor(body_mesh.faces, dtype=torch.long).cuda(self.args.th_cuda_idx)
568
+
569
+ # body_name_to_main_axis
570
+ # body_name_to_main_axis for the body_name_to_main_axis #
571
+ # visual_faces_ref #
572
+ # visual_pts_ref #
573
+
574
+ minn_pts, _ = torch.min(self.visual_pts_ref, dim=0) ### get the visual pts minn ###
575
+ maxx_pts, _ = torch.max(self.visual_pts_ref, dim=0) ### visual pts maxx ###
576
+ mean_pts = torch.mean(self.visual_pts_ref, dim=0) ### mean_pts of the mean_pts ###
577
+
578
+ if self.name in self.body_name_to_main_axis:
579
+ cur_main_axis = self.body_name_to_main_axis[self.name] ## get the body name ##
580
+
581
+ if cur_main_axis == -2:
582
+ main_axis_pts = minn_pts[1] # the main axis pts
583
+ full_main_axis_pts = torch.tensor([mean_pts[0], main_axis_pts, mean_pts[2]], dtype=torch.float32).cuda(self.args.th_cuda_idx)
584
+ elif cur_main_axis == 1:
585
+ main_axis_pts = maxx_pts[0] # the maxx axis pts
586
+ full_main_axis_pts = torch.tensor([main_axis_pts, mean_pts[1], mean_pts[2]], dtype=torch.float32).cuda(self.args.th_cuda_idx)
587
+ self.full_main_axis_pts_ref = full_main_axis_pts
588
+ else:
589
+ self.full_main_axis_pts_ref = mean_pts.clone() ### get the mean pts ###
590
+ # mean_pts
591
+ # main_axis_pts =
592
+
593
+
594
+ # self.visual_pts_ref = verts #
595
+ # self.visual_faces_ref = faces #
596
+ # get visual points colors # the color should be an optimizable property # # # or init visual point colors here ## # or init visual point colors here #
597
+ # simulatoable assets ## for the
598
+
599
+ def transform_visual_pts_ref(self,):
600
+ if self.name == "sphere":
601
+ visual_pts_ref = self.visual_pts_ref / 2. #
602
+ visual_pts_ref = visual_pts_ref * self.radius
603
+ else:
604
+ visual_pts_ref = self.visual_pts_ref
605
+ return visual_pts_ref
606
+
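+ # Hedged note: the template sphere mesh (18.obj) is assumed to have radius 2,
+ # so the reference verts are normalized to a unit sphere before being scaled
+ # by the learnable `radius` parameter.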
607
+ def transform_visual_pts(self, rot_mtx, trans_vec):
608
+ visual_pts_ref = self.transform_visual_pts_ref()
609
+ # rot_mtx: 3 x 3 numpy array
610
+ # trans_vec: 3 numpy array
611
+ # print(f"transforming body with rot_mtx: {rot_mtx} and trans_vec: {trans_vec}")
612
+ # self.visual_pts = np.matmul(rot_mtx, self.visual_pts_ref.T).T + trans_vec.reshape(1, 3) # reshape #
613
+ # print(f"rot_mtx: {rot_mtx}, trans_vec: {trans_vec}")
614
+ self.visual_pts = torch.matmul(rot_mtx, visual_pts_ref.transpose(1, 0)).transpose(1, 0) + trans_vec.unsqueeze(0)
615
+
616
+ # full_main_axis_pts ->
617
+ self.full_main_axis_pts = torch.matmul(rot_mtx, self.full_main_axis_pts_ref.unsqueeze(-1)).contiguous().squeeze(-1) + trans_vec
618
+ self.full_main_axis_pts = self.full_main_axis_pts.unsqueeze(0)
619
+
620
+ return self.visual_pts
621
+
622
+ def get_tot_transformed_joints(self, transformed_joints):
623
+ if self.name in self.body_name_to_main_axis:
624
+ transformed_joints.append(self.full_main_axis_pts)
625
+ return transformed_joints
626
+
627
+ def get_nn_pts(self,):
628
+ self.nn_pts = self.visual_pts_ref.size(0)
629
+ return self.nn_pts
630
+
631
+ def set_args(self, args):
632
+ self.args = args
633
+
634
+ def clear_grad(self, ):
635
+ if self.pos.grad is not None:
636
+ self.pos.grad.data = self.pos.grad.data * 0.
637
+ if self.radius.grad is not None:
638
+ self.radius.grad.data = self.radius.grad.data * 0.
639
+
640
+
641
+ # get the visual counterparts of the body mesh or elements #
642
+
643
+ # xyz attribute ## ## xyz attribute #
644
+
645
+ # use get_name_to_visual_pts
646
+ # use get_name_to_visual_pts_faces to get the transformed visual pts and faces #
647
+ class Link:
648
+ def __init__(self, name, joint: Joint, body: Body, children, args) -> None:
649
+
650
+ self.joint = joint
651
+ self.body = body
652
+ self.children = children
653
+ self.name = name
654
+
655
+ self.args = args
656
+
657
+ # joint # parent_rot_mtx, parent_trans_vec
658
+ self.parent_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
659
+ self.parent_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
660
+ self.curr_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
661
+ self.curr_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
662
+ #
663
+ self.tot_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
664
+ self.tot_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True) ## torch zeros #
665
+
666
+ def print_grads(self, ): ### print grads here ###
667
+ print(f"parent_rot_mtx: {self.parent_rot_mtx.grad}")
668
+ print(f"parent_trans_vec: {self.parent_trans_vec.grad}")
669
+ print(f"curr_rot_mtx: {self.curr_rot_mtx.grad}")
670
+ print(f"curr_trans_vec: {self.curr_trans_vec.grad}")
671
+ print(f"tot_rot_mtx: {self.tot_rot_mtx.grad}")
672
+ print(f"tot_trans_vec: {self.tot_trans_vec.grad}")
673
+ print("Joint")
674
+ self.joint.print_grads()
675
+ for cur_link in self.children:
676
+ cur_link.print_grads()
677
+
678
+
679
+ def set_state(self, name_to_state):
680
+ self.joint.set_state(name_to_state=name_to_state)
681
+ for child_link in self.children:
682
+ child_link.set_state(name_to_state)
683
+
684
+
685
+ def set_state_via_vec(self, state_vec): #
686
+ self.joint.set_state_via_vec(state_vec)
687
+ for child_link in self.children:
688
+ child_link.set_state_via_vec(state_vec)
689
+ # if self.joint_idx >= 0:
690
+ # self.state = state_vec[self.joint_idx]
691
+
692
+ ##
693
+ def get_tot_transformed_joints(self, transformed_joints):
694
+ cur_joint_transformed_pts = self.joint.transformed_joint_pts.unsqueeze(0) ### 3 pts
695
+ transformed_joints.append(cur_joint_transformed_pts)
696
+ transformed_joints = self.body.get_tot_transformed_joints(transformed_joints)
697
+ # if self.joint.name
698
+ for cur_link in self.children:
699
+ transformed_joints = cur_link.get_tot_transformed_joints(transformed_joints)
700
+ return transformed_joints
701
+
702
+ def compute_transformation_via_state_vecs(self, state_vals, parent_rot_mtx, parent_trans_vec, visual_pts_list):
703
+ # state vecs and rot mtx # state vecs #####
704
+ joint_rot_mtx, joint_trans_vec = self.joint.compute_transformation_via_state_vals(state_vals=state_vals)
705
+
706
+ self.curr_rot_mtx = joint_rot_mtx
707
+ self.curr_trans_vec = joint_trans_vec
708
+
709
+ self.joint.transform_joints_via_parent_rot_trans_infos(parent_rot_mtx=parent_rot_mtx, parent_trans_vec=parent_trans_vec) ## get rot and trans mtx and vecs ###
710
+
711
+ tot_parent_rot_mtx = torch.matmul(parent_rot_mtx, joint_rot_mtx)
712
+ tot_parent_trans_vec = torch.matmul(parent_rot_mtx, joint_trans_vec.unsqueeze(-1)).view(3) + parent_trans_vec
713
+
714
+ self.tot_rot_mtx = tot_parent_rot_mtx
715
+ self.tot_trans_vec = tot_parent_trans_vec
716
+
717
+ # self.tot_rot_mtx = np.copy(tot_parent_rot_mtx)
718
+ # self.tot_trans_vec = np.copy(tot_parent_trans_vec)
719
+
720
+ ### visual_pts_list for recording visual pts ###
721
+
722
+ cur_body_visual_pts = self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
723
+ visual_pts_list.append(cur_body_visual_pts)
724
+
725
+ for cur_link in self.children:
726
+ # cur_link.parent_rot_mtx = np.copy(tot_parent_rot_mtx) ### set children parent rot mtx and the trans vec
727
+ # cur_link.parent_trans_vec = np.copy(tot_parent_trans_vec) ##
728
+ cur_link.parent_rot_mtx = tot_parent_rot_mtx ### set children parent rot mtx and the trans vec #
729
+ cur_link.parent_trans_vec = tot_parent_trans_vec ##
730
+ # cur_link.compute_transformation() ## compute self's transformations
731
+ cur_link.compute_transformation_via_state_vecs(state_vals, tot_parent_rot_mtx, tot_parent_trans_vec, visual_pts_list)
732
+
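+ # Hedged summary of the recursion above: with (R_p, t_p) the parent's
+ # accumulated transform and (R_j, t_j) the joint's local transform,
+ # R_tot = R_p @ R_j and t_tot = R_p @ t_j + t_p,
+ # and (R_tot, t_tot) is pushed down to every child link.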
733
+ def get_visual_pts_rgba_values(self, pts_rgba_vals_list):
734
+
735
+ cur_body_visual_rgba_vals = self.body.get_visual_pts_colors()
736
+ pts_rgba_vals_list.append(cur_body_visual_rgba_vals)
737
+
738
+ for cur_link in self.children:
739
+ cur_link.get_visual_pts_rgba_values(pts_rgba_vals_list)
740
+
741
+
742
+
743
+ def compute_transformation(self,):
744
+ self.joint.compute_transformation()
745
+ # self.curr_rot_mtx = np.copy(self.joint.rot_mtx)
746
+ # self.curr_trans_vec = np.copy(self.joint.trans_vec)
747
+
748
+ self.curr_rot_mtx = self.joint.rot_mtx
749
+ self.curr_trans_vec = self.joint.trans_vec
750
+ # rot_p (rot_c p + trans_c) + trans_p #
751
+ # rot_p rot_c p + rot_p trans_c + trans_p #
752
+ #### matmul ####
753
+ # tot_parent_rot_mtx = np.matmul(self.parent_rot_mtx, self.curr_rot_mtx)
754
+ # tot_parent_trans_vec = np.matmul(self.parent_rot_mtx, self.curr_trans_vec.reshape(3, 1)).reshape(3) + self.parent_trans_vec
755
+
756
+ tot_parent_rot_mtx = torch.matmul(self.parent_rot_mtx, self.curr_rot_mtx)
757
+ tot_parent_trans_vec = torch.matmul(self.parent_rot_mtx, self.curr_trans_vec.unsqueeze(-1)).view(3) + self.parent_trans_vec
758
+
759
+ self.tot_rot_mtx = tot_parent_rot_mtx
760
+ self.tot_trans_vec = tot_parent_trans_vec
761
+
762
+ # self.tot_rot_mtx = np.copy(tot_parent_rot_mtx)
763
+ # self.tot_trans_vec = np.copy(tot_parent_trans_vec)
764
+
765
+ for cur_link in self.children:
766
+ # cur_link.parent_rot_mtx = np.copy(tot_parent_rot_mtx) ### set children parent rot mtx and the trans vec
767
+ # cur_link.parent_trans_vec = np.copy(tot_parent_trans_vec) ##
768
+ cur_link.parent_rot_mtx = tot_parent_rot_mtx ### set children parent rot mtx and the trans vec #
769
+ cur_link.parent_trans_vec = tot_parent_trans_vec ##
770
+ cur_link.compute_transformation() ## compute self's transformations
771
+
772
+ def get_name_to_visual_pts_faces(self, name_to_visual_pts_faces):
773
+ # transform_visual_pts # ## rot_mt
774
+ self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
775
+ name_to_visual_pts_faces[self.body.name] = {"pts": self.body.visual_pts, "faces": self.body.visual_faces_ref}
776
+ for cur_link in self.children:
777
+ cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces) ## transform the pts faces
778
+
779
+ def get_visual_pts_list(self, visual_pts_list):
780
+ # transform_visual_pts # ## rot_mt
781
+ self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
782
+ visual_pts_list.append(self.body.visual_pts) # body template #
783
+ # name_to_visual_pts_faces[self.body.name] = {"pts": self.body.visual_pts, "faces": self.body.visual_faces_ref}
784
+ for cur_link in self.children:
785
+ # cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces) ## transform the pts faces
786
+ cur_link.get_visual_pts_list(visual_pts_list)
787
+
788
+
789
+
790
+ def set_joint_idx(self, joint_name_to_idx):
791
+ self.joint.set_joint_idx(joint_name_to_idx)
792
+ for cur_link in self.children:
793
+ cur_link.set_joint_idx(joint_name_to_idx)
794
+ # if self.name in joint_name_to_idx:
795
+ # self.joint_idx = joint_name_to_idx[self.name]
796
+
797
+ def get_nn_pts(self,):
798
+ nn_pts = 0
799
+ nn_pts += self.body.get_nn_pts()
800
+ for cur_link in self.children:
801
+ nn_pts += cur_link.get_nn_pts()
802
+ self.nn_pts = nn_pts
803
+ return self.nn_pts
804
+
805
+ def clear_grads(self,):
806
+
807
+ if self.parent_rot_mtx.grad is not None:
808
+ self.parent_rot_mtx.grad.data = self.parent_rot_mtx.grad.data * 0.
809
+ if self.parent_trans_vec.grad is not None:
810
+ self.parent_trans_vec.grad.data = self.parent_trans_vec.grad.data * 0.
811
+ if self.curr_rot_mtx.grad is not None:
812
+ self.curr_rot_mtx.grad.data = self.curr_rot_mtx.grad.data * 0.
813
+ if self.curr_trans_vec.grad is not None:
814
+ self.curr_trans_vec.grad.data = self.curr_trans_vec.grad.data * 0.
815
+ if self.tot_rot_mtx.grad is not None:
816
+ self.tot_rot_mtx.grad.data = self.tot_rot_mtx.grad.data * 0.
817
+ if self.tot_trans_vec.grad is not None:
818
+ self.tot_trans_vec.grad.data = self.tot_trans_vec.grad.data * 0.
819
+ # print(f"parent_rot_mtx: {self.parent_rot_mtx.grad}")
820
+ # print(f"parent_trans_vec: {self.parent_trans_vec.grad}")
821
+ # print(f"curr_rot_mtx: {self.curr_rot_mtx.grad}")
822
+ # print(f"curr_trans_vec: {self.curr_trans_vec.grad}")
823
+ # print(f"tot_rot_mtx: {self.tot_rot_mtx.grad}")
824
+ # print(f"tot_trans_vec: {self.tot_trans_vec.grad}")
825
+ # print(f"Joint")
826
+ self.joint.clear_grads()
827
+ self.body.clear_grad()
828
+ for cur_link in self.children:
829
+ cur_link.clear_grads()
830
+
831
+ def set_args(self, args):
832
+ self.args = args
833
+ for cur_link in self.children:
834
+ cur_link.set_args(args)
835
+
836
+
837
+
838
+
839
+ class Robot: # robot and the robot #
840
+ def __init__(self, children_links, args) -> None:
841
+ self.children = children_links
842
+ self.args = args
843
+
844
+ def set_state(self, name_to_state):
845
+ for cur_link in self.children:
846
+ cur_link.set_state(name_to_state)
847
+
848
+ def compute_transformation(self,):
849
+ for cur_link in self.children:
850
+ cur_link.compute_transformation()
851
+
852
+ def get_name_to_visual_pts_faces(self, name_to_visual_pts_faces):
853
+ for cur_link in self.children:
854
+ cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces)
855
+
856
+ def get_visual_pts_list(self, visual_pts_list):
857
+ for cur_link in self.children:
858
+ cur_link.get_visual_pts_list(visual_pts_list)
859
+
860
+ def set_joint_idx(self, joint_name_to_idx):
861
+ for cur_link in self.children:
862
+ cur_link.set_joint_idx(joint_name_to_idx) ### set joint idx ###
863
+
864
+ def set_state_via_vec(self, state_vec): ### set joint states from a flat state vector ###
865
+ for cur_link in self.children:
866
+ cur_link.set_state_via_vec(state_vec)
867
+ # self.joint.set_state_via_vec(state_vec)
868
+ # for child_link in self.children:
869
+ # child_link.set_state_via_vec(state_vec)
870
+
871
+ # get_tot_transformed_joints
872
+ def get_tot_transformed_joints(self, transformed_joints):
873
+ for cur_link in self.children: #
874
+ transformed_joints = cur_link.get_tot_transformed_joints(transformed_joints)
875
+ return transformed_joints
876
+
877
+ def get_nn_pts(self):
878
+ nn_pts = 0
879
+ for cur_link in self.children:
880
+ nn_pts += cur_link.get_nn_pts()
881
+ self.nn_pts = nn_pts
882
+ return self.nn_pts
883
+
884
+ def set_args(self, args):
885
+ self.args = args
886
+ for cur_link in self.children: ## args ##
887
+ cur_link.set_args(args)
888
+
889
+ def print_grads(self):
890
+ for cur_link in self.children:
891
+ cur_link.print_grads()
892
+
893
+ def clear_grads(self,): ## clear grads ##
894
+ for cur_link in self.children:
895
+ cur_link.clear_grads()
896
+
897
+ def compute_transformation_via_state_vecs(self, state_vals, visual_pts_list):
898
+ # parent_rot_mtx, parent_trans_vec
899
+ for cur_link in self.children:
900
+ cur_link.compute_transformation_via_state_vecs(state_vals, cur_link.parent_rot_mtx, cur_link.parent_trans_vec, visual_pts_list)
901
+ return visual_pts_list
902
+
903
+ # get_visual_pts_rgba_values(self, pts_rgba_vals_list):
904
+ def get_visual_pts_rgba_values(self, pts_rgba_vals_list):
905
+ for cur_link in self.children:
906
+ cur_link.get_visual_pts_rgba_values(pts_rgba_vals_list)
907
+ return pts_rgba_vals_list ## compute pts rgba vals list ##
908
+
909
+ def parse_nparray_from_string(strr, args):
910
+ vals = strr.split(" ")
911
+ vals = [float(val) for val in vals]
912
+ vals = np.array(vals, dtype=np.float32)
913
+ vals = torch.from_numpy(vals).float()
914
+ ## vals ##
915
+ vals = nn.Parameter(vals.cuda(args.th_cuda_idx), requires_grad=True)
916
+
917
+ return vals
918
+
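# A quick usage sketch for parse_nparray_from_string (hypothetical values; the
# args namespace only needs a th_cuda_idx attribute, as consumed above):
# axis = parse_nparray_from_string("0 0 1", args=args)  # learnable (3,) Parameter on cuda:args.th_cuda_idx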
919
+
920
+ ### parse link data ###
921
+ def parse_link_data(link, args):
922
+
923
+ link_name = link.attrib["name"]
924
+ # print(f"parsing link: {link_name}") ## joints body meshes #
925
+
926
+ joint = link.find("./joint")
927
+
928
+ joint_name = joint.attrib["name"]
929
+ joint_type = joint.attrib["type"]
930
+ if joint_type in ["revolute"]: ## a general xml parser here?
931
+ axis = joint.attrib["axis"]
932
+ axis = parse_nparray_from_string(axis, args=args)
933
+ else:
934
+ axis = None
935
+ pos = joint.attrib["pos"] #
936
+ pos = parse_nparray_from_string(pos, args=args)
937
+ quat = joint.attrib["quat"]
938
+ quat = parse_nparray_from_string(quat, args=args)
939
+
940
+ try:
941
+ frame = joint.attrib["frame"]
942
+ except KeyError:
943
+ frame = "WORLD"
944
+
945
+ if joint_type not in ["fixed"]:
946
+ damping = joint.attrib["damping"]
947
+ damping = float(damping)
948
+ else:
949
+ damping = 0.0
950
+
951
+ cur_joint = Joint(joint_name, joint_type, axis, pos, quat, frame, damping, args=args)
952
+
953
+ body = link.find("./body")
954
+ body_name = body.attrib["name"]
955
+ body_type = body.attrib["type"]
956
+ if body_type == "mesh":
957
+ filename = body.attrib["filename"]
958
+ else:
959
+ filename = ""
960
+
961
+ if body_type == "sphere":
962
+ radius = body.attrib["radius"]
963
+ radius = float(radius)
964
+ else:
965
+ radius = 0.
966
+
967
+ pos = body.attrib["pos"]
968
+ pos = parse_nparray_from_string(pos, args=args)
969
+ quat = body.attrib["quat"]
970
+ quat = parse_nparray_from_string(quat, args=args)
971
+ try:
972
+ transform_type = body.attrib["transform_type"]
973
+ except KeyError:
974
+ transform_type = "OBJ_TO_WORLD"
975
+ density = body.attrib["density"]
976
+ density = float(density)
977
+ mu = body.attrib["mu"]
978
+ mu = float(mu)
979
+ try: ## rgba ##
980
+ rgba = body.attrib["rgba"]
981
+ rgba = parse_nparray_from_string(rgba, args=args)
982
+ except KeyError:
983
+ rgba = np.zeros((4,), dtype=np.float32)
984
+
985
+ cur_body = Body(body_name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args=args)
986
+
987
+ children_link = []
988
+ links = link.findall("./link")
989
+ for child_link in links: #
990
+ cur_child_link = parse_link_data(child_link, args=args)
991
+ children_link.append(cur_child_link)
992
+
993
+ link_name = link.attrib["name"]
994
+ link_obj = Link(link_name, joint=cur_joint, body=cur_body, children=children_link, args=args)
995
+ return link_obj
996
+
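# parse_link_data above assumes a recursively nested link structure; a sketch
# of the expected XML (illustrative attribute values, not a file from this repo):
# <link name="link0">
#   <joint name="joint0" type="revolute" axis="0 0 1" pos="0 0 0"
#          quat="1 0 0 0" frame="WORLD" damping="1e4"/>
#   <body name="body0" type="mesh" filename="body0.obj" pos="0 0 0"
#         quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0.1" rgba="0 0 0 1"/>
#   <link name="link1"> <!-- child links are parsed recursively --> </link>
# </link>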
997
+
998
+
999
+
1000
+ def parse_data_from_xml(xml_fn, args):
1001
+
1002
+ tree = ElementTree()
1003
+ tree.parse(xml_fn)
1004
+
1005
+ ### get total robots ###
1006
+ robots = tree.findall("./robot")
1007
+ i_robot = 0
1008
+ tot_robots = []
1009
+ for cur_robot in robots:
1010
+ print(f"Getting robot: {i_robot}")
1011
+ i_robot += 1
1012
+ cur_links = cur_robot.findall("./link")
1013
+ # i_link = 0
1014
+ cur_robot_links = []
1015
+ for cur_link in cur_links: ## child of the link ##
1016
+ ### parse-link util: each link consists of a joint, a body, and children links (possibly none)
1017
+ # cur_link_name = cur_link.attrib["name"]
1018
+ # print(f"Getting link: {i_link} with name: {cur_link_name}")
1019
+ # i_link += 1 ##
1020
+ cur_robot_links.append(parse_link_data(cur_link, args=args))
1021
+ cur_robot_obj = Robot(cur_robot_links, args=args)
1022
+ tot_robots.append(cur_robot_obj)
1023
+
1024
+
1025
+ tot_actuators = []
1026
+ actuators = tree.findall("./actuator/motor")
1027
+ joint_nm_to_joint_idx = {}
1028
+ i_act = 0
1029
+ for cur_act in actuators:
1030
+ cur_act_joint_nm = cur_act.attrib["joint"]
1031
+ joint_nm_to_joint_idx[cur_act_joint_nm] = i_act
1032
+ i_act += 1 ### add the act ###
1033
+
1034
+ tot_robots[0].set_joint_idx(joint_nm_to_joint_idx) ### set joint idx here ### # tot robots #
1035
+ tot_robots[0].get_nn_pts()
1036
+ tot_robots[1].get_nn_pts()
1037
+
1038
+ return tot_robots
1039
+
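# A minimal usage sketch (hypothetical path; args is any namespace carrying
# th_cuda_idx). The XML is expected to define two <robot> elements, since the
# code above indexes both tot_robots[0] and tot_robots[1], plus an <actuator>
# block whose <motor joint="..."> order fixes the joint indices:
# active_robot, passive_robot = parse_data_from_xml("hand_sphere.xml", args=args)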
1040
+ def get_name_to_state_from_str(states_str):
1041
+ tot_states = states_str.split(" ")
1042
+ tot_states = [float(cur_state) for cur_state in tot_states]
1043
+ joint_name_to_state = {}
1044
+ for i in range(len(tot_states)):
1045
+ cur_joint_name = f"joint{i + 1}"
1046
+ cur_joint_state = tot_states[i]
1047
+ joint_name_to_state[cur_joint_name] = cur_joint_state
1048
+ return joint_name_to_state
1049
+
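# A quick example of the mapping above:
# get_name_to_state_from_str("0.1 0.2 0.3") -> {"joint1": 0.1, "joint2": 0.2, "joint3": 0.3}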
1050
+ def create_zero_states():
1051
+ nn_joints = 17
1052
+ joint_name_to_state = {}
1053
+ for i_j in range(nn_joints):
1054
+ cur_joint_name = f"joint{i_j + 1}"
1055
+ joint_name_to_state[cur_joint_name] = 0.
1056
+ return joint_name_to_state
1057
+
1058
+ # [6.96331033e-17 3.54807679e-06 1.74046190e-15 2.66367417e-05
1059
+ # 1.22444894e-05 3.38976792e-06 1.46917635e-15 2.66367383e-05
1060
+ # 1.22444882e-05 3.38976786e-06 1.97778813e-15 2.66367383e-05
1061
+ # 1.22444882e-05 3.38976786e-06 4.76033293e-16 1.26279884e-05
1062
+ # 3.51189993e-06 0.00000000e+00 4.89999978e-03 0.00000000e+00]
1063
+
1064
+
1065
+ def rotation_matrix_from_axis_angle_np(axis, angle): # rotation_matrix_from_axis_angle ->
1066
+ sin_ = np.sin(angle) # ti.math.sin(angle)
1067
+ cos_ = np.cos(angle) # ti.math.cos(angle)
1068
+ # sin_ = torch.sin(angle) # ti.math.sin(angle)
1069
+ # cos_ = torch.cos(angle) # ti.math.cos(angle)
1070
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
1071
+ u_xx = u_x * u_x
1072
+ u_yy = u_y * u_y
1073
+ u_zz = u_z * u_z
1074
+ u_xy = u_x * u_y
1075
+ u_xz = u_x * u_z
1076
+ u_yz = u_y * u_z ##
1077
+
1078
+
1079
+ row_a = np.stack(
1080
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], axis=0
1081
+ )
1082
+ # print(f"row_a: {row_a.size()}")
1083
+ row_b = np.stack(
1084
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], axis=0
1085
+ )
1086
+ # print(f"row_b: {row_b.size()}")
1087
+ row_c = np.stack(
1088
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], axis=0
1089
+ )
1090
+ # print(f"row_c: {row_c.size()}")
1091
+
1092
+ ### rot_mtx for the rot_mtx ###
1093
+ rot_mtx = np.stack(
1094
+ [row_a, row_b, row_c], axis=-1 ### stack the rows as columns of the rot matrix ##
1095
+ )
1096
+
1097
+ return rot_mtx
1098
+
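# The rows stacked along axis=-1 above form the columns of Rodrigues' rotation
# matrix R = cos(a) I + sin(a) [u]_x + (1 - cos(a)) u u^T for a unit axis u.
# A sanity-check sketch, assuming scipy is available (it is not used elsewhere here):
# from scipy.spatial.transform import Rotation
# axis = np.array([0., 0., 1.], dtype=np.float32)
# R_ref = Rotation.from_rotvec(axis * 0.3).as_matrix()
# assert np.allclose(rotation_matrix_from_axis_angle_np(axis, 0.3), R_ref, atol=1e-5)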
1099
+
1100
+
1101
+
1102
+ def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
1103
+ # sin_ = np.sin(angle) # ti.math.sin(angle)
1104
+ # cos_ = np.cos(angle) # ti.math.cos(angle)
1105
+ sin_ = torch.sin(angle) # ti.math.sin(angle)
1106
+ cos_ = torch.cos(angle) # ti.math.cos(angle)
1107
+ u_x, u_y, u_z = axis[0], axis[1], axis[2]
1108
+ u_xx = u_x * u_x
1109
+ u_yy = u_y * u_y
1110
+ u_zz = u_z * u_z
1111
+ u_xy = u_x * u_y
1112
+ u_xz = u_x * u_z
1113
+ u_yz = u_y * u_z ##
1114
+
1115
+
1116
+ row_a = torch.stack(
1117
+ [cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
1118
+ )
1119
+ # print(f"row_a: {row_a.size()}")
1120
+ row_b = torch.stack(
1121
+ [u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
1122
+ )
1123
+ # print(f"row_b: {row_b.size()}")
1124
+ row_c = torch.stack(
1125
+ [u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
1126
+ )
1127
+ # print(f"row_c: {row_c.size()}")
1128
+
1129
+ ### rot_mtx for the rot_mtx ###
1130
+ rot_mtx = torch.stack(
1131
+ [row_a, row_b, row_c], dim=-1 ### stack the rows as columns of the rot matrix ##
1132
+ )
1133
+
1134
+ return rot_mtx
1135
+
1136
+
1137
+ def get_camera_to_world_poses(n=10, ):
1138
+ ## sample from the upper half sphere ##
1139
+ # theta and phi for the camera viewing direction
1140
+ theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # azimuth in the xz plane #
1141
+ phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## phi in [-pi, 0]: the upper half sphere ##
1142
+ # theta = torch.from_numpy(theta).float().cuda()
1143
+ tot_c2w_matrix = []
1144
+ for i_n in range(n):
1145
+ # y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
1146
+ # y_rot_mtx = load_utils.rotation_matrix_from_axis_angle(rot_vec, rot_angle)
1147
+
1148
+
1149
+ z_axis_rot_axis = np.array([0, 0, 1.], dtype=np.float32)
1150
+ z_axis_rot_angle = np.pi - theta[i_n]
1151
+ z_axis_rot_matrix = rotation_matrix_from_axis_angle_np(z_axis_rot_axis, z_axis_rot_angle)
1152
+ rotated_plane_rot_axis_ori = np.array([1, -1, 0], dtype=np.float32)
1153
+ rotated_plane_rot_axis_ori = rotated_plane_rot_axis_ori / np.sqrt(np.sum(rotated_plane_rot_axis_ori ** 2))
1154
+ rotated_plane_rot_axis = np.matmul(z_axis_rot_matrix, rotated_plane_rot_axis_ori)
1155
+
1156
+ plane_rot_angle = phi[i_n]
1157
+ plane_rot_matrix = rotation_matrix_from_axis_angle_np(rotated_plane_rot_axis, plane_rot_angle)
1158
+
1159
+ c2w_matrix = np.matmul(plane_rot_matrix, z_axis_rot_matrix)
1160
+ c2w_trans_matrix = np.array(
1161
+ [np.cos(theta[i_n]) * np.sin(phi[i_n]), np.sin(theta[i_n]) * np.sin(phi[i_n]), np.cos(phi[i_n])], dtype=np.float32
1162
+ )
1163
+ c2w_matrix = np.concatenate(
1164
+ [c2w_matrix, c2w_trans_matrix.reshape(3, 1)], axis=-1
1165
+ ) ##c2w matrix
1166
+ tot_c2w_matrix.append(c2w_matrix)
1167
+ tot_c2w_matrix = np.stack(tot_c2w_matrix, axis=0)
1168
+ return tot_c2w_matrix
1169
+
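# Each stacked entry above is a 3x4 camera-to-world matrix [R | t], with t the
# camera center placed on the unit sphere at spherical angles (theta, phi).
# A sketch of applying it (apply_c2w is a hypothetical helper, not defined here):
# def apply_c2w(c2w, pts_cam):  # c2w: (3, 4); pts_cam: (N, 3) camera-frame points
#     return pts_cam @ c2w[:, :3].T + c2w[:, 3]  # -> (N, 3) world-frame points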
1170
+
1171
+ def get_camera_to_world_poses_th(n=10, th_cuda_idx=0):
1172
+ ## sample from the upper half sphere ##
1173
+ # theta and phi for the camera viewing direction
1174
+ theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # azimuth in the xz plane #
1175
+ phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## phi in [-pi, 0]: the upper half sphere ##
1176
+
1177
+ # n_total = 14
1178
+ # n_xz = 14
1179
+ # n_y = 7
1180
+ # theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
1181
+ # phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
1182
+
1183
+
1184
+ theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
1185
+ phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
1186
+
1187
+ tot_c2w_matrix = []
1188
+ for i_n in range(n): # for very dense views, see the routine_1 / routine_2 variants below
1189
+ y_rot_angle = theta[i_n]
1190
+ y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
1191
+ y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
1192
+
1193
+ x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
1194
+ y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
1195
+
1196
+ x_rot_angle = phi[i_n]
1197
+ x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
1198
+
1199
+ rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
1200
+ xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
1201
+ rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
1202
+
1203
+ c2w_matrix = torch.cat(
1204
+ [rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
1205
+ )
1206
+ tot_c2w_matrix.append(c2w_matrix)
1207
+
1208
+
1209
+ # z_axis_rot_axis = np.array([0, 0, 1.], dtype=np.float32)
1210
+ # z_axis_rot_angle = np.pi - theta[i_n]
1211
+ # z_axis_rot_matrix = rotation_matrix_from_axis_angle_np(z_axis_rot_axis, z_axis_rot_angle)
1212
+ # rotated_plane_rot_axis_ori = np.array([1, -1, 0], dtype=np.float32)
1213
+ # rotated_plane_rot_axis_ori = rotated_plane_rot_axis_ori / np.sqrt(np.sum(rotated_plane_rot_axis_ori ** 2))
1214
+ # rotated_plane_rot_axis = np.matmul(z_axis_rot_matrix, rotated_plane_rot_axis_ori)
1215
+
1216
+ # plane_rot_angle = phi[i_n]
1217
+ # plane_rot_matrix = rotation_matrix_from_axis_angle_np(rotated_plane_rot_axis, plane_rot_angle)
1218
+
1219
+ # c2w_matrix = np.matmul(plane_rot_matrix, z_axis_rot_matrix)
1220
+ # c2w_trans_matrix = np.array(
1221
+ # [np.cos(theta[i_n]) * np.sin(phi[i_n]), np.sin(theta[i_n]) * np.sin(phi[i_n]), np.cos(phi[i_n])], dtype=np.float32
1222
+ # )
1223
+ # c2w_matrix = np.concatenate(
1224
+ # [c2w_matrix, c2w_trans_matrix.reshape(3, 1)], axis=-1
1225
+ # ) ##c2w matrix
1226
+ # tot_c2w_matrix.append(c2w_matrix)
1227
+ # tot_c2w_matrix = np.stack(tot_c2w_matrix, axis=0)
1228
+ tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
1229
+ return tot_c2w_matrix
1230
+
1231
+
1232
+ def get_camera_to_world_poses_th_routine_1(n=7, th_cuda_idx=0):
1233
+ ## sample from the upper half sphere ##
1234
+ # theta and phi for the camera viewing direction
1235
+
1236
+ # theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
1237
+ # phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
1238
+
1239
+ # n_total = 14
1240
+ n_xz = 2 * n # 14
1241
+ n_y = n # 7
1242
+ theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
1243
+ phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
1244
+
1245
+ theta = torch.tensor(theta).float().cuda(th_cuda_idx)
1246
+ phi = torch.tensor(phi).float().cuda(th_cuda_idx)
1247
+ # theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
1248
+ # phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
1249
+
1250
+ tot_c2w_matrix = []
1251
+
1252
+ for i_theta in range(theta.size(0)):
1253
+ for i_phi in range(phi.size(0)):
1254
+ y_rot_angle = theta[i_theta]
1255
+ y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
1256
+ y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
1257
+
1258
+ x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
1259
+ y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
1260
+
1261
+ x_rot_angle = phi[i_phi]
1262
+ x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
1263
+
1264
+ rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
1265
+ xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
1266
+ rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
1267
+
1268
+ c2w_matrix = torch.cat(
1269
+ [rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
1270
+ )
1271
+ tot_c2w_matrix.append(c2w_matrix)
1272
+
1273
+ tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
1274
+ return tot_c2w_matrix
1275
+
1276
+
1277
+ def get_camera_to_world_poses_th_routine_2(n=7, th_cuda_idx=0):
1278
+ ## sample camera poses around the full sphere (phi spans a full revolution here) ##
1279
+ # theta and phi for the camera viewing direction
1280
+
1281
+ # theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
1282
+ # phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
1283
+
1284
+ # n_total = 14
1285
+ n_xz = 2 * n # 14
1286
+ n_y = 2 * n # 7
1287
+ theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
1288
+ # phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
1289
+ phi = [i_y * (-1.0) / float(n_y) * np.pi * 2. for i_y in range(n_y)]
1290
+
1291
+ theta = torch.tensor(theta).float().cuda(th_cuda_idx)
1292
+ phi = torch.tensor(phi).float().cuda(th_cuda_idx)
1293
+ # theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
1294
+ # phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
1295
+
1296
+ tot_c2w_matrix = []
1297
+
1298
+ for i_theta in range(theta.size(0)):
1299
+ for i_phi in range(phi.size(0)):
1300
+ y_rot_angle = theta[i_theta]
1301
+ y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
1302
+ y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
1303
+
1304
+ x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
1305
+ y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
1306
+
1307
+ x_rot_angle = phi[i_phi]
1308
+ x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
1309
+
1310
+ rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
1311
+ xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
1312
+ rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
1313
+
1314
+ c2w_matrix = torch.cat(
1315
+ [rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
1316
+ )
1317
+ tot_c2w_matrix.append(c2w_matrix)
1318
+
1319
+ tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
1320
+ return tot_c2w_matrix
1321
+
1322
+
1323
+
1324
+
1325
+
1326
+
1327
+ if __name__=='__main__':
1328
+ xml_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere.xml"
1329
+ import argparse # minimal stand-in args; parse_nparray_from_string reads args.th_cuda_idx
+ args = argparse.Namespace(th_cuda_idx=0)
+ tot_robots = parse_data_from_xml(xml_fn=xml_fn, args=args)
1330
+ # tot_robots =
1331
+
1332
+ active_optimized_states = """-0.00025872 -0.00025599 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025694 -0.00024656 -0.00025556 0. 0.0049 0."""
1333
+ active_optimized_states = """-1.10617972 -1.10742263 -1.06198363 -1.03212746 -1.05429142 -1.08617289 -1.05868192 -1.01624365 -1.04478191 -1.08260959 -1.06719107 -1.04082455 -1.05995886 -1.08674006 -1.09396691 -1.08965532 -1.10036577 -10.7117466 -3.62511998 1.49450353"""
1334
+ # active_goal_optimized_states = """-1.10617972 -1.10742263 -1.0614858 -1.03189609 -1.05404354 -1.08610468 -1.05863293 -1.0174248 -1.04576456 -1.08297396 -1.06719107 -1.04082455 -1.05995886 -1.08674006 -1.09396691 -1.08965532 -1.10036577 -10.73396897 -3.68095432 1.50679285"""
1335
+ active_optimized_states = """-0.42455298 -0.42570447 -0.40567708 -0.39798589 -0.40953955 -0.42025055 -0.37910662 -0.496165 -0.37664644 -0.41942727 -0.40596508 -0.3982109 -0.40959847 -0.42024905 -0.41835001 -0.41929961 -0.42365131 -1.18756073 -2.90337822 0.4224685"""
1336
+ active_optimized_states = """-0.42442816 -0.42557961 -0.40366201 -0.3977891 -0.40947627 -0.4201424 -0.3799285 -0.3808375 -0.37953552 -0.42039598 -0.4058405 -0.39808804 -0.40947487 -0.42012458 -0.41822534 -0.41917521 -0.4235266 -0.87189658 -1.42093761 0.21977979"""
1337
+
1338
+ active_robot = tot_robots[0]
1339
+ zero_states = create_zero_states()
1340
+ active_robot.set_state(zero_states)
1341
+ active_robot.compute_transformation()
1342
+ name_to_visual_pts_surfaces = {}
1343
+ active_robot.get_name_to_visual_pts_faces(name_to_visual_pts_surfaces)
1344
+ print(len(name_to_visual_pts_surfaces))
1345
+
1346
+ sv_res_rt = "/home/xueyi/diffsim/DiffHand/examples/save_res"
1347
+ sv_res_rt = os.path.join(sv_res_rt, "load_utils_test")
1348
+ os.makedirs(sv_res_rt, exist_ok=True)
1349
+
1350
+ tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"res_with_zero_states.npy")
1351
+ np.save(tmp_visual_res_sv_fn, name_to_visual_pts_surfaces)
1352
+ print(f"tmp visual res saved to {tmp_visual_res_sv_fn}")
1353
+
1354
+
1355
+ optimized_states = get_name_to_state_from_str(active_optimized_states)
1356
+ active_robot.set_state(optimized_states)
1357
+ active_robot.compute_transformation()
1358
+ name_to_visual_pts_surfaces = {}
1359
+ active_robot.get_name_to_visual_pts_faces(name_to_visual_pts_surfaces)
1360
+ print(len(name_to_visual_pts_surfaces))
1361
+ # sv_res_rt = "/home/xueyi/diffsim/DiffHand/examples/save_res"
1362
+ # sv_res_rt = os.path.join(sv_res_rt, "load_utils_test")
1363
+ # os.makedirs(sv_res_rt, exist_ok=True)
1364
+
1365
+ # tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"res_with_optimized_states.npy")
1366
+ tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"active_ngoal_res_with_optimized_states_goal_n3.npy")
1367
+ np.save(tmp_visual_res_sv_fn, name_to_visual_pts_surfaces)
1368
+ print(f"tmp visual res with optimized states saved to {tmp_visual_res_sv_fn}")
1369
+
models/embedder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
6
+ class Embedder:
7
+ def __init__(self, **kwargs):
8
+ self.kwargs = kwargs
9
+ self.create_embedding_fn()
10
+
11
+ def create_embedding_fn(self):
12
+ embed_fns = []
13
+ d = self.kwargs['input_dims']
14
+ out_dim = 0
15
+ if self.kwargs['include_input']:
16
+ embed_fns.append(lambda x: x)
17
+ out_dim += d
18
+
19
+ max_freq = self.kwargs['max_freq_log2']
20
+ N_freqs = self.kwargs['num_freqs']
21
+
22
+ if self.kwargs['log_sampling']:
23
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
24
+ else:
25
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
26
+
27
+ for freq in freq_bands:
28
+ for p_fn in self.kwargs['periodic_fns']:
29
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30
+ out_dim += d
31
+
32
+ self.embed_fns = embed_fns
33
+ self.out_dim = out_dim
34
+
35
+ def embed(self, inputs):
36
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37
+
38
+
39
+ def get_embedder(multires, input_dims=3):
40
+ embed_kwargs = {
41
+ 'include_input': True,
42
+ 'input_dims': input_dims,
43
+ 'max_freq_log2': multires-1,
44
+ 'num_freqs': multires,
45
+ 'log_sampling': True,
46
+ 'periodic_fns': [torch.sin, torch.cos],
47
+ }
48
+
49
+ embedder_obj = Embedder(**embed_kwargs)
50
+ def embed(x, eo=embedder_obj): return eo.embed(x)
51
+ return embed, embedder_obj.out_dim
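# A minimal usage sketch: with multires=10 and input_dims=3 the embedding width
# is out_dim = 3 + 3 * 2 * 10 = 63 (the identity term plus sin/cos at 10
# frequencies per input dim).
# embed_fn, out_dim = get_embedder(10, input_dims=3)
# feats = embed_fn(torch.rand(1024, 3))  # -> (1024, 63)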
models/fields.py ADDED
The diff for this file is too large to render. See raw diff
 
models/fields_old.py ADDED
The diff for this file is too large to render. See raw diff
 
models/renderer.py ADDED
@@ -0,0 +1,641 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import logging
6
+ import mcubes
7
+ from icecream import ic
8
+ import os
9
+
10
+ import trimesh
11
+ from pysdf import SDF
12
+
13
+ from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
14
+
15
+ def create_mt_variable(device):
16
+ triangle_table = torch.tensor(
17
+ [
18
+ [-1, -1, -1, -1, -1, -1],
19
+ [1, 0, 2, -1, -1, -1],
20
+ [4, 0, 3, -1, -1, -1],
21
+ [1, 4, 2, 1, 3, 4],
22
+ [3, 1, 5, -1, -1, -1],
23
+ [2, 3, 0, 2, 5, 3],
24
+ [1, 4, 0, 1, 5, 4],
25
+ [4, 2, 5, -1, -1, -1],
26
+ [4, 5, 2, -1, -1, -1],
27
+ [4, 1, 0, 4, 5, 1],
28
+ [3, 2, 0, 3, 5, 2],
29
+ [1, 3, 5, -1, -1, -1],
30
+ [4, 1, 2, 4, 3, 1],
31
+ [3, 0, 4, -1, -1, -1],
32
+ [2, 0, 1, -1, -1, -1],
33
+ [-1, -1, -1, -1, -1, -1]
34
+ ], dtype=torch.long, device=device)
35
+
36
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
37
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
38
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
39
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
40
+
41
+
42
+
43
+ def extract_fields_from_tets(bound_min, bound_max, resolution, query_func):
44
+ # load tet via resolution #
45
+ # scale them via bounds #
46
+ # extract the geometry #
47
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
48
+ device = bound_min.device
49
+ # if resolution in [64, 70, 80, 90, 100]:
50
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
51
+ # else:
52
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
53
+ tets = np.load(tet_fn)
54
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
55
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
56
+ # split #
57
+ # verts; verts; #
58
+ minn_verts, _ = torch.min(verts, dim=0)
59
+ maxx_verts, _ = torch.max(verts, dim=0) # (3,)
60
+ # scale_verts = maxx_verts - minn_verts
61
+ scale_bounds = bound_max - bound_min # scale bounds #
62
+
63
+ ### scale the vertices ###
64
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
65
+
66
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
67
+
68
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
69
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
70
+
71
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
72
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
73
+
74
+ sdf_values = []
75
+ N = 64
76
+ query_bundles = N ** 3 ### N^3
77
+ query_NNs = scaled_verts.size(0) // query_bundles
78
+ if query_NNs * query_bundles < scaled_verts.size(0):
79
+ query_NNs += 1
80
+ for i_query in range(query_NNs):
81
+ cur_bundle_st = i_query * query_bundles
82
+ cur_bundle_ed = (i_query + 1) * query_bundles
83
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
84
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
85
+ cur_query_vals = query_func(cur_query_pts)
86
+ sdf_values.append(cur_query_vals)
87
+ sdf_values = torch.cat(sdf_values, dim=0)
88
+ # print(f"queried sdf values: {sdf_values.size()}") #
89
+
90
+ GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
91
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
92
+
93
+ # intrinsic, tet values, pts values, sdf network #
94
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
95
+ tet_table, num_tets_table = create_tetmesh_variables(device)
96
+
97
+ sdf_values = sdf_values.squeeze(-1) # (nn_verts,)
98
+
99
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
100
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
101
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
102
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
103
+ # marching_tets_tetmesh ##
104
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
105
+ ### use the GT sdf values for the marching tets ###
106
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
107
+
108
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
109
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
110
+
111
+ def extract_fields(bound_min, bound_max, resolution, query_func):
112
+ N = 64
113
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
114
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
115
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
116
+
117
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
118
+ with torch.no_grad():
119
+ for xi, xs in enumerate(X):
120
+ for yi, ys in enumerate(Y):
121
+ for zi, zs in enumerate(Z):
122
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
123
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
124
+ val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
125
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
126
+ # should save u here #
127
+ # save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
128
+ # np.save(save_u_path, u)
129
+ # print(f"u saved to {save_u_path}")
130
+ return u
131
+
132
+
133
+ def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
134
+ print('threshold: {}'.format(threshold))
135
+
136
+ ## using marching cubes ###
137
+ u = extract_fields(bound_min, bound_max, resolution, query_func)
138
+ vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
139
+ b_max_np = bound_max.detach().cpu().numpy()
140
+ b_min_np = bound_min.detach().cpu().numpy()
141
+
142
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
143
+ ### using marching cubes ###
144
+
145
+ ### using marching tets ###
146
+ # vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
147
+ # vertices = vertices.detach().cpu().numpy()
148
+ # triangles = triangles.detach().cpu().numpy()
149
+ ### using marching tets ###
150
+
151
+ # b_max_np = bound_max.detach().cpu().numpy()
152
+ # b_min_np = bound_min.detach().cpu().numpy()
153
+
154
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
155
+ return vertices, triangles
156
+
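# A minimal usage sketch (sdf_network stands in for any module exposing a
# .sdf(pts) method, mirroring NeuSRenderer.extract_geometry below):
# bound_min = torch.tensor([-1., -1., -1.])
# bound_max = torch.tensor([1., 1., 1.])
# verts, tris = extract_geometry(bound_min, bound_max, resolution=128, threshold=0.0,
#                                query_func=lambda pts: -sdf_network.sdf(pts))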
157
+ def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func):
158
+ # print('threshold: {}'.format(threshold))
159
+
160
+ ### using marching cubes ###
161
+ # u = extract_fields(bound_min, bound_max, resolution, query_func)
162
+ # vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
163
+ # b_max_np = bound_max.detach().cpu().numpy()
164
+ # b_min_np = bound_min.detach().cpu().numpy()
165
+
166
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
167
+ ### using marching cubes ###
168
+
169
+ ##
170
+ ### using marching tets ### fields from tets ##
171
+ vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
172
+ # vertices = vertices.detach().cpu().numpy()
173
+ # triangles = triangles.detach().cpu().numpy()
174
+ ### using marching tets ###
175
+
176
+ # b_max_np = bound_max.detach().cpu().numpy()
177
+ # b_min_np = bound_min.detach().cpu().numpy()
178
+ #
179
+
180
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
181
+ return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
182
+
183
+
184
+ def sample_pdf(bins, weights, n_samples, det=False):
185
+ # This implementation is from NeRF
186
+ # Get pdf
187
+ weights = weights + 1e-5 # prevent nans
188
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
189
+ cdf = torch.cumsum(pdf, -1)
190
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
191
+ # Take uniform samples
192
+ if det:
193
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
194
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
195
+ else:
196
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
197
+
198
+ # Invert CDF # invert cdf #
199
+ u = u.contiguous()
200
+ inds = torch.searchsorted(cdf, u, right=True)
201
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
202
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
203
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
204
+
205
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
206
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
207
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
208
+
209
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
210
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
211
+ t = (u - cdf_g[..., 0]) / denom
212
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
213
+
214
+ return samples
215
+
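# A minimal usage sketch of the inverse-CDF resampling above: bins holds the
# coarse z-values of shape (n_rays, n_bins) and weights the per-interval
# weights of shape (n_rays, n_bins - 1), as in up_sample below.
# z_vals = torch.linspace(0., 1., 65).expand(4, 65)
# weights = torch.rand(4, 64)
# fine_z = sample_pdf(z_vals, weights, n_samples=16, det=True)  # -> (4, 16)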
216
+
217
+ def load_GT_vertices(GT_meshes_folder):
218
+ tot_meshes_fns = os.listdir(GT_meshes_folder)
219
+ tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
220
+ tot_mesh_verts = []
221
+ tot_mesh_faces = []
222
+ n_tot_verts = 0
223
+ for fn in tot_meshes_fns:
224
+ cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
225
+ obj_mesh = trimesh.load(cur_mesh_fn, process=False)
226
+ # obj_mesh.remove_degenerate_faces(height=1e-06)
227
+
228
+ verts_obj = np.array(obj_mesh.vertices)
229
+ faces_obj = np.array(obj_mesh.faces)
230
+
231
+ tot_mesh_verts.append(verts_obj)
232
+ tot_mesh_faces.append(faces_obj + n_tot_verts)
233
+ n_tot_verts += verts_obj.shape[0]
234
+
235
+ # tot_mesh_faces.append(faces_obj)
236
+ tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
237
+ tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
238
+ return tot_mesh_verts, tot_mesh_faces
239
+
240
+
241
+ class NeuSRenderer:
242
+ def __init__(self,
243
+ nerf,
244
+ sdf_network,
245
+ deviation_network,
246
+ color_network,
247
+ n_samples,
248
+ n_importance,
249
+ n_outside,
250
+ up_sample_steps,
251
+ perturb):
252
+ self.nerf = nerf
253
+ self.sdf_network = sdf_network
254
+ self.deviation_network = deviation_network
255
+ self.color_network = color_network
256
+ self.n_samples = n_samples
257
+ self.n_importance = n_importance
258
+ self.n_outside = n_outside
259
+ self.up_sample_steps = up_sample_steps
260
+ self.perturb = perturb
261
+
262
+ GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
263
+ self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
264
+ maxx_pts = 25.
265
+ minn_pts = -15.
266
+ self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
267
+ f = SDF(self.mesh_vertices, self.mesh_faces)
268
+ self.gt_sdf = f ## a unit sphere or box
269
+
270
+ self.minn_pts = 0
271
+ self.maxx_pts = 1.
272
+
273
+ # self.minn_pts = -1.5
274
+ # self.maxx_pts = 1.5
275
+ self.bkg_pts = ... # TODO: the bkg pts
276
+ self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
277
+ self.dist_interp_thres = ... # TODO
278
+
279
+ # get the pts and render the pts #
280
+ # pts and the rendering pts #
281
+ def deform_pts(self, pts):
282
+ # pts: nn_batch x nn_samples x 3
283
+ nnb, nns = pts.size(0), pts.size(1)
284
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
285
+ dist_pts_to_bkg_pts = torch.sum(
286
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
287
+ ) # the bkg-pts-based deformation (cf. self.dist_interp_thres) is left unfinished here
288
+
289
+
290
+
291
+
292
+ def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
293
+ """
294
+ Render background
295
+ """
296
+ batch_size, n_samples = z_vals.shape
297
+
298
+ # Section length
299
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
300
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
301
+ mid_z_vals = z_vals + dists * 0.5
302
+
303
+ # Section midpoints #
304
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
305
+
306
+ # pts = pts.flip((-1,)) * 2 - 1
307
+ pts = pts * 2 - 1
308
+
309
+ dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
310
+ pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
311
+
312
+ dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
313
+
314
+ pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
315
+ dirs = dirs.reshape(-1, 3)
316
+
317
+ density, sampled_color = nerf(pts, dirs)
318
+ sampled_color = torch.sigmoid(sampled_color)
319
+ alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
320
+ alpha = alpha.reshape(batch_size, n_samples)
321
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
322
+ sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
323
+ color = (weights[:, :, None] * sampled_color).sum(dim=1)
324
+ if background_rgb is not None:
325
+ color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
326
+
327
+ return {
328
+ 'color': color,
329
+ 'sampled_color': sampled_color,
330
+ 'alpha': alpha,
331
+ 'weights': weights,
332
+ }
333
+
334
+ def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
335
+ """
336
+ Up sampling given a fixed inv_s
337
+ """
338
+ batch_size, n_samples = z_vals.shape
339
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
340
+
341
+ # pts = pts.flip((-1,)) * 2 - 1
342
+ pts = pts * 2 - 1
343
+
344
+ radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
345
+ inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
346
+ sdf = sdf.reshape(batch_size, n_samples)
347
+ prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
348
+ prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
349
+ mid_sdf = (prev_sdf + next_sdf) * 0.5
350
+ cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
351
+
352
+ # ----------------------------------------------------------------------------------------------------------
353
+ # Use min value of [ cos, prev_cos ]
354
+ # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
355
+ # robust when meeting situations like below:
356
+ #
357
+ # SDF
358
+ # ^
359
+ # |\ -----x----...
360
+ # | \ /
361
+ # | x x
362
+ # |---\----/-------------> 0 level
363
+ # | \ /
364
+ # | \/
365
+ # |
366
+ # ----------------------------------------------------------------------------------------------------------
367
+ prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
368
+ cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
369
+ cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
370
+ cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
371
+
372
+ dist = (next_z_vals - prev_z_vals)
373
+ prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
374
+ next_esti_sdf = mid_sdf + cos_val * dist * 0.5
375
+ prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
376
+ next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
377
+ alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
378
+ weights = alpha * torch.cumprod(
379
+ torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
380
+
381
+ z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
382
+ return z_samples
383
+
384
+ def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
385
+ batch_size, n_samples = z_vals.shape
386
+ _, n_importance = new_z_vals.shape
387
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
388
+
389
+ # pts = pts.flip((-1,)) * 2 - 1
390
+ pts = pts * 2 - 1
391
+
392
+ z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
393
+ z_vals, index = torch.sort(z_vals, dim=-1)
394
+
395
+ if not last:
396
+ new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
397
+ sdf = torch.cat([sdf, new_sdf], dim=-1)
398
+ xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
399
+ index = index.reshape(-1)
400
+ sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
401
+
402
+ return z_vals, sdf
403
+
404
+ def render_core(self,
405
+ rays_o,
406
+ rays_d,
407
+ z_vals,
408
+ sample_dist,
409
+ sdf_network,
410
+ deviation_network,
411
+ color_network,
412
+ background_alpha=None,
413
+ background_sampled_color=None,
414
+ background_rgb=None,
415
+ cos_anneal_ratio=0.0):
416
+ batch_size, n_samples = z_vals.shape
417
+
418
+ # Section length
419
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
420
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
421
+ mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
422
+
423
+ # Section midpoints
424
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
425
+ dirs = rays_d[:, None, :].expand(pts.shape)
426
+
427
+ pts = pts.reshape(-1, 3) # flatten to (n_rays * n_samples, 3)
428
+ dirs = dirs.reshape(-1, 3)
429
+
430
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
431
+
432
+ # pts = pts.flip((-1,)) * 2 - 1 #
433
+ pts = pts * 2 - 1
434
+
435
+ sdf_nn_output = sdf_network(pts)
436
+ sdf = sdf_nn_output[:, :1]
437
+ feature_vector = sdf_nn_output[:, 1:]
438
+
439
+ gradients = sdf_network.gradient(pts).squeeze()
440
+ sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
441
+
442
+ # deviation network #
443
+ inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
444
+ inv_s = inv_s.expand(batch_size * n_samples, 1)
445
+
446
+ true_cos = (dirs * gradients).sum(-1, keepdim=True)
447
+
448
+ # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
449
+ # the cos value "not dead" at the beginning training iterations, for better convergence.
450
+ iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
451
+ F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
452
+
453
+ # Estimate signed distances at section points
454
+ estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
455
+ estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
456
+
457
+ prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
458
+ next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
459
+
460
+ p = prev_cdf - next_cdf
461
+ c = prev_cdf
462
+
463
+ alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
464
+
465
+ pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
466
+ inside_sphere = (pts_norm < 1.0).float().detach()
467
+ relax_inside_sphere = (pts_norm < 1.2).float().detach()
468
+
469
+ # Render with background
470
+ if background_alpha is not None:
471
+ alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
472
+ alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
473
+ sampled_color = sampled_color * inside_sphere[:, :, None] +\
474
+ background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
475
+ sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
476
+
477
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
478
+ weights_sum = weights.sum(dim=-1, keepdim=True)
479
+
480
+ color = (sampled_color * weights[:, :, None]).sum(dim=1)
481
+ if background_rgb is not None: # Fixed background, usually black
482
+ color = color + background_rgb * (1.0 - weights_sum)
483
+
484
+ # Eikonal loss
485
+ gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
486
+ dim=-1) - 1.0) ** 2
487
+ gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
488
+
489
+ return {
490
+ 'color': color,
491
+ 'sdf': sdf,
492
+ 'dists': dists,
493
+ 'gradients': gradients.reshape(batch_size, n_samples, 3),
494
+ 's_val': 1.0 / inv_s,
495
+ 'mid_z_vals': mid_z_vals,
496
+ 'weights': weights,
497
+ 'cdf': c.reshape(batch_size, n_samples),
498
+ 'gradient_error': gradient_error,
499
+ 'inside_sphere': inside_sphere
500
+ }
501
+
502
+ def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
503
+ batch_size = len(rays_o)
504
+ sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
505
+ z_vals = torch.linspace(0.0, 1.0, self.n_samples)
506
+ z_vals = near + (far - near) * z_vals[None, :]
507
+
508
+ z_vals_outside = None
509
+ if self.n_outside > 0:
510
+ z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
511
+
512
+ n_samples = self.n_samples
513
+ perturb = self.perturb
514
+
515
+ if perturb_overwrite >= 0:
516
+ perturb = perturb_overwrite
517
+ if perturb > 0:
518
+ t_rand = (torch.rand([batch_size, 1]) - 0.5)
519
+ z_vals = z_vals + t_rand * 2.0 / self.n_samples
520
+
521
+ if self.n_outside > 0:
522
+ mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
523
+ upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
524
+ lower = torch.cat([z_vals_outside[..., :1], mids], -1)
525
+ t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
526
+ z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
527
+
528
+ if self.n_outside > 0:
529
+ z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
530
+
531
+ background_alpha = None
532
+ background_sampled_color = None
533
+
534
+ # Up sample
535
+ if self.n_importance > 0:
536
+ with torch.no_grad():
537
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
538
+
539
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
540
+ # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
541
+ # gt_sdf #
542
+
543
+ #
544
+ # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
545
+
546
+ # pts = pts.flip((-1,)) * 2 - 1
547
+ pts = pts * 2 - 1
548
+
549
+ pts_exp = pts.reshape(-1, 3)
550
+ # minn_pts, _ = torch.min(pts_exp, dim=0)
551
+ # maxx_pts, _ = torch.max(pts_exp, dim=0)
552
+ # print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
553
+
554
+ # pts_to_near = pts - near.unsqueeze(1)
555
+ # maxx_pts = 1.5; minn_pts = -1.5
556
+ # # maxx_pts = 3; minn_pts = -3
557
+ # # maxx_pts = 1; minn_pts = -1
558
+ # pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
559
+
560
+ ## render and images ####
561
+ if use_gt_sdf:
562
+ ### use the GT sdf field ####
563
+ # print(f"Using gt sdf :")
564
+ sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
565
+ sdf = torch.from_numpy(sdf).float().cuda()
566
+ sdf = sdf.reshape(batch_size, self.n_samples)
567
+ ### use the GT sdf field ####
568
+ else:
569
+ #### use the optimized sdf field ####
570
+ sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
571
+ #### use the optimized sdf field ####
572
+
573
+ for i in range(self.up_sample_steps):
574
+ new_z_vals = self.up_sample(rays_o,
575
+ rays_d,
576
+ z_vals,
577
+ sdf,
578
+ self.n_importance // self.up_sample_steps,
579
+ 64 * 2**i)
580
+ z_vals, sdf = self.cat_z_vals(rays_o,
581
+ rays_d,
582
+ z_vals,
583
+ new_z_vals,
584
+ sdf,
585
+ last=(i + 1 == self.up_sample_steps))
586
+
587
+ n_samples = self.n_samples + self.n_importance
588
+
589
+ # Background model
590
+ if self.n_outside > 0:
591
+ z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
592
+ z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
593
+ ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
594
+
595
+ background_sampled_color = ret_outside['sampled_color']
596
+ background_alpha = ret_outside['alpha']
597
+
598
+ # Render core
599
+ ret_fine = self.render_core(rays_o, #
600
+ rays_d,
601
+ z_vals,
602
+ sample_dist,
603
+ self.sdf_network,
604
+ self.deviation_network,
605
+ self.color_network,
606
+ background_rgb=background_rgb,
607
+ background_alpha=background_alpha,
608
+ background_sampled_color=background_sampled_color,
609
+ cos_anneal_ratio=cos_anneal_ratio)
610
+
611
+ color_fine = ret_fine['color']
612
+ weights = ret_fine['weights']
613
+ weights_sum = weights.sum(dim=-1, keepdim=True)
614
+ gradients = ret_fine['gradients']
615
+ s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
616
+
617
+ return {
618
+ 'color_fine': color_fine,
619
+ 's_val': s_val,
620
+ 'cdf_fine': ret_fine['cdf'],
621
+ 'weight_sum': weights_sum,
622
+ 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
623
+ 'gradients': gradients,
624
+ 'weights': weights,
625
+ 'gradient_error': ret_fine['gradient_error'],
626
+ 'inside_sphere': ret_fine['inside_sphere']
627
+ }
628
+
629
+ def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
630
+ return extract_geometry(bound_min, # extract geometry #
631
+ bound_max,
632
+ resolution=resolution,
633
+ threshold=threshold,
634
+ query_func=lambda pts: -self.sdf_network.sdf(pts))
635
+
636
+ def extract_geometry_tets(self, bound_min, bound_max, resolution, threshold=0.0):
637
+ return extract_geometry_tets(bound_min, # extract geometry #
638
+ bound_max,
639
+ resolution=resolution,
640
+ threshold=threshold,
641
+ query_func=lambda pts: -self.sdf_network.sdf(pts))
models/renderer_def.py ADDED
@@ -0,0 +1,725 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import logging
6
+ import mcubes
7
+ from icecream import ic
8
+ import os
9
+
10
+ import trimesh
11
+ from pysdf import SDF
12
+
13
+ from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
14
+
15
+ def create_mt_variable(device):
16
+ triangle_table = torch.tensor(
17
+ [
18
+ [-1, -1, -1, -1, -1, -1],
19
+ [1, 0, 2, -1, -1, -1],
20
+ [4, 0, 3, -1, -1, -1],
21
+ [1, 4, 2, 1, 3, 4],
22
+ [3, 1, 5, -1, -1, -1],
23
+ [2, 3, 0, 2, 5, 3],
24
+ [1, 4, 0, 1, 5, 4],
25
+ [4, 2, 5, -1, -1, -1],
26
+ [4, 5, 2, -1, -1, -1],
27
+ [4, 1, 0, 4, 5, 1],
28
+ [3, 2, 0, 3, 5, 2],
29
+ [1, 3, 5, -1, -1, -1],
30
+ [4, 1, 2, 4, 3, 1],
31
+ [3, 0, 4, -1, -1, -1],
32
+ [2, 0, 1, -1, -1, -1],
33
+ [-1, -1, -1, -1, -1, -1]
34
+ ], dtype=torch.long, device=device)
35
+
36
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
37
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
38
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
39
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
40
+
41
+
42
+
43
+ def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
44
+ # load tet via resolution #
45
+ # scale them via bounds #
46
+ # extract the geometry #
47
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
48
+ device = bound_min.device
49
+ # if resolution in [64, 70, 80, 90, 100]:
50
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
51
+ # else:
52
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
53
+ tets = np.load(tet_fn)
54
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
55
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
56
+ # split #
57
+ # verts; verts; #
58
+ minn_verts, _ = torch.min(verts, dim=0)
59
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
60
+ # scale_verts = maxx_verts - minn_verts
61
+ scale_bounds = bound_max - bound_min # scale bounds #
62
+
63
+ ### scale the vertices ###
64
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
65
+
66
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
67
+
68
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
69
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
70
+
71
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
72
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
73
+
74
+ sdf_values = []
75
+ N = 64
76
+ query_bundles = N ** 3 ### N^3
77
+ query_NNs = scaled_verts.size(0) // query_bundles
78
+ if query_NNs * query_bundles < scaled_verts.size(0):
79
+ query_NNs += 1
80
+ for i_query in range(query_NNs):
81
+ cur_bundle_st = i_query * query_bundles
82
+ cur_bundle_ed = (i_query + 1) * query_bundles
83
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
84
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
85
+ if def_func is not None:
86
+ cur_query_pts = def_func(cur_query_pts)
87
+ cur_query_vals = query_func(cur_query_pts)
88
+ sdf_values.append(cur_query_vals)
89
+ sdf_values = torch.cat(sdf_values, dim=0)
90
+ # print(f"queryed sdf values: {sdf_values.size()}") #
91
+
92
+ GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
93
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
94
+
95
+ # intrinsic, tet values, pts values, sdf network #
96
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
97
+ tet_table, num_tets_table = create_tetmesh_variables(device)
98
+
99
+ sdf_values = sdf_values.squeeze(-1) # (nn_tet_verts,) sdf values #
100
+
101
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
102
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
103
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
104
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
105
+ # marching_tets_tetmesh ##
106
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
107
+ ### use the GT sdf values for the marching tets ###
108
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
109
+
110
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
111
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
112
+
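A note on the bundled query above: evaluating the SDF network over all tet vertices at once can exhaust memory, so the loop queries in fixed-size chunks. A minimal standalone sketch of the same pattern (query_in_bundles and toy_query are illustrative names, not part of this commit):

import torch

def query_in_bundles(pts, query_func, bundle_size=64 ** 3):
    # evaluate query_func over pts in fixed-size chunks to bound peak memory
    vals = [query_func(pts[st: st + bundle_size])
            for st in range(0, pts.size(0), bundle_size)]
    return torch.cat(vals, dim=0)

pts = torch.rand(300000, 3) * 2. - 1.                      # toy query points
toy_query = lambda p: p.norm(dim=-1, keepdim=True) - 0.5   # sphere SDF stand-in
vals = query_in_bundles(pts, toy_query)                    # (300000, 1)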
113
+ def extract_fields(bound_min, bound_max, resolution, query_func):
114
+ N = 64
115
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
116
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
117
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
118
+
119
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
120
+ with torch.no_grad():
121
+ for xi, xs in enumerate(X):
122
+ for yi, ys in enumerate(Y):
123
+ for zi, zs in enumerate(Z):
124
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
125
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
126
+ val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
127
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
128
+ # should save u here #
129
+ # save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
130
+ # np.save(save_u_path, u)
131
+ # print(f"u saved to {save_u_path}")
132
+ return u
133
+
134
+
135
+ def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
136
+ print('threshold: {}'.format(threshold))
137
+
138
+ ### using marching cubes ###
139
+ u = extract_fields(bound_min, bound_max, resolution, query_func)
140
+ vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
141
+ b_max_np = bound_max.detach().cpu().numpy()
142
+ b_min_np = bound_min.detach().cpu().numpy()
143
+
144
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
145
+ ### using marching cubes ###
146
+
147
+ ### using marching tets ###
148
+ # vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
149
+ # vertices = vertices.detach().cpu().numpy()
150
+ # triangles = triangles.detach().cpu().numpy()
151
+ ### using marching tets ###
152
+
153
+ # b_max_np = bound_max.detach().cpu().numpy()
154
+ # b_min_np = bound_min.detach().cpu().numpy()
155
+
156
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
157
+ return vertices, triangles
158
+
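The vertex rescaling in extract_geometry maps marching-cubes output from grid-index space [0, resolution - 1] back into the world-space bounding box. A small numeric check of that mapping (values are illustrative):

import numpy as np
vertices = np.array([[0., 0., 0.], [2., 2., 2.]])  # grid-index coords at resolution 3
b_min = np.array([-1., -1., -1.])
b_max = np.array([1., 1., 1.])
world = vertices / (3 - 1.0) * (b_max - b_min)[None, :] + b_min[None, :]
# world == [[-1., -1., -1.], [1., 1., 1.]]: the grid corners land exactly on the bounds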
159
+ def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
160
+ # print('threshold: {}'.format(threshold))
161
+
162
+ ### using marching cubes ###
163
+ # u = extract_fields(bound_min, bound_max, resolution, query_func)
164
+ # vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
165
+ # b_max_np = bound_max.detach().cpu().numpy()
166
+ # b_min_np = bound_min.detach().cpu().numpy()
167
+
168
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
169
+ ### using marching cubes ###
170
+
171
+ ##
172
+ ### using marching tets ### fields from tets ##
173
+ vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
174
+ # vertices = vertices.detach().cpu().numpy()
175
+ # triangles = triangles.detach().cpu().numpy()
176
+ ### using marching tets ###
177
+
178
+ # b_max_np = bound_max.detach().cpu().numpy()
179
+ # b_min_np = bound_min.detach().cpu().numpy()
180
+ #
181
+
182
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
183
+ return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
184
+
185
+
186
+ def sample_pdf(bins, weights, n_samples, det=False):
187
+ # This implementation is from NeRF
188
+ # Get pdf
189
+ weights = weights + 1e-5 # prevent nans
190
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
191
+ cdf = torch.cumsum(pdf, -1)
192
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
193
+ # Take uniform samples
194
+ if det:
195
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
196
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
197
+ else:
198
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
199
+
200
+ # Invert CDF # invert cdf #
201
+ u = u.contiguous()
202
+ inds = torch.searchsorted(cdf, u, right=True)
203
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
204
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
205
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
206
+
207
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
208
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
209
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
210
+
211
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
212
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
213
+ t = (u - cdf_g[..., 0]) / denom
214
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
215
+
216
+ return samples
217
+
218
+
219
+ def load_GT_vertices(GT_meshes_folder):
220
+ tot_meshes_fns = os.listdir(GT_meshes_folder)
221
+ tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
222
+ tot_mesh_verts = []
223
+ tot_mesh_faces = []
224
+ n_tot_verts = 0
225
+ for fn in tot_meshes_fns:
226
+ cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
227
+ obj_mesh = trimesh.load(cur_mesh_fn, process=False)
228
+ # obj_mesh.remove_degenerate_faces(height=1e-06)
229
+
230
+ verts_obj = np.array(obj_mesh.vertices)
231
+ faces_obj = np.array(obj_mesh.faces)
232
+
233
+ tot_mesh_verts.append(verts_obj)
234
+ tot_mesh_faces.append(faces_obj + n_tot_verts)
235
+ n_tot_verts += verts_obj.shape[0]
236
+
237
+ # tot_mesh_faces.append(faces_obj)
238
+ tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
239
+ tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
240
+ return tot_mesh_verts, tot_mesh_faces
241
+
242
+
243
+ class NeuSRenderer:
244
+ def __init__(self,
245
+ nerf,
246
+ sdf_network,
247
+ deviation_network,
248
+ color_network,
249
+ n_samples,
250
+ n_importance,
251
+ n_outside,
252
+ up_sample_steps,
253
+ perturb):
254
+ self.nerf = nerf
255
+ self.sdf_network = sdf_network
256
+ self.deviation_network = deviation_network
257
+ self.color_network = color_network
258
+ self.n_samples = n_samples
259
+ self.n_importance = n_importance
260
+ self.n_outside = n_outside
261
+ self.up_sample_steps = up_sample_steps
262
+ self.perturb = perturb
263
+
264
+ GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
265
+ self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
266
+ maxx_pts = 25.
267
+ minn_pts = -15.
268
+ self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
269
+ f = SDF(self.mesh_vertices, self.mesh_faces)
270
+ self.gt_sdf = f ## a unit sphere or box
271
+
272
+ self.minn_pts = 0
273
+ self.maxx_pts = 1.
274
+
275
+ # self.minn_pts = -1.5 # ground-truth states with the deformation -> update the sdf value field
276
+ # self.maxx_pts = 1.5
277
+ self.bkg_pts = ... # TODO: the bkg pts # bkg_pts; # bkg_pts_defs #
278
+ self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
279
+ self.dist_interp_thres = ... # TODO: set the dist_interp_thres #
280
+
281
+ self.bending_network = ... # TODO: add the bending network #
282
+ self.use_bending_network = ... # TODO: set the property #
283
+ self.use_delta_bending = ... # TODO
284
+ # use bending network #
285
+
286
+
287
+ # get the pts and render the pts #
288
+ # pts and the rendering pts #
289
+ def deform_pts(self, pts, pts_ts=0):
290
+
291
+ if self.use_bending_network:
292
+ if len(pts.size()) == 3:
293
+ nnb, nns = pts.size(0), pts.size(1)
294
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
295
+ else:
296
+ pts_exp = pts
297
+ # pts_ts #
298
+ if self.use_delta_bending:
299
+ # if pts_ts >= 5:
300
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
301
+ # for cur_pts_ts in range(4, -1, -1):
302
+ # # print(f"using delta bending with pts_ts: {cur_pts_ts}")
303
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
304
+ # else:
305
+ # for cur_pts_ts in range(pts_ts, -1, -1):
306
+ # # print(f"using delta bending with pts_ts: {cur_pts_ts}")
307
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
308
+ for cur_pts_ts in range(pts_ts, -1, -1):
309
+ # print(f"using delta bending with pts_ts: {cur_pts_ts}")
310
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
311
+ else:
312
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
313
+ if len(pts.size()) == 3:
314
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
315
+ else:
316
+ pts = pts_exp
317
+ return pts
318
+
319
+ # pts: nn_batch x nn_samples x 3
320
+ if len(pts.size()) == 3:
321
+ nnb, nns = pts.size(0), pts.size(1)
322
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
323
+ else:
324
+ pts_exp = pts
325
+ # print(f"prior to deforming: {pts.size()}")
326
+
327
+ dist_pts_to_bkg_pts = torch.sum(
328
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
329
+ )
330
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
331
+ dist_mask_float = dist_mask.float()
332
+
333
+ # dist_mask_float #
334
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
335
+ cur_fr_pts_def = torch.sum(
336
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
337
+ )
338
+ dist_mask_float_summ = torch.sum(
339
+ dist_mask_float, dim=1
340
+ )
341
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
342
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
343
+ pts_exp = pts_exp - cur_fr_pts_def
344
+ if len(pts.size()) == 3:
345
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
346
+ else:
347
+ pts = pts_exp
348
+ return pts #
349
+
350
+
351
+
352
+
353
+ def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
354
+ """
355
+ Render background
356
+ """
357
+ batch_size, n_samples = z_vals.shape
358
+
359
+ # Section length
360
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
361
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
362
+ mid_z_vals = z_vals + dists * 0.5
363
+
364
+ # Section midpoints #
365
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
366
+
367
+ # pts = pts.flip((-1,)) * 2 - 1
368
+ pts = pts * 2 - 1
369
+
370
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
371
+
372
+ dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
373
+ pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
374
+
375
+ dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
376
+
377
+ pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
378
+ dirs = dirs.reshape(-1, 3)
379
+
380
+ density, sampled_color = nerf(pts, dirs)
381
+ sampled_color = torch.sigmoid(sampled_color)
382
+ alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
383
+ alpha = alpha.reshape(batch_size, n_samples)
384
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
385
+ sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
386
+ color = (weights[:, :, None] * sampled_color).sum(dim=1)
387
+ if background_rgb is not None:
388
+ color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
389
+
390
+ return {
391
+ 'color': color,
392
+ 'sampled_color': sampled_color,
393
+ 'alpha': alpha,
394
+ 'weights': weights,
395
+ }
396
+
397
+ def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
398
+ """
399
+ Up sampling given a fixed inv_s
400
+ """
401
+ batch_size, n_samples = z_vals.shape
402
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
403
+
404
+ # pts = pts.flip((-1,)) * 2 - 1
405
+ pts = pts * 2 - 1
406
+
407
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
408
+
409
+ radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
410
+ inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
411
+ sdf = sdf.reshape(batch_size, n_samples)
412
+ prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
413
+ prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
414
+ mid_sdf = (prev_sdf + next_sdf) * 0.5
415
+ cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
416
+
417
+ # ----------------------------------------------------------------------------------------------------------
418
+ # Use min value of [ cos, prev_cos ]
419
+ # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
420
+ # robust when meeting situations like below:
421
+ #
422
+ # SDF
423
+ # ^
424
+ # |\ -----x----...
425
+ # | \ /
426
+ # | x x
427
+ # |---\----/-------------> 0 level
428
+ # | \ /
429
+ # | \/
430
+ # |
431
+ # ----------------------------------------------------------------------------------------------------------
432
+ prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
433
+ cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
434
+ cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
435
+ cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
436
+
437
+ dist = (next_z_vals - prev_z_vals)
438
+ prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
439
+ next_esti_sdf = mid_sdf + cos_val * dist * 0.5
440
+ prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
441
+ next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
442
+ alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
443
+ weights = alpha * torch.cumprod(
444
+ torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
445
+
446
+ z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
447
+ return z_samples
448
+
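The min-of-[cos, prev_cos] trick documented in the comment block above can be checked in isolation; a toy example showing how a positive (outward-facing) slope gets suppressed before the clip (values are illustrative):

import torch
cos_val = torch.tensor([[-0.5, 0.8, -0.3]])                            # (batch, n_samples - 1)
prev_cos_val = torch.cat([torch.zeros(1, 1), cos_val[:, :-1]], dim=-1)
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1).min(dim=-1)[0]
cos_val = cos_val.clip(-1e3, 0.0)   # -> [[-0.5, -0.5, -0.3]]: the positive slope is gone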
449
+ def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
450
+ batch_size, n_samples = z_vals.shape
451
+ _, n_importance = new_z_vals.shape
452
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
453
+
454
+ # pts = pts.flip((-1,)) * 2 - 1
455
+ pts = pts * 2 - 1
456
+
457
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
458
+
459
+ z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
460
+ z_vals, index = torch.sort(z_vals, dim=-1)
461
+
462
+ if not last:
463
+ new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
464
+ sdf = torch.cat([sdf, new_sdf], dim=-1)
465
+ xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
466
+ index = index.reshape(-1)
467
+ sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
468
+
469
+ return z_vals, sdf
470
+
471
+ def render_core(self,
472
+ rays_o,
473
+ rays_d,
474
+ z_vals,
475
+ sample_dist,
476
+ sdf_network,
477
+ deviation_network,
478
+ color_network,
479
+ background_alpha=None,
480
+ background_sampled_color=None,
481
+ background_rgb=None,
482
+ cos_anneal_ratio=0.0,
483
+ pts_ts=0):
484
+ batch_size, n_samples = z_vals.shape
485
+
486
+ # Section length
487
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
488
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
489
+ mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
490
+
491
+ # Section midpoints
492
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
493
+ dirs = rays_d[:, None, :].expand(pts.shape)
494
+
495
+ pts = pts.reshape(-1, 3) # pts, nn_ou
496
+ dirs = dirs.reshape(-1, 3)
497
+
498
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
499
+
500
+ # pts = pts.flip((-1,)) * 2 - 1
501
+ pts = pts * 2 - 1
502
+
503
+
504
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
505
+
506
+ sdf_nn_output = sdf_network(pts)
507
+ sdf = sdf_nn_output[:, :1]
508
+ feature_vector = sdf_nn_output[:, 1:]
509
+
510
+ gradients = sdf_network.gradient(pts).squeeze()
511
+ sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
512
+
513
+ # deviation network #
514
+ inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
515
+ inv_s = inv_s.expand(batch_size * n_samples, 1)
516
+
517
+ true_cos = (dirs * gradients).sum(-1, keepdim=True)
518
+
519
+ # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
520
+ # the cos value "not dead" at the beginning training iterations, for better convergence.
521
+ iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
522
+ F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
523
+
524
+ # Estimate signed distances at section points
525
+ estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
526
+ estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
527
+
528
+ prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
529
+ next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
530
+
531
+ p = prev_cdf - next_cdf
532
+ c = prev_cdf
533
+
534
+ alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
535
+
536
+ pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
537
+ inside_sphere = (pts_norm < 1.0).float().detach()
538
+ relax_inside_sphere = (pts_norm < 1.2).float().detach()
539
+
540
+ # Render with background
541
+ if background_alpha is not None:
542
+ alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
543
+ alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
544
+ sampled_color = sampled_color * inside_sphere[:, :, None] +\
545
+ background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
546
+ sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
547
+
548
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
549
+ weights_sum = weights.sum(dim=-1, keepdim=True)
550
+
551
+ color = (sampled_color * weights[:, :, None]).sum(dim=1)
552
+ if background_rgb is not None: # Fixed background, usually black
553
+ color = color + background_rgb * (1.0 - weights_sum)
554
+
555
+ # Eikonal loss
556
+ gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
557
+ dim=-1) - 1.0) ** 2
558
+ gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
559
+
560
+ return {
561
+ 'color': color,
562
+ 'sdf': sdf,
563
+ 'dists': dists,
564
+ 'gradients': gradients.reshape(batch_size, n_samples, 3),
565
+ 's_val': 1.0 / inv_s,
566
+ 'mid_z_vals': mid_z_vals,
567
+ 'weights': weights,
568
+ 'cdf': c.reshape(batch_size, n_samples),
569
+ 'gradient_error': gradient_error,
570
+ 'inside_sphere': inside_sphere
571
+ }
572
+
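The alpha computed in render_core follows the NeuS formulation: the opacity of a ray segment is the normalized drop of the sigmoid CDF of the estimated SDF across that segment. A self-contained sketch of just that formula (toy SDF values):

import torch
inv_s = 64.0
prev_sdf_est = torch.tensor([0.02, 0.00, -0.02])   # segments entering a surface
next_sdf_est = torch.tensor([0.00, -0.02, -0.04])
prev_cdf = torch.sigmoid(prev_sdf_est * inv_s)
next_cdf = torch.sigmoid(next_sdf_est * inv_s)
alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)
# alpha grows as the segment moves to the inside of the zero level set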
573
+ def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
574
+ batch_size = len(rays_o)
575
+ sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
576
+ z_vals = torch.linspace(0.0, 1.0, self.n_samples)
577
+ z_vals = near + (far - near) * z_vals[None, :]
578
+
579
+ z_vals_outside = None
580
+ if self.n_outside > 0:
581
+ z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
582
+
583
+ n_samples = self.n_samples
584
+ perturb = self.perturb
585
+
586
+ if perturb_overwrite >= 0:
587
+ perturb = perturb_overwrite
588
+ if perturb > 0:
589
+ t_rand = (torch.rand([batch_size, 1]) - 0.5)
590
+ z_vals = z_vals + t_rand * 2.0 / self.n_samples
591
+
592
+ if self.n_outside > 0:
593
+ mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
594
+ upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
595
+ lower = torch.cat([z_vals_outside[..., :1], mids], -1)
596
+ t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
597
+ z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
598
+
599
+ if self.n_outside > 0:
600
+ z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
601
+
602
+ background_alpha = None
603
+ background_sampled_color = None
604
+
605
+ # Up sample
606
+ if self.n_importance > 0:
607
+ with torch.no_grad():
608
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
609
+
610
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
611
+ # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
612
+ # gt_sdf #
613
+
614
+ #
615
+ # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
616
+
617
+ # pts = pts.flip((-1,)) * 2 - 1
618
+ pts = pts * 2 - 1
619
+
620
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
621
+
622
+ pts_exp = pts.reshape(-1, 3)
623
+ # minn_pts, _ = torch.min(pts_exp, dim=0)
624
+ # maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
625
+ # print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
626
+
627
+ # pts_to_near = pts - near.unsqueeze(1)
628
+ # maxx_pts = 1.5; minn_pts = -1.5
629
+ # # maxx_pts = 3; minn_pts = -3
630
+ # # maxx_pts = 1; minn_pts = -1
631
+ # pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
632
+
633
+ ## render the images ####
634
+ if use_gt_sdf:
635
+ ### use the GT sdf field ####
636
+ # print(f"Using gt sdf :")
637
+ sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
638
+ sdf = torch.from_numpy(sdf).float().cuda()
639
+ sdf = sdf.reshape(batch_size, self.n_samples)
640
+ ### use the GT sdf field ####
641
+ else:
642
+ #### use the optimized sdf field ####
643
+ sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
644
+ #### use the optimized sdf field ####
645
+
646
+ for i in range(self.up_sample_steps):
647
+ new_z_vals = self.up_sample(rays_o,
648
+ rays_d,
649
+ z_vals,
650
+ sdf,
651
+ self.n_importance // self.up_sample_steps,
652
+ 64 * 2**i,
653
+ pts_ts=pts_ts)
654
+ z_vals, sdf = self.cat_z_vals(rays_o,
655
+ rays_d,
656
+ z_vals,
657
+ new_z_vals,
658
+ sdf,
659
+ last=(i + 1 == self.up_sample_steps),
660
+ pts_ts=pts_ts)
661
+
662
+ n_samples = self.n_samples + self.n_importance
663
+
664
+ # Background model
665
+ if self.n_outside > 0:
666
+ z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
667
+ z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
668
+ ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
669
+
670
+ background_sampled_color = ret_outside['sampled_color']
671
+ background_alpha = ret_outside['alpha']
672
+
673
+ # Render core
674
+ ret_fine = self.render_core(rays_o, #
675
+ rays_d,
676
+ z_vals,
677
+ sample_dist,
678
+ self.sdf_network,
679
+ self.deviation_network,
680
+ self.color_network,
681
+ background_rgb=background_rgb,
682
+ background_alpha=background_alpha,
683
+ background_sampled_color=background_sampled_color,
684
+ cos_anneal_ratio=cos_anneal_ratio,
685
+ pts_ts=pts_ts)
686
+
687
+ color_fine = ret_fine['color']
688
+ weights = ret_fine['weights']
689
+ weights_sum = weights.sum(dim=-1, keepdim=True)
690
+ gradients = ret_fine['gradients']
691
+ s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
692
+
693
+ return {
694
+ 'color_fine': color_fine,
695
+ 's_val': s_val,
696
+ 'cdf_fine': ret_fine['cdf'],
697
+ 'weight_sum': weights_sum,
698
+ 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
699
+ 'gradients': gradients,
700
+ 'weights': weights,
701
+ 'gradient_error': ret_fine['gradient_error'],
702
+ 'inside_sphere': ret_fine['inside_sphere']
703
+ }
704
+
705
+ def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
706
+ return extract_geometry(bound_min, # extract geometry #
707
+ bound_max,
708
+ resolution=resolution,
709
+ threshold=threshold,
710
+ query_func=lambda pts: -self.sdf_network.sdf(pts))
711
+
712
+ def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
713
+ if wdef:
714
+ return extract_geometry_tets(bound_min, # extract geometry #
715
+ bound_max,
716
+ resolution=resolution,
717
+ threshold=threshold,
718
+ query_func=lambda pts: -self.sdf_network.sdf(pts),
719
+ def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts))
720
+ else:
721
+ return extract_geometry_tets(bound_min, # extract geometry #
722
+ bound_max,
723
+ resolution=resolution,
724
+ threshold=threshold,
725
+ query_func=lambda pts: -self.sdf_network.sdf(pts))
models/renderer_def_multi_objs.py ADDED
@@ -0,0 +1,1088 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import logging
6
+ import mcubes
7
+ from icecream import ic
8
+ import os
9
+
10
+ import trimesh
11
+ from pysdf import SDF
12
+
13
+ import models.fields as fields
14
+
15
+ from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
16
+
17
+ def batched_index_select(values, indices, dim = 1):
18
+ value_dims = values.shape[(dim + 1):]
19
+ values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
20
+ indices = indices[(..., *((None,) * len(value_dims)))]
21
+ indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
22
+ value_expand_len = len(indices_shape) - (dim + 1)
23
+ values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
24
+
25
+ value_expand_shape = [-1] * len(values.shape)
26
+ expand_slice = slice(dim, (dim + value_expand_len))
27
+ value_expand_shape[expand_slice] = indices.shape[expand_slice]
28
+ values = values.expand(*value_expand_shape)
29
+
30
+ dim += value_expand_len
31
+ return values.gather(dim, indices)
32
+
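batched_index_select gathers one entry per row along an arbitrary dim; it is used below to pick one of two bent point candidates per query point. A minimal usage sketch (shapes match that use):

import torch
bended_pts = torch.rand(10, 2, 3)        # nn_pts x 2 candidates x 3
selector = torch.randint(0, 2, (10,))    # per-point choice in {0, 1}
picked = batched_index_select(values=bended_pts, indices=selector.unsqueeze(1), dim=1)
picked = picked.squeeze(1)               # nn_pts x 3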
33
+
34
+ def create_mt_variable(device):
35
+ triangle_table = torch.tensor(
36
+ [
37
+ [-1, -1, -1, -1, -1, -1],
38
+ [1, 0, 2, -1, -1, -1],
39
+ [4, 0, 3, -1, -1, -1],
40
+ [1, 4, 2, 1, 3, 4],
41
+ [3, 1, 5, -1, -1, -1],
42
+ [2, 3, 0, 2, 5, 3],
43
+ [1, 4, 0, 1, 5, 4],
44
+ [4, 2, 5, -1, -1, -1],
45
+ [4, 5, 2, -1, -1, -1],
46
+ [4, 1, 0, 4, 5, 1],
47
+ [3, 2, 0, 3, 5, 2],
48
+ [1, 3, 5, -1, -1, -1],
49
+ [4, 1, 2, 4, 3, 1],
50
+ [3, 0, 4, -1, -1, -1],
51
+ [2, 0, 1, -1, -1, -1],
52
+ [-1, -1, -1, -1, -1, -1]
53
+ ], dtype=torch.long, device=device)
54
+
55
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
56
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
57
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
58
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
59
+
60
+
61
+
62
+ def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
63
+ # load tet via resolution #
64
+ # scale them via bounds #
65
+ # extract the geometry #
66
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
67
+ device = bound_min.device
68
+ # if resolution in [64, 70, 80, 90, 100]:
69
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
70
+ # else:
71
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
72
+ tets = np.load(tet_fn)
73
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
74
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
75
+ # split #
76
+ # verts; verts; #
77
+ minn_verts, _ = torch.min(verts, dim=0)
78
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
79
+ # scale_verts = maxx_verts - minn_verts
80
+ scale_bounds = bound_max - bound_min # scale bounds #
81
+
82
+ ### scale the vertices ###
83
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
84
+
85
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
86
+
87
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
88
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
89
+
90
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
91
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
92
+
93
+ sdf_values = []
94
+ N = 64
95
+ query_bundles = N ** 3 ### N^3
96
+ query_NNs = scaled_verts.size(0) // query_bundles
97
+ if query_NNs * query_bundles < scaled_verts.size(0):
98
+ query_NNs += 1
99
+ for i_query in range(query_NNs):
100
+ cur_bundle_st = i_query * query_bundles
101
+ cur_bundle_ed = (i_query + 1) * query_bundles
102
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
103
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
104
+ if def_func is not None:
105
+ cur_query_pts = def_func(cur_query_pts)
106
+ cur_query_vals = query_func(cur_query_pts)
107
+ sdf_values.append(cur_query_vals)
108
+ sdf_values = torch.cat(sdf_values, dim=0)
109
+ # print(f"queryed sdf values: {sdf_values.size()}") #
110
+
111
+ GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
112
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
113
+
114
+ # intrinsic, tet values, pts values, sdf network #
115
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
116
+ tet_table, num_tets_table = create_tetmesh_variables(device)
117
+
118
+ sdf_values = sdf_values.squeeze(-1) # (nn_tet_verts,) sdf values #
119
+
120
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
121
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
122
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
123
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
124
+ # marching_tets_tetmesh ##
125
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
126
+ ### use the GT sdf values for the marching tets ###
127
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
128
+
129
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
130
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
131
+
132
+ def extract_fields(bound_min, bound_max, resolution, query_func):
133
+ N = 64
134
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
135
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
136
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
137
+
138
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
139
+ with torch.no_grad():
140
+ for xi, xs in enumerate(X):
141
+ for yi, ys in enumerate(Y):
142
+ for zi, zs in enumerate(Z):
143
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
144
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
145
+ val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
146
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
147
+ # should save u here #
148
+ # save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
149
+ # np.save(save_u_path, u)
150
+ # print(f"u saved to {save_u_path}")
151
+ return u
152
+
153
+
154
+ def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
155
+ print('threshold: {}'.format(threshold))
156
+
157
+ ### using marching cubes ###
158
+ u = extract_fields(bound_min, bound_max, resolution, query_func)
159
+ vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
160
+ b_max_np = bound_max.detach().cpu().numpy()
161
+ b_min_np = bound_min.detach().cpu().numpy()
162
+
163
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
164
+ ### using marching cubes ###
165
+
166
+ ### using marching tets ###
167
+ # vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
168
+ # vertices = vertices.detach().cpu().numpy()
169
+ # triangles = triangles.detach().cpu().numpy()
170
+ ### using marching tets ###
171
+
172
+ # b_max_np = bound_max.detach().cpu().numpy()
173
+ # b_min_np = bound_min.detach().cpu().numpy()
174
+
175
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
176
+ return vertices, triangles
177
+
178
+ def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
179
+ # print('threshold: {}'.format(threshold))
180
+
181
+ ### using marching cubes ###
182
+ # u = extract_fields(bound_min, bound_max, resolution, query_func)
183
+ # vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
184
+ # b_max_np = bound_max.detach().cpu().numpy()
185
+ # b_min_np = bound_min.detach().cpu().numpy()
186
+
187
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
188
+ ### using marching cubes ###
189
+
190
+ ##
191
+ ### using marching tets ### fields from tets ##
192
+ vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
193
+ # vertices = vertices.detach().cpu().numpy()
194
+ # triangles = triangles.detach().cpu().numpy()
195
+ ### using marching tets ###
196
+
197
+ # b_max_np = bound_max.detach().cpu().numpy()
198
+ # b_min_np = bound_min.detach().cpu().numpy()
199
+ #
200
+
201
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
202
+ return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
203
+
204
+
205
+ def sample_pdf(bins, weights, n_samples, det=False):
206
+ # This implementation is from NeRF
207
+ # Get pdf
208
+ weights = weights + 1e-5 # prevent nans
209
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
210
+ cdf = torch.cumsum(pdf, -1)
211
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
212
+ # Take uniform samples
213
+ if det:
214
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
215
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
216
+ else:
217
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
218
+
219
+ # Invert CDF # invert cdf #
220
+ u = u.contiguous()
221
+ inds = torch.searchsorted(cdf, u, right=True)
222
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
223
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
224
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
225
+
226
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
227
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
228
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
229
+
230
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
231
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
232
+ t = (u - cdf_g[..., 0]) / denom
233
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
234
+
235
+ return samples
236
+
237
+
238
+ def load_GT_vertices(GT_meshes_folder):
239
+ tot_meshes_fns = os.listdir(GT_meshes_folder)
240
+ tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
241
+ tot_mesh_verts = []
242
+ tot_mesh_faces = []
243
+ n_tot_verts = 0
244
+ for fn in tot_meshes_fns:
245
+ cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
246
+ obj_mesh = trimesh.load(cur_mesh_fn, process=False)
247
+ # obj_mesh.remove_degenerate_faces(height=1e-06)
248
+
249
+ verts_obj = np.array(obj_mesh.vertices)
250
+ faces_obj = np.array(obj_mesh.faces)
251
+
252
+ tot_mesh_verts.append(verts_obj)
253
+ tot_mesh_faces.append(faces_obj + n_tot_verts)
254
+ n_tot_verts += verts_obj.shape[0]
255
+
256
+ # tot_mesh_faces.append(faces_obj)
257
+ tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
258
+ tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
259
+ return tot_mesh_verts, tot_mesh_faces
260
+
261
+
262
+ class NeuSRenderer:
263
+ def __init__(self,
264
+ nerf,
265
+ sdf_network,
266
+ deviation_network,
267
+ color_network,
268
+ n_samples,
269
+ n_importance,
270
+ n_outside,
271
+ up_sample_steps,
272
+ perturb):
273
+ self.nerf = nerf # multiple sdf networks and deviation networks #
274
+ self.sdf_network = sdf_network
275
+ self.deviation_network = deviation_network
276
+ self.color_network = color_network
277
+ self.n_samples = n_samples
278
+ self.n_importance = n_importance
279
+ self.n_outside = n_outside
280
+ self.up_sample_steps = up_sample_steps
281
+ self.perturb = perturb
282
+
283
+ GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
284
+ self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
285
+ maxx_pts = 25.
286
+ minn_pts = -15.
287
+ self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
288
+ f = SDF(self.mesh_vertices, self.mesh_faces)
289
+ self.gt_sdf = f ## a unit sphere or box
290
+
291
+ self.minn_pts = 0
292
+ self.maxx_pts = 1.
293
+
294
+ # self.minn_pts = -1.5 # ground-truth states with the deformation -> update the sdf value field
295
+ # self.maxx_pts = 1.5 #
296
+ self.bkg_pts = ... # TODO: the bkg pts # bkg_pts; # bkg_pts_defs #
297
+ self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
298
+ self.dist_interp_thres = ... # TODO: set the dist_interp_thres #
299
+
300
+ self.bending_network = ... # TODO: add the bending network #
301
+ self.use_bending_network = ... # TODO: set the property #
302
+ self.use_delta_bending = ... # TODO
303
+ self.prev_sdf_network = ... # TODO
304
+ self.use_selector = False
305
+ # use bending network #
306
+ # two bending networks
307
+ # two sdf networks
308
+
309
+
310
+ # get the pts and render the pts #
311
+ # pts and the rendering pts #
312
+ def deform_pts(self, pts, pts_ts=0): # deform pts #
313
+
314
+ if self.use_bending_network:
315
+ if len(pts.size()) == 3:
316
+ nnb, nns = pts.size(0), pts.size(1)
317
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
318
+ else:
319
+ pts_exp = pts
320
+ # pts_ts #
321
+ if self.use_delta_bending:
322
+
323
+ if isinstance(self.bending_network, list):
324
+ pts_offsets = []
325
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
326
+ if isinstance(cur_bending_network, fields.BendingNetwork):
327
+ for cur_pts_ts in range(pts_ts, -1, -1):
328
+ cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
329
+ elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
330
+ cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts) # rigid trans applied once at the query timestep; cur_pts_ts is undefined in this branch
331
+ else:
332
+ raise ValueError('Encountered with unexpected bending network class...')
333
+ pts_offsets.append(cur_pts_exp - pts_exp)
334
+ pts_offsets = torch.stack(pts_offsets, dim=0)
335
+ pts_offsets = torch.sum(pts_offsets, dim=0)
336
+ pts_exp = pts_exp + pts_offsets
337
+ # for cur_pts_ts in range(pts_ts, -1, -1):
338
+ # if isinstance(self.bending_network, list): # pts ts #
339
+ # for i_obj, cur_bending_network in enumerate(self.bending_network):
340
+ # pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
341
+ # else:
342
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
343
+ else:
344
+ if isinstance(self.bending_network, list): # prev sdf network #
345
+ pts_offsets = []
346
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
347
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
348
+ pts_offsets.append(bended_pts_exp - pts_exp)
349
+ pts_offsets = torch.stack(pts_offsets, dim=0)
350
+ pts_offsets = torch.sum(pts_offsets, dim=0)
351
+ pts_exp = pts_exp + pts_offsets
352
+ else:
353
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
354
+ if len(pts.size()) == 3:
355
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
356
+ else:
357
+ pts = pts_exp
358
+ return pts
359
+
360
+ # pts: nn_batch x nn_samples x 3
361
+ if len(pts.size()) == 3:
362
+ nnb, nns = pts.size(0), pts.size(1)
363
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
364
+ else:
365
+ pts_exp = pts
366
+ # print(f"prior to deforming: {pts.size()}")
367
+
368
+ dist_pts_to_bkg_pts = torch.sum(
369
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
370
+ )
371
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
372
+ dist_mask_float = dist_mask.float()
373
+
374
+ # dist_mask_float #
375
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
376
+ cur_fr_pts_def = torch.sum(
377
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
378
+ )
379
+ dist_mask_float_summ = torch.sum(
380
+ dist_mask_float, dim=1
381
+ )
382
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
383
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
384
+ pts_exp = pts_exp - cur_fr_pts_def
385
+ if len(pts.size()) == 3:
386
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
387
+ else:
388
+ pts = pts_exp
389
+ return pts #
390
+
391
+
392
+ def deform_pts_with_selector(self, pts, pts_ts=0): # deform pts #
393
+
394
+ if self.use_bending_network:
395
+ if len(pts.size()) == 3:
396
+ nnb, nns = pts.size(0), pts.size(1)
397
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
398
+ else:
399
+ pts_exp = pts
400
+ # pts_ts #
401
+ if self.use_delta_bending:
402
+ if isinstance(self.bending_network, list):
403
+ bended_pts = []
404
+ queries_sdfs_selector = []
405
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
406
+ if cur_bending_network.use_opt_rigid_translations:
407
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
408
+ else:
409
+ # bended_pts_exp = pts_exp.clone()
410
+ for cur_pts_ts in range(pts_ts, -1, -1):
411
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
412
+ _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
413
+ bended_pts.append(bended_pts_exp)
414
+ queries_sdfs_selector.append(cur_bended_pts_selector)
415
+ bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
416
+ queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
417
+ # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
418
+ sdf_selector = queries_sdfs_selector[:, -1]
419
+ # sdf_selector = queries_sdfs_selector
420
+ # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
421
+ bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
422
+ # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
423
+ pts_exp = bended_pts.squeeze(1)
424
+
425
+
426
+ # for cur_pts_ts in range(pts_ts, -1, -1):
427
+ # if isinstance(self.bending_network, list):
428
+ # for i_obj, cur_bending_network in enumerate(self.bending_network):
429
+ # pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
430
+ # else:
431
+
432
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
433
+ else:
434
+ if isinstance(self.bending_network, list): # prev sdf network #
435
+ # pts_offsets = []
436
+ bended_pts = []
437
+ queries_sdfs_selector = []
438
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
439
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
440
+ # pts_offsets.append(bended_pts_exp - pts_exp)
441
+ _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
442
+ bended_pts.append(bended_pts_exp)
443
+ queries_sdfs_selector.append(cur_bended_pts_selector)
444
+ bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
445
+ queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
446
+ # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
447
+ sdf_selector = queries_sdfs_selector[:, -1]
448
+ # sdf_selector = queries_sdfs_selector
449
+
450
+
451
+ # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
452
+ bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
453
+ # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
454
+ pts_exp = bended_pts.squeeze(1)
455
+
456
+ # pts_offsets = torch.stack(pts_offsets, dim=0)
457
+ # pts_offsets = torch.sum(pts_offsets, dim=0)
458
+ # pts_exp = pts_exp + pts_offsets
459
+ else:
460
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
461
+ if len(pts.size()) == 3:
462
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
463
+ else:
464
+ pts = pts_exp
465
+ return pts
466
+
467
+ # pts: nn_batch x nn_samples x 3
468
+ if len(pts.size()) == 3:
469
+ nnb, nns = pts.size(0), pts.size(1)
470
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
471
+ else:
472
+ pts_exp = pts
473
+ # print(f"prior to deforming: {pts.size()}")
474
+
475
+ dist_pts_to_bkg_pts = torch.sum(
476
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
477
+ )
478
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
479
+ dist_mask_float = dist_mask.float()
480
+
481
+ # dist_mask_float #
482
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
483
+ cur_fr_pts_def = torch.sum(
484
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
485
+ )
486
+ dist_mask_float_summ = torch.sum(
487
+ dist_mask_float, dim=1
488
+ )
489
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
490
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
491
+ pts_exp = pts_exp - cur_fr_pts_def
492
+ if len(pts.size()) == 3:
493
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
494
+ else:
495
+ pts = pts_exp
496
+ return pts #
497
+
498
+
499
+ def deform_pts_passive(self, pts, pts_ts=0):
500
+
501
+ if self.use_bending_network:
502
+ if len(pts.size()) == 3:
503
+ nnb, nns = pts.size(0), pts.size(1)
504
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
505
+ else:
506
+ pts_exp = pts
507
+ # pts_ts #
508
+ if self.use_delta_bending:
509
+ for cur_pts_ts in range(pts_ts, -1, -1):
510
+ if isinstance(self.bending_network, list):
511
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
512
+ pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
513
+ else:
514
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
515
+ else:
516
+ # if isinstance(self.bending_network, list):
517
+ # pts_offsets = []
518
+ # for i_obj, cur_bending_network in enumerate(self.bending_network):
519
+ # bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
520
+ # pts_offsets.append(bended_pts_exp - pts_exp)
521
+ # pts_offsets = torch.stack(pts_offsets, dim=0)
522
+ # pts_offsets = torch.sum(pts_offsets, dim=0)
523
+ # pts_exp = pts_exp + pts_offsets
524
+ # else:
525
+ pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
526
+ if len(pts.size()) == 3:
527
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
528
+ else:
529
+ pts = pts_exp
530
+ return pts
531
+
532
+ # pts: nn_batch x nn_samples x 3
533
+ if len(pts.size()) == 3:
534
+ nnb, nns = pts.size(0), pts.size(1)
535
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
536
+ else:
537
+ pts_exp = pts
538
+ # print(f"prior to deforming: {pts.size()}")
539
+
540
+ dist_pts_to_bkg_pts = torch.sum(
541
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
542
+ )
543
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
544
+ dist_mask_float = dist_mask.float()
545
+
546
+ # dist_mask_float #
547
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
548
+ cur_fr_pts_def = torch.sum(
549
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
550
+ )
551
+ dist_mask_float_summ = torch.sum(
552
+ dist_mask_float, dim=1
553
+ )
554
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
555
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
556
+ pts_exp = pts_exp - cur_fr_pts_def
557
+ if len(pts.size()) == 3:
558
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
559
+ else:
560
+ pts = pts_exp
561
+ return pts #
562
+
563
+
564
+ def query_pts_sdf_fn_for_selector(self, pts):
565
+ # for negative
566
+ # 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
567
+ # 2) negative in this field and also negative in the previous field --->
568
+ # 3) for positive values of this current field --->
569
+ cur_sdf = self.sdf_network.sdf(pts)
570
+ prev_sdf = self.prev_sdf_network.sdf(pts)
571
+ neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
572
+ neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
573
+
574
+ neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
575
+
576
+ pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
577
+ pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
578
+ res_sdf = torch.zeros_like(cur_sdf)
579
+ res_sdf[neg_neg] = 1. #
580
+ res_sdf[neg_pos] = cur_sdf[neg_pos]
581
+ res_sdf[pos_neg] = cur_sdf[pos_neg]
582
+
583
+ # inside the residual mesh -> must be neg and pos
584
+ res_sdf_selector = torch.zeros_like(cur_sdf).long() #
585
+ # res_sdf_selector[neg_pos] = 1 # is the residual mesh
586
+ res_sdf_selector[neg_weq_pos] = 1
587
+ # res_sdf_selector[]
588
+
589
+ cat_cur_prev_sdf = torch.stack(
590
+ [cur_sdf, prev_sdf], dim=-1
591
+ )
592
+ minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
593
+ res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
594
+
595
+ return res_sdf, res_sdf_selector
596
+
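query_pts_sdf_fn_for_selector compares the current and previous SDF fields; a point strictly inside the current field but outside the previous one (the residual region) receives selector 1. A toy check of that case analysis (values are illustrative):

import torch
cur_sdf  = torch.tensor([-0.1, -0.1,  0.2,  0.3])
prev_sdf = torch.tensor([-0.2,  0.1, -0.1,  0.4])
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
selector = torch.zeros_like(cur_sdf).long()
selector[neg_weq_pos] = 1   # -> [0, 1, 0, 0]: only the neg/pos point is residual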
597
+ def query_func_sdf(self, pts):
598
+ if isinstance(self.sdf_network, list):
599
+ tot_sdf_values = []
600
+ for i_obj, cur_sdf_network in enumerate(self.sdf_network):
601
+ cur_sdf_values = cur_sdf_network.sdf(pts)
602
+ tot_sdf_values.append(cur_sdf_values)
603
+ tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
604
+ tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
605
+ sdf = tot_sdf_values
606
+ else:
607
+ sdf = self.sdf_network.sdf(pts)
608
+ return sdf
609
+
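+ # Taking the pointwise min over per-object SDFs composes the scene as a union:
+ # a point is inside the union iff it is inside at least one object. A minimal
+ # sketch with two hypothetical sphere SDFs:
+ # sdf_a = lambda p: p.norm(dim=-1, keepdim=True) - 1.0 # unit sphere
+ # sdf_b = lambda p: (p - 0.5).norm(dim=-1, keepdim=True) - 0.3 # offset sphere
+ # union_sdf = lambda p: torch.min(torch.cat([sdf_a(p), sdf_b(p)], dim=-1), dim=-1)[0]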
610
+ def query_func_sdf_passive(self, pts):
611
+ # if isinstance(self.sdf_network, list):
612
+ # tot_sdf_values = []
613
+ # for i_obj, cur_sdf_network in enumerate(self.sdf_network):
614
+ # cur_sdf_values = cur_sdf_network.sdf(pts)
615
+ # tot_sdf_values.append(cur_sdf_values)
616
+ # tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
617
+ # tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
618
+ # sdf = tot_sdf_values
619
+ # else:
620
+ sdf = self.sdf_network[-1].sdf(pts)
621
+
622
+ return sdf
623
+
624
+
625
+ def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
626
+ """
627
+ Render background
628
+ """
629
+ batch_size, n_samples = z_vals.shape
630
+
631
+ # Section length
632
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
633
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
634
+ mid_z_vals = z_vals + dists * 0.5
635
+
636
+ # Section midpoints #
637
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
638
+
639
+ # pts = pts.flip((-1,)) * 2 - 1
640
+ pts = pts * 2 - 1
641
+
642
+ if self.use_selector:
643
+ pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
644
+ else:
645
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
646
+
647
+ dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
648
+ pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
649
+
650
+ dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
651
+
652
+ pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
653
+ dirs = dirs.reshape(-1, 3)
654
+
655
+ density, sampled_color = nerf(pts, dirs)
656
+ sampled_color = torch.sigmoid(sampled_color)
657
+ alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
658
+ alpha = alpha.reshape(batch_size, n_samples)
659
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
660
+ sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
661
+ color = (weights[:, :, None] * sampled_color).sum(dim=1)
662
+ if background_rgb is not None:
663
+ color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
664
+
665
+ return {
666
+ 'color': color,
667
+ 'sampled_color': sampled_color,
668
+ 'alpha': alpha,
669
+ 'weights': weights,
670
+ }
671
+
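+ # The weights above are standard front-to-back volume compositing,
+ # w_i = alpha_i * prod_{j<i} (1 - alpha_j). A toy sketch (hypothetical alphas):
+ # alpha = torch.tensor([[0.1, 0.5, 0.9]])
+ # trans = torch.cumprod(torch.cat([torch.ones(1, 1), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+ # weights = alpha * trans # -> [[0.1, 0.45, 0.405]], summing to <= 1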
672
+ def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
673
+ """
674
+ Up sampling given a fixed inv_s
675
+ """
676
+ batch_size, n_samples = z_vals.shape
677
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
678
+
679
+ # pts = pts.flip((-1,)) * 2 - 1
680
+ pts = pts * 2 - 1
681
+
682
+ if self.use_selector:
683
+ pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
684
+ else:
685
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
686
+
687
+ radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
688
+ inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
689
+ sdf = sdf.reshape(batch_size, n_samples)
690
+ prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
691
+ prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
692
+ mid_sdf = (prev_sdf + next_sdf) * 0.5
693
+ cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
694
+
695
+ # ----------------------------------------------------------------------------------------------------------
696
+ # Use min value of [ cos, prev_cos ]
697
+ # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
698
+ # robust when meeting situations like below:
699
+ #
700
+ # SDF
701
+ # ^
702
+ # |\ -----x----...
703
+ # | \ /
704
+ # | x x
705
+ # |---\----/-------------> 0 level
706
+ # | \ /
707
+ # | \/
708
+ # |
709
+ # ----------------------------------------------------------------------------------------------------------
710
+ prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
711
+ cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
712
+ cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
713
+ cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
714
+
715
+ dist = (next_z_vals - prev_z_vals)
716
+ prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
717
+ next_esti_sdf = mid_sdf + cos_val * dist * 0.5
718
+ prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
719
+ next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
720
+ alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
721
+ weights = alpha * torch.cumprod(
722
+ torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
723
+
724
+ z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
725
+ return z_samples
726
+
727
+ def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
728
+ batch_size, n_samples = z_vals.shape
729
+ _, n_importance = new_z_vals.shape
730
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
731
+
732
+ # pts = pts.flip((-1,)) * 2 - 1
733
+ pts = pts * 2 - 1
734
+
735
+ if self.use_selector:
736
+ pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
737
+ else:
738
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
739
+
740
+ z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
741
+ z_vals, index = torch.sort(z_vals, dim=-1)
742
+
743
+ if not last:
744
+ if isinstance(self.sdf_network, list):
745
+ tot_new_sdf = []
746
+ for i_obj, cur_sdf_network in enumerate(self.sdf_network):
747
+ cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
748
+ tot_new_sdf.append(cur_new_sdf)
749
+ tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
750
+ new_sdf, _ = torch.min(tot_new_sdf, dim=-1) #
751
+ else:
752
+ new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
753
+ sdf = torch.cat([sdf, new_sdf], dim=-1)
754
+ xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
755
+ index = index.reshape(-1)
756
+ sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
757
+
758
+ return z_vals, sdf
759
+
760
+
761
+
762
+ def render_core(self,
763
+ rays_o,
764
+ rays_d,
765
+ z_vals,
766
+ sample_dist,
767
+ sdf_network,
768
+ deviation_network,
769
+ color_network,
770
+ background_alpha=None,
771
+ background_sampled_color=None,
772
+ background_rgb=None,
773
+ cos_anneal_ratio=0.0,
774
+ pts_ts=0):
775
+ batch_size, n_samples = z_vals.shape
776
+
777
+ # Section length
778
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
779
+ dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
780
+ mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
781
+
782
+ # Section midpoints
783
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
784
+ dirs = rays_d[:, None, :].expand(pts.shape)
785
+
786
+ pts = pts.reshape(-1, 3) # pts, nn_ou
787
+ dirs = dirs.reshape(-1, 3)
788
+
789
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
790
+
791
+ # pts = pts.flip((-1,)) * 2 - 1
792
+ pts = pts * 2 - 1
793
+
794
+ if self.use_selector:
795
+ pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
796
+ else:
797
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
798
+
799
+ if isinstance(sdf_network, list):
800
+ tot_sdf = []
801
+ tot_feature_vector = []
802
+ tot_obj_sel = []
803
+ tot_gradients = []
804
+ for i_obj, cur_sdf_network in enumerate(sdf_network):
805
+ cur_sdf_nn_output = cur_sdf_network(pts)
806
+ cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
807
+ tot_sdf.append(cur_sdf)
808
+ tot_feature_vector.append(cur_feature_vector)
809
+
810
+ gradients = cur_sdf_network.gradient(pts).squeeze()
811
+ tot_gradients.append(gradients)
812
+ tot_sdf = torch.stack(tot_sdf, dim=-1)
813
+ sdf, obj_sel = torch.min(tot_sdf, dim=-1)
814
+ feature_vector = torch.stack(tot_feature_vector, dim=1)
815
+
816
+ # batched_index_select
817
+ # print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
818
+ feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
819
+
820
+
821
+ # feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
822
+ # print(f"after sel: {feature_vector.size()}")
823
+ tot_gradients = torch.stack(tot_gradients, dim=1)
824
+ # gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
825
+ gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
826
+ # print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
827
+
828
+ else:
829
+ sdf_nn_output = sdf_network(pts)
830
+ sdf = sdf_nn_output[:, :1]
831
+ feature_vector = sdf_nn_output[:, 1:]
832
+ gradients = sdf_network.gradient(pts).squeeze()
833
+
834
+ sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
835
+
836
+ # deviation network #
837
+ inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
838
+ inv_s = inv_s.expand(batch_size * n_samples, 1)
839
+
840
+ true_cos = (dirs * gradients).sum(-1, keepdim=True)
841
+
842
+ # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
843
+ # the cos value "not dead" at the beginning training iterations, for better convergence.
844
+ iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
845
+ F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
846
+
847
+ # Estimate signed distances at section points
848
+ estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
849
+ estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
850
+
851
+ prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
852
+ next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
853
+
854
+ p = prev_cdf - next_cdf
855
+ c = prev_cdf
856
+
857
+ alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
858
+
859
+ pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
860
+ inside_sphere = (pts_norm < 1.0).float().detach()
861
+ relax_inside_sphere = (pts_norm < 1.2).float().detach()
862
+
863
+ # Render with background
864
+ if background_alpha is not None:
865
+ alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
866
+ alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
867
+ sampled_color = sampled_color * inside_sphere[:, :, None] +\
868
+ background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
869
+ sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
870
+
871
+ weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
872
+ weights_sum = weights.sum(dim=-1, keepdim=True)
873
+
874
+ color = (sampled_color * weights[:, :, None]).sum(dim=1)
875
+ if background_rgb is not None: # Fixed background, usually black
876
+ color = color + background_rgb * (1.0 - weights_sum)
877
+
878
+ # Eikonal loss
879
+ gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
880
+ dim=-1) - 1.0) ** 2
881
+ gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
882
+
883
+ return {
884
+ 'color': color,
885
+ 'sdf': sdf,
886
+ 'dists': dists,
887
+ 'gradients': gradients.reshape(batch_size, n_samples, 3),
888
+ 's_val': 1.0 / inv_s,
889
+ 'mid_z_vals': mid_z_vals,
890
+ 'weights': weights,
891
+ 'cdf': c.reshape(batch_size, n_samples),
892
+ 'gradient_error': gradient_error,
893
+ 'inside_sphere': inside_sphere
894
+ }
895
+
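+ # The opacity above is the NeuS-style alpha: with Phi_s(x) = sigmoid(inv_s * x),
+ # alpha = max((Phi_s(sdf_prev) - Phi_s(sdf_next)) / Phi_s(sdf_prev), 0),
+ # which is what the (p + 1e-5) / (c + 1e-5) ratio computes after clipping.
+ # Toy sketch (hypothetical values):
+ # inv_s = 64.0
+ # prev_cdf = torch.sigmoid(torch.tensor([0.02]) * inv_s)
+ # next_cdf = torch.sigmoid(torch.tensor([-0.01]) * inv_s)
+ # alpha = ((prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)).clip(0.0, 1.0)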
896
+ def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
897
+ batch_size = len(rays_o)
898
+ sample_dist = 2.0 / self.n_samples # assuming the region of interest is a unit sphere
899
+ z_vals = torch.linspace(0.0, 1.0, self.n_samples) # linspace #
900
+ z_vals = near + (far - near) * z_vals[None, :]
901
+
902
+ z_vals_outside = None
903
+ if self.n_outside > 0:
904
+ z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
905
+
906
+ n_samples = self.n_samples
907
+ perturb = self.perturb
908
+
909
+ if perturb_overwrite >= 0:
910
+ perturb = perturb_overwrite
911
+ if perturb > 0:
912
+ t_rand = (torch.rand([batch_size, 1]) - 0.5)
913
+ z_vals = z_vals + t_rand * 2.0 / self.n_samples
914
+
915
+ if self.n_outside > 0: # z values output # n_outside #
916
+ mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
917
+ upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
918
+ lower = torch.cat([z_vals_outside[..., :1], mids], -1)
919
+ t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
920
+ z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
921
+
922
+ if self.n_outside > 0:
923
+ z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
924
+
925
+ background_alpha = None
926
+ background_sampled_color = None
927
+
928
+ # Up sample
929
+ if self.n_importance > 0:
930
+ with torch.no_grad():
931
+ pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
932
+
933
+ pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
934
+ # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
935
+ # gt_sdf #
936
+
937
+ #
938
+ # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
939
+
940
+ # pts = pts.flip((-1,)) * 2 - 1
941
+ pts = pts * 2 - 1
942
+
943
+ if self.use_selector:
944
+ pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
945
+ else:
946
+ pts = self.deform_pts(pts=pts, pts_ts=pts_ts) # given the pts
947
+
948
+ pts_exp = pts.reshape(-1, 3)
949
+ # minn_pts, _ = torch.min(pts_exp, dim=0)
950
+ # maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
951
+ # print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
952
+
953
+ # pts_to_near = pts - near.unsqueeze(1)
954
+ # maxx_pts = 1.5; minn_pts = -1.5
955
+ # # maxx_pts = 3; minn_pts = -3
956
+ # # maxx_pts = 1; minn_pts = -1
957
+ # pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
958
+
959
+ ## render the images ##
960
+ if use_gt_sdf:
961
+ ### use the GT sdf field ####
962
+ # print(f"Using gt sdf :")
963
+ sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
964
+ sdf = torch.from_numpy(sdf).float().cuda()
965
+ sdf = sdf.reshape(batch_size, self.n_samples)
966
+ ### use the GT sdf field ####
967
+ else:
968
+ # pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
969
+ #### use the optimized sdf field ####
970
+
971
+ # sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
972
+
973
+ if isinstance(self.sdf_network, list):
974
+ tot_sdf_values = []
975
+ for i_obj, cur_sdf_network in enumerate(self.sdf_network):
976
+ cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
977
+ tot_sdf_values.append(cur_sdf_values)
978
+ tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
979
+ tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
980
+ sdf = tot_sdf_values
981
+ else:
982
+ sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
983
+
984
+ #### use the optimized sdf field ####
985
+
986
+ for i in range(self.up_sample_steps):
987
+ new_z_vals = self.up_sample(rays_o,
988
+ rays_d,
989
+ z_vals,
990
+ sdf,
991
+ self.n_importance // self.up_sample_steps,
992
+ 64 * 2**i,
993
+ pts_ts=pts_ts)
994
+ z_vals, sdf = self.cat_z_vals(rays_o,
995
+ rays_d,
996
+ z_vals,
997
+ new_z_vals,
998
+ sdf,
999
+ last=(i + 1 == self.up_sample_steps),
1000
+ pts_ts=pts_ts)
1001
+
1002
+ n_samples = self.n_samples + self.n_importance
1003
+
1004
+ # Background model
1005
+ if self.n_outside > 0:
1006
+ z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
1007
+ z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
1008
+ ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
1009
+
1010
+ background_sampled_color = ret_outside['sampled_color']
1011
+ background_alpha = ret_outside['alpha']
1012
+
1013
+ # Render core
1014
+ ret_fine = self.render_core(rays_o, #
1015
+ rays_d,
1016
+ z_vals,
1017
+ sample_dist,
1018
+ self.sdf_network,
1019
+ self.deviation_network,
1020
+ self.color_network,
1021
+ background_rgb=background_rgb,
1022
+ background_alpha=background_alpha,
1023
+ background_sampled_color=background_sampled_color,
1024
+ cos_anneal_ratio=cos_anneal_ratio,
1025
+ pts_ts=pts_ts)
1026
+
1027
+ color_fine = ret_fine['color']
1028
+ weights = ret_fine['weights']
1029
+ weights_sum = weights.sum(dim=-1, keepdim=True)
1030
+ gradients = ret_fine['gradients']
1031
+ s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
1032
+
1033
+ return {
1034
+ 'color_fine': color_fine,
1035
+ 's_val': s_val,
1036
+ 'cdf_fine': ret_fine['cdf'],
1037
+ 'weight_sum': weights_sum,
1038
+ 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
1039
+ 'gradients': gradients,
1040
+ 'weights': weights,
1041
+ 'gradient_error': ret_fine['gradient_error'],
1042
+ 'inside_sphere': ret_fine['inside_sphere']
1043
+ }
1044
+
1045
+ def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
1046
+ return extract_geometry(bound_min, # extract geometry #
1047
+ bound_max,
1048
+ resolution=resolution,
1049
+ threshold=threshold,
1050
+ # query_func=lambda pts: -self.sdf_network.sdf(pts),
1051
+ query_func=lambda pts: -self.query_func_sdf(pts)
1052
+ )
1053
+
1054
+ # if self.deform_pts_with_selector:
1055
+ # pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
1056
+ def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
1057
+ if wdef:
1058
+ return extract_geometry_tets(bound_min, # extract geometry #
1059
+ bound_max,
1060
+ resolution=resolution,
1061
+ threshold=threshold,
1062
+ query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
1063
+ def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts))
1064
+ else:
1065
+ return extract_geometry_tets(bound_min, # extract geometry #
1066
+ bound_max,
1067
+ resolution=resolution,
1068
+ threshold=threshold,
1069
+ # query_func=lambda pts: -self.sdf_network.sdf(pts)
1070
+ query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
1071
+ )
1072
+
1073
+ def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
1074
+ if wdef:
1075
+ return extract_geometry_tets(bound_min, # extract geometry #
1076
+ bound_max,
1077
+ resolution=resolution,
1078
+ threshold=threshold,
1079
+ query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
1080
+ def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts))
1081
+ else:
1082
+ return extract_geometry_tets(bound_min, # extract geometry #
1083
+ bound_max,
1084
+ resolution=resolution,
1085
+ threshold=threshold,
1086
+ # query_func=lambda pts: -self.sdf_network.sdf(pts)
1087
+ query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
1088
+ )
models/renderer_def_multi_objs_compositional.py ADDED
@@ -0,0 +1,1510 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import logging
6
+ import mcubes
7
+ from icecream import ic
8
+ import os
9
+
10
+ import trimesh
11
+ from pysdf import SDF
12
+
13
+ import models.fields as fields
14
+
15
+ from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
16
+
17
+ def batched_index_select(values, indices, dim = 1):
18
+ value_dims = values.shape[(dim + 1):]
19
+ values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
20
+ indices = indices[(..., *((None,) * len(value_dims)))]
21
+ indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
22
+ value_expand_len = len(indices_shape) - (dim + 1)
23
+ values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
24
+
25
+ value_expand_shape = [-1] * len(values.shape)
26
+ expand_slice = slice(dim, (dim + value_expand_len))
27
+ value_expand_shape[expand_slice] = indices.shape[expand_slice]
28
+ values = values.expand(*value_expand_shape)
29
+
30
+ dim += value_expand_len
31
+ return values.gather(dim, indices)
32
+
33
+
34
+ def create_mt_variable(device):
35
+ triangle_table = torch.tensor(
36
+ [
37
+ [-1, -1, -1, -1, -1, -1],
38
+ [1, 0, 2, -1, -1, -1],
39
+ [4, 0, 3, -1, -1, -1],
40
+ [1, 4, 2, 1, 3, 4],
41
+ [3, 1, 5, -1, -1, -1],
42
+ [2, 3, 0, 2, 5, 3],
43
+ [1, 4, 0, 1, 5, 4],
44
+ [4, 2, 5, -1, -1, -1],
45
+ [4, 5, 2, -1, -1, -1],
46
+ [4, 1, 0, 4, 5, 1],
47
+ [3, 2, 0, 3, 5, 2],
48
+ [1, 3, 5, -1, -1, -1],
49
+ [4, 1, 2, 4, 3, 1],
50
+ [3, 0, 4, -1, -1, -1],
51
+ [2, 0, 1, -1, -1, -1],
52
+ [-1, -1, -1, -1, -1, -1]
53
+ ], dtype=torch.long, device=device)
54
+
55
+ num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
56
+ base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
57
+ v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
58
+ return triangle_table, num_triangles_table, base_tet_edges, v_id
59
+
60
+
61
+
62
+ def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
63
+ # load tet via resolution #
64
+ # scale them via bounds #
65
+ # extract the geometry #
66
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
67
+ device = bound_min.device
68
+ # if resolution in [64, 70, 80, 90, 100]:
69
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
70
+ # else:
71
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
72
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
73
+ if not os.path.exists(tet_fn):
74
+ tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
75
+ tets = np.load(tet_fn)
76
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
77
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
78
+ # split #
79
+ # verts; verts; #
80
+ minn_verts, _ = torch.min(verts, dim=0)
81
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
82
+ # scale_verts = maxx_verts - minn_verts
83
+ scale_bounds = bound_max - bound_min # scale bounds #
84
+
85
+ ### scale the vertices ###
86
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
87
+
88
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
89
+
90
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
91
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
92
+
93
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
94
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
95
+
96
+ sdf_values = []
97
+ N = 64
98
+ query_bundles = N ** 3 ### N^3
99
+ query_NNs = scaled_verts.size(0) // query_bundles
100
+ if query_NNs * query_bundles < scaled_verts.size(0):
101
+ query_NNs += 1
102
+ for i_query in range(query_NNs):
103
+ cur_bundle_st = i_query * query_bundles
104
+ cur_bundle_ed = (i_query + 1) * query_bundles
105
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
106
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
107
+ if def_func is not None:
108
+ cur_query_pts = def_func(cur_query_pts)
109
+ cur_query_vals = query_func(cur_query_pts)
110
+ sdf_values.append(cur_query_vals)
111
+ sdf_values = torch.cat(sdf_values, dim=0)
112
+ # print(f"queryed sdf values: {sdf_values.size()}") #
113
+
114
+ # GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
115
+ gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
116
+ if not os.path.exists(gt_sdf_fn):
117
+ gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
118
+ GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
119
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
120
+
121
+ # intrinsic, tet values, pts values, sdf network #
122
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
123
+ tet_table, num_tets_table = create_tetmesh_variables(device)
124
+
125
+ sdf_values = sdf_values.squeeze(-1) # squeeze to one sdf value per tet vertex #
126
+
127
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
128
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
129
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
130
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
131
+ # marching_tets_tetmesh ##
132
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
133
+ ### use the GT sdf values for the marching tets ###
134
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
135
+
136
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
137
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
138
+
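+ # Hypothetical usage sketch (the sdf_network and bounds here are assumptions,
+ # not part of this function's contract); the query is negated, as in the
+ # renderer below, so that "inside" is positive for the marching tets:
+ # bound_min = torch.tensor([-1., -1., -1.], device=device)
+ # bound_max = torch.tensor([1., 1., 1.], device=device)
+ # verts, faces, sdf_vals, gt_verts, gt_faces = extract_fields_from_tets(
+ #     bound_min, bound_max, resolution=100,
+ #     query_func=lambda pts: -sdf_network.sdf(pts))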
139
+
140
+ def extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=None):
141
+ # load tet via resolution #
142
+ # scale them via bounds #
143
+ # extract the geometry #
144
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
145
+ device = bound_min.device
146
+ # if resolution in [64, 70, 80, 90, 100]:
147
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
148
+ # else:
149
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
150
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
151
+ if not os.path.exists(tet_fn):
152
+ tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
153
+ tets = np.load(tet_fn)
154
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
155
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
156
+ # split #
157
+ # verts; verts; #
158
+ minn_verts, _ = torch.min(verts, dim=0)
159
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
160
+ # scale_verts = maxx_verts - minn_verts
161
+ scale_bounds = bound_max - bound_min # scale bounds #
162
+
163
+ ### scale the vertices ###
164
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
165
+
166
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
167
+
168
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
169
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
170
+
171
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
172
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
173
+
174
+ sdf_values = []
175
+ N = 64
176
+ query_bundles = N ** 3 ### N^3
177
+ query_NNs = scaled_verts.size(0) // query_bundles
178
+ if query_NNs * query_bundles < scaled_verts.size(0):
179
+ query_NNs += 1
180
+ for i_query in range(query_NNs):
181
+ cur_bundle_st = i_query * query_bundles
182
+ cur_bundle_ed = (i_query + 1) * query_bundles
183
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
184
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
185
+ if def_func is not None:
186
+ cur_query_pts, _ = def_func(cur_query_pts)
187
+ cur_query_vals = query_func(cur_query_pts)
188
+ sdf_values.append(cur_query_vals)
189
+ sdf_values = torch.cat(sdf_values, dim=0)
190
+ # print(f"queryed sdf values: {sdf_values.size()}") #
191
+
192
+ # GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
193
+ gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
194
+ if not os.path.exists(gt_sdf_fn):
195
+ gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
196
+ GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
197
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
198
+
199
+ # intrinsic, tet values, pts values, sdf network #
200
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
201
+ tet_table, num_tets_table = create_tetmesh_variables(device)
202
+
203
+ sdf_values = sdf_values.squeeze(-1) # squeeze to one sdf value per tet vertex #
204
+
205
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
206
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
207
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
208
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
209
+ # marching_tets_tetmesh ##
210
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
211
+ ### use the GT sdf values for the marching tets ###
212
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
213
+
214
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
215
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
216
+
217
+
218
+ def extract_fields(bound_min, bound_max, resolution, query_func):
219
+ N = 64
220
+ X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
221
+ Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
222
+ Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
223
+
224
+ u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
225
+ with torch.no_grad():
226
+ for xi, xs in enumerate(X):
227
+ for yi, ys in enumerate(Y):
228
+ for zi, zs in enumerate(Z):
229
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
230
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
231
+ val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
232
+ u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
233
+ # should save u here #
234
+ # save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
235
+ # np.save(save_u_path, u) #
236
+ # print(f"u saved to {save_u_path}")
237
+ return u
238
+
239
+
240
+ def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
241
+ print('threshold: {}'.format(threshold))
242
+
243
+ ## using marching cubes ###
244
+ u = extract_fields(bound_min, bound_max, resolution, query_func)
245
+ vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
246
+ b_max_np = bound_max.detach().cpu().numpy()
247
+ b_min_np = bound_min.detach().cpu().numpy()
248
+
249
+ vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
250
+ ### using marching cubes ###
251
+
252
+ ### using marching tets ###
253
+ # vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
254
+ # vertices = vertices.detach().cpu().numpy()
255
+ # triangles = triangles.detach().cpu().numpy()
256
+ ### using marching tets ###
257
+
258
+ # b_max_np = bound_max.detach().cpu().numpy()
259
+ # b_min_np = bound_min.detach().cpu().numpy()
260
+
261
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
262
+ return vertices, triangles
263
+
264
+ def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None, selector=False):
265
+ # print('threshold: {}'.format(threshold))
266
+
267
+ ### using marching cubes ###
268
+ # u = extract_fields(bound_min, bound_max, resolution, query_func)
269
+ # vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
270
+ # b_max_np = bound_max.detach().cpu().numpy()
271
+ # b_min_np = bound_min.detach().cpu().numpy()
272
+
273
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
274
+ ### using marching cubes ###
275
+
276
+ ##
277
+ ### using marching tets ### fields from tets ##
278
+ if selector:
279
+ vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=def_func)
280
+ else:
281
+ vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
282
+ # vertices = vertices.detach().cpu().numpy()
283
+ # triangles = triangles.detach().cpu().numpy()
284
+ ### using marching tets ###
285
+
286
+ # b_max_np = bound_max.detach().cpu().numpy()
287
+ # b_min_np = bound_min.detach().cpu().numpy()
288
+ #
289
+
290
+ # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
291
+ return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
292
+
293
+
294
+ ### sample pdfs ###
295
+ def sample_pdf(bins, weights, n_samples, det=False):
296
+ # This implementation is from NeRF
297
+ # Get pdf
298
+ weights = weights + 1e-5 # prevent nans
299
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
300
+ cdf = torch.cumsum(pdf, -1)
301
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
302
+ # Take uniform samples
303
+ if det:
304
+ u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
305
+ u = u.expand(list(cdf.shape[:-1]) + [n_samples])
306
+ else:
307
+ u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
308
+
309
+ # Invert CDF # invert cdf #
310
+ u = u.contiguous()
311
+ inds = torch.searchsorted(cdf, u, right=True)
312
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
313
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
314
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
315
+
316
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
317
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
318
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
319
+
320
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
321
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
322
+ t = (u - cdf_g[..., 0]) / denom
323
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
324
+
325
+ return samples
326
+
327
+
328
+ def load_GT_vertices(GT_meshes_folder):
329
+ tot_meshes_fns = os.listdir(GT_meshes_folder)
330
+ tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
331
+ tot_mesh_verts = []
332
+ tot_mesh_faces = []
333
+ n_tot_verts = 0
334
+ for fn in tot_meshes_fns:
335
+ cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
336
+ obj_mesh = trimesh.load(cur_mesh_fn, process=False)
337
+ # obj_mesh.remove_degenerate_faces(height=1e-06)
338
+
339
+ verts_obj = np.array(obj_mesh.vertices)
340
+ faces_obj = np.array(obj_mesh.faces)
341
+
342
+ tot_mesh_verts.append(verts_obj)
343
+ tot_mesh_faces.append(faces_obj + n_tot_verts)
344
+ n_tot_verts += verts_obj.shape[0]
345
+
346
+ # tot_mesh_faces.append(faces_obj)
347
+ tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
348
+ tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
349
+ return tot_mesh_verts, tot_mesh_faces
350
+
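+ # The `faces_obj + n_tot_verts` offset above re-indexes each mesh's faces into
+ # the concatenated vertex array. Toy sketch (hypothetical arrays): two
+ # single-triangle meshes of 3 vertices each end up indexed [0,1,2] and [3,4,5].
+ # verts = np.concatenate([np.zeros((3, 3)), np.ones((3, 3))], axis=0)
+ # faces = np.concatenate([np.array([[0, 1, 2]]), np.array([[0, 1, 2]]) + 3], axis=0)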
351
+
352
+ class NeuSRenderer:
353
+ def __init__(self,
354
+ nerf,
355
+ sdf_network,
356
+ deviation_network,
357
+ color_network,
358
+ n_samples,
359
+ n_importance,
360
+ n_outside,
361
+ up_sample_steps,
362
+ perturb):
363
+ self.nerf = nerf #
364
+ self.sdf_network = sdf_network
365
+ self.deviation_network = deviation_network
366
+ self.color_network = color_network
367
+ self.n_samples = n_samples
368
+ self.n_importance = n_importance
369
+ self.n_outside = n_outside
370
+ self.up_sample_steps = up_sample_steps
371
+ self.perturb = perturb
372
+
373
+ GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
374
+ if not os.path.exists(GT_meshes_folder):
375
+ GT_meshes_folder = "/data/xueyi/diffsim/DiffHand/assets/hand"
376
+ self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
377
+ maxx_pts = 25.
378
+ minn_pts = -15.
379
+ self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
380
+ f = SDF(self.mesh_vertices, self.mesh_faces)
381
+ self.gt_sdf = f ## a unit sphere or box
382
+
383
+ self.minn_pts = 0
384
+ self.maxx_pts = 1.
385
+
386
+ # self.minn_pts = -1.5 #
387
+ # self.maxx_pts = 1.5 #
388
+ self.bkg_pts = ... # TODO
389
+ self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
390
+ self.dist_interp_thres = ... # TODO: set the cur_bkg_pts_defs #
391
+
392
+ self.bending_network = ... # TODO: add the bending network #
393
+ self.use_bending_network = ... # TODO: set the property #
394
+ self.use_delta_bending = ... # TODO
395
+ self.prev_sdf_network = ... # TODO
396
+ self.use_selector = False
397
+ self.timestep_to_passive_mesh = ... # TODO
398
+ self.timestep_to_active_mesh = ... # TODO
399
+
400
+
401
+
402
+ def deform_pts(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
403
+
404
+ if self.use_bending_network:
405
+ if len(pts.size()) == 3:
406
+ nnb, nns = pts.size(0), pts.size(1)
407
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
408
+ else:
409
+ pts_exp = pts
410
+ # pts_ts #
411
+ if self.use_delta_bending:
412
+
413
+ if isinstance(self.bending_network, list):
414
+ pts_offsets = []
415
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
416
+ if isinstance(cur_bending_network, fields.BendingNetwork):
417
+ for cur_pts_ts in range(pts_ts, -1, -1):
418
+ cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
419
+ elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
420
+ cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts) # rigid trans is applied once at pts_ts (cur_pts_ts would be undefined on this branch)
421
+ else:
422
+ raise ValueError('Encountered an unexpected bending network class...')
423
+ pts_offsets.append(cur_pts_exp - pts_exp)
424
+ pts_offsets = torch.stack(pts_offsets, dim=0)
425
+ pts_offsets = torch.sum(pts_offsets, dim=0)
426
+ pts_exp = pts_exp + pts_offsets
427
+ # for cur_pts_ts in range(pts_ts, -1, -1):
428
+ # if isinstance(self.bending_network, list): # pts ts #
429
+ # for i_obj, cur_bending_network in enumerate(self.bending_network):
430
+ # pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
431
+ # else:
432
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
433
+ else:
434
+ if isinstance(self.bending_network, list): # prev sdf network #
435
+ pts_offsets = []
436
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
437
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
438
+ pts_offsets.append(bended_pts_exp - pts_exp)
439
+ pts_offsets = torch.stack(pts_offsets, dim=0)
440
+ pts_offsets = torch.sum(pts_offsets, dim=0)
441
+ pts_exp = pts_exp + pts_offsets
442
+ else:
443
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
444
+ if len(pts.size()) == 3:
445
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
446
+ else:
447
+ pts = pts_exp
448
+ return pts
449
+
450
+ # pts: nn_batch x nn_samples x 3
451
+ if len(pts.size()) == 3:
452
+ nnb, nns = pts.size(0), pts.size(1)
453
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
454
+ else:
455
+ pts_exp = pts
456
+ # print(f"prior to deforming: {pts.size()}")
457
+
458
+ dist_pts_to_bkg_pts = torch.sum(
459
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
460
+ )
461
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
462
+ dist_mask_float = dist_mask.float()
463
+
464
+ # dist_mask_float #
465
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
466
+ cur_fr_pts_def = torch.sum(
467
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
468
+ )
469
+ dist_mask_float_summ = torch.sum(
470
+ dist_mask_float, dim=1
471
+ )
472
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
473
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
474
+ pts_exp = pts_exp - cur_fr_pts_def
475
+ if len(pts.size()) == 3:
476
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
477
+ else:
478
+ pts = pts_exp
479
+ return pts #
480
+
481
+
482
+ def deform_pts_with_selector(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
483
+
484
+ if self.use_bending_network:
485
+ if len(pts.size()) == 3:
486
+ nnb, nns = pts.size(0), pts.size(1)
487
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
488
+ else:
489
+ pts_exp = pts
490
+ # pts_ts #
491
+ if self.use_delta_bending:
492
+ if isinstance(self.bending_network, list):
493
+ bended_pts = []
494
+ queries_sdfs_selector = []
495
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
496
+ if cur_bending_network.use_opt_rigid_translations:
497
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
498
+ else:
499
+ # bended_pts_exp = pts_exp.clone()
500
+ if i_obj == 1 and pts_ts == 0:
501
+ bended_pts_exp = pts_exp
502
+ elif i_obj == 1:
503
+ for cur_pts_ts in range(pts_ts, 0, -1): ### before 0 ###
504
+ if isinstance(cur_bending_network, fields.BendingNetwork):
505
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
506
+ elif isinstance(cur_bending_network, fields.BendingNetworkForceForward) or isinstance(cur_bending_network, fields.BendingNetworkRigidTransForward):
507
+ # input_pts, input_pts_ts, timestep_to_passive_mesh, act_sdf_net=None, details=None, special_loss_return=False
508
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh)
509
+ elif isinstance(cur_bending_network, fields.BendingNetworkForceFieldForward):
510
+ # input_pts, input_pts_ts, timestep_to_passive_mesh, passive_sdf_net, details=None, special_loss_return=False
511
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1])
512
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForward):
513
+ # active_bending_net, active_sdf_net,
514
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
515
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV2):
516
+ # active_bending_net, active_sdf_net,
517
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
518
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV3):
519
+ # active_bending_net, active_sdf_net,
520
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
521
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV4):
522
+ # active_bending_net, active_sdf_net,
523
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
524
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV5):
525
+ # active_bending_net, active_sdf_net,
526
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
527
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV6):
528
+ # active_bending_net, active_sdf_net,
529
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
530
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV7):
531
+ # active_bending_net, active_sdf_net,
532
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
533
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV8):
534
+ # active_bending_net, active_sdf_net,
535
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
536
+ elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV9):
537
+ # active_bending_net, active_sdf_net,
538
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
539
+ else:
540
+ raise ValueError(f"Unrecognized bending network type: {type(cur_bending_network)}")
541
+ else:
542
+ for cur_pts_ts in range(pts_ts, -1, -1):
543
+ bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
544
+ # _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
545
+ _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector_ndelta(bended_pts_exp, i_net=i_obj)
546
+ bended_pts.append(bended_pts_exp)
547
+ queries_sdfs_selector.append(cur_bended_pts_selector)
548
+ bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts # # bended_pts #
549
+ queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
550
+ # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
551
+ #### get the final sdf_selector from queries_sdfs_selector ####
552
+ # sdf_selector = queries_sdfs_selector[:, -1]
553
+ # neg_neg = ((queries_sdfs_selector[:, 0] == 0).float() + (queries_sdfs_selector[:, -1] == 1).float()) > 1.5 #### both inside of the object
554
+ sdf_selector = 1 - queries_sdfs_selector[:, 0]
555
+ # neg_neg
556
+ # sdf_selector = queries_sdfs_selector
557
+ # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
558
+ bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
559
+ # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
560
+ # pts_exp = bended_pts.squeeze(1)
561
+ pts_exp = bended_pts
562
+ # for cur_pts_ts in range(pts_ts, -1, -1):
563
+ # if isinstance(self.bending_network, list):
564
+ # for i_obj, cur_bending_network in enumerate(self.bending_network):
565
+ # pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
566
+ # else:
567
+
568
+ # pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
569
+ else:
570
+ if isinstance(self.bending_network, list): # prev sdf network #
571
+ # pts_offsets = []
572
+ bended_pts = []
573
+ queries_sdfs_selector = []
574
+ for i_obj, cur_bending_network in enumerate(self.bending_network):
575
+ bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
576
+ # pts_offsets.append(bended_pts_exp - pts_exp)
577
+ _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
578
+ bended_pts.append(bended_pts_exp)
579
+ queries_sdfs_selector.append(cur_bended_pts_selector)
580
+ bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
581
+ queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
582
+ # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
583
+ sdf_selector = queries_sdfs_selector[:, -1]
584
+ # sdf_selector = queries_sdfs_selector
585
+
586
+
587
+ # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
588
+ bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
589
+ # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
590
+ pts_exp = bended_pts.squeeze(1)
591
+
592
+ # pts_offsets = torch.stack(pts_offsets, dim=0)
593
+ # pts_offsets = torch.sum(pts_offsets, dim=0)
594
+ # pts_exp = pts_exp + pts_offsets
595
+ else:
596
+ pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
597
+ if len(pts.size()) == 3:
598
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
599
+ else:
600
+ pts = pts_exp
601
+ return pts, sdf_selector
602
+ else:
603
+ # pts: nn_batch x nn_samples x 3
604
+ if len(pts.size()) == 3:
605
+ nnb, nns = pts.size(0), pts.size(1)
606
+ pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
607
+ else:
608
+ pts_exp = pts
609
+ # print(f"prior to deforming: {pts.size()}")
610
+
611
+ dist_pts_to_bkg_pts = torch.sum(
612
+ (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
613
+ )
614
+ dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
615
+ dist_mask_float = dist_mask.float()
616
+
617
+ # dist_mask_float #
618
+ cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
619
+ cur_fr_pts_def = torch.sum(
620
+ cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
621
+ )
622
+ dist_mask_float_summ = torch.sum(
623
+ dist_mask_float, dim=1
624
+ )
625
+ dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
626
+ cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
627
+ pts_exp = pts_exp - cur_fr_pts_def
628
+ if len(pts.size()) == 3:
629
+ pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
630
+ else:
631
+ pts = pts_exp
632
+ return pts #
+
+
+     def deform_pts_passive(self, pts, pts_ts=0):
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+             if self.use_delta_bending:
+                 if pts_ts > 0:
+                     for cur_pts_ts in range(pts_ts, 0, -1):
+                         # if isinstance(self.bending_network, list):
+                         #     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                         #         pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                         # else:
+                         if isinstance(self.bending_network[-1], fields.BendingNetwork):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts)
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkForceForward) or isinstance(self.bending_network[-1], fields.BendingNetworkRigidTransForward):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh)
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkForceFieldForward):
+                             # input_pts, input_pts_ts, timestep_to_passive_mesh, passive_sdf_net, details=None, special_loss_return=False
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForward):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV2):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV3):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV4):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV5):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV6):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV7):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV8):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV9):
+                             pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
+                         else:
+                             raise ValueError(f"Unrecognized bending network type: {type(self.bending_network[-1])}")
+                         # pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts)
+             else:
+                 # if isinstance(self.bending_network, list):
+                 #     pts_offsets = []
+                 #     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                 #         pts_offsets.append(bended_pts_exp - pts_exp)
+                 #     pts_offsets = torch.stack(pts_offsets, dim=0)
+                 #     pts_offsets = torch.sum(pts_offsets, dim=0)
+                 #     pts_exp = pts_exp + pts_offsets
+                 # else:
+                 pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
+         dist_mask_float = dist_mask.float()
+
+         # dist_mask_float #
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts #
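+     # With use_delta_bending, the passive points are pulled back through the per-timestep
+     # delta fields one step at a time, e.g. (sketch of the loop above):
+     #   for t in range(pts_ts, 0, -1):
+     #       pts = bending_net(pts, input_pts_ts=t)
+     # so pts at frame pts_ts is mapped to the canonical frame 0 by composing the deltas.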
+
+     # delta mesh as passive mesh #
+     def query_pts_sdf_fn_for_selector(self, pts):
+         # for negative
+         # 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
+         # 2) negative in this field and also negative in the previous field --->
+         # 3) for positive values of this current field --->
+         cur_sdf = self.sdf_network.sdf(pts)
+         prev_sdf = self.prev_sdf_network.sdf(pts)
+         neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+
+         neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
+
+         pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+         res_sdf = torch.zeros_like(cur_sdf)
+         res_sdf[neg_neg] = 1. #
+         res_sdf[neg_pos] = cur_sdf[neg_pos]
+         res_sdf[pos_neg] = cur_sdf[pos_neg]
+
+         # inside the residual mesh -> must be neg and pos
+         res_sdf_selector = torch.zeros_like(cur_sdf).long() #
+         # res_sdf_selector[neg_pos] = 1 # is the residual mesh
+         res_sdf_selector[neg_weq_pos] = 1
+         # res_sdf_selector[]
+
+         cat_cur_prev_sdf = torch.stack(
+             [cur_sdf, prev_sdf], dim=-1
+         )
+         minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
+         res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
+
+         return res_sdf, res_sdf_selector
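+     # Sign combinations handled above (cur = this field, prev = previous field):
+     #   cur < 0,  prev < 0  -> res_sdf = 1            (inside both; not the residual region)
+     #   cur < 0,  prev >= 0 -> res_sdf = cur          (inside the residual mesh only)
+     #   cur >= 0, prev < 0  -> res_sdf = cur
+     #   cur >= 0, prev >= 0 -> res_sdf = min(cur, prev)  (outside both)
+     # res_sdf_selector is 1 exactly where cur <= 0 < prev, i.e. on the residual mesh.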
+
+     def query_pts_sdf_fn_for_selector_ndelta(self, pts, i_net):
+         # for negative
+         # 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
+         # 2) negative in this field and also negative in the previous field --->
+         # 3) for positive values of this current field --->
+         passive_sdf = self.sdf_network[i_net].sdf(pts).squeeze(-1)
+         passive_sdf_selector = torch.zeros_like(passive_sdf).long()
+         passive_sdf_selector[passive_sdf <= 0.] = 1
+         return passive_sdf, passive_sdf_selector
+
+         # NOTE: the code below is unreachable (early return above); it duplicates
+         # query_pts_sdf_fn_for_selector.
+         cur_sdf = self.sdf_network.sdf(pts)
+         prev_sdf = self.prev_sdf_network.sdf(pts)
+         neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+
+         neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
+
+         pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+         res_sdf = torch.zeros_like(cur_sdf)
+         res_sdf[neg_neg] = 1. #
+         res_sdf[neg_pos] = cur_sdf[neg_pos]
+         res_sdf[pos_neg] = cur_sdf[pos_neg]
+
+         # inside the residual mesh -> must be neg and pos
+         res_sdf_selector = torch.zeros_like(cur_sdf).long() #
+         # res_sdf_selector[neg_pos] = 1 # is the residual mesh
+         res_sdf_selector[neg_weq_pos] = 1
+         # res_sdf_selector[]
+
+         cat_cur_prev_sdf = torch.stack(
+             [cur_sdf, prev_sdf], dim=-1
+         )
+         minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
+         res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
+
+         return res_sdf, res_sdf_selector
+
+
+     def query_func_sdf(self, pts):
+         if isinstance(self.sdf_network, list):
+             tot_sdf_values = []
+             for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                 cur_sdf_values = cur_sdf_network.sdf(pts)
+                 tot_sdf_values.append(cur_sdf_values)
+             tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
+             tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # tot sdf values #
+             sdf = tot_sdf_values
+         else:
+             sdf = self.sdf_network.sdf(pts)
+         return sdf
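+     # The per-object minimum above is the standard CSG union of signed distance fields:
+     #   sdf_union(p) = min_i sdf_i(p)
+     # so a point counts as "inside" if it is inside any of the objects.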
+
+     def query_func_sdf_passive(self, pts):
+         # if isinstance(self.sdf_network, list):
+         #     tot_sdf_values = []
+         #     for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+         #         cur_sdf_values = cur_sdf_network.sdf(pts)
+         #         tot_sdf_values.append(cur_sdf_values)
+         #     tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
+         #     tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # tot sdf values #
+         #     sdf = tot_sdf_values
+         # else:
+         sdf = self.sdf_network[-1].sdf(pts)
+
+         return sdf
+
+
+     def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
+         """
+         Render background
+         """
+         batch_size, n_samples = z_vals.shape
+
+         # Section length
+         dists = z_vals[..., 1:] - z_vals[..., :-1]
+         dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
+         mid_z_vals = z_vals + dists * 0.5
+
+         # Section midpoints #
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
+         pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
+
+         dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
+
+         pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) ### deformed_pts ###
+         dirs = dirs.reshape(-1, 3)
+
+         if self.use_selector:
+             tot_density, tot_sampled_color = [], []
+             for i_nerf, cur_nerf in enumerate(nerf):
+                 cur_density, cur_sampled_color = cur_nerf(pts, dirs)
+                 tot_density.append(cur_density)
+                 tot_sampled_color.append(cur_sampled_color)
+             tot_density = torch.stack(tot_density, dim=1)
+             tot_sampled_color = torch.stack(tot_sampled_color, dim=1) ### sampled colors
+             # print(f"tot_density: {tot_density.size()}, tot_sampled_color: {tot_sampled_color.size()}, sdf_selector: {sdf_selector.size()}")
+             density = batched_index_select(values=tot_density, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+             sampled_color = batched_index_select(values=tot_sampled_color, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+         else:
+             density, sampled_color = nerf(pts, dirs)
+         sampled_color = torch.sigmoid(sampled_color)
+         alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
+         alpha = alpha.reshape(batch_size, n_samples)
+         weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+         sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
+         color = (weights[:, :, None] * sampled_color).sum(dim=1)
+         if background_rgb is not None:
+             color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
+
+         return {
+             'color': color,
+             'sampled_color': sampled_color,
+             'alpha': alpha,
+             'weights': weights,
+         }
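+     # Background points above are parameterized NeRF++-style as (x/r, y/r, z/r, 1/r) with
+     # r = ||pts|| clipped to >= 1, mapping the unbounded outside region to a bounded 4D
+     # input; the opacity then follows the usual volume-rendering form
+     #   alpha_i = 1 - exp(-softplus(density_i) * dist_i).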
+     def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
+         """
+         Up sampling given a fixed inv_s
+         """
+         batch_size, n_samples = z_vals.shape
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
+         inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
+         sdf = sdf.reshape(batch_size, n_samples)
+         prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
+         prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
+         mid_sdf = (prev_sdf + next_sdf) * 0.5
+         cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
+
+         # ----------------------------------------------------------------------------------------------------------
+         # Use min value of [ cos, prev_cos ]
+         # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
+         # robust when meeting situations like below:
+         #
+         # SDF
+         # ^
+         # |\          -----x----...
+         # | \        /
+         # |  x      x
+         # |---\----/-------------> 0 level
+         # |    \  /
+         # |     \/
+         # |
+         # ----------------------------------------------------------------------------------------------------------
+         prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
+         cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
+         cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
+         cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
+
+         dist = (next_z_vals - prev_z_vals)
+         prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
+         next_esti_sdf = mid_sdf + cos_val * dist * 0.5
+         prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
+         next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
+         alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
+         weights = alpha * torch.cumprod(
+             torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+
+         z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
+         return z_samples
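+     # Section alphas above follow the NeuS up-sampling rule: with SDFs linearly
+     # extrapolated to the section endpoints and Phi_s(x) = sigmoid(s * x),
+     #   alpha_i = (Phi_s(sdf_prev_i) - Phi_s(sdf_next_i)) / Phi_s(sdf_prev_i)
+     # (plus small epsilons); new z values are then drawn from the induced weights
+     # by inverse-CDF sampling in sample_pdf.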
+     def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
+         batch_size, n_samples = z_vals.shape
+         _, n_importance = new_z_vals.shape
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
+         z_vals, index = torch.sort(z_vals, dim=-1)
+
+         if not last:
+             if isinstance(self.sdf_network, list):
+                 tot_new_sdf = []
+                 for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                     cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+                     tot_new_sdf.append(cur_new_sdf)
+                 tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
+                 new_sdf, _ = torch.min(tot_new_sdf, dim=-1) #
+             else:
+                 new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+             sdf = torch.cat([sdf, new_sdf], dim=-1)
+             xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
+             index = index.reshape(-1)
+             sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
+
+         return z_vals, sdf
+
+
+     def render_core(self,
+                     rays_o,
+                     rays_d,
+                     z_vals,
+                     sample_dist,
+                     sdf_network,
+                     deviation_network,
+                     color_network,
+                     background_alpha=None,
+                     background_sampled_color=None,
+                     background_rgb=None,
+                     cos_anneal_ratio=0.0,
+                     pts_ts=0):
+         batch_size, n_samples = z_vals.shape
+
+         # Section length
+         dists = z_vals[..., 1:] - z_vals[..., :-1]
+         dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
+         mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
+
+         # Section midpoints
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
+         dirs = rays_d[:, None, :].expand(pts.shape)
+
+         pts = pts.reshape(-1, 3) # pts, nn_ou
+         dirs = dirs.reshape(-1, 3)
+
+         pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         if isinstance(sdf_network, list):
+             tot_sdf = []
+             tot_feature_vector = []
+             tot_obj_sel = []
+             tot_gradients = []
+             for i_obj, cur_sdf_network in enumerate(sdf_network):
+                 cur_sdf_nn_output = cur_sdf_network(pts)
+                 cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
+                 tot_sdf.append(cur_sdf)
+                 tot_feature_vector.append(cur_feature_vector)
+
+                 gradients = cur_sdf_network.gradient(pts).squeeze()
+                 tot_gradients.append(gradients)
+             tot_sdf = torch.stack(tot_sdf, dim=-1)
+
+             #
+             if self.use_selector:
+                 sdf = batched_index_select(tot_sdf, sdf_selector.unsqueeze(1).unsqueeze(1), dim=2).squeeze(-1)
+                 obj_sel = sdf_selector.unsqueeze(1)
+             else:
+                 sdf, obj_sel = torch.min(tot_sdf, dim=-1)
+             feature_vector = torch.stack(tot_feature_vector, dim=1)
+
+             # batched_index_select
+             # print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
+             feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
+
+             # feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
+             # print(f"after sel: {feature_vector.size()}")
+             tot_gradients = torch.stack(tot_gradients, dim=1)
+             # gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
+             gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
+             # print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
+
+         else:
+             sdf_nn_output = sdf_network(pts)
+             sdf = sdf_nn_output[:, :1]
+             feature_vector = sdf_nn_output[:, 1:]
+             gradients = sdf_network.gradient(pts).squeeze()
+
+         if self.use_selector:
+             tot_sampled_color = []
+             for i_color_net, cur_color_network in enumerate(color_network):
+                 cur_sampled_color = cur_color_network(pts, gradients, dirs, feature_vector) # .reshape(batch_size, n_samples, 3)
+                 tot_sampled_color.append(cur_sampled_color)
+             # print(f"tot_density: {tot_density.size()}, tot_sampled_color: {tot_sampled_color.size()}, sdf_selector: {sdf_selector.size()}")
+             tot_sampled_color = torch.stack(tot_sampled_color, dim=1)
+             sampled_color = batched_index_select(values=tot_sampled_color, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1).reshape(batch_size, n_samples, 3)
+         else:
+             sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
+
+         if self.use_selector and isinstance(deviation_network, list):
+             tot_inv_s = []
+             for i_dev_net, cur_deviation_network in enumerate(deviation_network):
+                 cur_inv_s = cur_deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
+                 tot_inv_s.append(cur_inv_s)
+             tot_inv_s = torch.stack(tot_inv_s, dim=1)
+             inv_s = batched_index_select(values=tot_inv_s, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+             # inv_s =
+         else:
+             # deviation network #
+             inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
+             inv_s = inv_s.expand(batch_size * n_samples, 1)
+
+         true_cos = (dirs * gradients).sum(-1, keepdim=True)
+
+         # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
+         # the cos value "not dead" at the beginning training iterations, for better convergence.
+         iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
+                      F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
+
+         # Estimate signed distances at section points
+         estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
+         estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
+
+         prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
+         next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
+
+         p = prev_cdf - next_cdf
+         c = prev_cdf
+
+         alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
+
+         pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
+         inside_sphere = (pts_norm < 1.0).float().detach()
+         relax_inside_sphere = (pts_norm < 1.2).float().detach()
+
+         # Render with background
+         if background_alpha is not None:
+             alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
+             alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
+             sampled_color = sampled_color * inside_sphere[:, :, None] +\
+                             background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
+             sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
+
+         weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+         weights_sum = weights.sum(dim=-1, keepdim=True)
+
+         color = (sampled_color * weights[:, :, None]).sum(dim=1)
+         if background_rgb is not None: # Fixed background, usually black
+             color = color + background_rgb * (1.0 - weights_sum)
+
+         # Eikonal loss
+         gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
+                                             dim=-1) - 1.0) ** 2
+         gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
+
+         return {
+             'color': color,
+             'sdf': sdf,
+             'dists': dists,
+             'gradients': gradients.reshape(batch_size, n_samples, 3),
+             's_val': 1.0 / inv_s,
+             'mid_z_vals': mid_z_vals,
+             'weights': weights,
+             'cdf': c.reshape(batch_size, n_samples),
+             'gradient_error': gradient_error,
+             'inside_sphere': inside_sphere
+         }
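+     # The alpha above is the NeuS section opacity: with Phi_s(x) = sigmoid(s * x) and the
+     # SDF linearly extrapolated to the section endpoints via iter_cos,
+     #   alpha_i = clip((Phi_s(sdf_prev_i) - Phi_s(sdf_next_i)) / Phi_s(sdf_prev_i), 0, 1),
+     # and the weights follow the usual transmittance product alpha_i * prod_{j<i}(1 - alpha_j).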
+     def per_sdf_query(self, pts):
+         tot_sdfs = []
+         for i_sdf_net, cur_sdf_network in enumerate(self.sdf_network):
+             cur_sdf_value = cur_sdf_network.sdf(pts).squeeze(-1)
+             tot_sdfs.append(cur_sdf_value)
+         tot_sdfs = torch.stack(tot_sdfs, dim=1)
+         return tot_sdfs
+
+
+     def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
+         batch_size = len(rays_o)
+         sample_dist = 2.0 / self.n_samples # assuming the region of interest is a unit sphere #
+         z_vals = torch.linspace(0.0, 1.0, self.n_samples) # linspace #
+         z_vals = near + (far - near) * z_vals[None, :]
+
+         z_vals_outside = None
+         if self.n_outside > 0:
+             z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
+
+         n_samples = self.n_samples
+         perturb = self.perturb
+
+         if perturb_overwrite >= 0:
+             perturb = perturb_overwrite
+         if perturb > 0:
+             t_rand = (torch.rand([batch_size, 1]) - 0.5)
+             z_vals = z_vals + t_rand * 2.0 / self.n_samples
+
+             if self.n_outside > 0: # z values output # n_outside #
+                 mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
+                 upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
+                 lower = torch.cat([z_vals_outside[..., :1], mids], -1)
+                 t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
+                 z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
+
+         if self.n_outside > 0:
+             z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
+
+         background_alpha = None
+         background_sampled_color = None
+
+         # Up sample
+         if self.n_importance > 0:
+             with torch.no_grad():
+                 pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+                 pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
+                 # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
+                 # gt_sdf #
+
+                 #
+                 # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
+
+                 # pts = pts.flip((-1,)) * 2 - 1
+                 pts = pts * 2 - 1
+
+                 if self.use_selector:
+                     pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+                 else:
+                     pts = self.deform_pts(pts=pts, pts_ts=pts_ts) # given the pts
+
+                 pts_exp = pts.reshape(-1, 3)
+                 # minn_pts, _ = torch.min(pts_exp, dim=0)
+                 # maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
+                 # print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
+
+                 # pts_to_near = pts - near.unsqueeze(1)
+                 # maxx_pts = 1.5; minn_pts = -1.5
+                 # # maxx_pts = 3; minn_pts = -3
+                 # # maxx_pts = 1; minn_pts = -1
+                 # pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
+
+                 ## render the images ####
+                 # if use_gt_sdf:
+                 #     ### use the GT sdf field ####
+                 #     # print(f"Using gt sdf :")
+                 #     sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
+                 #     sdf = torch.from_numpy(sdf).float().cuda()
+                 #     sdf = sdf.reshape(batch_size, self.n_samples)
+                 #     ### use the GT sdf field ####
+                 # else:
+                 #     # pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
+                 #     #### use the optimized sdf field ####
+
+                 #     # sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+
+                 if isinstance(self.sdf_network, list):
+                     if self.use_selector:
+                         tot_sdf_values = []
+                         for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                             cur_sdf_values = cur_sdf_network.sdf(pts_exp).squeeze(-1)
+                             tot_sdf_values.append(cur_sdf_values)
+                         tot_sdf_values = torch.stack(tot_sdf_values, dim=1)
+                         tot_sdf_values = batched_index_select(tot_sdf_values, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
+                         sdf = tot_sdf_values.reshape(batch_size, self.n_samples)
+                     else:
+                         # tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # tot sdf values #
+                         tot_sdf_values = []
+                         for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                             cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+                             tot_sdf_values.append(cur_sdf_values)
+                         tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
+                         tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # tot sdf values #
+                         sdf = tot_sdf_values
+                 else:
+                     sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+
+                 #### use the optimized sdf field ####
+
+             for i in range(self.up_sample_steps):
+                 new_z_vals = self.up_sample(rays_o,
+                                             rays_d,
+                                             z_vals,
+                                             sdf,
+                                             self.n_importance // self.up_sample_steps,
+                                             64 * 2**i,
+                                             pts_ts=pts_ts)
+                 z_vals, sdf = self.cat_z_vals(rays_o,
+                                               rays_d,
+                                               z_vals,
+                                               new_z_vals,
+                                               sdf,
+                                               last=(i + 1 == self.up_sample_steps),
+                                               pts_ts=pts_ts)
+
+             n_samples = self.n_samples + self.n_importance
+
+         # Background model
+         if self.n_outside > 0:
+             z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
+             z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
+             ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
+
+             background_sampled_color = ret_outside['sampled_color']
+             background_alpha = ret_outside['alpha']
+
+         tot_sdfs = self.per_sdf_query(pts_exp)
+
+         # Render core
+         ret_fine = self.render_core(rays_o, #
+                                     rays_d,
+                                     z_vals,
+                                     sample_dist,
+                                     self.sdf_network,
+                                     self.deviation_network,
+                                     self.color_network,
+                                     background_rgb=background_rgb,
+                                     background_alpha=background_alpha,
+                                     background_sampled_color=background_sampled_color,
+                                     cos_anneal_ratio=cos_anneal_ratio,
+                                     pts_ts=pts_ts)
+
+         color_fine = ret_fine['color']
+         weights = ret_fine['weights']
+         weights_sum = weights.sum(dim=-1, keepdim=True)
+         gradients = ret_fine['gradients']
+         s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
+
+         return {
+             'color_fine': color_fine,
+             's_val': s_val,
+             'cdf_fine': ret_fine['cdf'],
+             'weight_sum': weights_sum,
+             'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
+             'gradients': gradients,
+             'weights': weights,
+             'gradient_error': ret_fine['gradient_error'],
+             'inside_sphere': ret_fine['inside_sphere'],
+             'tot_sdfs': tot_sdfs,
+         }
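+     # A minimal call sketch for this renderer (names such as `renderer`, `it`, and
+     # `anneal_end` are hypothetical, not from the original code):
+     #   render_out = renderer.render(rays_o, rays_d, near, far, pts_ts=t,
+     #                                background_rgb=torch.ones([1, 3]),
+     #                                cos_anneal_ratio=min(1.0, it / anneal_end))
+     #   color = render_out['color_fine']  # (batch_size, 3)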
+
+
+     def render_def(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False, update_tot_def=True):
+         batch_size = len(rays_o)
+         # sample_dist = 2.0 / self.n_samples # assuming the region of interest is a unit sphere #
+         z_vals = torch.linspace(0.0, 1.0, self.n_samples)
+         z_vals = near + (far - near) * z_vals[None, :]
+
+         z_vals_outside = None
+         if self.n_outside > 0:
+             z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
+
+         n_samples = self.n_samples
+         perturb = self.perturb
+
+         if perturb_overwrite >= 0:
+             perturb = perturb_overwrite
+         if perturb > 0:
+             t_rand = (torch.rand([batch_size, 1]) - 0.5)
+             z_vals = z_vals + t_rand * 2.0 / self.n_samples
+
+             if self.n_outside > 0: # z values output # n_outside #
+                 mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
+                 upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
+                 lower = torch.cat([z_vals_outside[..., :1], mids], -1)
+                 t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
+                 z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
+
+         if self.n_outside > 0:
+             z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
+
+         background_alpha = None
+         background_sampled_color = None
+
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+         pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
+         # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
+         # gt_sdf #
+
+         #
+         # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def) # given the pts
+
+         return {
+             'defed_pts': pts
+         }
+     #
+
+
+     def extract_fields_from_tets_selector_self(self, bound_min, bound_max, resolution, i_ts, passive=False):
+         # load tet via resolution #
+         # scale them via bounds #
+         # extract the geometry #
+         # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
+         device = bound_min.device
+         # if resolution in [64, 70, 80, 90, 100]:
+         #     tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
+         # else:
+         tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
+         if not os.path.exists(tet_fn):
+             tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
+         tets = np.load(tet_fn)
+         verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
+         indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
+         # split #
+         # verts; verts; #
+         minn_verts, _ = torch.min(verts, dim=0)
+         maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
+         # scale_verts = maxx_verts - minn_verts
+         scale_bounds = bound_max - bound_min # scale bounds #
+
+         ### scale the vertices ###
+         scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
+
+         # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
+
+         scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
+         # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
+
+         # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
+         # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
+
+         sdf_values = []
+         N = 64
+         query_bundles = N ** 3 ### N^3
+         query_NNs = scaled_verts.size(0) // query_bundles
+         if query_NNs * query_bundles < scaled_verts.size(0):
+             query_NNs += 1
+         for i_query in range(query_NNs):
+             cur_bundle_st = i_query * query_bundles
+             cur_bundle_ed = (i_query + 1) * query_bundles
+             cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
+             cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
+             # if def_func is not None:
+             cur_query_pts, sdf_selector = self.deform_pts_with_selector(cur_query_pts, pts_ts=i_ts)
+             # cur_query_pts, _
+             # cur_query_vals = query_func(cur_query_pts)
+
+             if passive:
+                 cur_query_vals = self.sdf_network[1].sdf(cur_query_pts) # .squeeze(-1)
+             else:
+                 tot_sdf_values = []
+                 for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                     cur_sdf_values = cur_sdf_network.sdf(cur_query_pts).squeeze(-1)
+                     tot_sdf_values.append(cur_sdf_values)
+                 tot_sdf_values = torch.stack(tot_sdf_values, dim=1)
+                 tot_sdf_values = batched_index_select(tot_sdf_values, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
+                 cur_query_vals = tot_sdf_values.unsqueeze(1)
+                 # sdf = tot_sdf_values.reshape(batch_size, self.n_samples)
+                 # for i_obj,
+             sdf_values.append(cur_query_vals)
+         sdf_values = torch.cat(sdf_values, dim=0)
+         # print(f"queried sdf values: {sdf_values.size()}") #
+
+         gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
+         if not os.path.exists(gt_sdf_fn):
+             gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
+         GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
+
+         GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
+
+         # intrinsic, tet values, pts values, sdf network #
+         triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
+         tet_table, num_tets_table = create_tetmesh_variables(device)
+
+         sdf_values = sdf_values.squeeze(-1) # how the rendering #
+
+         # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
+         # print(f"scaled_verts: {scaled_verts.size()}, ")
+         # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
+         # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
+         # marching_tets_tetmesh ##
+         verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
+         ### use the GT sdf values for the marching tets ###
+         GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
+
+         # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
+         return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
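+     # Pipeline sketch for the extraction above: deform the tet-grid vertices to frame
+     # i_ts, query the per-object SDFs, pick one value per vertex with the selector, then
+     # run marching tets on (scaled_verts, sdf_values, indices) to get a surface mesh;
+     # the GT_* outputs repeat the same marching with precomputed reference SDF values.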
+
+
+     def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
+         return extract_geometry(bound_min, # extract geometry #
+                                 bound_max,
+                                 resolution=resolution,
+                                 threshold=threshold,
+                                 # query_func=lambda pts: -self.sdf_network.sdf(pts),
+                                 query_func=lambda pts: -self.query_func_sdf(pts)
+                                 )
+
+     # if self.deform_pts_with_selector:
+     #     pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+     def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min, # extract geometry #
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
+                                          def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts),
+                                          selector=True)
+         else:
+             return extract_geometry_tets(bound_min, # extract geometry #
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          # query_func=lambda pts: -self.sdf_network.sdf(pts)
+                                          query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
+                                          selector=True
+                                          )
+
+     def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min, # extract geometry #
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
+                                          def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts),
+                                          selector=False
+                                          )
+             # return extract_geometry_tets(bound_min, # extract geometry #
+             #                              bound_max,
+             #                              resolution=resolution,
+             #                              threshold=threshold,
+             #                              query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
+             #                              def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts),
+             #                              selector=True)
+         else:
+             return extract_geometry_tets(bound_min, # extract geometry #
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          # query_func=lambda pts: -self.sdf_network.sdf(pts)
+                                          query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
+                                          selector=False
+                                          )
models/renderer_def_multi_objs_rigidtrans_forward.py ADDED
@@ -0,0 +1,1603 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import logging
+ import mcubes
+ from icecream import ic
+ import os
+
+ import trimesh
+ from pysdf import SDF
+
+ import models.fields as fields
+
+ from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
+
+ def batched_index_select(values, indices, dim=1):
+     value_dims = values.shape[(dim + 1):]
+     values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
+     indices = indices[(..., *((None,) * len(value_dims)))]
+     indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
+     value_expand_len = len(indices_shape) - (dim + 1)
+     values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
+
+     value_expand_shape = [-1] * len(values.shape)
+     expand_slice = slice(dim, (dim + value_expand_len))
+     value_expand_shape[expand_slice] = indices.shape[expand_slice]
+     values = values.expand(*value_expand_shape)
+
+     dim += value_expand_len
+     return values.gather(dim, indices)
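+ # e.g. (illustrative): values (N, 2, 3), indices (N, 1) with entries in {0, 1}
+ #   -> batched_index_select(values, indices, dim=1) has shape (N, 1, 3);
+ # this is how the renderers pick one candidate per point from stacked per-object results.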
+
+
+ def create_mt_variable(device):
+     triangle_table = torch.tensor(
+         [
+             [-1, -1, -1, -1, -1, -1],
+             [1, 0, 2, -1, -1, -1],
+             [4, 0, 3, -1, -1, -1],
+             [1, 4, 2, 1, 3, 4],
+             [3, 1, 5, -1, -1, -1],
+             [2, 3, 0, 2, 5, 3],
+             [1, 4, 0, 1, 5, 4],
+             [4, 2, 5, -1, -1, -1],
+             [4, 5, 2, -1, -1, -1],
+             [4, 1, 0, 4, 5, 1],
+             [3, 2, 0, 3, 5, 2],
+             [1, 3, 5, -1, -1, -1],
+             [4, 1, 2, 4, 3, 1],
+             [3, 0, 4, -1, -1, -1],
+             [2, 0, 1, -1, -1, -1],
+             [-1, -1, -1, -1, -1, -1]
+         ], dtype=torch.long, device=device)
+
+     num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
+     base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
+     v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
+     return triangle_table, num_triangles_table, base_tet_edges, v_id
59
+
60
+
61
+
62
+ def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
63
+ # load tet via resolution #
64
+ # scale them via bounds #
65
+ # extract the geometry #
66
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
67
+ device = bound_min.device
68
+ # if resolution in [64, 70, 80, 90, 100]:
69
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
70
+ # else:
71
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
72
+ tets = np.load(tet_fn)
73
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
74
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
75
+ # split #
76
+ # verts; verts; #
77
+ minn_verts, _ = torch.min(verts, dim=0)
78
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
79
+ # scale_verts = maxx_verts - minn_verts
80
+ scale_bounds = bound_max - bound_min # scale bounds #
81
+
82
+ ### scale the vertices ###
83
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
84
+
85
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
86
+
87
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
88
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
89
+
90
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
91
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
92
+
93
+ sdf_values = []
94
+ N = 64
95
+ query_bundles = N ** 3 ### N^3
96
+ query_NNs = scaled_verts.size(0) // query_bundles
97
+ if query_NNs * query_bundles < scaled_verts.size(0):
98
+ query_NNs += 1
99
+ for i_query in range(query_NNs):
100
+ cur_bundle_st = i_query * query_bundles
101
+ cur_bundle_ed = (i_query + 1) * query_bundles
102
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
103
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
104
+ if def_func is not None:
105
+ # cur_query_pts = def_func(cur_query_pts)
106
+ cur_query_pts, _ = def_func(cur_query_pts)
107
+ cur_query_vals = query_func(cur_query_pts)
108
+ sdf_values.append(cur_query_vals)
109
+ sdf_values = torch.cat(sdf_values, dim=0)
110
+ # print(f"queryed sdf values: {sdf_values.size()}") #
111
+
112
+ GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
113
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
114
+
115
+ # intrinsic, tet values, pts values, sdf network #
116
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
117
+ tet_table, num_tets_table = create_tetmesh_variables(device)
118
+
119
+ sdf_values = sdf_values.squeeze(-1) # how the rendering #
120
+
121
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
122
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
123
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
124
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
125
+ # marching_tets_tetmesh ##
126
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
127
+ ### use the GT sdf values for the marching tets ###
128
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
129
+
130
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
131
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
+
+
+ # NOTE: currently identical to extract_fields_from_tets above.
+ def extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=None):
+     # load tet via resolution #
+     # scale them via bounds #
+     # extract the geometry #
+     # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
+     device = bound_min.device
+     # if resolution in [64, 70, 80, 90, 100]:
+     #     tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
+     # else:
+     tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
+     tets = np.load(tet_fn)
+     verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
+     indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
+     # split #
+     # verts; verts; #
+     minn_verts, _ = torch.min(verts, dim=0)
+     maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
+     # scale_verts = maxx_verts - minn_verts
+     scale_bounds = bound_max - bound_min # scale bounds #
+
+     ### scale the vertices ###
+     scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
+
+     # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
+
+     scaled_verts = scaled_verts * 2. - 1. # init the sdf field via the tet mesh vertices and the sdf values ##
+     # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
+
+     # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
+     # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
+
+     sdf_values = []
+     N = 64
+     query_bundles = N ** 3 ### N^3
+     query_NNs = scaled_verts.size(0) // query_bundles
+     if query_NNs * query_bundles < scaled_verts.size(0):
+         query_NNs += 1
+     for i_query in range(query_NNs):
+         cur_bundle_st = i_query * query_bundles
+         cur_bundle_ed = (i_query + 1) * query_bundles
+         cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
+         cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
+         if def_func is not None:
+             # cur_query_pts = def_func(cur_query_pts)
+             cur_query_pts, _ = def_func(cur_query_pts)
+         cur_query_vals = query_func(cur_query_pts)
+         sdf_values.append(cur_query_vals)
+     sdf_values = torch.cat(sdf_values, dim=0)
+     # print(f"queried sdf values: {sdf_values.size()}") #
+
+     GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
+     GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
+
+     # intrinsic, tet values, pts values, sdf network #
+     triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
+     tet_table, num_tets_table = create_tetmesh_variables(device)
+
+     sdf_values = sdf_values.squeeze(-1) # how the rendering #
+
+     # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
+     # print(f"scaled_verts: {scaled_verts.size()}, ")
+     # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
+     # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
+     # marching_tets_tetmesh ##
+     verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
+     ### use the GT sdf values for the marching tets ###
+     GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
+
+     # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
+     return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
+
+
+ def extract_fields(bound_min, bound_max, resolution, query_func):
+     N = 64
+     X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
+     Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
+     Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
+
+     u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
+     with torch.no_grad():
+         for xi, xs in enumerate(X):
+             for yi, ys in enumerate(Y):
+                 for zi, zs in enumerate(Z):
+                     xx, yy, zz = torch.meshgrid(xs, ys, zs)
+                     pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
+                     val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
+                     u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
+     # should save u here #
+     # save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
+     # np.save(save_u_path, u)
+     # print(f"u saved to {save_u_path}")
+     return u
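+ # The grid above is filled in N-sized chunks per axis, so at most N^3 points are queried
+ # per batch; e.g. resolution=512, N=64 -> 8 x 8 x 8 = 512 chunks of up to 64^3 points each.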
+
+
+ def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
+     print('threshold: {}'.format(threshold))
+
+     ### using marching cubes ###
+     u = extract_fields(bound_min, bound_max, resolution, query_func)
+     vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
+     b_max_np = bound_max.detach().cpu().numpy()
+     b_min_np = bound_min.detach().cpu().numpy()
+
+     vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+     ### using marching cubes ###
+
+     ### using marching tets ###
+     # vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
+     # vertices = vertices.detach().cpu().numpy()
+     # triangles = triangles.detach().cpu().numpy()
+     ### using marching tets ###
+
+     # b_max_np = bound_max.detach().cpu().numpy()
+     # b_min_np = bound_min.detach().cpu().numpy()
+
+     # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+     return vertices, triangles
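+ # Marching cubes returns vertices in lattice coordinates [0, resolution - 1]; the rescale
+ # above maps them back to world space, e.g. for bound_min=-1, bound_max=1, resolution=64:
+ #   v_world = v / 63.0 * 2.0 - 1.0 (per axis).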
+
+ def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
+     # print('threshold: {}'.format(threshold))
+
+     ### using marching cubes ###
+     # u = extract_fields(bound_min, bound_max, resolution, query_func)
+     # vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
+     # b_max_np = bound_max.detach().cpu().numpy()
+     # b_min_np = bound_min.detach().cpu().numpy()
+
+     # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+     ### using marching cubes ###
+
+     ##
+     ### using marching tets ### fields from tets ##
+     vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
+     # vertices = vertices.detach().cpu().numpy()
+     # triangles = triangles.detach().cpu().numpy()
+     ### using marching tets ###
+
+     # b_max_np = bound_max.detach().cpu().numpy()
+     # b_min_np = bound_min.detach().cpu().numpy()
+     #
+
+     # vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
+     return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
277
+
278
+
279
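+ # Inverse-CDF sampling along each ray: normalize the per-bin weights into a
+ # pdf, build the cumulative cdf, draw uniform samples u (stratified when
+ # det=True), then invert the cdf by locating each u between its surrounding
+ # cdf entries and interpolating linearly within the matched bin.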
+ def sample_pdf(bins, weights, n_samples, det=False):
+     # This implementation is from NeRF
+     # Get pdf
+     weights = weights + 1e-5  # prevent nans
+     pdf = weights / torch.sum(weights, -1, keepdim=True)
+     cdf = torch.cumsum(pdf, -1)
+     cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
+     # Take uniform samples
+     if det:
+         u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
+         u = u.expand(list(cdf.shape[:-1]) + [n_samples])
+     else:
+         u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
+
+     # Invert CDF
+     u = u.contiguous()
+     inds = torch.searchsorted(cdf, u, right=True)
+     below = torch.max(torch.zeros_like(inds - 1), inds - 1)
+     above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
+     inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
+
+     matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+     cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+     bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+     denom = (cdf_g[..., 1] - cdf_g[..., 0])
+     denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+     t = (u - cdf_g[..., 0]) / denom
+     samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+     return samples
+
+
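+ # load_GT_vertices loads every .obj mesh in the folder and merges them into a
+ # single (verts, faces) pair; each mesh's face indices are offset by the
+ # number of vertices accumulated so far, so they stay valid after concatenation.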
+ def load_GT_vertices(GT_meshes_folder):
+     tot_meshes_fns = os.listdir(GT_meshes_folder)
+     tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
+     tot_mesh_verts = []
+     tot_mesh_faces = []
+     n_tot_verts = 0
+     for fn in tot_meshes_fns:
+         cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
+         obj_mesh = trimesh.load(cur_mesh_fn, process=False)
+         # obj_mesh.remove_degenerate_faces(height=1e-06)
+
+         verts_obj = np.array(obj_mesh.vertices)
+         faces_obj = np.array(obj_mesh.faces)
+
+         tot_mesh_verts.append(verts_obj)
+         tot_mesh_faces.append(faces_obj + n_tot_verts)
+         n_tot_verts += verts_obj.shape[0]
+
+         # tot_mesh_faces.append(faces_obj)
+     tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
+     tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
+     return tot_mesh_verts, tot_mesh_faces
+
+
+ class NeuSRenderer:
+     def __init__(self,
+                  nerf,
+                  sdf_network,
+                  deviation_network,
+                  color_network,
+                  n_samples,
+                  n_importance,
+                  n_outside,
+                  up_sample_steps,
+                  perturb):
+         self.nerf = nerf  # multiple sdf networks and deviation networks #
+         self.sdf_network = sdf_network
+         self.deviation_network = deviation_network
+         self.color_network = color_network
+         self.n_samples = n_samples
+         self.n_importance = n_importance
+         self.n_outside = n_outside
+         self.up_sample_steps = up_sample_steps
+         self.perturb = perturb
+
+         GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
+         self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
+         maxx_pts = 25.
+         minn_pts = -15.
+         self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
+         f = SDF(self.mesh_vertices, self.mesh_faces)
+         self.gt_sdf = f  ## a unit sphere or box
+
+         self.minn_pts = 0
+         self.maxx_pts = 1.
+
+         # self.minn_pts = -1.5  # ground-truth states with the deformation -> update the sdf value field
+         # self.maxx_pts = 1.5
+         self.bkg_pts = ...  # TODO: the bkg pts # bkg_pts_defs #
+         self.cur_fr_bkg_pts_defs = ...  # TODO: set the cur_fr_bkg_pts_defs for each frame #
+         self.dist_interp_thres = ...  # TODO: set the dist_interp_thres #
+
+         self.bending_network = ...  # TODO: add the bending network #
+         self.use_bending_network = ...  # TODO: set the property #
+         self.use_delta_bending = ...  # TODO
+         self.prev_sdf_network = ...  # TODO
+         self.use_selector = False
+         # self.bending_network_rigidtrans_forward = ...  ## TODO: set the rigid-trans forward ###
+         # timestep_to_mesh, timestep_to_passive_mesh, bending_net, bending_net_passive, act_sdf_net, details=None, special_loss_return=False
+         self.timestep_to_mesh = ...  ## TODO
+         self.timestep_to_passive_mesh = ...  ### TODO
+         self.bending_net = ...  ## TODO
+         self.bending_net_passive = ...  ### TODO
+         # self.act_sdf_net = ...  ### TODO
+         self.bending_net_kinematic = ...  ### TODO
+         self.time_to_act_joints = ...  ## TODO
+         # use bending network #
+         # two bending networks
+         # two sdf networks # deform pts kinematic #
+
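+     # Kinematic point deformation. With use_delta_bending, each bending
+     # network in the list is unrolled backwards from pts_ts to timestep 0 and
+     # the per-network offsets (deformed points minus input points) are summed;
+     # otherwise a single bending call at pts_ts is used. The code after the
+     # early return handles the non-bending path by interpolating precomputed
+     # background point deformations over nearby bkg_pts.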
+     def deform_pts_kinematic(self, pts, pts_ts=0):  # deform pts #
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+             if self.use_delta_bending:
+
+                 if isinstance(self.bending_net_kinematic, list):
+                     pts_offsets = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
+                         if isinstance(cur_bending_network, fields.BendingNetwork):
+                             for cur_pts_ts in range(pts_ts, -1, -1):
+                                 cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
+                         # elif isinstance(cur_bending_network, fields.BendingNetworkForward):
+                         #     cur_pts_exp = cur_bending_network(input_pts=pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                         elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
+                             cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                         else:
+                             raise ValueError('Encountered unexpected bending network class...')
+                         pts_offsets.append(cur_pts_exp - pts_exp)
+                     pts_offsets = torch.stack(pts_offsets, dim=0)
+                     pts_offsets = torch.sum(pts_offsets, dim=0)
+                     pts_exp = pts_exp + pts_offsets
+                 # for cur_pts_ts in range(pts_ts, -1, -1):
+                 #     if isinstance(self.bending_network, list):  # pts ts #
+                 #         for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #             pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #     else:
+                 #         pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+             else:
+                 if isinstance(self.bending_net_kinematic, list):  # prev sdf network #
+                     pts_offsets = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
+                         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                         pts_offsets.append(bended_pts_exp - pts_exp)
+                     pts_offsets = torch.stack(pts_offsets, dim=0)
+                     pts_offsets = torch.sum(pts_offsets, dim=0)
+                     pts_exp = pts_exp + pts_offsets
+                 else:
+                     pts_exp = self.bending_net_kinematic(pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1  ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres
+         dist_mask_float = dist_mask.float()
+
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1)  # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts
+
+
+     def deform_pts_kinematic_active(self, pts, pts_ts=0):  # deform pts #
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+             if self.use_delta_bending:
+                 for cur_pts_ts in range(pts_ts, -1, -1):
+                     pts_exp = self.bending_net(pts_exp, input_pts_ts=cur_pts_ts)
+                 # if isinstance(self.bending_net_kinematic, list):
+                 #     pts_offsets = []
+                 #     for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
+                 #         if isinstance(cur_bending_network, fields.BendingNetwork):
+                 #             for cur_pts_ts in range(pts_ts, -1, -1):
+                 #                 cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
+                 #         # elif isinstance(cur_bending_network, fields.BendingNetworkForward):
+                 #         #     cur_pts_exp = cur_bending_network(input_pts=pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                 #         elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
+                 #             cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #         else:
+                 #             raise ValueError('Encountered unexpected bending network class...')
+                 #         pts_offsets.append(cur_pts_exp - pts_exp)
+                 #     pts_offsets = torch.stack(pts_offsets, dim=0)
+                 #     pts_offsets = torch.sum(pts_offsets, dim=0)
+                 #     pts_exp = pts_exp + pts_offsets
+                 # for cur_pts_ts in range(pts_ts, -1, -1):
+                 #     if isinstance(self.bending_network, list):  # pts ts #
+                 #         for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #             pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #     else:
+                 #         pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+             else:
+                 pts_exp = self.bending_net(pts_exp, input_pts_ts=pts_ts)
+                 # if isinstance(self.bending_net_kinematic, list):  # prev sdf network #
+                 #     pts_offsets = []
+                 #     for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
+                 #         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                 #         pts_offsets.append(bended_pts_exp - pts_exp)
+                 #     pts_offsets = torch.stack(pts_offsets, dim=0)
+                 #     pts_offsets = torch.sum(pts_offsets, dim=0)
+                 #     pts_exp = pts_exp + pts_offsets
+                 # else:
+                 #     pts_exp = self.bending_net_kinematic(pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1  ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres
+         dist_mask_float = dist_mask.float()
+
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1)  # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts
+
+
+     # get the pts and render the pts #
+     # pts and the rendering pts #
+     def deform_pts(self, pts, pts_ts=0):  # deform pts #
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+             if self.use_delta_bending:
+
+                 if isinstance(self.bending_network, list):
+                     pts_offsets = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                         if isinstance(cur_bending_network, fields.BendingNetwork):
+                             for cur_pts_ts in range(pts_ts, -1, -1):
+                                 cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
+                         elif isinstance(cur_bending_network, fields.BendingNetworkForward):
+                             for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                                 cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                         elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
+                             # for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                             for cur_pts_ts in range(pts_ts, 0, -1):
+                                 cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                         # elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
+                         #     cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                         else:
+                             raise ValueError('Encountered unexpected bending network class...')
+                         pts_offsets.append(cur_pts_exp - pts_exp)
+                     pts_offsets = torch.stack(pts_offsets, dim=0)
+                     pts_offsets = torch.sum(pts_offsets, dim=0)
+                     pts_exp = pts_exp + pts_offsets
+                 # for cur_pts_ts in range(pts_ts, -1, -1):
+                 #     if isinstance(self.bending_network, list):  # pts ts #
+                 #         for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #             pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #     else:
+                 #         pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+             else:
+                 if isinstance(self.bending_network, list):  # prev sdf network #
+                     pts_offsets = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                         pts_offsets.append(bended_pts_exp - pts_exp)
+                     pts_offsets = torch.stack(pts_offsets, dim=0)
+                     pts_offsets = torch.sum(pts_offsets, dim=0)
+                     pts_exp = pts_exp + pts_offsets
+                 else:
+                     pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1  ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres
+         dist_mask_float = dist_mask.float()
+
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1)  # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts
+
+
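+     # Selector variant of deform_pts: every bending network produces its own
+     # bended copy of the points, query_pts_sdf_fn_for_selector labels each
+     # point with the field it belongs to, and batched_index_select keeps, per
+     # point, the bended copy from the selected field. Returns (pts, sdf_selector).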
+     def deform_pts_with_selector(self, pts, pts_ts=0):  # deform pts #
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+             if self.use_delta_bending:
+                 if isinstance(self.bending_network, list):
+                     bended_pts = []
+                     queries_sdfs_selector = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                         # if cur_bending_network.use_opt_rigid_translations:
+                         #     bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                         # else:
+                         #     # bended_pts_exp = pts_exp.clone()
+                         #     for cur_pts_ts in range(pts_ts, -1, -1):
+                         #         bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
+
+                         if isinstance(cur_bending_network, fields.BendingNetwork):
+                             for cur_pts_ts in range(pts_ts, -1, -1):
+                                 bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
+                         elif isinstance(cur_bending_network, fields.BendingNetworkForward):
+                             for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                                 bended_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                         elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
+                             # for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                             for cur_pts_ts in range(pts_ts, 0, -1):
+                                 bended_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+
+                         if pts_ts == 0:
+                             bended_pts_exp = pts_exp.clone()
+                         _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
+                         bended_pts.append(bended_pts_exp)
+                         queries_sdfs_selector.append(cur_bended_pts_selector)
+                     bended_pts = torch.stack(bended_pts, dim=1)  # nn_pts x 2 x 3 for bended pts #
+                     queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1)  # nn_pts x 2
+                     # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
+                     sdf_selector = queries_sdfs_selector[:, -1]
+                     # sdf_selector = queries_sdfs_selector
+                     # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
+                     # print(f"bended_pts: {bended_pts.size()}, sdf_selector: {sdf_selector.size()}, maxx_sdf_selector: {torch.max(sdf_selector)}, minn_sdf_selector: {torch.min(sdf_selector)}")
+                     bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)  # nn_pts x 3 #
+                     # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
+                     # pts_exp = bended_pts.squeeze(1)
+                     pts_exp = bended_pts
+
+                 # for cur_pts_ts in range(pts_ts, -1, -1):
+                 #     if isinstance(self.bending_network, list):
+                 #         for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #             pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #     else:
+                 #         pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+             else:
+                 if isinstance(self.bending_network, list):  # prev sdf network #
+                     # pts_offsets = []
+                     bended_pts = []
+                     queries_sdfs_selector = []
+                     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                         # pts_offsets.append(bended_pts_exp - pts_exp)
+                         _, cur_bended_pts_selector = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
+                         bended_pts.append(bended_pts_exp)
+                         queries_sdfs_selector.append(cur_bended_pts_selector)
+                     bended_pts = torch.stack(bended_pts, dim=1)  # nn_pts x 2 x 3 for bended pts #
+                     queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1)  # nn_pts x 2
+                     # queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
+                     sdf_selector = queries_sdfs_selector[:, -1]
+                     # sdf_selector = queries_sdfs_selector
+
+                     # delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
+                     bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)  # nn_pts x 3 #
+                     # print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
+                     pts_exp = bended_pts.squeeze(1)
+
+                     # pts_offsets = torch.stack(pts_offsets, dim=0)
+                     # pts_offsets = torch.sum(pts_offsets, dim=0)
+                     # pts_exp = pts_exp + pts_offsets
+                 else:
+                     pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts, sdf_selector
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1  ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres
+         dist_mask_float = dist_mask.float()
+
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1)  # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts
+
+
+     def deform_pts_passive(self, pts, pts_ts=0):
+
+         if self.use_bending_network:
+             if len(pts.size()) == 3:
+                 nnb, nns = pts.size(0), pts.size(1)
+                 pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+             else:
+                 pts_exp = pts
+             # pts_ts #
+
+             if self.use_delta_bending:
+                 cur_bending_network = self.bending_network[-1]
+                 if isinstance(cur_bending_network, fields.BendingNetwork):
+                     for cur_pts_ts in range(pts_ts, -1, -1):
+                         cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
+                 elif isinstance(cur_bending_network, fields.BendingNetworkForward):
+                     for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                         cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                 elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
+                     # for cur_pts_ts in range(pts_ts - 1, -1, -1):
+                     for cur_pts_ts in range(pts_ts, 0, -1):
+                         cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
+                 # elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
+                 #     cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 else:
+                     raise ValueError('Encountered unexpected bending network class...')
+                 # for cur_pts_ts in range(pts_ts, -1, -1):
+                 #     if isinstance(self.bending_network, list):
+                 #         for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #             pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 #     else:
+                 #         pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
+                 # time_to_offset = {
+                 #     0: [-0.03338804, 0.07566567, 0.0958022],
+                 #     1: [-0.05909395, 0.05454276, 0.09974975],
+                 #     2: [-0.07214502, -0.00118192, 0.09003166],
+                 #     3: [-0.10040219, -0.01334709, 0.08493543],
+                 #     4: [-0.10047092, -0.01264334, 0.05320398],
+                 #     5: [-0.09152254, 0.00722668, 0.0101514],
+                 # }
+
+                 if pts_ts > 0:
+                     pts_exp = cur_pts_exp
+             else:
+                 # if isinstance(self.bending_network, list):
+                 #     pts_offsets = []
+                 #     for i_obj, cur_bending_network in enumerate(self.bending_network):
+                 #         bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
+                 #         pts_offsets.append(bended_pts_exp - pts_exp)
+                 #     pts_offsets = torch.stack(pts_offsets, dim=0)
+                 #     pts_offsets = torch.sum(pts_offsets, dim=0)
+                 #     pts_exp = pts_exp + pts_offsets
+                 # else:
+                 pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
+             if len(pts.size()) == 3:
+                 pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+             else:
+                 pts = pts_exp
+             return pts
+
+         # pts: nn_batch x nn_samples x 3
+         if len(pts.size()) == 3:
+             nnb, nns = pts.size(0), pts.size(1)
+             pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
+         else:
+             pts_exp = pts
+         # print(f"prior to deforming: {pts.size()}")
+
+         dist_pts_to_bkg_pts = torch.sum(
+             (pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1  ## nn_pts_exp x nn_bkg_pts
+         )
+         dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres
+         dist_mask_float = dist_mask.float()
+
+         cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
+         cur_fr_pts_def = torch.sum(
+             cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
+         )
+         dist_mask_float_summ = torch.sum(
+             dist_mask_float, dim=1
+         )
+         dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
+         cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1)  # bkg pts deformation #
+         pts_exp = pts_exp - cur_fr_pts_def
+         if len(pts.size()) == 3:
+             pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
+         else:
+             pts = pts_exp
+         return pts
+
+
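+     # Compose the current and previous SDF fields by the signs of the two
+     # values at each query point: neg/neg -> constant 1 (deep inside both),
+     # neg/pos and pos/neg -> the current sdf, pos/pos -> min of the two.
+     # The selector is 1 exactly where the point is inside the current field
+     # but outside the previous one (the neg_weq_pos mask).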
+     def query_pts_sdf_fn_for_selector(self, pts):
+         # for negative values:
+         # 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
+         # 2) negative in this field and also negative in the previous field --->
+         # 3) for positive values of this current field --->
+         # maxx_pts, _ = torch.max(pts, dim=0)
+         # minn_pts, _ = torch.min(pts, dim=0)
+
+         cur_sdf = self.sdf_network.sdf(pts).squeeze(-1)
+         prev_sdf = self.prev_sdf_network.sdf(pts).squeeze(-1)
+         neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+
+         neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
+
+         pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+         res_sdf = torch.zeros_like(cur_sdf)
+
+         # print(f"res_sdf: {res_sdf.size()}, neg_neg: {neg_neg.size()}, pts: {pts.size()}")
+         # if torch.sum(neg_neg.float()).item() > 0.:
+         res_sdf[neg_neg] = 1.
+         # if torch.sum(neg_pos.float()).item() > 0.:
+         res_sdf[neg_pos] = cur_sdf[neg_pos]
+         # if torch.sum(pos_neg.float()).item() > 0.:
+         res_sdf[pos_neg] = cur_sdf[pos_neg]
+
+         # inside the residual mesh -> must be neg and pos
+         res_sdf_selector = torch.zeros_like(cur_sdf).long()
+         # res_sdf_selector[neg_pos] = 1  # is the residual mesh
+         # if torch.sum(neg_weq_pos.float()).item() > 0.:
+         res_sdf_selector[neg_weq_pos] = 1
+
+         cat_cur_prev_sdf = torch.stack(
+             [cur_sdf, prev_sdf], dim=-1
+         )
+         minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
+
+         if torch.sum(pos_pos.float()).item() > 0.:
+             res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
+
+         return res_sdf, res_sdf_selector
+
+     def query_func_sdf(self, pts):
+         # if isinstance(self.sdf_network, list):
+         #     tot_sdf_values = []
+         #     for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+         #         cur_sdf_values = cur_sdf_network.sdf(pts)
+         #         tot_sdf_values.append(cur_sdf_values)
+         #     tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
+         #     tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1)  # tot sdf values #
+         #     sdf = tot_sdf_values
+         # else:
+         #     sdf = self.sdf_network.sdf(pts)
+
+         cur_sdf = self.sdf_network.sdf(pts)
+         prev_sdf = self.prev_sdf_network.sdf(pts)
+         neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+
+         neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
+
+         pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
+         pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
+         res_sdf = torch.zeros_like(cur_sdf)
+
+         # if torch.sum(neg_neg.float()).item() > 0.:
+         res_sdf[neg_neg] = prev_sdf[neg_neg]
+         # if torch.sum(neg_pos.float()).item() > 0.:
+         res_sdf[neg_pos] = cur_sdf[neg_pos]
+         # if torch.sum(pos_neg.float()).item() > 0.:
+         res_sdf[pos_neg] = cur_sdf[pos_neg]
+
+         # inside the residual mesh -> must be neg and pos
+         # res_sdf_selector = torch.zeros_like(cur_sdf).long()
+         # res_sdf_selector[neg_pos] = 1  # is the residual mesh
+         # if torch.sum(neg_weq_pos.float()).item() > 0.:
+         #     res_sdf_selector[neg_weq_pos] = 1
+
+         cat_cur_prev_sdf = torch.stack(
+             [cur_sdf, prev_sdf], dim=-1
+         )
+         minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
+
+         # if torch.sum(pos_pos.float()).item() > 0.:
+         res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
+
+         return res_sdf
+
+     def query_func_active(self, pts):
+         # (the commented-out list handling mirrors the per-object minimum in query_func_sdf)
+         sdf = self.prev_sdf_network.sdf(pts)
+         return sdf
+
+     def query_func_sdf_passive(self, pts):
+         # (the commented-out list handling mirrors the per-object minimum in query_func_sdf)
+         sdf = self.sdf_network[-1].sdf(pts)
+
+         return sdf
+
+
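+     # NeRF++-style background: sample points outside the unit sphere,
+     # reparameterize each as (x / |x|, 1 / |x|) so unbounded space maps into a
+     # bounded 4D input, and alpha-composite the background radiance along the ray.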
+     def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
+         """
+         Render background
+         """
+         batch_size, n_samples = z_vals.shape
+
+         # Section length
+         dists = z_vals[..., 1:] - z_vals[..., :-1]
+         dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
+         mid_z_vals = z_vals + dists * 0.5
+
+         # Section midpoints #
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]  # batch_size, n_samples, 3 #
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
+         pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1)  # batch_size, n_samples, 4 #
+
+         dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
+
+         pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
+         dirs = dirs.reshape(-1, 3)
+
+         density, sampled_color = nerf(pts, dirs)
+         sampled_color = torch.sigmoid(sampled_color)
+         alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
+         alpha = alpha.reshape(batch_size, n_samples)
+         weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+         sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
+         color = (weights[:, :, None] * sampled_color).sum(dim=1)
+         if background_rgb is not None:
+             color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
+
+         return {
+             'color': color,
+             'sampled_color': sampled_color,
+             'alpha': alpha,
+             'weights': weights,
+         }
+
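+     # Hierarchical importance sampling: estimate section SDFs from midpoint
+     # values and a clamped local cos, convert them to alphas with the fixed
+     # inv_s sigmoid CDF, and draw n_importance extra depths from the resulting
+     # weights via sample_pdf.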
+     def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
+         """
+         Up sampling given a fixed inv_s
+         """
+         batch_size, n_samples = z_vals.shape
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]  # n_rays, n_samples, 3
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
+         inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
+         sdf = sdf.reshape(batch_size, n_samples)
+         prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
+         prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
+         mid_sdf = (prev_sdf + next_sdf) * 0.5
+         cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
+
+         # ----------------------------------------------------------------------------------------------------------
+         # Use min value of [ cos, prev_cos ]
+         # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
+         # robust when meeting situations like below:
+         #
+         # SDF
+         # ^
+         # |\          -----x----...
+         # | \        /
+         # |  x      x
+         # |---\----/-------------> 0 level
+         # |    \  /
+         # |     \/
+         # |
+         # ----------------------------------------------------------------------------------------------------------
+         prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
+         cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
+         cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
+         cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
+
+         dist = (next_z_vals - prev_z_vals)
+         prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
+         next_esti_sdf = mid_sdf + cos_val * dist * 0.5
+         prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
+         next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
+         alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
+         weights = alpha * torch.cumprod(
+             torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+
+         z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
+         return z_samples
+
+     def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
+         batch_size, n_samples = z_vals.shape
+         _, n_importance = new_z_vals.shape
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
+         z_vals, index = torch.sort(z_vals, dim=-1)
+
+         if not last:
+             if isinstance(self.sdf_network, list):
+                 tot_new_sdf = []
+                 for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                     cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+                     tot_new_sdf.append(cur_new_sdf)
+                 tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
+                 new_sdf, _ = torch.min(tot_new_sdf, dim=-1)
+             else:
+                 if self.use_selector:
+                     new_sdf_cur = self.sdf_network.sdf(pts.reshape(-1, 3))  # .reshape(batch_size, n_importance)
+                     new_sdf_prev = self.prev_sdf_network.sdf(pts.reshape(-1, 3))  # .reshape(batch_size, n_importance)
+                     new_sdf = torch.stack([new_sdf_prev, new_sdf_cur], dim=1)
+                     new_sdf = batched_index_select(new_sdf, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+                     new_sdf = new_sdf.reshape(batch_size, n_importance)
+                 else:
+                     new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+             sdf = torch.cat([sdf, new_sdf], dim=-1)
+             xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
+             index = index.reshape(-1)
+             sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
+
+         return z_vals, sdf
+
+
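+     # Core NeuS shading: with the (possibly selector-composed) SDF and its
+     # gradients, alpha for each ray section is
+     #     (sigmoid(prev_sdf * inv_s) - sigmoid(next_sdf * inv_s)) / sigmoid(prev_sdf * inv_s),
+     # composited front-to-back; an Eikonal penalty keeps |grad(sdf)| close to
+     # 1 inside the relaxed unit sphere.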
+     def render_core(self,
+                     rays_o,
+                     rays_d,
+                     z_vals,
+                     sample_dist,
+                     sdf_network,
+                     deviation_network,
+                     color_network,
+                     background_alpha=None,
+                     background_sampled_color=None,
+                     background_rgb=None,
+                     cos_anneal_ratio=0.0,
+                     pts_ts=0):
+         batch_size, n_samples = z_vals.shape
+
+         # Section length
+         dists = z_vals[..., 1:] - z_vals[..., :-1]
+         dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
+         mid_z_vals = z_vals + dists * 0.5
+
+         # Section midpoints
+         pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None]  # n_rays, n_samples, 3
+         dirs = rays_d[:, None, :].expand(pts.shape)
+
+         pts = pts.reshape(-1, 3)
+         dirs = dirs.reshape(-1, 3)
+
+         pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
+
+         # pts = pts.flip((-1,)) * 2 - 1
+         pts = pts * 2 - 1
+
+         if self.use_selector:
+             pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+         else:
+             pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
+
+         if isinstance(sdf_network, list):
+             tot_sdf = []
+             tot_feature_vector = []
+             tot_obj_sel = []
+             tot_gradients = []
+             for i_obj, cur_sdf_network in enumerate(sdf_network):
+                 cur_sdf_nn_output = cur_sdf_network(pts)
+                 cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
+                 tot_sdf.append(cur_sdf)
+                 tot_feature_vector.append(cur_feature_vector)
+
+                 gradients = cur_sdf_network.gradient(pts).squeeze()
+                 tot_gradients.append(gradients)
+             tot_sdf = torch.stack(tot_sdf, dim=-1)
+             sdf, obj_sel = torch.min(tot_sdf, dim=-1)
+             feature_vector = torch.stack(tot_feature_vector, dim=1)
+
+             # batched_index_select
+             # print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
+             feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
+
+             # feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
+             # print(f"after sel: {feature_vector.size()}")
+             tot_gradients = torch.stack(tot_gradients, dim=1)
+             # gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
+             gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
+             # print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
+
+         else:
+             # sdf_nn_output = sdf_network(pts)
+             # sdf = sdf_nn_output[:, :1]
+             # feature_vector = sdf_nn_output[:, 1:]
+             # gradients = sdf_network.gradient(pts).squeeze()
+
+             if self.use_selector:
+                 prev_sdf_nn_output = self.prev_sdf_network(pts)
+                 prev_gradients = self.prev_sdf_network.gradient(pts).squeeze()
+                 cur_sdf_nn_output = self.sdf_network(pts)
+                 cur_gradients = self.sdf_network.gradient(pts).squeeze()
+
+                 sdf_nn_output = torch.stack([prev_sdf_nn_output, cur_sdf_nn_output], dim=1)
+                 sdf_nn_output = batched_index_select(sdf_nn_output, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+
+                 sdf = sdf_nn_output[:, :1]
+                 feature_vector = sdf_nn_output[:, 1:]
+
+                 gradients = torch.stack([prev_gradients, cur_gradients], dim=1)
+                 gradients = batched_index_select(gradients, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+             else:
+                 sdf_nn_output = sdf_network(pts)
+                 sdf = sdf_nn_output[:, :1]
+                 feature_vector = sdf_nn_output[:, 1:]
+                 gradients = sdf_network.gradient(pts).squeeze()
+             # new_sdf_cur = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+             # new_sdf_prev = self.prev_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
+             # new_sdf = torch.stack([new_sdf_prev, new_sdf_cur], dim=1)
+             # new_sdf = batched_index_select(new_sdf, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+
+         sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
+
+         # deviation network #
+         inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)  # Single parameter
+         inv_s = inv_s.expand(batch_size * n_samples, 1)
+
+         true_cos = (dirs * gradients).sum(-1, keepdim=True)
+
+         # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
+         # the cos value "not dead" at the beginning training iterations, for better convergence.
+         iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
+                      F.relu(-true_cos) * cos_anneal_ratio)  # always non-positive
+
+         # Estimate signed distances at section points
+         estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
+         estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
+
+         prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
+         next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
+
+         p = prev_cdf - next_cdf
+         c = prev_cdf
+
+         alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
+
+         pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
+         inside_sphere = (pts_norm < 1.0).float().detach()
+         relax_inside_sphere = (pts_norm < 1.2).float().detach()
+
+         # Render with background
+         if background_alpha is not None:
+             alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
+             alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
+             sampled_color = sampled_color * inside_sphere[:, :, None] +\
+                             background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
+             sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
+
+         weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
+         weights_sum = weights.sum(dim=-1, keepdim=True)
+
+         color = (sampled_color * weights[:, :, None]).sum(dim=1)
+         if background_rgb is not None:  # Fixed background, usually black
+             color = color + background_rgb * (1.0 - weights_sum)
+
+         # Eikonal loss
+         gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
+                                             dim=-1) - 1.0) ** 2
+         gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
+
+         return {
+             'color': color,
+             'sdf': sdf,
+             'dists': dists,
+             'gradients': gradients.reshape(batch_size, n_samples, 3),
+             's_val': 1.0 / inv_s,
+             'mid_z_vals': mid_z_vals,
+             'weights': weights,
+             'cdf': c.reshape(batch_size, n_samples),
+             'gradient_error': gradient_error,
+             'inside_sphere': inside_sphere
+         }
+
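+     # Full render pass for a ray batch: stratified coarse samples, optional
+     # up_sample/cat_z_vals refinement rounds under torch.no_grad(), an optional
+     # outside-sphere background model, then render_core for the final color,
+     # weights, and Eikonal error.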
+     def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
+         batch_size = len(rays_o)
+         sample_dist = 2.0 / self.n_samples  # Assuming the region of interest is a unit sphere
+         z_vals = torch.linspace(0.0, 1.0, self.n_samples)
+         z_vals = near + (far - near) * z_vals[None, :]
+
+         z_vals_outside = None
+         if self.n_outside > 0:
+             z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
+
+         n_samples = self.n_samples
+         perturb = self.perturb
+
+         if perturb_overwrite >= 0:
+             perturb = perturb_overwrite
+         if perturb > 0:
+             t_rand = (torch.rand([batch_size, 1]) - 0.5)
+             z_vals = z_vals + t_rand * 2.0 / self.n_samples
+
+             if self.n_outside > 0:  # z values outside # n_outside #
+                 mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
+                 upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
+                 lower = torch.cat([z_vals_outside[..., :1], mids], -1)
+                 t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
+                 z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
+
+         if self.n_outside > 0:
+             z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
+
+         background_alpha = None
+         background_sampled_color = None
+
+         # Up sample
+         if self.n_importance > 0:
+             with torch.no_grad():
+                 pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
+
+                 pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
+                 # sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
+                 # gt_sdf #
+
+                 # pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
+                 # pts = pts.flip((-1,)) * 2 - 1
+                 pts = pts * 2 - 1
+
+                 if self.use_selector:
+                     pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
+                 else:
+                     pts = self.deform_pts(pts=pts, pts_ts=pts_ts)  # given the pts
+
+                 pts_exp = pts.reshape(-1, 3)
+                 # minn_pts, _ = torch.min(pts_exp, dim=0)
+                 # maxx_pts, _ = torch.max(pts_exp, dim=0)  # deformation field (not a rigid one) -> the meshes #
+                 # print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
+
+                 # pts_to_near = pts - near.unsqueeze(1)
+                 # maxx_pts = 1.5; minn_pts = -1.5
+                 # # maxx_pts = 3; minn_pts = -3
+                 # # maxx_pts = 1; minn_pts = -1
+                 # pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
+
+                 ## render the images ##
+                 if use_gt_sdf:
+                     ### use the GT sdf field ###
+                     # print(f"Using gt sdf :")
+                     sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
+                     sdf = torch.from_numpy(sdf).float().cuda()
+                     sdf = sdf.reshape(batch_size, self.n_samples)
+                     ### use the GT sdf field ###
+                 else:
+                     # pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
+                     #### use the optimized sdf field ####
+                     # sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+
+                     if isinstance(self.sdf_network, list):
+                         tot_sdf_values = []
+                         for i_obj, cur_sdf_network in enumerate(self.sdf_network):
+                             cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+                             tot_sdf_values.append(cur_sdf_values)
+                         tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
+                         tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1)  # tot sdf values #
+                         sdf = tot_sdf_values
+                     else:
+                         # sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+
+                         if self.use_selector:
+                             prev_sdf = self.prev_sdf_network.sdf(pts_exp)  # .reshape(batch_size, self.n_samples)
+                             cur_sdf = self.sdf_network.sdf(pts_exp)  # .reshape(batch_size, self.n_samples)
+                             sdf = torch.stack([prev_sdf, cur_sdf], dim=1)
+                             sdf = batched_index_select(sdf, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
+                             sdf = sdf.reshape(batch_size, self.n_samples)
+                         else:
+                             sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
+
+                     #### use the optimized sdf field ####
+
+                 for i in range(self.up_sample_steps):
+                     new_z_vals = self.up_sample(rays_o,
+                                                 rays_d,
+                                                 z_vals,
+                                                 sdf,
+                                                 self.n_importance // self.up_sample_steps,
+                                                 64 * 2 ** i,
+                                                 pts_ts=pts_ts)
+                     z_vals, sdf = self.cat_z_vals(rays_o,
+                                                   rays_d,
+                                                   z_vals,
+                                                   new_z_vals,
+                                                   sdf,
+                                                   last=(i + 1 == self.up_sample_steps),
+                                                   pts_ts=pts_ts)
+
+             n_samples = self.n_samples + self.n_importance
+
+         # Background model
+         if self.n_outside > 0:
+             z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
+             z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
+             ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
+
+             background_sampled_color = ret_outside['sampled_color']
+             background_alpha = ret_outside['alpha']
+
+         # Render core
+         ret_fine = self.render_core(rays_o,
+                                     rays_d,
+                                     z_vals,
+                                     sample_dist,
+                                     self.sdf_network,
+                                     self.deviation_network,
+                                     self.color_network,
+                                     background_rgb=background_rgb,
+                                     background_alpha=background_alpha,
+                                     background_sampled_color=background_sampled_color,
+                                     cos_anneal_ratio=cos_anneal_ratio,
+                                     pts_ts=pts_ts)
+
+         color_fine = ret_fine['color']
+         weights = ret_fine['weights']
+         weights_sum = weights.sum(dim=-1, keepdim=True)
+         gradients = ret_fine['gradients']
+         s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
+
+         return {
+             'color_fine': color_fine,
+             's_val': s_val,
+             'cdf_fine': ret_fine['cdf'],
+             'weight_sum': weights_sum,
+             'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
+             'gradients': gradients,
+             'weights': weights,
+             'gradient_error': ret_fine['gradient_error'],
+             'inside_sphere': ret_fine['inside_sphere']
+         }
+
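+     # Tet-grid variant of field extraction with the two-field selector: load a
+     # precomputed tet grid, scale its vertices into [-1, 1], deform + select
+     # each query chunk, and run marching tets on both the predicted and the GT
+     # SDF values.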
+ def extract_fields_from_tets_with_selector(self, bound_min, bound_max, resolution, pts_ts ):
1448
+ # load tet via resolution #
1449
+ # scale them via bounds #
1450
+ # extract the geometry #
1451
+ # /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
1452
+ device = bound_min.device
1453
+ # if resolution in [64, 70, 80, 90, 100]:
1454
+ # tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
1455
+ # else:
1456
+ tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
1457
+ tets = np.load(tet_fn)
1458
+ verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
1459
+ indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
1460
+ # split #
1461
+ # verts; verts; #
1462
+ minn_verts, _ = torch.min(verts, dim=0)
1463
+ maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
1464
+ # scale_verts = maxx_verts - minn_verts
1465
+ scale_bounds = bound_max - bound_min # scale bounds #
1466
+
1467
+ ### scale the vertices ###
1468
+ scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
1469
+
1470
+ # scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
1471
+
1472
+ scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
1473
+ # scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
1474
+
1475
+ # scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
1476
+ # scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
1477
+
1478
+ sdf_values = []
1479
+ N = 64
1480
+ query_bundles = N ** 3 ### N^3
1481
+ query_NNs = scaled_verts.size(0) // query_bundles
1482
+ if query_NNs * query_bundles < scaled_verts.size(0):
1483
+ query_NNs += 1
1484
+ for i_query in range(query_NNs):
1485
+ cur_bundle_st = i_query * query_bundles
1486
+ cur_bundle_ed = (i_query + 1) * query_bundles
1487
+ cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
1488
+ cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
1489
+ # if def_func is not None:
1490
+ cur_query_pts, sdf_selector = self.deform_pts_with_selector(pts=cur_query_pts, pts_ts=pts_ts)
1491
+
1492
+ prev_query_vals = -self.prev_sdf_network.sdf(cur_query_pts)
1493
+ cur_query_vals = -self.sdf_network.sdf(cur_query_pts)
1494
+ cur_query_vals = torch.stack([prev_query_vals, cur_query_vals], dim=1)
1495
+ cur_query_vals = batched_index_select(cur_query_vals, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
1496
+
1497
+ # cur_query_vals = query_func(cur_query_pts)
1498
+ sdf_values.append(cur_query_vals)
1499
+ sdf_values = torch.cat(sdf_values, dim=0)
1500
+ # print(f"queryed sdf values: {sdf_values.size()}") #
1501
+
1502
+ GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
1503
+ GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
1504
+
1505
+ # intrinsic, tet values, pts values, sdf network #
1506
+ triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
1507
+ tet_table, num_tets_table = create_tetmesh_variables(device)
1508
+
1509
+ sdf_values = sdf_values.squeeze(-1) # how the rendering #
1510
+
1511
+ # print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
1512
+ # print(f"scaled_verts: {scaled_verts.size()}, ")
1513
+ # pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
1514
+ # return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
1515
+ # marching_tets_tetmesh ##
1516
+ verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
1517
+ ### use the GT sdf values for the marching tets ###
1518
+ GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
1519
+
1520
+ # print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
1521
+ return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
1522
+
1523
+
1524
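The function above queries a large tet-vertex set against two SDF networks in fixed-size bundles and then picks one value per point with a selector. A minimal self-contained sketch of that pattern, with stand-in SDFs and torch.gather in place of the repo's batched_index_select; every name here is illustrative, not the repo's API:

import torch

def query_sdf_in_bundles(pts, sdf_a, sdf_b, selector, bundle_size=64 ** 3):
    # pts: (N, 3); selector: (N,) of {0, 1} choosing sdf_a / sdf_b per point
    vals = []
    n_bundles = (pts.size(0) + bundle_size - 1) // bundle_size  # ceil division
    for i in range(n_bundles):
        cur = pts[i * bundle_size: (i + 1) * bundle_size]
        sel = selector[i * bundle_size: (i + 1) * bundle_size]
        stacked = torch.stack([sdf_a(cur), sdf_b(cur)], dim=1)    # (B, 2)
        vals.append(torch.gather(stacked, 1, sel.view(-1, 1)).squeeze(1))
    return torch.cat(vals, dim=0)

if __name__ == "__main__":
    pts = torch.rand(10000, 3) * 2. - 1.           # points in [-1, 1]^3
    sphere = lambda p: p.norm(dim=-1) - 0.5        # stand-in SDF a
    box = lambda p: p.abs().max(dim=-1)[0] - 0.5   # stand-in SDF b
    sel = (pts[:, 0] > 0).long()                   # toy per-point selector
    print(query_sdf_in_bundles(pts, sphere, box, sel, bundle_size=4096).shape)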
+     def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
+         return extract_geometry(bound_min,
+                                 bound_max,
+                                 resolution=resolution,
+                                 threshold=threshold,
+                                 query_func=lambda pts: -self.query_func_sdf(pts))
+
+     def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts),
+                                          def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts))
+         else:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts))
+
+     def extract_geometry_tets_kinematic(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts),
+                                          def_func=lambda pts: self.deform_pts_kinematic(pts, pts_ts=pts_ts))
+         else:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts))
+
+     def extract_geometry_tets_active(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_active(pts),
+                                          def_func=lambda pts: self.deform_pts_kinematic_active(pts, pts_ts=pts_ts))
+         else:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts))
+
+     def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
+         if wdef:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf_passive(pts),
+                                          def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts))
+         else:
+             return extract_geometry_tets(bound_min,
+                                          bound_max,
+                                          resolution=resolution,
+                                          threshold=threshold,
+                                          query_func=lambda pts: -self.query_func_sdf(pts))
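These wrappers all feed the same extract_geometry_tets entry point and differ only in which SDF query and which deformation callback they pass: when wdef is set, sample points are first mapped by def_func (into the frame at pts_ts) and only then queried. A toy sketch of that callback plumbing under those assumptions; extract_level_set_samples and both lambdas below are invented for illustration, not the repo's marching-tets routine:

import torch

def extract_level_set_samples(samples, query_func, def_func=None, threshold=0.0):
    # optionally deform the samples first, then evaluate the (negated) SDF
    pts = def_func(samples) if def_func is not None else samples
    occ = query_func(pts)            # with query_func = -sdf, inside is > 0
    return samples[occ > threshold]  # keep the samples inside the level set

if __name__ == "__main__":
    samples = torch.rand(4096, 3) * 2. - 1.
    sdf = lambda p: p.norm(dim=-1) - 0.5                 # stand-in sphere SDF
    shift = lambda p: p - torch.tensor([0.2, 0.0, 0.0])  # stand-in deformation
    inside = extract_level_set_samples(samples, lambda p: -sdf(p), def_func=shift)
    print(inside.shape)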
models/test.js ADDED
File without changes
pre-requirements.txt ADDED
@@ -0,0 +1,2 @@
+ pip==23.3.2
+ torch==2.2.0
requirements.txt ADDED
@@ -0,0 +1,27 @@
+ -f https://download.pytorch.org/whl/cpu/torch_stable.html
+ -f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
+ # pip==20.2.4
+ torch==2.2.0
+ # torchvision==0.13.1
+ # torchaudio==0.12.1
+ scipy
+ trimesh
+ icecream
+ tqdm
+ pyhocon
+ open3d
+ tensorboard
+
+ # blobfile==2.0.1
+ # manopth @ git+https://github.com/hassony2/manopth.git
+ # numpy==1.23.1
+ # psutil==5.9.2
+ # scikit-learn
+ # scipy==1.9.3
+ # tensorboard
+ # tensorboardx
+ # tqdm
+ # trimesh
+ # clip
+ # chumpy
+ # opencv-python
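Presumably the split between the two files encodes install order: pre-requirements.txt pins pip==23.3.2 and torch==2.2.0 first, so that when requirements.txt is resolved the torch-dependent wheels can be taken from the two -f CPU indices (torch_stable and the torch-2.2.0+cpu PyG index) against an already-fixed torch version; i.e. pip install -r pre-requirements.txt followed by pip install -r requirements.txt.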
scripts_demo/train_grab_pointset_points_dyn_s1.sh ADDED
@@ -0,0 +1,28 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn_s1.conf
+
+ export conf_root="./confs_new"
+
+
+ export data_path="./data/102_grab_all_data.npy"
+ # bash scripts_demo/train_grab_pointset_points_dyn_s1.sh
+
+ export cuda_ids="0"
+
+
+ CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case} --data_path=${data_path}
+
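The training scripts in scripts_demo/ and scripts_new/ all drive the same four-flag CLI of the selected trainer. A hypothetical argparse mirror of that surface (the real parser in exp_runner_stage_1.py is not part of this diff, so the flag semantics below are inferred from the invocations, not confirmed):

import argparse

def build_parser():
    # sketch of the trainer CLI as exercised by the shell scripts
    p = argparse.ArgumentParser(description="stage-1 trainer CLI (sketch)")
    p.add_argument("--mode", required=True)      # e.g. "train_point_set"
    p.add_argument("--conf", required=True)      # e.g. "./confs_new/<name>.conf"
    p.add_argument("--case", required=True)      # dataset case name
    p.add_argument("--data_path", default=None)  # only some scripts pass this
    return p

if __name__ == "__main__":
    args = build_parser().parse_args([
        "--mode", "train_point_set",
        "--conf", "./confs_new/dyn_grab_pointset_points_dyn_s1.conf",
        "--case", "hand_test_routine_2_light_color_wtime_active_passive",
        "--data_path", "./data/102_grab_all_data.npy",
    ])
    print(args)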
scripts_new/train_grab_mano.sh ADDED
@@ -0,0 +1,26 @@
+
+ export PYTHONPATH=.
+
+ export conf=wmask_refine_passive_rigidtrans_forward.conf
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+ export mode="train_dyn_mano_model"
+
+
+ export conf=dyn_grab_pointset_mano.conf
+
+ export conf_root="./confs_new"
+
+
+
+ # bash scripts_new/train_grab_mano.sh
+
+ export cuda_ids="0"
+ #
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_mano_wreact.sh ADDED
@@ -0,0 +1,26 @@
+
+ export PYTHONPATH=.
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+ export mode="train_dyn_mano_model_wreact" ## wreact ##
+
+ export conf=dyn_grab_pointset_mano_dyn.conf
+
+
+ export conf_root="./confs_new"
+
+
+
+ export cuda_ids="0"
+
+
+
+ CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_mano_wreact_optacts.sh ADDED
@@ -0,0 +1,26 @@
+
+ export PYTHONPATH=.
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+ export mode="train_dyn_mano_model_wreact" ## wreact ##
+
+ export conf=dyn_grab_pointset_mano_dyn_optacts.conf
+
+
+ export conf_root="./confs_new"
+
+
+
+ export cuda_ids="0"
+
+
+
+ CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset.sh ADDED
@@ -0,0 +1,99 @@
+
+ export PYTHONPATH=.
+
+ export cuda_ids="4"
+
+ # export cuda_ids="5"
+
+
+ export trainer=exp_runner_arti_forward.py
+ export conf=wmask_refine_passive_rigidtrans_forward.conf
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+ # export trainer=exp_runner_arti_multi_objs_compositional.py
+ export trainer=exp_runner_arti_multi_objs_compositional_ks.py
+ # export trainer=exp_runner_arti_multi_objs_dyn.py
+ export trainer=exp_runner_arti_multi_objs_pointset.py
+ # /home/xueyi/diffsim/NeuS/confs/wmask_refine_passive_compositional.conf
+ export conf=wmask_refine_passive_compositional.conf
+ export conf=dyn_arctic_ks.conf
+
+ export conf=dyn_arctic_ks_robohand.conf
+ # /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf
+ export conf=dyn_arctic_ks_robohand_from_mano_model_rules.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
+
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ # export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+
+ export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ # export conf=dyn_arctic_ks_robohand_from_mano_model_rules_arti.conf
+ # a very stiff system #
+ export conf=dyn_grab_pointset_mano.conf
+
+ export mode="train_from_model_rules"
+ export mode="train_from_model_rules"
+
+ export mode="train_sdf_from_model_rules"
+ export mode="train_actions_from_model_rules"
+ # export mode="train_actions_from_sim_rules"
+
+
+
+ export mode="train_def"
+ # export mode="train_actions_from_model_rules"
+ export mode="train_mano_actions_from_model_rules"
+
+ # virtual forces #
+ # /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
+ export mode="train_actions_from_model_rules"
+
+ export mode="train_actions_from_mano_model_rules"
+
+ export mode="train_real_robot_actions_from_mano_model_rules"
+ export mode="train_real_robot_actions_from_mano_model_rules_diffhand"
+
+ export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
+ export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest"
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
+
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
+
+ ###### diffsim ######
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
+
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
+
+ export mode="train_dyn_mano_model_states"
+
+
+ ## train dyn mano model states wreact ##
+ export mode="train_dyn_mano_model_states_wreact"
+ export conf=dyn_grab_pointset_mano_dyn.conf
+ ## train dyn mano model states wreact ##
+
+
+ export conf_root="./confs_new"
+
+
+ #
+ # bash scripts_new/train_grab_pointset.sh
+
+ export cuda_ids="2"
+ #
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
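Note that this script, like the two shadow scripts further below, accumulates its experiment history as successive export trainer=.../mode=.../conf=... lines; shell assignment means only the last value of each variable reaches the trainer, so the effective configuration here is trainer=exp_runner_arti_multi_objs_pointset.py with mode="train_dyn_mano_model_states_wreact" and conf=dyn_grab_pointset_mano_dyn.conf, and all earlier assignments are inert.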
scripts_new/train_grab_pointset_points_dyn.sh ADDED
@@ -0,0 +1,27 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn.conf
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn.sh
+
+ export cuda_ids="0"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset_points_dyn_retar.sh ADDED
@@ -0,0 +1,36 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+
+ ### stage 1 -> tracking MANO expanded set using Shadow's object mesh points ###
+ export mode="train_point_set_retar"
+ export conf=dyn_grab_pointset_points_dyn_retar.conf
+
+ # ### stage 2 -> tracking MANO expanded set using Shadow's expanded points ###
+ # export mode="train_expanded_set_motions_retar_pts"
+ # export conf=dyn_grab_pointset_points_dyn_retar_pts.conf
+
+
+
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_retar.sh
+
+ export cuda_ids="0"
+
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset_points_dyn_retar_pts.sh ADDED
@@ -0,0 +1,36 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+
+ # ### stage 1 -> tracking MANO expanded set using Shadow's object mesh points ###
+ # export mode="train_expanded_set_motions_retar"
+ # export conf=dyn_grab_pointset_points_dyn_retar.conf
+
+ ### stage 2 -> tracking MANO expanded set using Shadow's expanded points ###
+ export mode="train_point_set_retar_pts"
+ export conf=dyn_grab_pointset_points_dyn_retar_pts.conf
+
+
+
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_retar_pts.sh
+
+ export cuda_ids="0"
+
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
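As the commented-out blocks in the two scripts above indicate, they form a two-stage retargeting pipeline: run train_grab_pointset_points_dyn_retar.sh first (stage 1, tracking the MANO expanded set using Shadow's object mesh points), then train_grab_pointset_points_dyn_retar_pts.sh (stage 2, tracking it using Shadow's expanded points); each script carries the other stage's mode/conf pair as comments for switching between the two.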
scripts_new/train_grab_pointset_points_dyn_s1.sh ADDED
@@ -0,0 +1,27 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn_s1.conf
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_s1.sh
+
+ export cuda_ids="0"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset_points_dyn_s2.sh ADDED
@@ -0,0 +1,27 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn_s2.conf
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_s2.sh
+
+ export cuda_ids="0"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset_points_dyn_s3.sh ADDED
@@ -0,0 +1,27 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn_s3.conf
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_s3.sh
+
+ export cuda_ids="0"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_pointset_points_dyn_s4.sh ADDED
@@ -0,0 +1,27 @@
+
+ export PYTHONPATH=.
+
+
+
+
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+
+ export trainer=exp_runner_stage_1.py
+
+
+ export mode="train_point_set"
+
+ export conf=dyn_grab_pointset_points_dyn_s4.conf
+
+ export conf_root="./confs_new"
+
+
+ # bash scripts_new/train_grab_pointset_points_dyn_s4.sh
+
+ export cuda_ids="0"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_shadow_multistages.sh ADDED
@@ -0,0 +1,102 @@
+
+ export PYTHONPATH=.
+
+ export cuda_ids="3"
+
+
+
+ export trainer=exp_runner_arti_forward.py
+ export conf=wmask_refine_passive_rigidtrans_forward.conf
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+ # export trainer=exp_runner_arti_multi_objs_compositional.py
+ export trainer=exp_runner_arti_multi_objs_compositional_ks.py
+ export trainer=exp_runner_arti_multi_objs_arti_dyn.py
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
+
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+
+ # /home/xueyi/diffsim/NeuS/confs/dyn_grab_mano_model_states.conf
+ export conf=dyn_grab_mano_model_states.conf
+
+ export conf=dyn_grab_shadow_model_states.conf
+
+
+
+
+ export mode="train_def"
+ # export mode="train_actions_from_model_rules"
+ export mode="train_mano_actions_from_model_rules"
+
+ # virtual force #
+ # /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
+
+ export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
+
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
+
+
+ ### using the diffsim ###
+
+ ###### diffsim ######
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
+
+ export mode="train_dyn_mano_model_states"
+
+ #
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab"
+ ## dyn grab shadow
+
+
+ ### optimize for the manipulatable hand actions ###
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab_redmax_acts"
+
+
+
+ # ## acts ##
+ # export conf=dyn_grab_shadow_model_states_224.conf
+ # export conf=dyn_grab_shadow_model_states_89.conf
+ # export conf=dyn_grab_shadow_model_states_102.conf
+ # export conf=dyn_grab_shadow_model_states_7.conf
+ # export conf=dyn_grab_shadow_model_states_47.conf
+ # export conf=dyn_grab_shadow_model_states_67.conf
+ # export conf=dyn_grab_shadow_model_states_76.conf
+ # export conf=dyn_grab_shadow_model_states_85.conf
+ # # export conf=dyn_grab_shadow_model_states_91.conf
+ # # export conf=dyn_grab_shadow_model_states_167.conf
+ # export conf=dyn_grab_shadow_model_states_107.conf
+ # export conf=dyn_grab_shadow_model_states_306.conf
+ # export conf=dyn_grab_shadow_model_states_313.conf
+ # export conf=dyn_grab_shadow_model_states_322.conf
+
+
+ # /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_multi_stages.conf
+ export conf=dyn_grab_arti_shadow_multi_stages.conf
+ # export conf=dyn_grab_shadow_model_states_398.conf
+ # export conf=dyn_grab_shadow_model_states_363.conf
+ # export conf=dyn_grab_shadow_model_states_358.conf
+
+ # bash scripts_new/train_grab_shadow_multistages.sh
+
+
+
+
+
+ export conf_root="./confs_new"
+
+
+ export cuda_ids="4"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+
scripts_new/train_grab_shadow_singlestage.sh ADDED
@@ -0,0 +1,101 @@
+
+ export PYTHONPATH=.
+
+ export cuda_ids="3"
+
+
+
+ export trainer=exp_runner_arti_forward.py
+ export conf=wmask_refine_passive_rigidtrans_forward.conf
+ export data_case=hand_test_routine_2_light_color_wtime_active_passive
+
+
+ # export trainer=exp_runner_arti_multi_objs_compositional.py
+ export trainer=exp_runner_arti_multi_objs_compositional_ks.py
+ export trainer=exp_runner_arti_multi_objs_arti_dyn.py
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
+ export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
+
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+ export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
+
+ # /home/xueyi/diffsim/NeuS/confs/dyn_grab_mano_model_states.conf
+ export conf=dyn_grab_mano_model_states.conf
+
+ export conf=dyn_grab_shadow_model_states.conf
+
+
+
+
+ export mode="train_def"
+ # export mode="train_actions_from_model_rules"
+ export mode="train_mano_actions_from_model_rules"
+
+ # virtual force #
+ # /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
+
+ export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
+
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
+
+
+ ###### diffsim ######
+ # /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
+
+
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
+
+ export mode="train_dyn_mano_model_states"
+
+ #
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab"
+ ## dyn grab shadow
+
+ ### optimize for the redmax hand actions from joint states ###
+ export mode="train_real_robot_actions_from_mano_model_rules_shadowhand"
+
+ ### optimize for the manipulatable hand actions ###
+ export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab_redmax_acts"
+
+
+
+ # ## acts ##
+ # export conf=dyn_grab_shadow_model_states_224.conf
+ # export conf=dyn_grab_shadow_model_states_89.conf
+ # export conf=dyn_grab_shadow_model_states_102.conf
+ # export conf=dyn_grab_shadow_model_states_7.conf
+ # export conf=dyn_grab_shadow_model_states_47.conf
+ # export conf=dyn_grab_shadow_model_states_67.conf
+ # export conf=dyn_grab_shadow_model_states_76.conf
+ # export conf=dyn_grab_shadow_model_states_85.conf
+ # # export conf=dyn_grab_shadow_model_states_91.conf
+ # # export conf=dyn_grab_shadow_model_states_167.conf
+ # export conf=dyn_grab_shadow_model_states_107.conf
+ # export conf=dyn_grab_shadow_model_states_306.conf
+ # export conf=dyn_grab_shadow_model_states_313.conf
+ # export conf=dyn_grab_shadow_model_states_322.conf
+
+ # /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_multi_stages.conf
+ # /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_single_stage.conf
+ export conf=dyn_grab_arti_shadow_single_stage.conf
+ # export conf=dyn_grab_shadow_model_states_398.conf
+ # export conf=dyn_grab_shadow_model_states_363.conf
+ # export conf=dyn_grab_shadow_model_states_358.conf
+
+ # bash scripts_new/train_grab_shadow_singlestage.sh
+
+
+
+
+ export conf_root="./confs_new"
+
+
+ export cuda_ids="3"
+
+
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
+