Spaces:
Running
Running
meow
commited on
Commit
•
710e818
1
Parent(s):
cb900d8
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +193 -0
- README.md +3 -3
- confs_new/dyn_grab_arti_shadow_dm.conf +288 -0
- confs_new/dyn_grab_arti_shadow_dm_curriculum.conf +326 -0
- confs_new/dyn_grab_arti_shadow_dm_singlestage.conf +318 -0
- confs_new/dyn_grab_pointset_mano.conf +215 -0
- confs_new/dyn_grab_pointset_mano_dyn.conf +218 -0
- confs_new/dyn_grab_pointset_mano_dyn_optacts.conf +218 -0
- confs_new/dyn_grab_pointset_points_dyn.conf +257 -0
- confs_new/dyn_grab_pointset_points_dyn_retar.conf +274 -0
- confs_new/dyn_grab_pointset_points_dyn_retar_pts.conf +281 -0
- confs_new/dyn_grab_pointset_points_dyn_retar_pts_opts.conf +287 -0
- confs_new/dyn_grab_pointset_points_dyn_s1.conf +256 -0
- confs_new/dyn_grab_pointset_points_dyn_s2.conf +258 -0
- confs_new/dyn_grab_pointset_points_dyn_s3.conf +259 -0
- confs_new/dyn_grab_pointset_points_dyn_s4.conf +259 -0
- confs_new/dyn_grab_sparse_retar.conf +214 -0
- exp_runner_stage_1.py +0 -0
- models/data_utils_torch.py +1547 -0
- models/dataset.py +359 -0
- models/dataset_wtime.py +403 -0
- models/dyn_model_act.py +0 -0
- models/dyn_model_act_v2.py +0 -0
- models/dyn_model_act_v2_deformable.py +1582 -0
- models/dyn_model_utils.py +1369 -0
- models/embedder.py +51 -0
- models/fields.py +0 -0
- models/fields_old.py +0 -0
- models/renderer.py +641 -0
- models/renderer_def.py +725 -0
- models/renderer_def_multi_objs.py +1088 -0
- models/renderer_def_multi_objs_compositional.py +1510 -0
- models/renderer_def_multi_objs_rigidtrans_forward.py +1603 -0
- models/test.js +0 -0
- pre-requirements.txt +2 -0
- requirements.txt +27 -0
- scripts_demo/train_grab_pointset_points_dyn_s1.sh +28 -0
- scripts_new/train_grab_mano.sh +26 -0
- scripts_new/train_grab_mano_wreact.sh +26 -0
- scripts_new/train_grab_mano_wreact_optacts.sh +26 -0
- scripts_new/train_grab_pointset.sh +99 -0
- scripts_new/train_grab_pointset_points_dyn.sh +27 -0
- scripts_new/train_grab_pointset_points_dyn_retar.sh +36 -0
- scripts_new/train_grab_pointset_points_dyn_retar_pts.sh +36 -0
- scripts_new/train_grab_pointset_points_dyn_s1.sh +27 -0
- scripts_new/train_grab_pointset_points_dyn_s2.sh +27 -0
- scripts_new/train_grab_pointset_points_dyn_s3.sh +27 -0
- scripts_new/train_grab_pointset_points_dyn_s4.sh +27 -0
- scripts_new/train_grab_shadow_multistages.sh +102 -0
- scripts_new/train_grab_shadow_singlestage.sh +101 -0
.gitignore
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Created by .ignore support plugin (hsz.mobi)
|
2 |
+
### Python template
|
3 |
+
# Byte-compiled / optimized / DLL files
|
4 |
+
__pycache__/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
*.npy
|
12 |
+
|
13 |
+
*.zip
|
14 |
+
|
15 |
+
*.ply
|
16 |
+
|
17 |
+
|
18 |
+
#### old files ####
|
19 |
+
exp_runner_arti_*
|
20 |
+
|
21 |
+
exp_runner.py
|
22 |
+
exp_runner_arti.py
|
23 |
+
exp_runner_dyn_model.py
|
24 |
+
exp_runner_sim.py
|
25 |
+
get-pip.py
|
26 |
+
|
27 |
+
test.py
|
28 |
+
|
29 |
+
|
30 |
+
#### old scripts and data&exp folders ###
|
31 |
+
scripts/
|
32 |
+
confs/
|
33 |
+
data/
|
34 |
+
ckpts/
|
35 |
+
exp/
|
36 |
+
uni_rep/
|
37 |
+
|
38 |
+
*/*_local.sh
|
39 |
+
*/*_local.conf
|
40 |
+
|
41 |
+
# Distribution / packaging
|
42 |
+
.Python
|
43 |
+
env/
|
44 |
+
build/
|
45 |
+
develop-eggs/
|
46 |
+
dist/
|
47 |
+
downloads/
|
48 |
+
eggs/
|
49 |
+
.eggs/
|
50 |
+
lib/
|
51 |
+
lib64/
|
52 |
+
parts/
|
53 |
+
sdist/
|
54 |
+
var/
|
55 |
+
*.egg-info/
|
56 |
+
.installed.cfg
|
57 |
+
*.egg
|
58 |
+
|
59 |
+
rsc/
|
60 |
+
raw_data/
|
61 |
+
|
62 |
+
# PyInstaller
|
63 |
+
# Usually these files are written by a python script from a template
|
64 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
65 |
+
*.manifest
|
66 |
+
*.spec
|
67 |
+
|
68 |
+
# Installer logs
|
69 |
+
pip-log.txt
|
70 |
+
pip-delete-this-directory.txt
|
71 |
+
|
72 |
+
# Unit test / coverage reports
|
73 |
+
htmlcov/
|
74 |
+
.tox/
|
75 |
+
.coverage
|
76 |
+
.coverage.*
|
77 |
+
.cache
|
78 |
+
nosetests.xml
|
79 |
+
coverage.xml
|
80 |
+
*,cover
|
81 |
+
.hypothesis/
|
82 |
+
|
83 |
+
# Translations
|
84 |
+
*.mo
|
85 |
+
*.pot
|
86 |
+
|
87 |
+
# Django stuff:
|
88 |
+
*.log
|
89 |
+
local_settings.py
|
90 |
+
|
91 |
+
# Flask stuff:
|
92 |
+
instance/
|
93 |
+
.webassets-cache
|
94 |
+
|
95 |
+
# Scrapy stuff:
|
96 |
+
.scrapy
|
97 |
+
|
98 |
+
# Sphinx documentation
|
99 |
+
docs/_build/
|
100 |
+
|
101 |
+
# PyBuilder
|
102 |
+
target/
|
103 |
+
|
104 |
+
# IPython Notebook
|
105 |
+
.ipynb_checkpoints
|
106 |
+
|
107 |
+
# pyenv
|
108 |
+
.python-version
|
109 |
+
|
110 |
+
# celery beat schedule file
|
111 |
+
celerybeat-schedule
|
112 |
+
|
113 |
+
# dotenv
|
114 |
+
.env
|
115 |
+
|
116 |
+
# virtualenv
|
117 |
+
venv/
|
118 |
+
ENV/
|
119 |
+
|
120 |
+
# Spyder project settings
|
121 |
+
.spyderproject
|
122 |
+
|
123 |
+
# Rope project settings
|
124 |
+
.ropeproject
|
125 |
+
### VirtualEnv template
|
126 |
+
# Virtualenv
|
127 |
+
# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
|
128 |
+
.Python
|
129 |
+
[Bb]in
|
130 |
+
[Ii]nclude
|
131 |
+
[Ll]ib
|
132 |
+
[Ll]ib64
|
133 |
+
[Ll]ocal
|
134 |
+
# [Ss]cripts
|
135 |
+
pyvenv.cfg
|
136 |
+
.venv
|
137 |
+
pip-selfcheck.json
|
138 |
+
### JetBrains template
|
139 |
+
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
|
140 |
+
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
141 |
+
|
142 |
+
# User-specific stuff:
|
143 |
+
.idea/workspace.xml
|
144 |
+
.idea/tasks.xml
|
145 |
+
.idea/dictionaries
|
146 |
+
.idea/vcs.xml
|
147 |
+
.idea/jsLibraryMappings.xml
|
148 |
+
|
149 |
+
# Sensitive or high-churn files:
|
150 |
+
.idea/dataSources.ids
|
151 |
+
.idea/dataSources.xml
|
152 |
+
.idea/dataSources.local.xml
|
153 |
+
.idea/sqlDataSources.xml
|
154 |
+
.idea/dynamic.xml
|
155 |
+
.idea/uiDesigner.xml
|
156 |
+
|
157 |
+
# Gradle:
|
158 |
+
.idea/gradle.xml
|
159 |
+
.idea/libraries
|
160 |
+
|
161 |
+
# Mongo Explorer plugin:
|
162 |
+
.idea/mongoSettings.xml
|
163 |
+
|
164 |
+
.idea/
|
165 |
+
|
166 |
+
## File-based project format:
|
167 |
+
*.iws
|
168 |
+
|
169 |
+
|
170 |
+
## Plugin-specific files:
|
171 |
+
|
172 |
+
# IntelliJ
|
173 |
+
/out/
|
174 |
+
|
175 |
+
# mpeltonen/sbt-idea plugin
|
176 |
+
.idea_modules/
|
177 |
+
|
178 |
+
# JIRA plugin
|
179 |
+
atlassian-ide-plugin.xml
|
180 |
+
|
181 |
+
# Crashlytics plugin (for Android Studio and IntelliJ)
|
182 |
+
com_crashlytics_export_strings.xml
|
183 |
+
crashlytics.properties
|
184 |
+
crashlytics-build.properties
|
185 |
+
fabric.properties
|
186 |
+
|
187 |
+
data
|
188 |
+
public_data
|
189 |
+
exp
|
190 |
+
tmp
|
191 |
+
|
192 |
+
.models/*.npy
|
193 |
+
.models/*.ply
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Quasi Physical Sims
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.25.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: Quasi Physical Sims
|
3 |
+
emoji: 🏃
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.25.0
|
8 |
app_file: app.py
|
confs_new/dyn_grab_arti_shadow_dm.conf
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
4 |
+
|
5 |
+
tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_"
|
6 |
+
|
7 |
+
recording = [
|
8 |
+
./,
|
9 |
+
./models
|
10 |
+
]
|
11 |
+
}
|
12 |
+
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
|
19 |
+
obj_idx = 102
|
20 |
+
}
|
21 |
+
|
22 |
+
train {
|
23 |
+
learning_rate = 5e-4
|
24 |
+
# learning_rate = 5e-6
|
25 |
+
learning_rate_actions = 5e-6
|
26 |
+
learning_rate_alpha = 0.05
|
27 |
+
end_iter = 300000
|
28 |
+
|
29 |
+
# batch_size = 128 # 64
|
30 |
+
# batch_size = 4000
|
31 |
+
# batch_size = 3096 # 64
|
32 |
+
batch_size = 1024
|
33 |
+
validate_resolution_level = 4
|
34 |
+
warm_up_end = 5000
|
35 |
+
anneal_end = 0
|
36 |
+
use_white_bkgd = False
|
37 |
+
|
38 |
+
# save_freq = 10000
|
39 |
+
save_freq = 10000
|
40 |
+
val_freq = 20 # 2500
|
41 |
+
val_mesh_freq = 20 # 5000
|
42 |
+
report_freq = 10
|
43 |
+
### igr weight ###
|
44 |
+
igr_weight = 0.1
|
45 |
+
mask_weight = 0.1
|
46 |
+
}
|
47 |
+
|
48 |
+
model {
|
49 |
+
|
50 |
+
|
51 |
+
load_redmax_robot_actions = ""
|
52 |
+
penetrating_depth_penalty = 0.0
|
53 |
+
train_states = True
|
54 |
+
|
55 |
+
minn_dist_threshold = 0.000
|
56 |
+
obj_mass = 100.0
|
57 |
+
obj_mass = 30.0
|
58 |
+
|
59 |
+
# use_mano_hand_for_test = False
|
60 |
+
use_mano_hand_for_test = True
|
61 |
+
|
62 |
+
# train_residual_friction = False
|
63 |
+
train_residual_friction = True
|
64 |
+
|
65 |
+
# use_LBFGS = True
|
66 |
+
use_LBFGS = False
|
67 |
+
|
68 |
+
|
69 |
+
use_mano_hand_for_test = False
|
70 |
+
train_residual_friction = True
|
71 |
+
|
72 |
+
extract_delta_mesh = False
|
73 |
+
freeze_weights = True
|
74 |
+
# gt_act_xs_def = True
|
75 |
+
gt_act_xs_def = False
|
76 |
+
use_bending_network = True
|
77 |
+
use_delta_bending = True
|
78 |
+
|
79 |
+
use_passive_nets = True
|
80 |
+
use_split_network = True
|
81 |
+
|
82 |
+
n_timesteps = 60
|
83 |
+
|
84 |
+
|
85 |
+
# using_delta_glb_trans = True
|
86 |
+
using_delta_glb_trans = False
|
87 |
+
# train_multi_seqs = True
|
88 |
+
|
89 |
+
|
90 |
+
# optimize_with_intermediates = False
|
91 |
+
optimize_with_intermediates = True
|
92 |
+
|
93 |
+
|
94 |
+
loss_tangential_diff_coef = 1000
|
95 |
+
loss_tangential_diff_coef = 0
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
optimize_active_object = True
|
100 |
+
|
101 |
+
no_friction_constraint = False
|
102 |
+
|
103 |
+
optimize_glb_transformations = True
|
104 |
+
|
105 |
+
|
106 |
+
sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
107 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
108 |
+
mano_mult_const_after_cent = 1.0
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
|
111 |
+
bending_net_type = "active_force_field_v18"
|
112 |
+
|
113 |
+
|
114 |
+
### try to train the residual friction ? ###
|
115 |
+
train_residual_friction = True
|
116 |
+
optimize_rules = True
|
117 |
+
|
118 |
+
load_optimized_init_actions = ""
|
119 |
+
|
120 |
+
|
121 |
+
use_optimizable_params = True
|
122 |
+
|
123 |
+
|
124 |
+
### grab train seq 224 ###
|
125 |
+
penetration_determining = "sdf_of_canon"
|
126 |
+
train_with_forces_to_active = False
|
127 |
+
loss_scale_coef = 1000.0
|
128 |
+
# penetration_proj_k_to_robot_friction = 40000000.0
|
129 |
+
# penetration_proj_k_to_robot_friction = 100000000.0 # as friction coefs here #
|
130 |
+
use_same_contact_spring_k = False
|
131 |
+
|
132 |
+
|
133 |
+
# sim_model_path = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
134 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
135 |
+
|
136 |
+
|
137 |
+
optimize_rules = False
|
138 |
+
|
139 |
+
penetration_determining = "sdf_of_canon"
|
140 |
+
|
141 |
+
optimize_rules = False
|
142 |
+
|
143 |
+
optim_sim_model_params_from_mano = False
|
144 |
+
optimize_rules = False
|
145 |
+
|
146 |
+
|
147 |
+
penetration_proj_k_to_robot_friction = 100000000.0 # as friction coefs here # ## confs ##
|
148 |
+
penetration_proj_k_to_robot = 40000000.0 #
|
149 |
+
|
150 |
+
|
151 |
+
penetrating_depth_penalty = 1
|
152 |
+
|
153 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
154 |
+
|
155 |
+
|
156 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
157 |
+
|
158 |
+
optim_sim_model_params_from_mano = False
|
159 |
+
optimize_rules = False
|
160 |
+
|
161 |
+
optim_sim_model_params_from_mano = True
|
162 |
+
optimize_rules = True
|
163 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
164 |
+
|
165 |
+
optim_sim_model_params_from_mano = False
|
166 |
+
optimize_rules = False
|
167 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
168 |
+
|
169 |
+
|
170 |
+
### kinematics confgs ###
|
171 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
172 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
173 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
174 |
+
# ckpt_fn = ""
|
175 |
+
# load_optimized_init_transformations = ""
|
176 |
+
optim_sim_model_params_from_mano = True
|
177 |
+
optimize_rules = True
|
178 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
179 |
+
|
180 |
+
optim_sim_model_params_from_mano = False
|
181 |
+
optimize_rules = False
|
182 |
+
optimize_rules = True
|
183 |
+
|
184 |
+
ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
|
185 |
+
load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
|
186 |
+
optimize_rules = False
|
187 |
+
|
188 |
+
optimize_rules = True
|
189 |
+
|
190 |
+
## opt roboto ##
|
191 |
+
opt_robo_glb_trans = True
|
192 |
+
opt_robo_glb_rot = False # opt rot # ## opt rot ##
|
193 |
+
opt_robo_states = True
|
194 |
+
|
195 |
+
|
196 |
+
use_multi_stages = False
|
197 |
+
|
198 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
199 |
+
|
200 |
+
penetration_proj_k_to_robot = 40000
|
201 |
+
penetration_proj_k_to_robot_friction = 100000
|
202 |
+
|
203 |
+
drive_robot = "actions"
|
204 |
+
opt_robo_glb_trans = False
|
205 |
+
opt_robo_states = True
|
206 |
+
opt_robo_glb_rot = False
|
207 |
+
|
208 |
+
train_with_forces_to_active = True
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
load_redmax_robot_actions_fn = ""
|
213 |
+
|
214 |
+
|
215 |
+
optimize_rules = False
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
loss_scale_coef = 1.0
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
use_opt_rigid_translations=True
|
224 |
+
|
225 |
+
train_def = True
|
226 |
+
|
227 |
+
|
228 |
+
optimizable_rigid_translations=True
|
229 |
+
|
230 |
+
nerf {
|
231 |
+
D = 8,
|
232 |
+
d_in = 4,
|
233 |
+
d_in_view = 3,
|
234 |
+
W = 256,
|
235 |
+
multires = 10,
|
236 |
+
multires_view = 4,
|
237 |
+
output_ch = 4,
|
238 |
+
skips=[4],
|
239 |
+
use_viewdirs=True
|
240 |
+
}
|
241 |
+
|
242 |
+
sdf_network {
|
243 |
+
d_out = 257,
|
244 |
+
d_in = 3,
|
245 |
+
d_hidden = 256,
|
246 |
+
n_layers = 8,
|
247 |
+
skip_in = [4],
|
248 |
+
multires = 6,
|
249 |
+
bias = 0.5,
|
250 |
+
scale = 1.0,
|
251 |
+
geometric_init = True,
|
252 |
+
weight_norm = True,
|
253 |
+
}
|
254 |
+
|
255 |
+
variance_network {
|
256 |
+
init_val = 0.3
|
257 |
+
}
|
258 |
+
|
259 |
+
rendering_network {
|
260 |
+
d_feature = 256,
|
261 |
+
mode = idr,
|
262 |
+
d_in = 9,
|
263 |
+
d_out = 3,
|
264 |
+
d_hidden = 256,
|
265 |
+
n_layers = 4,
|
266 |
+
weight_norm = True,
|
267 |
+
multires_view = 4,
|
268 |
+
squeeze_out = True,
|
269 |
+
}
|
270 |
+
|
271 |
+
neus_renderer {
|
272 |
+
n_samples = 64,
|
273 |
+
n_importance = 64,
|
274 |
+
n_outside = 0,
|
275 |
+
up_sample_steps = 4 ,
|
276 |
+
perturb = 1.0,
|
277 |
+
}
|
278 |
+
|
279 |
+
bending_network {
|
280 |
+
multires = 6,
|
281 |
+
bending_latent_size = 32,
|
282 |
+
d_in = 3,
|
283 |
+
rigidity_hidden_dimensions = 64,
|
284 |
+
rigidity_network_depth = 5,
|
285 |
+
use_rigidity_network = False,
|
286 |
+
bending_n_timesteps = 10,
|
287 |
+
}
|
288 |
+
}
|
confs_new/dyn_grab_arti_shadow_dm_curriculum.conf
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
# base_exp_dir = exp/CASE_NAME/wmask
|
3 |
+
base_exp_dir = /data2/datasets/xueyi/neus/exp/CASE_NAME/wmask
|
4 |
+
|
5 |
+
tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_curriculum_"
|
6 |
+
|
7 |
+
recording = [
|
8 |
+
./,
|
9 |
+
./models
|
10 |
+
]
|
11 |
+
}
|
12 |
+
|
13 |
+
dataset {
|
14 |
+
data_dir = public_data/CASE_NAME/
|
15 |
+
render_cameras_name = cameras_sphere.npz
|
16 |
+
object_cameras_name = cameras_sphere.npz
|
17 |
+
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_actions = 5e-6
|
24 |
+
# learning_rate = 5e-6
|
25 |
+
# learning_rate = 5e-5
|
26 |
+
learning_rate_alpha = 0.05
|
27 |
+
end_iter = 300000
|
28 |
+
|
29 |
+
# batch_size = 128 # 64
|
30 |
+
# batch_size = 4000
|
31 |
+
# batch_size = 3096 # 64
|
32 |
+
batch_size = 1024
|
33 |
+
validate_resolution_level = 4
|
34 |
+
warm_up_end = 5000
|
35 |
+
anneal_end = 0
|
36 |
+
use_white_bkgd = False
|
37 |
+
|
38 |
+
# save_freq = 10000
|
39 |
+
save_freq = 10000
|
40 |
+
val_freq = 20 # 2500
|
41 |
+
val_mesh_freq = 20 # 5000
|
42 |
+
report_freq = 10
|
43 |
+
### igr weight ###
|
44 |
+
igr_weight = 0.1
|
45 |
+
mask_weight = 0.1
|
46 |
+
}
|
47 |
+
|
48 |
+
model {
|
49 |
+
|
50 |
+
|
51 |
+
penetration_proj_k_to_robot = 40
|
52 |
+
|
53 |
+
penetrating_depth_penalty = 1.0
|
54 |
+
penetrating_depth_penalty = 0.0
|
55 |
+
train_states = True
|
56 |
+
penetration_proj_k_to_robot = 4000000000.0
|
57 |
+
|
58 |
+
|
59 |
+
minn_dist_threshold = 0.000
|
60 |
+
# minn_dist_threshold = 0.01
|
61 |
+
obj_mass = 100.0
|
62 |
+
obj_mass = 30.0
|
63 |
+
|
64 |
+
optimize_rules = True
|
65 |
+
|
66 |
+
use_mano_hand_for_test = False
|
67 |
+
use_mano_hand_for_test = True
|
68 |
+
|
69 |
+
train_residual_friction = False
|
70 |
+
train_residual_friction = True
|
71 |
+
|
72 |
+
use_LBFGS = True
|
73 |
+
use_LBFGS = False
|
74 |
+
|
75 |
+
use_mano_hand_for_test = False
|
76 |
+
train_residual_friction = True
|
77 |
+
|
78 |
+
extract_delta_mesh = False
|
79 |
+
freeze_weights = True
|
80 |
+
# gt_act_xs_def = True
|
81 |
+
gt_act_xs_def = False
|
82 |
+
use_bending_network = True
|
83 |
+
### for ts = 3 ###
|
84 |
+
# use_delta_bending = False
|
85 |
+
### for ts = 3 ###
|
86 |
+
use_delta_bending = True
|
87 |
+
use_passive_nets = True
|
88 |
+
# use_passive_nets = False # sv mesh root #
|
89 |
+
|
90 |
+
use_split_network = True
|
91 |
+
|
92 |
+
penetration_determining = "plane_primitives"
|
93 |
+
|
94 |
+
|
95 |
+
n_timesteps = 3 #
|
96 |
+
# n_timesteps = 5 #
|
97 |
+
n_timesteps = 7
|
98 |
+
n_timesteps = 60
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
using_delta_glb_trans = True
|
104 |
+
using_delta_glb_trans = False
|
105 |
+
|
106 |
+
optimize_with_intermediates = False
|
107 |
+
optimize_with_intermediates = True
|
108 |
+
|
109 |
+
|
110 |
+
loss_tangential_diff_coef = 1000
|
111 |
+
loss_tangential_diff_coef = 0
|
112 |
+
|
113 |
+
|
114 |
+
|
115 |
+
optimize_active_object = False
|
116 |
+
optimize_active_object = True
|
117 |
+
|
118 |
+
# optimize_expanded_pts = False
|
119 |
+
# optimize_expanded_pts = True
|
120 |
+
|
121 |
+
no_friction_constraint = False
|
122 |
+
|
123 |
+
optimize_glb_transformations = True
|
124 |
+
sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
125 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
126 |
+
mano_mult_const_after_cent = 1.0
|
127 |
+
sim_num_steps = 1000000
|
128 |
+
|
129 |
+
bending_net_type = "active_force_field_v18"
|
130 |
+
|
131 |
+
|
132 |
+
### try to train the residual friction ? ###
|
133 |
+
train_residual_friction = True
|
134 |
+
optimize_rules = True
|
135 |
+
### cube ###
|
136 |
+
load_optimized_init_actions = ""
|
137 |
+
|
138 |
+
optimize_rules = False
|
139 |
+
|
140 |
+
|
141 |
+
## optimize rules ## penetration proj k to robot ##
|
142 |
+
optimize_rules = True
|
143 |
+
penetration_proj_k_to_robot = 4000000.0
|
144 |
+
use_optimizable_params = True
|
145 |
+
|
146 |
+
penetration_determining = "ball_primitives" # uing ball primitives
|
147 |
+
optimize_rules = True #
|
148 |
+
penetration_proj_k_to_robot = 4000000.0 #
|
149 |
+
use_optimizable_params = True
|
150 |
+
train_with_forces_to_active = False
|
151 |
+
|
152 |
+
# penetration_determining = "ball_primitives"
|
153 |
+
### obj sdf and normals for colllision eteftion and responses ##
|
154 |
+
### grab train seq 54; cylinder ###
|
155 |
+
penetration_determining = "sdf_of_canon"
|
156 |
+
optimize_rules = True
|
157 |
+
train_with_forces_to_active = False
|
158 |
+
|
159 |
+
### grab train seq 1 ###
|
160 |
+
penetration_determining = "sdf_of_canon"
|
161 |
+
train_with_forces_to_active = False
|
162 |
+
|
163 |
+
### grab train seq 224 ###
|
164 |
+
penetration_determining = "sdf_of_canon"
|
165 |
+
train_with_forces_to_active = False
|
166 |
+
loss_scale_coef = 1000.0
|
167 |
+
penetration_proj_k_to_robot_friction = 40000000.0
|
168 |
+
penetration_proj_k_to_robot_friction = 100000000.0
|
169 |
+
use_same_contact_spring_k = False
|
170 |
+
sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
171 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
172 |
+
|
173 |
+
|
174 |
+
penetration_determining = "sdf_of_canon"
|
175 |
+
optimize_rules = True
|
176 |
+
# optimize_rules = True
|
177 |
+
|
178 |
+
optimize_rules = False
|
179 |
+
|
180 |
+
optimize_rules = True
|
181 |
+
|
182 |
+
|
183 |
+
optimize_rules = False
|
184 |
+
|
185 |
+
optim_sim_model_params_from_mano = True
|
186 |
+
optimize_rules = True
|
187 |
+
optim_sim_model_params_from_mano = False
|
188 |
+
optimize_rules = False
|
189 |
+
|
190 |
+
penetration_proj_k_to_robot_friction = 100000000.0
|
191 |
+
penetration_proj_k_to_robot = 40000000.0
|
192 |
+
|
193 |
+
|
194 |
+
penetrating_depth_penalty = 1
|
195 |
+
|
196 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
197 |
+
|
198 |
+
|
199 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
200 |
+
|
201 |
+
optim_sim_model_params_from_mano = True
|
202 |
+
optimize_rules = True
|
203 |
+
optim_sim_model_params_from_mano = False
|
204 |
+
optimize_rules = False
|
205 |
+
optim_sim_model_params_from_mano = False
|
206 |
+
optimize_rules = False
|
207 |
+
|
208 |
+
load_optimized_init_transformations = ""
|
209 |
+
optim_sim_model_params_from_mano = True
|
210 |
+
optimize_rules = True
|
211 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
212 |
+
|
213 |
+
|
214 |
+
optim_sim_model_params_from_mano = False
|
215 |
+
|
216 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
217 |
+
|
218 |
+
|
219 |
+
### kinematics confgs ###
|
220 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
221 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
222 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
223 |
+
# ckpt_fn = ""
|
224 |
+
load_optimized_init_transformations = ""
|
225 |
+
optim_sim_model_params_from_mano = True
|
226 |
+
optimize_rules = True
|
227 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
228 |
+
|
229 |
+
optim_sim_model_params_from_mano = False
|
230 |
+
|
231 |
+
optimize_rules = True
|
232 |
+
|
233 |
+
ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
|
234 |
+
ckpt_fn = "/data2/datasets/xueyi/neus/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_states_optrobot__seq_102_optactswreacts_redmaxacts_rules_/checkpoints/ckpt_035459.pth"
|
235 |
+
load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
|
236 |
+
|
237 |
+
|
238 |
+
optimize_rules = True
|
239 |
+
|
240 |
+
## opt roboto ##
|
241 |
+
opt_robo_glb_trans = True
|
242 |
+
opt_robo_glb_rot = False # opt rot # ## opt rot ##
|
243 |
+
opt_robo_states = True
|
244 |
+
|
245 |
+
|
246 |
+
load_redmax_robot_actions_fn = "ckpts/grab/102/diffhand_act.npy"
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
ckpt_fn = ""
|
251 |
+
|
252 |
+
use_multi_stages = True
|
253 |
+
train_with_forces_to_active = True
|
254 |
+
|
255 |
+
|
256 |
+
# optimize_rules = False
|
257 |
+
loss_scale_coef = 1.0 ## loss scale coef ## loss scale coef ####
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
use_opt_rigid_translations=True
|
262 |
+
|
263 |
+
train_def = True
|
264 |
+
|
265 |
+
# optimizable_rigid_translations = False #
|
266 |
+
optimizable_rigid_translations=True
|
267 |
+
|
268 |
+
nerf {
|
269 |
+
D = 8,
|
270 |
+
d_in = 4,
|
271 |
+
d_in_view = 3,
|
272 |
+
W = 256,
|
273 |
+
multires = 10,
|
274 |
+
multires_view = 4,
|
275 |
+
output_ch = 4,
|
276 |
+
skips=[4],
|
277 |
+
use_viewdirs=True
|
278 |
+
}
|
279 |
+
|
280 |
+
sdf_network {
|
281 |
+
d_out = 257,
|
282 |
+
d_in = 3,
|
283 |
+
d_hidden = 256,
|
284 |
+
n_layers = 8,
|
285 |
+
skip_in = [4],
|
286 |
+
multires = 6,
|
287 |
+
bias = 0.5,
|
288 |
+
scale = 1.0,
|
289 |
+
geometric_init = True,
|
290 |
+
weight_norm = True,
|
291 |
+
}
|
292 |
+
|
293 |
+
variance_network {
|
294 |
+
init_val = 0.3
|
295 |
+
}
|
296 |
+
|
297 |
+
rendering_network {
|
298 |
+
d_feature = 256,
|
299 |
+
mode = idr,
|
300 |
+
d_in = 9,
|
301 |
+
d_out = 3,
|
302 |
+
d_hidden = 256,
|
303 |
+
n_layers = 4,
|
304 |
+
weight_norm = True,
|
305 |
+
multires_view = 4,
|
306 |
+
squeeze_out = True,
|
307 |
+
}
|
308 |
+
|
309 |
+
neus_renderer {
|
310 |
+
n_samples = 64,
|
311 |
+
n_importance = 64,
|
312 |
+
n_outside = 0,
|
313 |
+
up_sample_steps = 4 ,
|
314 |
+
perturb = 1.0,
|
315 |
+
}
|
316 |
+
|
317 |
+
bending_network {
|
318 |
+
multires = 6,
|
319 |
+
bending_latent_size = 32,
|
320 |
+
d_in = 3,
|
321 |
+
rigidity_hidden_dimensions = 64,
|
322 |
+
rigidity_network_depth = 5,
|
323 |
+
use_rigidity_network = False,
|
324 |
+
bending_n_timesteps = 10,
|
325 |
+
}
|
326 |
+
}
|
confs_new/dyn_grab_arti_shadow_dm_singlestage.conf
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
3 |
+
|
4 |
+
tag = "train_retargeted_shadow_hand_seq_102_diffhand_model_curriculum_"
|
5 |
+
|
6 |
+
recording = [
|
7 |
+
./,
|
8 |
+
./models
|
9 |
+
]
|
10 |
+
}
|
11 |
+
|
12 |
+
dataset {
|
13 |
+
data_dir = public_data/CASE_NAME/
|
14 |
+
render_cameras_name = cameras_sphere.npz
|
15 |
+
object_cameras_name = cameras_sphere.npz
|
16 |
+
|
17 |
+
obj_idx = 102
|
18 |
+
}
|
19 |
+
|
20 |
+
train {
|
21 |
+
learning_rate = 5e-4
|
22 |
+
learning_rate_actions = 5e-6
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20
|
35 |
+
val_mesh_freq = 20
|
36 |
+
report_freq = 10
|
37 |
+
igr_weight = 0.1
|
38 |
+
mask_weight = 0.1
|
39 |
+
}
|
40 |
+
|
41 |
+
model {
|
42 |
+
|
43 |
+
|
44 |
+
penetration_proj_k_to_robot = 40
|
45 |
+
|
46 |
+
penetrating_depth_penalty = 1.0
|
47 |
+
penetrating_depth_penalty = 0.0
|
48 |
+
train_states = True
|
49 |
+
penetration_proj_k_to_robot = 4000000000.0
|
50 |
+
|
51 |
+
|
52 |
+
minn_dist_threshold = 0.000
|
53 |
+
# minn_dist_threshold = 0.01
|
54 |
+
obj_mass = 100.0
|
55 |
+
obj_mass = 30.0
|
56 |
+
|
57 |
+
optimize_rules = True
|
58 |
+
|
59 |
+
use_mano_hand_for_test = False
|
60 |
+
use_mano_hand_for_test = True
|
61 |
+
|
62 |
+
train_residual_friction = False
|
63 |
+
train_residual_friction = True
|
64 |
+
|
65 |
+
use_LBFGS = True
|
66 |
+
use_LBFGS = False
|
67 |
+
|
68 |
+
use_mano_hand_for_test = False
|
69 |
+
train_residual_friction = True
|
70 |
+
|
71 |
+
extract_delta_mesh = False
|
72 |
+
freeze_weights = True
|
73 |
+
# gt_act_xs_def = True
|
74 |
+
gt_act_xs_def = False
|
75 |
+
use_bending_network = True
|
76 |
+
### for ts = 3 ###
|
77 |
+
# use_delta_bending = False
|
78 |
+
### for ts = 3 ###
|
79 |
+
use_delta_bending = True
|
80 |
+
use_passive_nets = True
|
81 |
+
# use_passive_nets = False # sv mesh root #
|
82 |
+
|
83 |
+
use_split_network = True
|
84 |
+
|
85 |
+
penetration_determining = "plane_primitives"
|
86 |
+
|
87 |
+
|
88 |
+
n_timesteps = 3 #
|
89 |
+
# n_timesteps = 5 #
|
90 |
+
n_timesteps = 7
|
91 |
+
n_timesteps = 60
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
using_delta_glb_trans = True
|
97 |
+
using_delta_glb_trans = False
|
98 |
+
|
99 |
+
optimize_with_intermediates = False
|
100 |
+
optimize_with_intermediates = True
|
101 |
+
|
102 |
+
|
103 |
+
loss_tangential_diff_coef = 1000
|
104 |
+
loss_tangential_diff_coef = 0
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
optimize_active_object = False
|
109 |
+
optimize_active_object = True
|
110 |
+
|
111 |
+
# optimize_expanded_pts = False
|
112 |
+
# optimize_expanded_pts = True
|
113 |
+
|
114 |
+
no_friction_constraint = False
|
115 |
+
|
116 |
+
optimize_glb_transformations = True
|
117 |
+
sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
118 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
119 |
+
mano_mult_const_after_cent = 1.0
|
120 |
+
sim_num_steps = 1000000
|
121 |
+
|
122 |
+
bending_net_type = "active_force_field_v18"
|
123 |
+
|
124 |
+
|
125 |
+
### try to train the residual friction ? ###
|
126 |
+
train_residual_friction = True
|
127 |
+
optimize_rules = True
|
128 |
+
### cube ###
|
129 |
+
load_optimized_init_actions = ""
|
130 |
+
|
131 |
+
optimize_rules = False
|
132 |
+
|
133 |
+
|
134 |
+
## optimize rules ## penetration proj k to robot ##
|
135 |
+
optimize_rules = True
|
136 |
+
penetration_proj_k_to_robot = 4000000.0
|
137 |
+
use_optimizable_params = True
|
138 |
+
|
139 |
+
penetration_determining = "ball_primitives" # uing ball primitives
|
140 |
+
optimize_rules = True #
|
141 |
+
penetration_proj_k_to_robot = 4000000.0 #
|
142 |
+
use_optimizable_params = True
|
143 |
+
train_with_forces_to_active = False
|
144 |
+
|
145 |
+
# penetration_determining = "ball_primitives"
|
146 |
+
### obj sdf and normals for colllision eteftion and responses ##
|
147 |
+
### grab train seq 54; cylinder ###
|
148 |
+
penetration_determining = "sdf_of_canon"
|
149 |
+
optimize_rules = True
|
150 |
+
train_with_forces_to_active = False
|
151 |
+
|
152 |
+
### grab train seq 1 ###
|
153 |
+
penetration_determining = "sdf_of_canon"
|
154 |
+
train_with_forces_to_active = False
|
155 |
+
|
156 |
+
### grab train seq 224 ###
|
157 |
+
penetration_determining = "sdf_of_canon"
|
158 |
+
train_with_forces_to_active = False
|
159 |
+
loss_scale_coef = 1000.0
|
160 |
+
penetration_proj_k_to_robot_friction = 40000000.0
|
161 |
+
penetration_proj_k_to_robot_friction = 100000000.0
|
162 |
+
use_same_contact_spring_k = False
|
163 |
+
sim_model_path = "DiffHand/assets/hand_sphere_only_hand_testt.xml"
|
164 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
165 |
+
|
166 |
+
|
167 |
+
penetration_determining = "sdf_of_canon"
|
168 |
+
optimize_rules = True
|
169 |
+
# optimize_rules = True
|
170 |
+
|
171 |
+
optimize_rules = False
|
172 |
+
|
173 |
+
optimize_rules = True
|
174 |
+
|
175 |
+
|
176 |
+
optimize_rules = False
|
177 |
+
|
178 |
+
optim_sim_model_params_from_mano = True
|
179 |
+
optimize_rules = True
|
180 |
+
optim_sim_model_params_from_mano = False
|
181 |
+
optimize_rules = False
|
182 |
+
|
183 |
+
penetration_proj_k_to_robot_friction = 100000000.0
|
184 |
+
penetration_proj_k_to_robot = 40000000.0
|
185 |
+
|
186 |
+
|
187 |
+
penetrating_depth_penalty = 1
|
188 |
+
|
189 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
190 |
+
|
191 |
+
|
192 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
193 |
+
|
194 |
+
optim_sim_model_params_from_mano = True
|
195 |
+
optimize_rules = True
|
196 |
+
optim_sim_model_params_from_mano = False
|
197 |
+
optimize_rules = False
|
198 |
+
optim_sim_model_params_from_mano = False
|
199 |
+
optimize_rules = False
|
200 |
+
|
201 |
+
load_optimized_init_transformations = ""
|
202 |
+
optim_sim_model_params_from_mano = True
|
203 |
+
optimize_rules = True
|
204 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
205 |
+
|
206 |
+
|
207 |
+
optim_sim_model_params_from_mano = False
|
208 |
+
|
209 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
210 |
+
|
211 |
+
|
212 |
+
### kinematics confgs ###
|
213 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
214 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
215 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
216 |
+
# ckpt_fn = ""
|
217 |
+
load_optimized_init_transformations = ""
|
218 |
+
optim_sim_model_params_from_mano = True
|
219 |
+
optimize_rules = True
|
220 |
+
minn_dist_threshold_robot_to_obj = 0.0
|
221 |
+
|
222 |
+
optim_sim_model_params_from_mano = False
|
223 |
+
|
224 |
+
optimize_rules = True
|
225 |
+
|
226 |
+
ckpt_fn = "ckpts/grab/102/retargeted_shadow.pth"
|
227 |
+
load_optimized_init_transformations = "ckpts/grab/102/retargeted_shadow.pth"
|
228 |
+
|
229 |
+
|
230 |
+
optimize_rules = True
|
231 |
+
|
232 |
+
## opt roboto ##
|
233 |
+
opt_robo_glb_trans = True
|
234 |
+
opt_robo_glb_rot = False
|
235 |
+
opt_robo_states = True
|
236 |
+
|
237 |
+
|
238 |
+
load_redmax_robot_actions_fn = "ckpts/grab/102/diffhand_act.npy"
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
ckpt_fn = ""
|
243 |
+
|
244 |
+
use_multi_stages = False
|
245 |
+
train_with_forces_to_active = True
|
246 |
+
|
247 |
+
|
248 |
+
# optimize_rules = False
|
249 |
+
loss_scale_coef = 1.0 ## loss scale coef ## loss scale coef ####
|
250 |
+
|
251 |
+
|
252 |
+
|
253 |
+
use_opt_rigid_translations=True
|
254 |
+
|
255 |
+
train_def = True
|
256 |
+
|
257 |
+
# optimizable_rigid_translations = False #
|
258 |
+
optimizable_rigid_translations=True
|
259 |
+
|
260 |
+
nerf {
|
261 |
+
D = 8,
|
262 |
+
d_in = 4,
|
263 |
+
d_in_view = 3,
|
264 |
+
W = 256,
|
265 |
+
multires = 10,
|
266 |
+
multires_view = 4,
|
267 |
+
output_ch = 4,
|
268 |
+
skips=[4],
|
269 |
+
use_viewdirs=True
|
270 |
+
}
|
271 |
+
|
272 |
+
sdf_network {
|
273 |
+
d_out = 257,
|
274 |
+
d_in = 3,
|
275 |
+
d_hidden = 256,
|
276 |
+
n_layers = 8,
|
277 |
+
skip_in = [4],
|
278 |
+
multires = 6,
|
279 |
+
bias = 0.5,
|
280 |
+
scale = 1.0,
|
281 |
+
geometric_init = True,
|
282 |
+
weight_norm = True,
|
283 |
+
}
|
284 |
+
|
285 |
+
variance_network {
|
286 |
+
init_val = 0.3
|
287 |
+
}
|
288 |
+
|
289 |
+
rendering_network {
|
290 |
+
d_feature = 256,
|
291 |
+
mode = idr,
|
292 |
+
d_in = 9,
|
293 |
+
d_out = 3,
|
294 |
+
d_hidden = 256,
|
295 |
+
n_layers = 4,
|
296 |
+
weight_norm = True,
|
297 |
+
multires_view = 4,
|
298 |
+
squeeze_out = True,
|
299 |
+
}
|
300 |
+
|
301 |
+
neus_renderer {
|
302 |
+
n_samples = 64,
|
303 |
+
n_importance = 64,
|
304 |
+
n_outside = 0,
|
305 |
+
up_sample_steps = 4 ,
|
306 |
+
perturb = 1.0,
|
307 |
+
}
|
308 |
+
|
309 |
+
bending_network {
|
310 |
+
multires = 6,
|
311 |
+
bending_latent_size = 32,
|
312 |
+
d_in = 3,
|
313 |
+
rigidity_hidden_dimensions = 64,
|
314 |
+
rigidity_network_depth = 5,
|
315 |
+
use_rigidity_network = False,
|
316 |
+
bending_n_timesteps = 10,
|
317 |
+
}
|
318 |
+
}
|
confs_new/dyn_grab_pointset_mano.conf
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
|
7 |
+
# tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
|
8 |
+
tag = "train_dyn_mano_acts_"
|
9 |
+
|
10 |
+
recording = [
|
11 |
+
./,
|
12 |
+
./models
|
13 |
+
]
|
14 |
+
}
|
15 |
+
|
16 |
+
dataset {
|
17 |
+
data_dir = public_data/CASE_NAME/
|
18 |
+
render_cameras_name = cameras_sphere.npz
|
19 |
+
object_cameras_name = cameras_sphere.npz
|
20 |
+
obj_idx = 102
|
21 |
+
}
|
22 |
+
|
23 |
+
train {
|
24 |
+
learning_rate = 5e-4
|
25 |
+
learning_rate_alpha = 0.05
|
26 |
+
end_iter = 300000
|
27 |
+
|
28 |
+
batch_size = 1024
|
29 |
+
validate_resolution_level = 4
|
30 |
+
warm_up_end = 5000
|
31 |
+
anneal_end = 0
|
32 |
+
use_white_bkgd = False
|
33 |
+
|
34 |
+
# save_freq = 10000
|
35 |
+
save_freq = 10000
|
36 |
+
val_freq = 20
|
37 |
+
val_mesh_freq = 20
|
38 |
+
report_freq = 10
|
39 |
+
igr_weight = 0.1
|
40 |
+
mask_weight = 0.1
|
41 |
+
}
|
42 |
+
|
43 |
+
model {
|
44 |
+
|
45 |
+
optimize_dyn_actions = True
|
46 |
+
|
47 |
+
|
48 |
+
optimize_robot = True
|
49 |
+
|
50 |
+
use_penalty_based_friction = True
|
51 |
+
|
52 |
+
use_split_params = False
|
53 |
+
|
54 |
+
use_sqr_spring_stiffness = True
|
55 |
+
|
56 |
+
use_pre_proj_frictions = True
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
use_sqrt_dist = True
|
61 |
+
contact_maintaining_dist_thres = 0.2
|
62 |
+
|
63 |
+
robot_actions_diff_coef = 0.001
|
64 |
+
|
65 |
+
|
66 |
+
use_sdf_as_contact_dist = True
|
67 |
+
|
68 |
+
|
69 |
+
#
|
70 |
+
use_contact_dist_as_sdf = False
|
71 |
+
|
72 |
+
use_glb_proj_delta = True
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
# penetration_proj_k_to_robot = 30
|
77 |
+
penetrating_depth_penalty = 1.0
|
78 |
+
train_states = True
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
minn_dist_threshold = 0.000
|
83 |
+
obj_mass = 30.0
|
84 |
+
|
85 |
+
|
86 |
+
use_LBFGS = True
|
87 |
+
use_LBFGS = False
|
88 |
+
|
89 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
90 |
+
|
91 |
+
extract_delta_mesh = False
|
92 |
+
freeze_weights = True
|
93 |
+
gt_act_xs_def = False
|
94 |
+
use_bending_network = True
|
95 |
+
### for ts = 3 ###
|
96 |
+
# use_delta_bending = False
|
97 |
+
### for ts = 3 ###
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
103 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
104 |
+
|
105 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
106 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
107 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
108 |
+
|
109 |
+
bending_net_type = "active_force_field_v18"
|
110 |
+
sim_num_steps = 1000000
|
111 |
+
n_timesteps = 60
|
112 |
+
optim_sim_model_params_from_mano = False
|
113 |
+
penetration_determining = "sdf_of_canon"
|
114 |
+
train_with_forces_to_active = False
|
115 |
+
loss_scale_coef = 1000.0
|
116 |
+
use_same_contact_spring_k = False
|
117 |
+
use_optimizable_params = True #
|
118 |
+
train_residual_friction = True
|
119 |
+
mano_mult_const_after_cent = 1.0
|
120 |
+
optimize_glb_transformations = True
|
121 |
+
no_friction_constraint = False
|
122 |
+
optimize_active_object = True
|
123 |
+
loss_tangential_diff_coef = 0
|
124 |
+
optimize_with_intermediates = True
|
125 |
+
using_delta_glb_trans = False
|
126 |
+
train_multi_seqs = False
|
127 |
+
use_split_network = True
|
128 |
+
use_delta_bending = True
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
##### contact spring model settings ####
|
134 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
135 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
136 |
+
penetration_proj_k_to_robot = 4000000.0
|
137 |
+
##### contact spring model settings ####
|
138 |
+
|
139 |
+
|
140 |
+
###### ######
|
141 |
+
# drive_pointset = "states"
|
142 |
+
fix_obj = True # to track the hand only
|
143 |
+
optimize_rules = False
|
144 |
+
train_pointset_acts_via_deltas = False
|
145 |
+
load_optimized_init_actions = ""
|
146 |
+
load_optimized_init_transformations = ""
|
147 |
+
ckpt_fn = ""
|
148 |
+
retar_only_glb = True
|
149 |
+
# use_multi_stages = True
|
150 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
151 |
+
|
152 |
+
use_opt_rigid_translations=True
|
153 |
+
|
154 |
+
train_def = True
|
155 |
+
optimizable_rigid_translations=True
|
156 |
+
|
157 |
+
nerf {
|
158 |
+
D = 8,
|
159 |
+
d_in = 4,
|
160 |
+
d_in_view = 3,
|
161 |
+
W = 256,
|
162 |
+
multires = 10,
|
163 |
+
multires_view = 4,
|
164 |
+
output_ch = 4,
|
165 |
+
skips=[4],
|
166 |
+
use_viewdirs=True
|
167 |
+
}
|
168 |
+
|
169 |
+
sdf_network {
|
170 |
+
d_out = 257,
|
171 |
+
d_in = 3,
|
172 |
+
d_hidden = 256,
|
173 |
+
n_layers = 8,
|
174 |
+
skip_in = [4],
|
175 |
+
multires = 6,
|
176 |
+
bias = 0.5,
|
177 |
+
scale = 1.0,
|
178 |
+
geometric_init = True,
|
179 |
+
weight_norm = True,
|
180 |
+
}
|
181 |
+
|
182 |
+
variance_network {
|
183 |
+
init_val = 0.3
|
184 |
+
}
|
185 |
+
|
186 |
+
rendering_network {
|
187 |
+
d_feature = 256,
|
188 |
+
mode = idr,
|
189 |
+
d_in = 9,
|
190 |
+
d_out = 3,
|
191 |
+
d_hidden = 256,
|
192 |
+
n_layers = 4,
|
193 |
+
weight_norm = True,
|
194 |
+
multires_view = 4,
|
195 |
+
squeeze_out = True,
|
196 |
+
}
|
197 |
+
|
198 |
+
neus_renderer {
|
199 |
+
n_samples = 64,
|
200 |
+
n_importance = 64,
|
201 |
+
n_outside = 0,
|
202 |
+
up_sample_steps = 4 ,
|
203 |
+
perturb = 1.0,
|
204 |
+
}
|
205 |
+
|
206 |
+
bending_network {
|
207 |
+
multires = 6,
|
208 |
+
bending_latent_size = 32,
|
209 |
+
d_in = 3,
|
210 |
+
rigidity_hidden_dimensions = 64,
|
211 |
+
rigidity_network_depth = 5,
|
212 |
+
use_rigidity_network = False,
|
213 |
+
bending_n_timesteps = 10,
|
214 |
+
}
|
215 |
+
}
|
confs_new/dyn_grab_pointset_mano_dyn.conf
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
|
7 |
+
# tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
|
8 |
+
# tag = "train_dyn_mano_acts_"
|
9 |
+
tag = "train_dyn_mano_acts_wreact_optps_"
|
10 |
+
|
11 |
+
recording = [
|
12 |
+
./,
|
13 |
+
./models
|
14 |
+
]
|
15 |
+
}
|
16 |
+
|
17 |
+
dataset {
|
18 |
+
data_dir = public_data/CASE_NAME/
|
19 |
+
render_cameras_name = cameras_sphere.npz
|
20 |
+
object_cameras_name = cameras_sphere.npz
|
21 |
+
obj_idx = 102
|
22 |
+
}
|
23 |
+
|
24 |
+
train {
|
25 |
+
learning_rate = 5e-4
|
26 |
+
learning_rate_alpha = 0.05
|
27 |
+
end_iter = 300000
|
28 |
+
|
29 |
+
batch_size = 1024
|
30 |
+
validate_resolution_level = 4
|
31 |
+
warm_up_end = 5000
|
32 |
+
anneal_end = 0
|
33 |
+
use_white_bkgd = False
|
34 |
+
|
35 |
+
# save_freq = 10000
|
36 |
+
save_freq = 10000
|
37 |
+
val_freq = 20
|
38 |
+
val_mesh_freq = 20
|
39 |
+
report_freq = 10
|
40 |
+
igr_weight = 0.1
|
41 |
+
mask_weight = 0.1
|
42 |
+
}
|
43 |
+
|
44 |
+
model {
|
45 |
+
|
46 |
+
optimize_dyn_actions = True
|
47 |
+
|
48 |
+
|
49 |
+
optimize_robot = True
|
50 |
+
|
51 |
+
use_penalty_based_friction = True
|
52 |
+
|
53 |
+
use_split_params = False
|
54 |
+
|
55 |
+
use_sqr_spring_stiffness = True
|
56 |
+
|
57 |
+
use_pre_proj_frictions = True
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
use_sqrt_dist = True
|
62 |
+
contact_maintaining_dist_thres = 0.2
|
63 |
+
|
64 |
+
robot_actions_diff_coef = 0.001
|
65 |
+
|
66 |
+
|
67 |
+
use_sdf_as_contact_dist = True
|
68 |
+
|
69 |
+
|
70 |
+
#
|
71 |
+
use_contact_dist_as_sdf = False
|
72 |
+
|
73 |
+
use_glb_proj_delta = True
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
# penetration_proj_k_to_robot = 30
|
78 |
+
penetrating_depth_penalty = 1.0
|
79 |
+
train_states = True
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
minn_dist_threshold = 0.000
|
84 |
+
obj_mass = 30.0
|
85 |
+
|
86 |
+
|
87 |
+
use_LBFGS = True
|
88 |
+
use_LBFGS = False
|
89 |
+
|
90 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
91 |
+
|
92 |
+
extract_delta_mesh = False
|
93 |
+
freeze_weights = True
|
94 |
+
gt_act_xs_def = False
|
95 |
+
use_bending_network = True
|
96 |
+
### for ts = 3 ###
|
97 |
+
# use_delta_bending = False
|
98 |
+
### for ts = 3 ###
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
104 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
105 |
+
|
106 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
107 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
108 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
109 |
+
|
110 |
+
bending_net_type = "active_force_field_v18"
|
111 |
+
sim_num_steps = 1000000
|
112 |
+
n_timesteps = 60
|
113 |
+
optim_sim_model_params_from_mano = False
|
114 |
+
penetration_determining = "sdf_of_canon"
|
115 |
+
train_with_forces_to_active = False
|
116 |
+
loss_scale_coef = 1000.0
|
117 |
+
use_same_contact_spring_k = False
|
118 |
+
use_optimizable_params = True #
|
119 |
+
train_residual_friction = True
|
120 |
+
mano_mult_const_after_cent = 1.0
|
121 |
+
optimize_glb_transformations = True
|
122 |
+
no_friction_constraint = False
|
123 |
+
optimize_active_object = True
|
124 |
+
loss_tangential_diff_coef = 0
|
125 |
+
optimize_with_intermediates = True
|
126 |
+
using_delta_glb_trans = False
|
127 |
+
train_multi_seqs = False
|
128 |
+
use_split_network = True
|
129 |
+
use_delta_bending = True
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
##### contact spring model settings ####
|
135 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
136 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
137 |
+
penetration_proj_k_to_robot = 4000000.0
|
138 |
+
##### contact spring model settings ####
|
139 |
+
|
140 |
+
|
141 |
+
###### Stage 1: optimize for the parametes ######
|
142 |
+
# drive_pointset = "states"
|
143 |
+
fix_obj = False
|
144 |
+
optimize_rules = True
|
145 |
+
train_pointset_acts_via_deltas = False
|
146 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
147 |
+
load_optimized_init_transformations = ""
|
148 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_arti.pth"
|
149 |
+
# retar_only_glb = True
|
150 |
+
# use_multi_stages = True
|
151 |
+
###### Stage 1: optimize for the parametes ######
|
152 |
+
|
153 |
+
use_opt_rigid_translations=True
|
154 |
+
|
155 |
+
train_def = True
|
156 |
+
optimizable_rigid_translations=True
|
157 |
+
|
158 |
+
nerf {
|
159 |
+
D = 8,
|
160 |
+
d_in = 4,
|
161 |
+
d_in_view = 3,
|
162 |
+
W = 256,
|
163 |
+
multires = 10,
|
164 |
+
multires_view = 4,
|
165 |
+
output_ch = 4,
|
166 |
+
skips=[4],
|
167 |
+
use_viewdirs=True
|
168 |
+
}
|
169 |
+
|
170 |
+
sdf_network {
|
171 |
+
d_out = 257,
|
172 |
+
d_in = 3,
|
173 |
+
d_hidden = 256,
|
174 |
+
n_layers = 8,
|
175 |
+
skip_in = [4],
|
176 |
+
multires = 6,
|
177 |
+
bias = 0.5,
|
178 |
+
scale = 1.0,
|
179 |
+
geometric_init = True,
|
180 |
+
weight_norm = True,
|
181 |
+
}
|
182 |
+
|
183 |
+
variance_network {
|
184 |
+
init_val = 0.3
|
185 |
+
}
|
186 |
+
|
187 |
+
rendering_network {
|
188 |
+
d_feature = 256,
|
189 |
+
mode = idr,
|
190 |
+
d_in = 9,
|
191 |
+
d_out = 3,
|
192 |
+
d_hidden = 256,
|
193 |
+
n_layers = 4,
|
194 |
+
weight_norm = True,
|
195 |
+
multires_view = 4,
|
196 |
+
squeeze_out = True,
|
197 |
+
}
|
198 |
+
|
199 |
+
neus_renderer {
|
200 |
+
n_samples = 64,
|
201 |
+
n_importance = 64,
|
202 |
+
n_outside = 0,
|
203 |
+
up_sample_steps = 4 ,
|
204 |
+
perturb = 1.0,
|
205 |
+
}
|
206 |
+
|
207 |
+
bending_network {
|
208 |
+
multires = 6,
|
209 |
+
bending_latent_size = 32,
|
210 |
+
d_in = 3,
|
211 |
+
rigidity_hidden_dimensions = 64,
|
212 |
+
rigidity_network_depth = 5,
|
213 |
+
use_rigidity_network = False,
|
214 |
+
bending_n_timesteps = 10,
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
|
confs_new/dyn_grab_pointset_mano_dyn_optacts.conf
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
|
7 |
+
# tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
|
8 |
+
# tag = "train_dyn_mano_acts_"
|
9 |
+
tag = "train_dyn_mano_acts_wreact_optps_optacts_"
|
10 |
+
|
11 |
+
recording = [
|
12 |
+
./,
|
13 |
+
./models
|
14 |
+
]
|
15 |
+
}
|
16 |
+
|
17 |
+
dataset {
|
18 |
+
data_dir = public_data/CASE_NAME/
|
19 |
+
render_cameras_name = cameras_sphere.npz
|
20 |
+
object_cameras_name = cameras_sphere.npz
|
21 |
+
obj_idx = 102
|
22 |
+
}
|
23 |
+
|
24 |
+
train {
|
25 |
+
learning_rate = 5e-4
|
26 |
+
learning_rate_alpha = 0.05
|
27 |
+
end_iter = 300000
|
28 |
+
|
29 |
+
batch_size = 1024
|
30 |
+
validate_resolution_level = 4
|
31 |
+
warm_up_end = 5000
|
32 |
+
anneal_end = 0
|
33 |
+
use_white_bkgd = False
|
34 |
+
|
35 |
+
# save_freq = 10000
|
36 |
+
save_freq = 10000
|
37 |
+
val_freq = 20
|
38 |
+
val_mesh_freq = 20
|
39 |
+
report_freq = 10
|
40 |
+
igr_weight = 0.1
|
41 |
+
mask_weight = 0.1
|
42 |
+
}
|
43 |
+
|
44 |
+
model {
|
45 |
+
|
46 |
+
optimize_dyn_actions = True
|
47 |
+
|
48 |
+
|
49 |
+
optimize_robot = True
|
50 |
+
|
51 |
+
use_penalty_based_friction = True
|
52 |
+
|
53 |
+
use_split_params = False
|
54 |
+
|
55 |
+
use_sqr_spring_stiffness = True
|
56 |
+
|
57 |
+
use_pre_proj_frictions = True
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
use_sqrt_dist = True
|
62 |
+
contact_maintaining_dist_thres = 0.2
|
63 |
+
|
64 |
+
robot_actions_diff_coef = 0.001
|
65 |
+
|
66 |
+
|
67 |
+
use_sdf_as_contact_dist = True
|
68 |
+
|
69 |
+
|
70 |
+
#
|
71 |
+
use_contact_dist_as_sdf = False
|
72 |
+
|
73 |
+
use_glb_proj_delta = True
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
# penetration_proj_k_to_robot = 30
|
78 |
+
penetrating_depth_penalty = 1.0
|
79 |
+
train_states = True
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
minn_dist_threshold = 0.000
|
84 |
+
obj_mass = 30.0
|
85 |
+
|
86 |
+
|
87 |
+
use_LBFGS = True
|
88 |
+
use_LBFGS = False
|
89 |
+
|
90 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
91 |
+
|
92 |
+
extract_delta_mesh = False
|
93 |
+
freeze_weights = True
|
94 |
+
gt_act_xs_def = False
|
95 |
+
use_bending_network = True
|
96 |
+
### for ts = 3 ###
|
97 |
+
# use_delta_bending = False
|
98 |
+
### for ts = 3 ###
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
104 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
105 |
+
|
106 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
107 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
108 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
109 |
+
|
110 |
+
bending_net_type = "active_force_field_v18"
|
111 |
+
sim_num_steps = 1000000
|
112 |
+
n_timesteps = 60
|
113 |
+
optim_sim_model_params_from_mano = False
|
114 |
+
penetration_determining = "sdf_of_canon"
|
115 |
+
train_with_forces_to_active = False
|
116 |
+
loss_scale_coef = 1000.0
|
117 |
+
use_same_contact_spring_k = False
|
118 |
+
use_optimizable_params = True #
|
119 |
+
train_residual_friction = True
|
120 |
+
mano_mult_const_after_cent = 1.0
|
121 |
+
optimize_glb_transformations = True
|
122 |
+
no_friction_constraint = False
|
123 |
+
optimize_active_object = True
|
124 |
+
loss_tangential_diff_coef = 0
|
125 |
+
optimize_with_intermediates = True
|
126 |
+
using_delta_glb_trans = False
|
127 |
+
train_multi_seqs = False
|
128 |
+
use_split_network = True
|
129 |
+
use_delta_bending = True
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
##### contact spring model settings ####
|
135 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
136 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
137 |
+
penetration_proj_k_to_robot = 4000000.0
|
138 |
+
##### contact spring model settings ####
|
139 |
+
|
140 |
+
|
141 |
+
###### Stage 1: optimize for the parametes ######
|
142 |
+
# drive_pointset = "states"
|
143 |
+
fix_obj = False
|
144 |
+
optimize_rules = False
|
145 |
+
train_pointset_acts_via_deltas = False
|
146 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
147 |
+
load_optimized_init_transformations = ""
|
148 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_arti.pth"
|
149 |
+
# retar_only_glb = True
|
150 |
+
# use_multi_stages = True
|
151 |
+
###### Stage 1: optimize for the parametes ######
|
152 |
+
|
153 |
+
use_opt_rigid_translations=True
|
154 |
+
|
155 |
+
train_def = True
|
156 |
+
optimizable_rigid_translations=True
|
157 |
+
|
158 |
+
nerf {
|
159 |
+
D = 8,
|
160 |
+
d_in = 4,
|
161 |
+
d_in_view = 3,
|
162 |
+
W = 256,
|
163 |
+
multires = 10,
|
164 |
+
multires_view = 4,
|
165 |
+
output_ch = 4,
|
166 |
+
skips=[4],
|
167 |
+
use_viewdirs=True
|
168 |
+
}
|
169 |
+
|
170 |
+
sdf_network {
|
171 |
+
d_out = 257,
|
172 |
+
d_in = 3,
|
173 |
+
d_hidden = 256,
|
174 |
+
n_layers = 8,
|
175 |
+
skip_in = [4],
|
176 |
+
multires = 6,
|
177 |
+
bias = 0.5,
|
178 |
+
scale = 1.0,
|
179 |
+
geometric_init = True,
|
180 |
+
weight_norm = True,
|
181 |
+
}
|
182 |
+
|
183 |
+
variance_network {
|
184 |
+
init_val = 0.3
|
185 |
+
}
|
186 |
+
|
187 |
+
rendering_network {
|
188 |
+
d_feature = 256,
|
189 |
+
mode = idr,
|
190 |
+
d_in = 9,
|
191 |
+
d_out = 3,
|
192 |
+
d_hidden = 256,
|
193 |
+
n_layers = 4,
|
194 |
+
weight_norm = True,
|
195 |
+
multires_view = 4,
|
196 |
+
squeeze_out = True,
|
197 |
+
}
|
198 |
+
|
199 |
+
neus_renderer {
|
200 |
+
n_samples = 64,
|
201 |
+
n_importance = 64,
|
202 |
+
n_outside = 0,
|
203 |
+
up_sample_steps = 4 ,
|
204 |
+
perturb = 1.0,
|
205 |
+
}
|
206 |
+
|
207 |
+
bending_network {
|
208 |
+
multires = 6,
|
209 |
+
bending_latent_size = 32,
|
210 |
+
d_in = 3,
|
211 |
+
rigidity_hidden_dimensions = 64,
|
212 |
+
rigidity_network_depth = 5,
|
213 |
+
use_rigidity_network = False,
|
214 |
+
bending_n_timesteps = 10,
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
|
confs_new/dyn_grab_pointset_points_dyn.conf
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
|
7 |
+
|
8 |
+
recording = [
|
9 |
+
./,
|
10 |
+
./models
|
11 |
+
]
|
12 |
+
}
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024 # 64
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
# save_freq = 10000
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20 # 2500
|
35 |
+
val_mesh_freq = 20 # 5000
|
36 |
+
report_freq = 10
|
37 |
+
### igr weight ###
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
###### threshold, ks settings 1, optimize acts ######
|
138 |
+
# drive_pointset = "actions"
|
139 |
+
# fix_obj = True
|
140 |
+
# optimize_rules = False
|
141 |
+
# train_pointset_acts_via_deltas = True
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
144 |
+
###### threshold, ks settings 1, optimize acts ######
|
145 |
+
|
146 |
+
|
147 |
+
##### contact spring model settings ####
|
148 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
149 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
150 |
+
penetration_proj_k_to_robot = 4000000.0
|
151 |
+
##### contact spring model settings ####
|
152 |
+
|
153 |
+
|
154 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
155 |
+
drive_pointset = "states"
|
156 |
+
fix_obj = True
|
157 |
+
optimize_rules = False
|
158 |
+
train_pointset_acts_via_deltas = False
|
159 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
160 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
161 |
+
|
162 |
+
|
163 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
164 |
+
drive_pointset = "actions"
|
165 |
+
fix_obj = True
|
166 |
+
optimize_rules = False
|
167 |
+
train_pointset_acts_via_deltas = True
|
168 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
169 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
170 |
+
|
171 |
+
|
172 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
173 |
+
drive_pointset = "actions"
|
174 |
+
fix_obj = False
|
175 |
+
optimize_rules = True
|
176 |
+
train_pointset_acts_via_deltas = True
|
177 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
178 |
+
##### model parameters optimized from the MANO hand trajectory #####
|
179 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
180 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
181 |
+
|
182 |
+
|
183 |
+
###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
184 |
+
drive_pointset = "actions"
|
185 |
+
fix_obj = False
|
186 |
+
optimize_rules = False
|
187 |
+
train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
188 |
+
##### model parameters optimized from the MANO hand expanded set trajectory #####
|
189 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
190 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
191 |
+
###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
192 |
+
|
193 |
+
|
194 |
+
use_opt_rigid_translations=True
|
195 |
+
|
196 |
+
train_def = True
|
197 |
+
optimizable_rigid_translations=True
|
198 |
+
|
199 |
+
nerf {
|
200 |
+
D = 8,
|
201 |
+
d_in = 4,
|
202 |
+
d_in_view = 3,
|
203 |
+
W = 256,
|
204 |
+
multires = 10,
|
205 |
+
multires_view = 4,
|
206 |
+
output_ch = 4,
|
207 |
+
skips=[4],
|
208 |
+
use_viewdirs=True
|
209 |
+
}
|
210 |
+
|
211 |
+
sdf_network {
|
212 |
+
d_out = 257,
|
213 |
+
d_in = 3,
|
214 |
+
d_hidden = 256,
|
215 |
+
n_layers = 8,
|
216 |
+
skip_in = [4],
|
217 |
+
multires = 6,
|
218 |
+
bias = 0.5,
|
219 |
+
scale = 1.0,
|
220 |
+
geometric_init = True,
|
221 |
+
weight_norm = True,
|
222 |
+
}
|
223 |
+
|
224 |
+
variance_network {
|
225 |
+
init_val = 0.3
|
226 |
+
}
|
227 |
+
|
228 |
+
rendering_network {
|
229 |
+
d_feature = 256,
|
230 |
+
mode = idr,
|
231 |
+
d_in = 9,
|
232 |
+
d_out = 3,
|
233 |
+
d_hidden = 256,
|
234 |
+
n_layers = 4,
|
235 |
+
weight_norm = True,
|
236 |
+
multires_view = 4,
|
237 |
+
squeeze_out = True,
|
238 |
+
}
|
239 |
+
|
240 |
+
neus_renderer {
|
241 |
+
n_samples = 64,
|
242 |
+
n_importance = 64,
|
243 |
+
n_outside = 0,
|
244 |
+
up_sample_steps = 4 ,
|
245 |
+
perturb = 1.0,
|
246 |
+
}
|
247 |
+
|
248 |
+
bending_network {
|
249 |
+
multires = 6,
|
250 |
+
bending_latent_size = 32,
|
251 |
+
d_in = 3,
|
252 |
+
rigidity_hidden_dimensions = 64,
|
253 |
+
rigidity_network_depth = 5,
|
254 |
+
use_rigidity_network = False,
|
255 |
+
bending_n_timesteps = 10,
|
256 |
+
}
|
257 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_retar.conf
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
|
7 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_"
|
8 |
+
|
9 |
+
recording = [
|
10 |
+
./,
|
11 |
+
./models
|
12 |
+
]
|
13 |
+
}
|
14 |
+
|
15 |
+
dataset {
|
16 |
+
data_dir = public_data/CASE_NAME/
|
17 |
+
render_cameras_name = cameras_sphere.npz
|
18 |
+
object_cameras_name = cameras_sphere.npz
|
19 |
+
obj_idx = 102
|
20 |
+
}
|
21 |
+
|
22 |
+
train {
|
23 |
+
learning_rate = 5e-4
|
24 |
+
learning_rate_alpha = 0.05
|
25 |
+
end_iter = 300000
|
26 |
+
|
27 |
+
batch_size = 1024 # 64
|
28 |
+
validate_resolution_level = 4
|
29 |
+
warm_up_end = 5000
|
30 |
+
anneal_end = 0
|
31 |
+
use_white_bkgd = False
|
32 |
+
|
33 |
+
# save_freq = 10000
|
34 |
+
save_freq = 10000
|
35 |
+
val_freq = 20 # 2500
|
36 |
+
val_mesh_freq = 20 # 5000
|
37 |
+
report_freq = 10
|
38 |
+
### igr weight ###
|
39 |
+
igr_weight = 0.1
|
40 |
+
mask_weight = 0.1
|
41 |
+
}
|
42 |
+
|
43 |
+
model {
|
44 |
+
|
45 |
+
optimize_dyn_actions = True
|
46 |
+
|
47 |
+
|
48 |
+
optimize_robot = True
|
49 |
+
|
50 |
+
use_penalty_based_friction = True
|
51 |
+
|
52 |
+
use_split_params = False
|
53 |
+
|
54 |
+
use_sqr_spring_stiffness = True
|
55 |
+
|
56 |
+
use_pre_proj_frictions = True
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
use_sqrt_dist = True
|
61 |
+
contact_maintaining_dist_thres = 0.2
|
62 |
+
|
63 |
+
robot_actions_diff_coef = 0.001
|
64 |
+
|
65 |
+
|
66 |
+
use_sdf_as_contact_dist = True
|
67 |
+
|
68 |
+
|
69 |
+
#
|
70 |
+
use_contact_dist_as_sdf = False
|
71 |
+
|
72 |
+
use_glb_proj_delta = True
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
# penetration_proj_k_to_robot = 30
|
77 |
+
penetrating_depth_penalty = 1.0
|
78 |
+
train_states = True
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
minn_dist_threshold = 0.000
|
83 |
+
obj_mass = 30.0
|
84 |
+
|
85 |
+
|
86 |
+
use_LBFGS = True
|
87 |
+
use_LBFGS = False
|
88 |
+
|
89 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
90 |
+
|
91 |
+
extract_delta_mesh = False
|
92 |
+
freeze_weights = True
|
93 |
+
gt_act_xs_def = False
|
94 |
+
use_bending_network = True
|
95 |
+
### for ts = 3 ###
|
96 |
+
# use_delta_bending = False
|
97 |
+
### for ts = 3 ###
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
103 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
104 |
+
|
105 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
106 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
107 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
108 |
+
|
109 |
+
bending_net_type = "active_force_field_v18"
|
110 |
+
sim_num_steps = 1000000
|
111 |
+
n_timesteps = 60
|
112 |
+
optim_sim_model_params_from_mano = False
|
113 |
+
penetration_determining = "sdf_of_canon"
|
114 |
+
train_with_forces_to_active = False
|
115 |
+
loss_scale_coef = 1000.0
|
116 |
+
use_same_contact_spring_k = False
|
117 |
+
use_optimizable_params = True #
|
118 |
+
train_residual_friction = True
|
119 |
+
mano_mult_const_after_cent = 1.0
|
120 |
+
optimize_glb_transformations = True
|
121 |
+
no_friction_constraint = False
|
122 |
+
optimize_active_object = True
|
123 |
+
loss_tangential_diff_coef = 0
|
124 |
+
optimize_with_intermediates = True
|
125 |
+
using_delta_glb_trans = False
|
126 |
+
train_multi_seqs = False
|
127 |
+
use_split_network = True
|
128 |
+
use_delta_bending = True
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
###### threshold, ks settings 1, optimize acts ######
|
139 |
+
# drive_pointset = "actions"
|
140 |
+
# fix_obj = True
|
141 |
+
# optimize_rules = False
|
142 |
+
# train_pointset_acts_via_deltas = True
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
144 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
145 |
+
###### threshold, ks settings 1, optimize acts ######
|
146 |
+
|
147 |
+
|
148 |
+
##### contact spring model settings ####
|
149 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
150 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
151 |
+
penetration_proj_k_to_robot = 4000000.0
|
152 |
+
##### contact spring model settings ####
|
153 |
+
|
154 |
+
|
155 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
156 |
+
# drive_pointset = "states"
|
157 |
+
# fix_obj = True
|
158 |
+
# optimize_rules = False
|
159 |
+
# train_pointset_acts_via_deltas = False
|
160 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
161 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
162 |
+
|
163 |
+
|
164 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
165 |
+
# drive_pointset = "actions"
|
166 |
+
# fix_obj = True
|
167 |
+
# optimize_rules = False
|
168 |
+
# train_pointset_acts_via_deltas = True
|
169 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
170 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
171 |
+
|
172 |
+
|
173 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
174 |
+
# drive_pointset = "actions"
|
175 |
+
# fix_obj = False
|
176 |
+
# optimize_rules = True
|
177 |
+
# train_pointset_acts_via_deltas = True
|
178 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
179 |
+
# ##### model parameters optimized from the MANO hand trajectory #####
|
180 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
181 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
182 |
+
|
183 |
+
|
184 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
185 |
+
# drive_pointset = "actions"
|
186 |
+
# fix_obj = False
|
187 |
+
# optimize_rules = False
|
188 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
189 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
190 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
191 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
192 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
###### Retargeting Stage 1 ######
|
197 |
+
drive_pointset = "actions"
|
198 |
+
fix_obj = False
|
199 |
+
optimize_rules = False
|
200 |
+
train_pointset_acts_via_deltas = True
|
201 |
+
##### model parameters optimized from the MANO hand expanded set trajectory #####
|
202 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
203 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
204 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti.pth"
|
205 |
+
finger_cd_loss = 1.0
|
206 |
+
optimize_pointset_motion_only = True
|
207 |
+
###### Retargeting Stage 1 ######
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
use_opt_rigid_translations=True
|
212 |
+
|
213 |
+
train_def = True
|
214 |
+
optimizable_rigid_translations=True
|
215 |
+
|
216 |
+
nerf {
|
217 |
+
D = 8,
|
218 |
+
d_in = 4,
|
219 |
+
d_in_view = 3,
|
220 |
+
W = 256,
|
221 |
+
multires = 10,
|
222 |
+
multires_view = 4,
|
223 |
+
output_ch = 4,
|
224 |
+
skips=[4],
|
225 |
+
use_viewdirs=True
|
226 |
+
}
|
227 |
+
|
228 |
+
sdf_network {
|
229 |
+
d_out = 257,
|
230 |
+
d_in = 3,
|
231 |
+
d_hidden = 256,
|
232 |
+
n_layers = 8,
|
233 |
+
skip_in = [4],
|
234 |
+
multires = 6,
|
235 |
+
bias = 0.5,
|
236 |
+
scale = 1.0,
|
237 |
+
geometric_init = True,
|
238 |
+
weight_norm = True,
|
239 |
+
}
|
240 |
+
|
241 |
+
variance_network {
|
242 |
+
init_val = 0.3
|
243 |
+
}
|
244 |
+
|
245 |
+
rendering_network {
|
246 |
+
d_feature = 256,
|
247 |
+
mode = idr,
|
248 |
+
d_in = 9,
|
249 |
+
d_out = 3,
|
250 |
+
d_hidden = 256,
|
251 |
+
n_layers = 4,
|
252 |
+
weight_norm = True,
|
253 |
+
multires_view = 4,
|
254 |
+
squeeze_out = True,
|
255 |
+
}
|
256 |
+
|
257 |
+
neus_renderer {
|
258 |
+
n_samples = 64,
|
259 |
+
n_importance = 64,
|
260 |
+
n_outside = 0,
|
261 |
+
up_sample_steps = 4 ,
|
262 |
+
perturb = 1.0,
|
263 |
+
}
|
264 |
+
|
265 |
+
bending_network {
|
266 |
+
multires = 6,
|
267 |
+
bending_latent_size = 32,
|
268 |
+
d_in = 3,
|
269 |
+
rigidity_hidden_dimensions = 64,
|
270 |
+
rigidity_network_depth = 5,
|
271 |
+
use_rigidity_network = False,
|
272 |
+
bending_n_timesteps = 10,
|
273 |
+
}
|
274 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_retar_pts.conf
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
|
7 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_"
|
8 |
+
|
9 |
+
recording = [
|
10 |
+
./,
|
11 |
+
./models
|
12 |
+
]
|
13 |
+
}
|
14 |
+
|
15 |
+
dataset {
|
16 |
+
data_dir = public_data/CASE_NAME/
|
17 |
+
render_cameras_name = cameras_sphere.npz
|
18 |
+
object_cameras_name = cameras_sphere.npz
|
19 |
+
obj_idx = 102
|
20 |
+
}
|
21 |
+
|
22 |
+
train {
|
23 |
+
learning_rate = 5e-4
|
24 |
+
learning_rate_alpha = 0.05
|
25 |
+
end_iter = 300000
|
26 |
+
|
27 |
+
batch_size = 1024 # 64
|
28 |
+
validate_resolution_level = 4
|
29 |
+
warm_up_end = 5000
|
30 |
+
anneal_end = 0
|
31 |
+
use_white_bkgd = False
|
32 |
+
|
33 |
+
# save_freq = 10000
|
34 |
+
save_freq = 10000
|
35 |
+
val_freq = 20 # 2500
|
36 |
+
val_mesh_freq = 20 # 5000
|
37 |
+
report_freq = 10
|
38 |
+
### igr weight ###
|
39 |
+
igr_weight = 0.1
|
40 |
+
mask_weight = 0.1
|
41 |
+
}
|
42 |
+
|
43 |
+
model {
|
44 |
+
|
45 |
+
optimize_dyn_actions = True
|
46 |
+
|
47 |
+
|
48 |
+
optimize_robot = True
|
49 |
+
|
50 |
+
use_penalty_based_friction = True
|
51 |
+
|
52 |
+
use_split_params = False
|
53 |
+
|
54 |
+
use_sqr_spring_stiffness = True
|
55 |
+
|
56 |
+
use_pre_proj_frictions = True
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
use_sqrt_dist = True
|
61 |
+
contact_maintaining_dist_thres = 0.2
|
62 |
+
|
63 |
+
robot_actions_diff_coef = 0.001
|
64 |
+
|
65 |
+
|
66 |
+
use_sdf_as_contact_dist = True
|
67 |
+
|
68 |
+
|
69 |
+
#
|
70 |
+
use_contact_dist_as_sdf = False
|
71 |
+
|
72 |
+
use_glb_proj_delta = True
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
# penetration_proj_k_to_robot = 30
|
77 |
+
penetrating_depth_penalty = 1.0
|
78 |
+
train_states = True
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
minn_dist_threshold = 0.000
|
83 |
+
obj_mass = 30.0
|
84 |
+
|
85 |
+
|
86 |
+
use_LBFGS = True
|
87 |
+
use_LBFGS = False
|
88 |
+
|
89 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
90 |
+
|
91 |
+
extract_delta_mesh = False
|
92 |
+
freeze_weights = True
|
93 |
+
gt_act_xs_def = False
|
94 |
+
use_bending_network = True
|
95 |
+
### for ts = 3 ###
|
96 |
+
# use_delta_bending = False
|
97 |
+
### for ts = 3 ###
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
103 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
104 |
+
|
105 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
106 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
107 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
108 |
+
|
109 |
+
bending_net_type = "active_force_field_v18"
|
110 |
+
sim_num_steps = 1000000
|
111 |
+
n_timesteps = 60
|
112 |
+
optim_sim_model_params_from_mano = False
|
113 |
+
penetration_determining = "sdf_of_canon"
|
114 |
+
train_with_forces_to_active = False
|
115 |
+
loss_scale_coef = 1000.0
|
116 |
+
use_same_contact_spring_k = False
|
117 |
+
use_optimizable_params = True #
|
118 |
+
train_residual_friction = True
|
119 |
+
mano_mult_const_after_cent = 1.0
|
120 |
+
optimize_glb_transformations = True
|
121 |
+
no_friction_constraint = False
|
122 |
+
optimize_active_object = True
|
123 |
+
loss_tangential_diff_coef = 0
|
124 |
+
optimize_with_intermediates = True
|
125 |
+
using_delta_glb_trans = False
|
126 |
+
train_multi_seqs = False
|
127 |
+
use_split_network = True
|
128 |
+
use_delta_bending = True
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
###### threshold, ks settings 1, optimize acts ######
|
139 |
+
# drive_pointset = "actions"
|
140 |
+
# fix_obj = True
|
141 |
+
# optimize_rules = False
|
142 |
+
# train_pointset_acts_via_deltas = True
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
144 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
145 |
+
###### threshold, ks settings 1, optimize acts ######
|
146 |
+
|
147 |
+
|
148 |
+
##### contact spring model settings ####
|
149 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
150 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
151 |
+
penetration_proj_k_to_robot = 4000000.0
|
152 |
+
##### contact spring model settings ####
|
153 |
+
|
154 |
+
|
155 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
156 |
+
# drive_pointset = "states"
|
157 |
+
# fix_obj = True
|
158 |
+
# optimize_rules = False
|
159 |
+
# train_pointset_acts_via_deltas = False
|
160 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
161 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
162 |
+
|
163 |
+
|
164 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
165 |
+
# drive_pointset = "actions"
|
166 |
+
# fix_obj = True
|
167 |
+
# optimize_rules = False
|
168 |
+
# train_pointset_acts_via_deltas = True
|
169 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
170 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
171 |
+
|
172 |
+
|
173 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
174 |
+
# drive_pointset = "actions"
|
175 |
+
# fix_obj = False
|
176 |
+
# optimize_rules = True
|
177 |
+
# train_pointset_acts_via_deltas = True
|
178 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
179 |
+
# ##### model parameters optimized from the MANO hand trajectory #####
|
180 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
181 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
182 |
+
|
183 |
+
|
184 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
185 |
+
# drive_pointset = "actions"
|
186 |
+
# fix_obj = False
|
187 |
+
# optimize_rules = False
|
188 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
189 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
190 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
191 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
192 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
###### Retargeting Stage 1 ######
|
197 |
+
drive_pointset = "actions"
|
198 |
+
fix_obj = False
|
199 |
+
optimize_rules = False
|
200 |
+
train_pointset_acts_via_deltas = True
|
201 |
+
##### model parameters optimized from the MANO hand expanded set trajectory #####
|
202 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
203 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
204 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar.pth"
|
205 |
+
finger_cd_loss = 1.0
|
206 |
+
optimize_pointset_motion_only = True
|
207 |
+
###### Retargeting Stage 1 ######
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
###### Retargeting Stage 2 ######
|
212 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti.pth"
|
213 |
+
###### Retargeting Stage 2 ######
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
use_opt_rigid_translations=True
|
219 |
+
|
220 |
+
train_def = True
|
221 |
+
optimizable_rigid_translations=True
|
222 |
+
|
223 |
+
nerf {
|
224 |
+
D = 8,
|
225 |
+
d_in = 4,
|
226 |
+
d_in_view = 3,
|
227 |
+
W = 256,
|
228 |
+
multires = 10,
|
229 |
+
multires_view = 4,
|
230 |
+
output_ch = 4,
|
231 |
+
skips=[4],
|
232 |
+
use_viewdirs=True
|
233 |
+
}
|
234 |
+
|
235 |
+
sdf_network {
|
236 |
+
d_out = 257,
|
237 |
+
d_in = 3,
|
238 |
+
d_hidden = 256,
|
239 |
+
n_layers = 8,
|
240 |
+
skip_in = [4],
|
241 |
+
multires = 6,
|
242 |
+
bias = 0.5,
|
243 |
+
scale = 1.0,
|
244 |
+
geometric_init = True,
|
245 |
+
weight_norm = True,
|
246 |
+
}
|
247 |
+
|
248 |
+
variance_network {
|
249 |
+
init_val = 0.3
|
250 |
+
}
|
251 |
+
|
252 |
+
rendering_network {
|
253 |
+
d_feature = 256,
|
254 |
+
mode = idr,
|
255 |
+
d_in = 9,
|
256 |
+
d_out = 3,
|
257 |
+
d_hidden = 256,
|
258 |
+
n_layers = 4,
|
259 |
+
weight_norm = True,
|
260 |
+
multires_view = 4,
|
261 |
+
squeeze_out = True,
|
262 |
+
}
|
263 |
+
|
264 |
+
neus_renderer {
|
265 |
+
n_samples = 64,
|
266 |
+
n_importance = 64,
|
267 |
+
n_outside = 0,
|
268 |
+
up_sample_steps = 4 ,
|
269 |
+
perturb = 1.0,
|
270 |
+
}
|
271 |
+
|
272 |
+
bending_network {
|
273 |
+
multires = 6,
|
274 |
+
bending_latent_size = 32,
|
275 |
+
d_in = 3,
|
276 |
+
rigidity_hidden_dimensions = 64,
|
277 |
+
rigidity_network_depth = 5,
|
278 |
+
use_rigidity_network = False,
|
279 |
+
bending_n_timesteps = 10,
|
280 |
+
}
|
281 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_retar_pts_opts.conf
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_"
|
7 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_"
|
8 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_retar_to_shadow_pointset_optrules_"
|
9 |
+
|
10 |
+
recording = [
|
11 |
+
./,
|
12 |
+
./models
|
13 |
+
]
|
14 |
+
}
|
15 |
+
|
16 |
+
dataset {
|
17 |
+
data_dir = public_data/CASE_NAME/
|
18 |
+
render_cameras_name = cameras_sphere.npz
|
19 |
+
object_cameras_name = cameras_sphere.npz
|
20 |
+
obj_idx = 102
|
21 |
+
}
|
22 |
+
|
23 |
+
train {
|
24 |
+
learning_rate = 5e-4
|
25 |
+
learning_rate_alpha = 0.05
|
26 |
+
end_iter = 300000
|
27 |
+
|
28 |
+
batch_size = 1024 # 64
|
29 |
+
validate_resolution_level = 4
|
30 |
+
warm_up_end = 5000
|
31 |
+
anneal_end = 0
|
32 |
+
use_white_bkgd = False
|
33 |
+
|
34 |
+
save_freq = 10000
|
35 |
+
val_freq = 20
|
36 |
+
val_mesh_freq = 20
|
37 |
+
report_freq = 10
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
###### threshold, ks settings 1, optimize acts ######
|
138 |
+
# drive_pointset = "actions"
|
139 |
+
# fix_obj = True
|
140 |
+
# optimize_rules = False
|
141 |
+
# train_pointset_acts_via_deltas = True
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
144 |
+
###### threshold, ks settings 1, optimize acts ######
|
145 |
+
|
146 |
+
|
147 |
+
##### contact spring model settings ####
|
148 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
149 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
150 |
+
penetration_proj_k_to_robot = 4000000.0
|
151 |
+
##### contact spring model settings ####
|
152 |
+
|
153 |
+
|
154 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
155 |
+
# drive_pointset = "states"
|
156 |
+
# fix_obj = True
|
157 |
+
# optimize_rules = False
|
158 |
+
# train_pointset_acts_via_deltas = False
|
159 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
160 |
+
# ###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
161 |
+
|
162 |
+
|
163 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
164 |
+
# drive_pointset = "actions"
|
165 |
+
# fix_obj = True
|
166 |
+
# optimize_rules = False
|
167 |
+
# train_pointset_acts_via_deltas = True
|
168 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
169 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
170 |
+
|
171 |
+
|
172 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
173 |
+
# drive_pointset = "actions"
|
174 |
+
# fix_obj = False
|
175 |
+
# optimize_rules = True
|
176 |
+
# train_pointset_acts_via_deltas = True
|
177 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
178 |
+
# ##### model parameters optimized from the MANO hand trajectory #####
|
179 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
180 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
181 |
+
|
182 |
+
|
183 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
184 |
+
# drive_pointset = "actions"
|
185 |
+
# fix_obj = False
|
186 |
+
# optimize_rules = False
|
187 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
188 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
189 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
190 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
191 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
###### Retargeting Stage 1 ######
|
196 |
+
drive_pointset = "actions"
|
197 |
+
fix_obj = False
|
198 |
+
optimize_rules = False
|
199 |
+
train_pointset_acts_via_deltas = True
|
200 |
+
##### model parameters optimized from the MANO hand expanded set trajectory #####
|
201 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
202 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps_optimized_acts.pth"
|
203 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar.pth"
|
204 |
+
finger_cd_loss = 1.0
|
205 |
+
optimize_pointset_motion_only = True
|
206 |
+
###### Retargeting Stage 1 ######
|
207 |
+
|
208 |
+
|
209 |
+
###### Retargeting Stage 2 ######
|
210 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti.pth"
|
211 |
+
###### Retargeting Stage 2 ######
|
212 |
+
|
213 |
+
|
214 |
+
###### Retargeting Stage 3 ######
|
215 |
+
load_optimized_init_transformations = "ckpts/grab/102/dyn_mano_shadow_arti_retar_optimized_arti_optimized_pts.pth"
|
216 |
+
optimize_anchored_pts = False
|
217 |
+
optimize_rules = True
|
218 |
+
optimize_pointset_motion_only = False
|
219 |
+
###### Retargeting Stage 3 ######
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
use_opt_rigid_translations=True
|
225 |
+
|
226 |
+
train_def = True
|
227 |
+
optimizable_rigid_translations=True
|
228 |
+
|
229 |
+
nerf {
|
230 |
+
D = 8,
|
231 |
+
d_in = 4,
|
232 |
+
d_in_view = 3,
|
233 |
+
W = 256,
|
234 |
+
multires = 10,
|
235 |
+
multires_view = 4,
|
236 |
+
output_ch = 4,
|
237 |
+
skips=[4],
|
238 |
+
use_viewdirs=True
|
239 |
+
}
|
240 |
+
|
241 |
+
sdf_network {
|
242 |
+
d_out = 257,
|
243 |
+
d_in = 3,
|
244 |
+
d_hidden = 256,
|
245 |
+
n_layers = 8,
|
246 |
+
skip_in = [4],
|
247 |
+
multires = 6,
|
248 |
+
bias = 0.5,
|
249 |
+
scale = 1.0,
|
250 |
+
geometric_init = True,
|
251 |
+
weight_norm = True,
|
252 |
+
}
|
253 |
+
|
254 |
+
variance_network {
|
255 |
+
init_val = 0.3
|
256 |
+
}
|
257 |
+
|
258 |
+
rendering_network {
|
259 |
+
d_feature = 256,
|
260 |
+
mode = idr,
|
261 |
+
d_in = 9,
|
262 |
+
d_out = 3,
|
263 |
+
d_hidden = 256,
|
264 |
+
n_layers = 4,
|
265 |
+
weight_norm = True,
|
266 |
+
multires_view = 4,
|
267 |
+
squeeze_out = True,
|
268 |
+
}
|
269 |
+
|
270 |
+
neus_renderer {
|
271 |
+
n_samples = 64,
|
272 |
+
n_importance = 64,
|
273 |
+
n_outside = 0,
|
274 |
+
up_sample_steps = 4 ,
|
275 |
+
perturb = 1.0,
|
276 |
+
}
|
277 |
+
|
278 |
+
bending_network {
|
279 |
+
multires = 6,
|
280 |
+
bending_latent_size = 32,
|
281 |
+
d_in = 3,
|
282 |
+
rigidity_hidden_dimensions = 64,
|
283 |
+
rigidity_network_depth = 5,
|
284 |
+
use_rigidity_network = False,
|
285 |
+
bending_n_timesteps = 10,
|
286 |
+
}
|
287 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_s1.conf
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_"
|
7 |
+
|
8 |
+
recording = [
|
9 |
+
./,
|
10 |
+
./models
|
11 |
+
]
|
12 |
+
}
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
# save_freq = 10000
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20
|
35 |
+
val_mesh_freq = 20
|
36 |
+
report_freq = 10
|
37 |
+
igr_weight = 0.1
|
38 |
+
mask_weight = 0.1
|
39 |
+
}
|
40 |
+
|
41 |
+
model {
|
42 |
+
|
43 |
+
optimize_dyn_actions = True
|
44 |
+
|
45 |
+
|
46 |
+
optimize_robot = True
|
47 |
+
|
48 |
+
use_penalty_based_friction = True
|
49 |
+
|
50 |
+
use_split_params = False
|
51 |
+
|
52 |
+
use_sqr_spring_stiffness = True
|
53 |
+
|
54 |
+
use_pre_proj_frictions = True
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
use_sqrt_dist = True
|
59 |
+
contact_maintaining_dist_thres = 0.2
|
60 |
+
|
61 |
+
robot_actions_diff_coef = 0.001
|
62 |
+
|
63 |
+
|
64 |
+
use_sdf_as_contact_dist = True
|
65 |
+
|
66 |
+
|
67 |
+
#
|
68 |
+
use_contact_dist_as_sdf = False
|
69 |
+
|
70 |
+
use_glb_proj_delta = True
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
# penetration_proj_k_to_robot = 30
|
75 |
+
penetrating_depth_penalty = 1.0
|
76 |
+
train_states = True
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
minn_dist_threshold = 0.000
|
81 |
+
obj_mass = 30.0
|
82 |
+
|
83 |
+
|
84 |
+
use_LBFGS = True
|
85 |
+
use_LBFGS = False
|
86 |
+
|
87 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
88 |
+
|
89 |
+
extract_delta_mesh = False
|
90 |
+
freeze_weights = True
|
91 |
+
gt_act_xs_def = False
|
92 |
+
use_bending_network = True
|
93 |
+
### for ts = 3 ###
|
94 |
+
# use_delta_bending = False
|
95 |
+
### for ts = 3 ###
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
101 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
102 |
+
|
103 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
104 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
105 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
106 |
+
|
107 |
+
bending_net_type = "active_force_field_v18"
|
108 |
+
sim_num_steps = 1000000
|
109 |
+
n_timesteps = 60
|
110 |
+
optim_sim_model_params_from_mano = False
|
111 |
+
penetration_determining = "sdf_of_canon"
|
112 |
+
train_with_forces_to_active = False
|
113 |
+
loss_scale_coef = 1000.0
|
114 |
+
use_same_contact_spring_k = False
|
115 |
+
use_optimizable_params = True #
|
116 |
+
train_residual_friction = True
|
117 |
+
mano_mult_const_after_cent = 1.0
|
118 |
+
optimize_glb_transformations = True
|
119 |
+
no_friction_constraint = False
|
120 |
+
optimize_active_object = True
|
121 |
+
loss_tangential_diff_coef = 0
|
122 |
+
optimize_with_intermediates = True
|
123 |
+
using_delta_glb_trans = False
|
124 |
+
train_multi_seqs = False
|
125 |
+
use_split_network = True
|
126 |
+
use_delta_bending = True
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
###### threshold, ks settings 1, optimize acts ######
|
137 |
+
# drive_pointset = "actions"
|
138 |
+
# fix_obj = True
|
139 |
+
# optimize_rules = False
|
140 |
+
# train_pointset_acts_via_deltas = True
|
141 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
143 |
+
###### threshold, ks settings 1, optimize acts ######
|
144 |
+
|
145 |
+
|
146 |
+
##### contact spring model settings ####
|
147 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
148 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
149 |
+
penetration_proj_k_to_robot = 4000000.0
|
150 |
+
##### contact spring model settings ####
|
151 |
+
|
152 |
+
|
153 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
154 |
+
drive_pointset = "states"
|
155 |
+
fix_obj = True
|
156 |
+
optimize_rules = False
|
157 |
+
train_pointset_acts_via_deltas = False
|
158 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
159 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
160 |
+
|
161 |
+
|
162 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
163 |
+
# drive_pointset = "actions"
|
164 |
+
# fix_obj = True
|
165 |
+
# optimize_rules = False
|
166 |
+
# train_pointset_acts_via_deltas = True
|
167 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
168 |
+
# ###### Stage 2: threshold, ks settings 1, optimize acts ######
|
169 |
+
|
170 |
+
|
171 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
172 |
+
# drive_pointset = "actions"
|
173 |
+
# fix_obj = False
|
174 |
+
# optimize_rules = True
|
175 |
+
# train_pointset_acts_via_deltas = True
|
176 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
177 |
+
# ##### model parameters optimized from the MANO hand trajectory #####
|
178 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
179 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
180 |
+
|
181 |
+
|
182 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
183 |
+
# drive_pointset = "actions"
|
184 |
+
# fix_obj = False
|
185 |
+
# optimize_rules = False
|
186 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
187 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
188 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
189 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
190 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
191 |
+
|
192 |
+
|
193 |
+
use_opt_rigid_translations=True
|
194 |
+
|
195 |
+
train_def = True
|
196 |
+
optimizable_rigid_translations=True
|
197 |
+
|
198 |
+
nerf {
|
199 |
+
D = 8,
|
200 |
+
d_in = 4,
|
201 |
+
d_in_view = 3,
|
202 |
+
W = 256,
|
203 |
+
multires = 10,
|
204 |
+
multires_view = 4,
|
205 |
+
output_ch = 4,
|
206 |
+
skips=[4],
|
207 |
+
use_viewdirs=True
|
208 |
+
}
|
209 |
+
|
210 |
+
sdf_network {
|
211 |
+
d_out = 257,
|
212 |
+
d_in = 3,
|
213 |
+
d_hidden = 256,
|
214 |
+
n_layers = 8,
|
215 |
+
skip_in = [4],
|
216 |
+
multires = 6,
|
217 |
+
bias = 0.5,
|
218 |
+
scale = 1.0,
|
219 |
+
geometric_init = True,
|
220 |
+
weight_norm = True,
|
221 |
+
}
|
222 |
+
|
223 |
+
variance_network {
|
224 |
+
init_val = 0.3
|
225 |
+
}
|
226 |
+
|
227 |
+
rendering_network {
|
228 |
+
d_feature = 256,
|
229 |
+
mode = idr,
|
230 |
+
d_in = 9,
|
231 |
+
d_out = 3,
|
232 |
+
d_hidden = 256,
|
233 |
+
n_layers = 4,
|
234 |
+
weight_norm = True,
|
235 |
+
multires_view = 4,
|
236 |
+
squeeze_out = True,
|
237 |
+
}
|
238 |
+
|
239 |
+
neus_renderer {
|
240 |
+
n_samples = 64,
|
241 |
+
n_importance = 64,
|
242 |
+
n_outside = 0,
|
243 |
+
up_sample_steps = 4 ,
|
244 |
+
perturb = 1.0,
|
245 |
+
}
|
246 |
+
|
247 |
+
bending_network {
|
248 |
+
multires = 6,
|
249 |
+
bending_latent_size = 32,
|
250 |
+
d_in = 3,
|
251 |
+
rigidity_hidden_dimensions = 64,
|
252 |
+
rigidity_network_depth = 5,
|
253 |
+
use_rigidity_network = False,
|
254 |
+
bending_n_timesteps = 10,
|
255 |
+
}
|
256 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_s2.conf
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_"
|
7 |
+
|
8 |
+
recording = [
|
9 |
+
./,
|
10 |
+
./models
|
11 |
+
]
|
12 |
+
}
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024 # 64
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
# save_freq = 10000
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20 # 2500
|
35 |
+
val_mesh_freq = 20 # 5000
|
36 |
+
report_freq = 10
|
37 |
+
### igr weight ###
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
###### threshold, ks settings 1, optimize acts ######
|
138 |
+
# drive_pointset = "actions"
|
139 |
+
# fix_obj = True
|
140 |
+
# optimize_rules = False
|
141 |
+
# train_pointset_acts_via_deltas = True
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
144 |
+
###### threshold, ks settings 1, optimize acts ######
|
145 |
+
|
146 |
+
|
147 |
+
##### contact spring model settings ####
|
148 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
149 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
150 |
+
penetration_proj_k_to_robot = 4000000.0
|
151 |
+
##### contact spring model settings ####
|
152 |
+
|
153 |
+
|
154 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
155 |
+
drive_pointset = "states"
|
156 |
+
fix_obj = True
|
157 |
+
optimize_rules = False
|
158 |
+
train_pointset_acts_via_deltas = False
|
159 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
160 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
161 |
+
|
162 |
+
|
163 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
164 |
+
drive_pointset = "actions"
|
165 |
+
fix_obj = True
|
166 |
+
optimize_rules = False
|
167 |
+
train_pointset_acts_via_deltas = True
|
168 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt" ## pre-optimized ckpts
|
169 |
+
load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_/checkpoints/ckpt_004000.pth"
|
170 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
171 |
+
|
172 |
+
|
173 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
174 |
+
# drive_pointset = "actions"
|
175 |
+
# fix_obj = False
|
176 |
+
# optimize_rules = True
|
177 |
+
# train_pointset_acts_via_deltas = True
|
178 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
179 |
+
# ##### model parameters optimized from the MANO hand trajectory #####
|
180 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
181 |
+
# ###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
182 |
+
|
183 |
+
|
184 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
185 |
+
# drive_pointset = "actions"
|
186 |
+
# fix_obj = False
|
187 |
+
# optimize_rules = False
|
188 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
189 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
190 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
191 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
192 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
193 |
+
|
194 |
+
|
195 |
+
use_opt_rigid_translations=True
|
196 |
+
|
197 |
+
train_def = True
|
198 |
+
optimizable_rigid_translations=True
|
199 |
+
|
200 |
+
nerf {
|
201 |
+
D = 8,
|
202 |
+
d_in = 4,
|
203 |
+
d_in_view = 3,
|
204 |
+
W = 256,
|
205 |
+
multires = 10,
|
206 |
+
multires_view = 4,
|
207 |
+
output_ch = 4,
|
208 |
+
skips=[4],
|
209 |
+
use_viewdirs=True
|
210 |
+
}
|
211 |
+
|
212 |
+
sdf_network {
|
213 |
+
d_out = 257,
|
214 |
+
d_in = 3,
|
215 |
+
d_hidden = 256,
|
216 |
+
n_layers = 8,
|
217 |
+
skip_in = [4],
|
218 |
+
multires = 6,
|
219 |
+
bias = 0.5,
|
220 |
+
scale = 1.0,
|
221 |
+
geometric_init = True,
|
222 |
+
weight_norm = True,
|
223 |
+
}
|
224 |
+
|
225 |
+
variance_network {
|
226 |
+
init_val = 0.3
|
227 |
+
}
|
228 |
+
|
229 |
+
rendering_network {
|
230 |
+
d_feature = 256,
|
231 |
+
mode = idr,
|
232 |
+
d_in = 9,
|
233 |
+
d_out = 3,
|
234 |
+
d_hidden = 256,
|
235 |
+
n_layers = 4,
|
236 |
+
weight_norm = True,
|
237 |
+
multires_view = 4,
|
238 |
+
squeeze_out = True,
|
239 |
+
}
|
240 |
+
|
241 |
+
neus_renderer {
|
242 |
+
n_samples = 64,
|
243 |
+
n_importance = 64,
|
244 |
+
n_outside = 0,
|
245 |
+
up_sample_steps = 4 ,
|
246 |
+
perturb = 1.0,
|
247 |
+
}
|
248 |
+
|
249 |
+
bending_network {
|
250 |
+
multires = 6,
|
251 |
+
bending_latent_size = 32,
|
252 |
+
d_in = 3,
|
253 |
+
rigidity_hidden_dimensions = 64,
|
254 |
+
rigidity_network_depth = 5,
|
255 |
+
use_rigidity_network = False,
|
256 |
+
bending_n_timesteps = 10,
|
257 |
+
}
|
258 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_s3.conf
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_"
|
7 |
+
|
8 |
+
recording = [
|
9 |
+
./,
|
10 |
+
./models
|
11 |
+
]
|
12 |
+
}
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024 # 64
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
# save_freq = 10000
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20 # 2500
|
35 |
+
val_mesh_freq = 20 # 5000
|
36 |
+
report_freq = 10
|
37 |
+
### igr weight ###
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
###### threshold, ks settings 1, optimize acts ######
|
138 |
+
# drive_pointset = "actions"
|
139 |
+
# fix_obj = True
|
140 |
+
# optimize_rules = False
|
141 |
+
# train_pointset_acts_via_deltas = True
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
144 |
+
###### threshold, ks settings 1, optimize acts ######
|
145 |
+
|
146 |
+
|
147 |
+
##### contact spring model settings ####
|
148 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
149 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
150 |
+
penetration_proj_k_to_robot = 4000000.0
|
151 |
+
##### contact spring model settings ####
|
152 |
+
|
153 |
+
|
154 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
155 |
+
drive_pointset = "states"
|
156 |
+
fix_obj = True
|
157 |
+
optimize_rules = False
|
158 |
+
train_pointset_acts_via_deltas = False
|
159 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
160 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
161 |
+
|
162 |
+
|
163 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
164 |
+
drive_pointset = "actions"
|
165 |
+
fix_obj = True
|
166 |
+
optimize_rules = False
|
167 |
+
train_pointset_acts_via_deltas = True
|
168 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
169 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
170 |
+
|
171 |
+
|
172 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
173 |
+
drive_pointset = "actions"
|
174 |
+
fix_obj = False
|
175 |
+
optimize_rules = True
|
176 |
+
train_pointset_acts_via_deltas = True
|
177 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
178 |
+
load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_/checkpoints/ckpt_004000.pth"
|
179 |
+
##### model parameters optimized from the MANO hand trajectory #####
|
180 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
181 |
+
ckpt_fn = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_/checkpoints/ckpt_004000.pth"
|
182 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
183 |
+
|
184 |
+
|
185 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
186 |
+
# drive_pointset = "actions"
|
187 |
+
# fix_obj = False
|
188 |
+
# optimize_rules = False
|
189 |
+
# train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
190 |
+
# ##### model parameters optimized from the MANO hand expanded set trajectory #####
|
191 |
+
# ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
192 |
+
# load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
193 |
+
# ###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
194 |
+
|
195 |
+
|
196 |
+
use_opt_rigid_translations=True
|
197 |
+
|
198 |
+
train_def = True
|
199 |
+
optimizable_rigid_translations=True
|
200 |
+
|
201 |
+
nerf {
|
202 |
+
D = 8,
|
203 |
+
d_in = 4,
|
204 |
+
d_in_view = 3,
|
205 |
+
W = 256,
|
206 |
+
multires = 10,
|
207 |
+
multires_view = 4,
|
208 |
+
output_ch = 4,
|
209 |
+
skips=[4],
|
210 |
+
use_viewdirs=True
|
211 |
+
}
|
212 |
+
|
213 |
+
sdf_network {
|
214 |
+
d_out = 257,
|
215 |
+
d_in = 3,
|
216 |
+
d_hidden = 256,
|
217 |
+
n_layers = 8,
|
218 |
+
skip_in = [4],
|
219 |
+
multires = 6,
|
220 |
+
bias = 0.5,
|
221 |
+
scale = 1.0,
|
222 |
+
geometric_init = True,
|
223 |
+
weight_norm = True,
|
224 |
+
}
|
225 |
+
|
226 |
+
variance_network {
|
227 |
+
init_val = 0.3
|
228 |
+
}
|
229 |
+
|
230 |
+
rendering_network {
|
231 |
+
d_feature = 256,
|
232 |
+
mode = idr,
|
233 |
+
d_in = 9,
|
234 |
+
d_out = 3,
|
235 |
+
d_hidden = 256,
|
236 |
+
n_layers = 4,
|
237 |
+
weight_norm = True,
|
238 |
+
multires_view = 4,
|
239 |
+
squeeze_out = True,
|
240 |
+
}
|
241 |
+
|
242 |
+
neus_renderer {
|
243 |
+
n_samples = 64,
|
244 |
+
n_importance = 64,
|
245 |
+
n_outside = 0,
|
246 |
+
up_sample_steps = 4 ,
|
247 |
+
perturb = 1.0,
|
248 |
+
}
|
249 |
+
|
250 |
+
bending_network {
|
251 |
+
multires = 6,
|
252 |
+
bending_latent_size = 32,
|
253 |
+
d_in = 3,
|
254 |
+
rigidity_hidden_dimensions = 64,
|
255 |
+
rigidity_network_depth = 5,
|
256 |
+
use_rigidity_network = False,
|
257 |
+
bending_n_timesteps = 10,
|
258 |
+
}
|
259 |
+
}
|
confs_new/dyn_grab_pointset_points_dyn_s4.conf
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_optacts_"
|
7 |
+
|
8 |
+
recording = [
|
9 |
+
./,
|
10 |
+
./models
|
11 |
+
]
|
12 |
+
}
|
13 |
+
|
14 |
+
dataset {
|
15 |
+
data_dir = public_data/CASE_NAME/
|
16 |
+
render_cameras_name = cameras_sphere.npz
|
17 |
+
object_cameras_name = cameras_sphere.npz
|
18 |
+
obj_idx = 102
|
19 |
+
}
|
20 |
+
|
21 |
+
train {
|
22 |
+
learning_rate = 5e-4
|
23 |
+
learning_rate_alpha = 0.05
|
24 |
+
end_iter = 300000
|
25 |
+
|
26 |
+
batch_size = 1024 # 64
|
27 |
+
validate_resolution_level = 4
|
28 |
+
warm_up_end = 5000
|
29 |
+
anneal_end = 0
|
30 |
+
use_white_bkgd = False
|
31 |
+
|
32 |
+
# save_freq = 10000
|
33 |
+
save_freq = 10000
|
34 |
+
val_freq = 20 # 2500
|
35 |
+
val_mesh_freq = 20 # 5000
|
36 |
+
report_freq = 10
|
37 |
+
### igr weight ###
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
###### threshold, ks settings 1, optimize acts ######
|
138 |
+
# drive_pointset = "actions"
|
139 |
+
# fix_obj = True
|
140 |
+
# optimize_rules = False
|
141 |
+
# train_pointset_acts_via_deltas = True
|
142 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d4_/checkpoints/ckpt_002000.pth"
|
143 |
+
# load_optimized_init_actions = "/data/xueyi/NeuS/exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_dyn_mano_hand_seq_102_mouse_optdynactions_points_optrobo_offsetdriven_optrules_multk100_wfixobj_optdelta_radius0d2_/checkpoints/ckpt_008000.pth"
|
144 |
+
###### threshold, ks settings 1, optimize acts ######
|
145 |
+
|
146 |
+
|
147 |
+
##### contact spring model settings ####
|
148 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
149 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
150 |
+
penetration_proj_k_to_robot = 4000000.0
|
151 |
+
##### contact spring model settings ####
|
152 |
+
|
153 |
+
|
154 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
155 |
+
drive_pointset = "states"
|
156 |
+
fix_obj = True
|
157 |
+
optimize_rules = False
|
158 |
+
train_pointset_acts_via_deltas = False
|
159 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_arti.pth"
|
160 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
161 |
+
|
162 |
+
|
163 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
164 |
+
drive_pointset = "actions"
|
165 |
+
fix_obj = True
|
166 |
+
optimize_rules = False
|
167 |
+
train_pointset_acts_via_deltas = True
|
168 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_states.pt"
|
169 |
+
###### Stage 2: threshold, ks settings 1, optimize acts ######
|
170 |
+
|
171 |
+
|
172 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
173 |
+
drive_pointset = "actions"
|
174 |
+
fix_obj = False
|
175 |
+
optimize_rules = True
|
176 |
+
train_pointset_acts_via_deltas = True
|
177 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_acts.pt"
|
178 |
+
##### model parameters optimized from the MANO hand trajectory #####
|
179 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_opts.pt"
|
180 |
+
###### Stage 3: threshold, ks settings 1, optimize params from acts ######
|
181 |
+
|
182 |
+
|
183 |
+
###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
184 |
+
drive_pointset = "actions"
|
185 |
+
fix_obj = False
|
186 |
+
optimize_rules = False
|
187 |
+
train_pointset_acts_via_deltas = True ## pointset acts via deltas ###
|
188 |
+
##### model parameters optimized from the MANO hand expanded set trajectory #####
|
189 |
+
ckpt_fn = "ckpts/grab/102/dyn_mano_pointset_optimized_acts_optimized_ps.pth"
|
190 |
+
load_optimized_init_actions = "ckpts/grab/102/dyn_mano_pointset_optimized_acts.pth"
|
191 |
+
ckpt_fn = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_/checkpoints/ckpt_044000.pth"
|
192 |
+
load_optimized_init_actions = "exp/hand_test_routine_2_light_color_wtime_active_passive/wmask_reverse_value_totviews_tag_train_retargeted_shadow_hand_seq_102_mano_pointset_acts_optstates_optacts_optsysps_/checkpoints/ckpt_044000.pth"
|
193 |
+
###### Stage 4: threshold, ks settings 1, optimize acts from optimized params ######
|
194 |
+
|
195 |
+
|
196 |
+
use_opt_rigid_translations=True
|
197 |
+
|
198 |
+
train_def = True
|
199 |
+
optimizable_rigid_translations=True
|
200 |
+
|
201 |
+
nerf {
|
202 |
+
D = 8,
|
203 |
+
d_in = 4,
|
204 |
+
d_in_view = 3,
|
205 |
+
W = 256,
|
206 |
+
multires = 10,
|
207 |
+
multires_view = 4,
|
208 |
+
output_ch = 4,
|
209 |
+
skips=[4],
|
210 |
+
use_viewdirs=True
|
211 |
+
}
|
212 |
+
|
213 |
+
sdf_network {
|
214 |
+
d_out = 257,
|
215 |
+
d_in = 3,
|
216 |
+
d_hidden = 256,
|
217 |
+
n_layers = 8,
|
218 |
+
skip_in = [4],
|
219 |
+
multires = 6,
|
220 |
+
bias = 0.5,
|
221 |
+
scale = 1.0,
|
222 |
+
geometric_init = True,
|
223 |
+
weight_norm = True,
|
224 |
+
}
|
225 |
+
|
226 |
+
variance_network {
|
227 |
+
init_val = 0.3
|
228 |
+
}
|
229 |
+
|
230 |
+
rendering_network {
|
231 |
+
d_feature = 256,
|
232 |
+
mode = idr,
|
233 |
+
d_in = 9,
|
234 |
+
d_out = 3,
|
235 |
+
d_hidden = 256,
|
236 |
+
n_layers = 4,
|
237 |
+
weight_norm = True,
|
238 |
+
multires_view = 4,
|
239 |
+
squeeze_out = True,
|
240 |
+
}
|
241 |
+
|
242 |
+
neus_renderer {
|
243 |
+
n_samples = 64,
|
244 |
+
n_importance = 64,
|
245 |
+
n_outside = 0,
|
246 |
+
up_sample_steps = 4 ,
|
247 |
+
perturb = 1.0,
|
248 |
+
}
|
249 |
+
|
250 |
+
bending_network {
|
251 |
+
multires = 6,
|
252 |
+
bending_latent_size = 32,
|
253 |
+
d_in = 3,
|
254 |
+
rigidity_hidden_dimensions = 64,
|
255 |
+
rigidity_network_depth = 5,
|
256 |
+
use_rigidity_network = False,
|
257 |
+
bending_n_timesteps = 10,
|
258 |
+
}
|
259 |
+
}
|
confs_new/dyn_grab_sparse_retar.conf
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
general {
|
2 |
+
|
3 |
+
|
4 |
+
base_exp_dir = exp/CASE_NAME/wmask
|
5 |
+
|
6 |
+
|
7 |
+
tag = "train_retargeted_shadow_hand_seq_102_mano_sparse_retargeting_"
|
8 |
+
|
9 |
+
recording = [
|
10 |
+
./,
|
11 |
+
./models
|
12 |
+
]
|
13 |
+
}
|
14 |
+
|
15 |
+
dataset {
|
16 |
+
data_dir = public_data/CASE_NAME/
|
17 |
+
render_cameras_name = cameras_sphere.npz
|
18 |
+
object_cameras_name = cameras_sphere.npz
|
19 |
+
obj_idx = 102
|
20 |
+
}
|
21 |
+
|
22 |
+
train {
|
23 |
+
learning_rate = 5e-4
|
24 |
+
learning_rate_alpha = 0.05
|
25 |
+
end_iter = 300000
|
26 |
+
|
27 |
+
batch_size = 1024
|
28 |
+
validate_resolution_level = 4
|
29 |
+
warm_up_end = 5000
|
30 |
+
anneal_end = 0
|
31 |
+
use_white_bkgd = False
|
32 |
+
|
33 |
+
# save_freq = 10000
|
34 |
+
save_freq = 10000
|
35 |
+
val_freq = 20
|
36 |
+
val_mesh_freq = 20
|
37 |
+
report_freq = 10
|
38 |
+
igr_weight = 0.1
|
39 |
+
mask_weight = 0.1
|
40 |
+
}
|
41 |
+
|
42 |
+
model {
|
43 |
+
|
44 |
+
optimize_dyn_actions = True
|
45 |
+
|
46 |
+
|
47 |
+
optimize_robot = True
|
48 |
+
|
49 |
+
use_penalty_based_friction = True
|
50 |
+
|
51 |
+
use_split_params = False
|
52 |
+
|
53 |
+
use_sqr_spring_stiffness = True
|
54 |
+
|
55 |
+
use_pre_proj_frictions = True
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
use_sqrt_dist = True
|
60 |
+
contact_maintaining_dist_thres = 0.2
|
61 |
+
|
62 |
+
robot_actions_diff_coef = 0.001
|
63 |
+
|
64 |
+
|
65 |
+
use_sdf_as_contact_dist = True
|
66 |
+
|
67 |
+
|
68 |
+
#
|
69 |
+
use_contact_dist_as_sdf = False
|
70 |
+
|
71 |
+
use_glb_proj_delta = True
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
# penetration_proj_k_to_robot = 30
|
76 |
+
penetrating_depth_penalty = 1.0
|
77 |
+
train_states = True
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
minn_dist_threshold = 0.000
|
82 |
+
obj_mass = 30.0
|
83 |
+
|
84 |
+
|
85 |
+
use_LBFGS = True
|
86 |
+
use_LBFGS = False
|
87 |
+
|
88 |
+
use_mano_hand_for_test = False # use the dynamic mano model here #
|
89 |
+
|
90 |
+
extract_delta_mesh = False
|
91 |
+
freeze_weights = True
|
92 |
+
gt_act_xs_def = False
|
93 |
+
use_bending_network = True
|
94 |
+
### for ts = 3 ###
|
95 |
+
# use_delta_bending = False
|
96 |
+
### for ts = 3 ###
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
sim_model_path = "rsc/shadow_hand_description/shadowhand_new.urdf"
|
102 |
+
mano_sim_model_path = "rsc/mano/mano_mean_wcollision_scaled_scaled_0_9507_nroot.urdf"
|
103 |
+
|
104 |
+
obj_sdf_fn = "data/grab/102/102_obj.npy"
|
105 |
+
kinematic_mano_gt_sv_fn = "data/grab/102/102_sv_dict.npy"
|
106 |
+
scaled_obj_mesh_fn = "data/grab/102/102_obj.obj"
|
107 |
+
|
108 |
+
bending_net_type = "active_force_field_v18"
|
109 |
+
sim_num_steps = 1000000
|
110 |
+
n_timesteps = 60
|
111 |
+
optim_sim_model_params_from_mano = False
|
112 |
+
penetration_determining = "sdf_of_canon"
|
113 |
+
train_with_forces_to_active = False
|
114 |
+
loss_scale_coef = 1000.0
|
115 |
+
use_same_contact_spring_k = False
|
116 |
+
use_optimizable_params = True #
|
117 |
+
train_residual_friction = True
|
118 |
+
mano_mult_const_after_cent = 1.0
|
119 |
+
optimize_glb_transformations = True
|
120 |
+
no_friction_constraint = False
|
121 |
+
optimize_active_object = True
|
122 |
+
loss_tangential_diff_coef = 0
|
123 |
+
optimize_with_intermediates = True
|
124 |
+
using_delta_glb_trans = False
|
125 |
+
train_multi_seqs = False
|
126 |
+
use_split_network = True
|
127 |
+
use_delta_bending = True
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
##### contact spring model settings ####
|
133 |
+
minn_dist_threshold_robot_to_obj = 0.1
|
134 |
+
penetration_proj_k_to_robot_friction = 10000000.0
|
135 |
+
penetration_proj_k_to_robot = 4000000.0
|
136 |
+
##### contact spring model settings ####
|
137 |
+
|
138 |
+
|
139 |
+
###### ######
|
140 |
+
drive_pointset = "states"
|
141 |
+
fix_obj = True
|
142 |
+
optimize_rules = False
|
143 |
+
train_pointset_acts_via_deltas = False
|
144 |
+
load_optimized_init_actions = ""
|
145 |
+
load_optimized_init_transformations = ""
|
146 |
+
ckpt_fn = ""
|
147 |
+
retar_only_glb = True
|
148 |
+
# use_multi_stages = True
|
149 |
+
###### Stage 1: threshold, ks settings 1, optimize offsets ######
|
150 |
+
|
151 |
+
use_opt_rigid_translations=True
|
152 |
+
|
153 |
+
train_def = True
|
154 |
+
optimizable_rigid_translations=True
|
155 |
+
|
156 |
+
nerf {
|
157 |
+
D = 8,
|
158 |
+
d_in = 4,
|
159 |
+
d_in_view = 3,
|
160 |
+
W = 256,
|
161 |
+
multires = 10,
|
162 |
+
multires_view = 4,
|
163 |
+
output_ch = 4,
|
164 |
+
skips=[4],
|
165 |
+
use_viewdirs=True
|
166 |
+
}
|
167 |
+
|
168 |
+
sdf_network {
|
169 |
+
d_out = 257,
|
170 |
+
d_in = 3,
|
171 |
+
d_hidden = 256,
|
172 |
+
n_layers = 8,
|
173 |
+
skip_in = [4],
|
174 |
+
multires = 6,
|
175 |
+
bias = 0.5,
|
176 |
+
scale = 1.0,
|
177 |
+
geometric_init = True,
|
178 |
+
weight_norm = True,
|
179 |
+
}
|
180 |
+
|
181 |
+
variance_network {
|
182 |
+
init_val = 0.3
|
183 |
+
}
|
184 |
+
|
185 |
+
rendering_network {
|
186 |
+
d_feature = 256,
|
187 |
+
mode = idr,
|
188 |
+
d_in = 9,
|
189 |
+
d_out = 3,
|
190 |
+
d_hidden = 256,
|
191 |
+
n_layers = 4,
|
192 |
+
weight_norm = True,
|
193 |
+
multires_view = 4,
|
194 |
+
squeeze_out = True,
|
195 |
+
}
|
196 |
+
|
197 |
+
neus_renderer {
|
198 |
+
n_samples = 64,
|
199 |
+
n_importance = 64,
|
200 |
+
n_outside = 0,
|
201 |
+
up_sample_steps = 4 ,
|
202 |
+
perturb = 1.0,
|
203 |
+
}
|
204 |
+
|
205 |
+
bending_network {
|
206 |
+
multires = 6,
|
207 |
+
bending_latent_size = 32,
|
208 |
+
d_in = 3,
|
209 |
+
rigidity_hidden_dimensions = 64,
|
210 |
+
rigidity_network_depth = 5,
|
211 |
+
use_rigidity_network = False,
|
212 |
+
bending_n_timesteps = 10,
|
213 |
+
}
|
214 |
+
}
|
exp_runner_stage_1.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/data_utils_torch.py
ADDED
@@ -0,0 +1,1547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Mesh data utilities."""
|
2 |
+
from re import I
|
3 |
+
# import matplotlib.pyplot as plt
|
4 |
+
# from mpl_toolkits import mplot3d # pylint: disable=unused-import
|
5 |
+
# from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
6 |
+
import networkx as nx
|
7 |
+
import numpy as np
|
8 |
+
import six
|
9 |
+
import os
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import models.fields as fields
|
13 |
+
# from options.options import opt
|
14 |
+
# from polygen_torch.
|
15 |
+
# from utils.constants import MASK_GRID_VALIE
|
16 |
+
|
17 |
+
try:
|
18 |
+
from torch_cluster import fps
|
19 |
+
except:
|
20 |
+
pass
|
21 |
+
|
22 |
+
MAX_RANGE = 0.1
|
23 |
+
MIN_RANGE = -0.1
|
24 |
+
|
25 |
+
# import open3d as o3d
|
26 |
+
|
27 |
+
def calculate_correspondences(last_mesh, bending_network, tot_timesteps, delta_bending):
|
28 |
+
# iterate all the timesteps and get the bended
|
29 |
+
timestep_to_pts = {
|
30 |
+
tot_timesteps - 1: last_mesh.detach().cpu().numpy()
|
31 |
+
}
|
32 |
+
if delta_bending:
|
33 |
+
|
34 |
+
for i_ts in range(tot_timesteps - 1, -1, -1):
|
35 |
+
if isinstance(bending_network, list):
|
36 |
+
tot_offsets = []
|
37 |
+
for i_bending, cur_bending_net in enumerate(bending_network):
|
38 |
+
cur_def_pts = cur_bending_net(last_mesh, i_ts)
|
39 |
+
tot_offsets.append(cur_def_pts - last_mesh)
|
40 |
+
tot_offsets = torch.stack(tot_offsets, dim=0)
|
41 |
+
tot_offsets = torch.sum(tot_offsets, dim=0)
|
42 |
+
last_mesh = last_mesh + tot_offsets
|
43 |
+
else:
|
44 |
+
last_mesh = bending_network(last_mesh, i_ts)
|
45 |
+
timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
|
46 |
+
elif isinstance(bending_network, fields.BendingNetworkRigidTrans):
|
47 |
+
for i_ts in range(tot_timesteps - 1, -1, -1):
|
48 |
+
last_mesh = bending_network.forward_delta(last_mesh, i_ts)
|
49 |
+
timestep_to_pts[i_ts - 1] = last_mesh.detach().cpu().numpy()
|
50 |
+
else:
|
51 |
+
raise ValueError(f"the function is designed for delta bending")
|
52 |
+
return timestep_to_pts
|
53 |
+
# pass
|
54 |
+
|
55 |
+
def joint_infos_to_numpy(joint_infos):
|
56 |
+
joint_infos_np = []
|
57 |
+
for part_joint_info in joint_infos:
|
58 |
+
for zz in ["dir", "center"]:
|
59 |
+
# if isinstance(part_joint_info["axis"][zz], np.array):
|
60 |
+
part_joint_info["axis"][zz] = part_joint_info["axis"][zz].detach().numpy()
|
61 |
+
joint_infos_np.append(part_joint_info)
|
62 |
+
return joint_infos_np
|
63 |
+
|
64 |
+
|
65 |
+
def normalie_pc_bbox_batched(pc, rt_stats=False):
|
66 |
+
pc_min = torch.min(pc, dim=1, keepdim=True)[0]
|
67 |
+
pc_max = torch.max(pc, dim=1, keepdim=True)[0]
|
68 |
+
pc_center = 0.5 * (pc_min + pc_max)
|
69 |
+
|
70 |
+
pc = pc - pc_center
|
71 |
+
extents = pc_max - pc_min
|
72 |
+
scale = torch.sqrt(torch.sum(extents ** 2, dim=-1, keepdim=True))
|
73 |
+
|
74 |
+
pc = pc / torch.clamp(scale, min=1e-6)
|
75 |
+
if rt_stats:
|
76 |
+
return pc, pc_center, scale
|
77 |
+
else:
|
78 |
+
return pc
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
def scale_vertices_to_target_scale(vertices, target_scale):
|
83 |
+
# vertices: bsz x N x 3;
|
84 |
+
# target_scale: bsz x 1
|
85 |
+
# normalized_vertices = normalize_vertices_scale_torch(vertices)
|
86 |
+
normalized_vertices = normalie_pc_bbox_batched(vertices)
|
87 |
+
normalized_vertices = normalized_vertices * target_scale.unsqueeze(1)
|
88 |
+
return normalized_vertices
|
89 |
+
|
90 |
+
def compute_normals_o3d(verts, faces): #### assume no batching... ####
|
91 |
+
mesh = o3d.geometry.TriangleMesh()
|
92 |
+
# o3d_mesh_b.vertices = verts_b
|
93 |
+
# o3d_mesh_b.triangles = np.array(faces_b, dtype=np.long)
|
94 |
+
mesh.vertices = verts.detach().cpu().numpy()
|
95 |
+
mesh.triangles = faces.detach().cpu().numpy().astype(np.long)
|
96 |
+
verts_normals = mesh.compute_vertex_normals(normalized=True)
|
97 |
+
verts_normals = torch.from_numpy(verts_normals, dtype=torch.float32).cuda()
|
98 |
+
return verts_normals
|
99 |
+
|
100 |
+
|
101 |
+
def get_vals_via_nearest_neighbours(pts_src, pts_tar, val_tar):
|
102 |
+
### n_src x 3 ---> n_src x n_tar
|
103 |
+
dist_src_tar = torch.sum((pts_src.unsqueeze(-2) - pts_tar.unsqueeze(0)) ** 2, dim=-1)
|
104 |
+
minn_dists_src_tar, minn_dists_tar_idxes = torch.min(dist_src_tar, dim=-1) ### n_src
|
105 |
+
selected_src_val = batched_index_select(values=val_tar, indices=minn_dists_tar_idxes, dim=0) ### n_src x dim
|
106 |
+
return selected_src_val
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
## sample conected componetn start from selected_verts
|
111 |
+
def sample_bfs_component(selected_vert, faces, max_num_grids):
|
112 |
+
vert_idx_to_adj_verts = {}
|
113 |
+
for i_f, cur_f in enumerate(faces):
|
114 |
+
# for i0, v0 in enumerate(cur_f):
|
115 |
+
for i0 in range(len(cur_f)):
|
116 |
+
v0 = cur_f[i0] - 1
|
117 |
+
i1 = (i0 + 1) % len(cur_f)
|
118 |
+
v1 = cur_f[i1] - 1
|
119 |
+
if v0 not in vert_idx_to_adj_verts:
|
120 |
+
vert_idx_to_adj_verts[v0] = {v1: 1}
|
121 |
+
else:
|
122 |
+
vert_idx_to_adj_verts[v0][v1] = 1
|
123 |
+
if v1 not in vert_idx_to_adj_verts:
|
124 |
+
vert_idx_to_adj_verts[v1] = {v0: 1}
|
125 |
+
else:
|
126 |
+
vert_idx_to_adj_verts[v1][v0] = 1
|
127 |
+
vert_idx_to_visited = {} # whether visisted here #
|
128 |
+
vis_que = [selected_vert]
|
129 |
+
vert_idx_to_visited[selected_vert] = 1
|
130 |
+
visited = [selected_vert]
|
131 |
+
while len(vis_que) > 0 and len(visited) < max_num_grids:
|
132 |
+
cur_frnt_vert = vis_que[0]
|
133 |
+
vis_que.pop(0)
|
134 |
+
if cur_frnt_vert in vert_idx_to_adj_verts:
|
135 |
+
cur_frnt_vert_adjs = vert_idx_to_adj_verts[cur_frnt_vert]
|
136 |
+
for adj_vert in cur_frnt_vert_adjs:
|
137 |
+
if adj_vert not in vert_idx_to_visited:
|
138 |
+
vert_idx_to_visited[adj_vert] = 1
|
139 |
+
vis_que.append(adj_vert)
|
140 |
+
visited.append(adj_vert)
|
141 |
+
if len(visited) >= max_num_grids:
|
142 |
+
visited = visited[: max_num_grids - 1]
|
143 |
+
return visited
|
144 |
+
|
145 |
+
def select_faces_via_verts(selected_verts, faces):
|
146 |
+
if not isinstance(selected_verts, list):
|
147 |
+
selected_verts = selected_verts.tolist()
|
148 |
+
# selected_verts_dict = {ii: 1 for ii in selected_verts}
|
149 |
+
old_idx_to_new_idx = {v + 1: ii + 1 for ii, v in enumerate(selected_verts)} ####### v + 1: ii + 1 --> for the selected_verts
|
150 |
+
new_faces = []
|
151 |
+
for i_f, cur_f in enumerate(faces):
|
152 |
+
cur_new_f = []
|
153 |
+
valid = True
|
154 |
+
for cur_v in cur_f:
|
155 |
+
if cur_v in old_idx_to_new_idx:
|
156 |
+
cur_new_f.append(old_idx_to_new_idx[cur_v])
|
157 |
+
else:
|
158 |
+
valid = False
|
159 |
+
break
|
160 |
+
if valid:
|
161 |
+
new_faces.append(cur_new_f)
|
162 |
+
return new_faces
|
163 |
+
|
164 |
+
|
165 |
+
def convert_grid_content_to_grid_pts(content_value, grid_size):
|
166 |
+
flat_grid = torch.zeros([grid_size ** 3], dtype=torch.long)
|
167 |
+
cur_idx = flat_grid.size(0) - 1
|
168 |
+
while content_value > 0:
|
169 |
+
flat_grid[cur_idx] = content_value % grid_size
|
170 |
+
content_value = content_value // grid_size
|
171 |
+
cur_idx -= 1
|
172 |
+
grid_pts = flat_grid.contiguous().view(grid_size, grid_size, grid_size).contiguous()
|
173 |
+
return grid_pts
|
174 |
+
|
175 |
+
# 0.2
|
176 |
+
def warp_coord(sampled_gradients, val, reflect=False): # val from [0.0, 1.0] # from the 0.0
|
177 |
+
# assume single value as inputs
|
178 |
+
grad_values = sampled_gradients.tolist()
|
179 |
+
# mid_val
|
180 |
+
mid_val = grad_values[0] * 0.2 + grad_values[1] * 0.2 + grad_values[2] * 0.1
|
181 |
+
if reflect:
|
182 |
+
grad_values[3] = grad_values[1]
|
183 |
+
grad_values[4] = grad_values[0]
|
184 |
+
|
185 |
+
# if not reflect:
|
186 |
+
accum_val = 0.0
|
187 |
+
for i_val in range(len(grad_values)):
|
188 |
+
if val > 0.2 * (i_val + 1) and i_val < 4: # if i_val == 4, directly use the reamining length * corresponding gradient value
|
189 |
+
accum_val += grad_values[i_val] * 0.2
|
190 |
+
else:
|
191 |
+
accum_val += grad_values[i_val] * (val - 0.2 * i_val)
|
192 |
+
break
|
193 |
+
return accum_val # modified value
|
194 |
+
|
195 |
+
def random_shift(vertices, shift_factor=0.25):
|
196 |
+
"""Apply random shift to vertices."""
|
197 |
+
# max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
|
198 |
+
|
199 |
+
# max_shift_pos = tf.maximum(max_shift_pos, 1e-9)
|
200 |
+
|
201 |
+
# max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
|
202 |
+
# max_shift_neg = tf.maximum(max_shift_neg, 1e-9)
|
203 |
+
|
204 |
+
# shift = tfd.TruncatedNormal(
|
205 |
+
# tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
|
206 |
+
# max_shift_pos).sample()
|
207 |
+
# shift = tf.cast(shift, tf.int32)
|
208 |
+
# vertices += shift
|
209 |
+
|
210 |
+
minn_tensor = torch.tensor([1e-9], dtype=torch.float32)
|
211 |
+
|
212 |
+
max_shift_pos = (255 - torch.max(vertices, dim=0)[0]).float()
|
213 |
+
max_shift_pos = torch.maximum(max_shift_pos, minn_tensor)
|
214 |
+
max_shift_neg = (torch.min(vertices, dim=0)[0]).float()
|
215 |
+
max_shift_neg = torch.maximum(max_shift_neg, minn_tensor)
|
216 |
+
|
217 |
+
shift = torch.zeros((1, 3), dtype=torch.float32)
|
218 |
+
# torch.nn.init.trunc_normal_(shift, 0., shift_factor * 255., -max_shift_neg, max_shift_pos)
|
219 |
+
for i_s in range(shift.size(-1)):
|
220 |
+
cur_axis_max_shift_neg = max_shift_neg[i_s].item()
|
221 |
+
cur_axis_max_shift_pos = max_shift_pos[i_s].item()
|
222 |
+
cur_axis_shift = torch.zeros((1,), dtype=torch.float32)
|
223 |
+
|
224 |
+
torch.nn.init.trunc_normal_(cur_axis_shift, 0., shift_factor * 255., -cur_axis_max_shift_neg, cur_axis_max_shift_pos)
|
225 |
+
shift[:, i_s] = cur_axis_shift.item()
|
226 |
+
|
227 |
+
shift = shift.long()
|
228 |
+
vertices += shift
|
229 |
+
|
230 |
+
return vertices
|
231 |
+
|
232 |
+
def safe_transpose(tsr, dima, dimb):
|
233 |
+
tsr = tsr.contiguous().transpose(dima, dimb).contiguous()
|
234 |
+
return tsr
|
235 |
+
|
236 |
+
def batched_index_select(values, indices, dim = 1):
|
237 |
+
value_dims = values.shape[(dim + 1):]
|
238 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
239 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
240 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
241 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
242 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
243 |
+
|
244 |
+
value_expand_shape = [-1] * len(values.shape)
|
245 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
246 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
247 |
+
values = values.expand(*value_expand_shape)
|
248 |
+
|
249 |
+
dim += value_expand_len
|
250 |
+
return values.gather(dim, indices)
|
251 |
+
|
252 |
+
def read_obj_file_ours(obj_fn, sub_one=False):
|
253 |
+
vertices = []
|
254 |
+
faces = []
|
255 |
+
with open(obj_fn, "r") as rf:
|
256 |
+
for line in rf:
|
257 |
+
items = line.strip().split(" ")
|
258 |
+
if items[0] == 'v':
|
259 |
+
cur_verts = items[1:]
|
260 |
+
cur_verts = [float(vv) for vv in cur_verts]
|
261 |
+
vertices.append(cur_verts)
|
262 |
+
elif items[0] == 'f':
|
263 |
+
cur_faces = items[1:] # faces
|
264 |
+
cur_face_idxes = []
|
265 |
+
for cur_f in cur_faces:
|
266 |
+
try:
|
267 |
+
cur_f_idx = int(cur_f.split("/")[0])
|
268 |
+
except:
|
269 |
+
cur_f_idx = int(cur_f.split("//")[0])
|
270 |
+
cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
|
271 |
+
faces.append(cur_face_idxes)
|
272 |
+
rf.close()
|
273 |
+
vertices = np.array(vertices, dtype=np.float)
|
274 |
+
return vertices, faces
|
275 |
+
|
276 |
+
def read_obj_file(obj_file):
|
277 |
+
"""Read vertices and faces from already opened file."""
|
278 |
+
vertex_list = []
|
279 |
+
flat_vertices_list = []
|
280 |
+
flat_vertices_indices = {}
|
281 |
+
flat_triangles = []
|
282 |
+
|
283 |
+
for line in obj_file:
|
284 |
+
tokens = line.split()
|
285 |
+
if not tokens:
|
286 |
+
continue
|
287 |
+
line_type = tokens[0]
|
288 |
+
# We skip lines not starting with v or f.
|
289 |
+
if line_type == 'v': #
|
290 |
+
vertex_list.append([float(x) for x in tokens[1:]])
|
291 |
+
elif line_type == 'f':
|
292 |
+
triangle = []
|
293 |
+
for i in range(len(tokens) - 1):
|
294 |
+
vertex_name = tokens[i + 1]
|
295 |
+
if vertex_name in flat_vertices_indices: # triangles
|
296 |
+
triangle.append(flat_vertices_indices[vertex_name])
|
297 |
+
continue
|
298 |
+
flat_vertex = []
|
299 |
+
for index in six.ensure_str(vertex_name).split('/'):
|
300 |
+
if not index:
|
301 |
+
continue
|
302 |
+
# obj triangle indices are 1 indexed, so subtract 1 here.
|
303 |
+
flat_vertex += vertex_list[int(index) - 1]
|
304 |
+
flat_vertex_index = len(flat_vertices_list)
|
305 |
+
flat_vertices_list.append(flat_vertex)
|
306 |
+
# flat_vertex_index
|
307 |
+
flat_vertices_indices[vertex_name] = flat_vertex_index
|
308 |
+
triangle.append(flat_vertex_index)
|
309 |
+
flat_triangles.append(triangle)
|
310 |
+
|
311 |
+
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
|
312 |
+
|
313 |
+
|
314 |
+
def batched_index_select(values, indices, dim = 1):
|
315 |
+
value_dims = values.shape[(dim + 1):]
|
316 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
317 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
318 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
319 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
320 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
321 |
+
|
322 |
+
value_expand_shape = [-1] * len(values.shape)
|
323 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
324 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
325 |
+
values = values.expand(*value_expand_shape)
|
326 |
+
|
327 |
+
dim += value_expand_len
|
328 |
+
return values.gather(dim, indices)
|
329 |
+
|
330 |
+
|
331 |
+
def safe_transpose(x, dim1, dim2):
|
332 |
+
x = x.contiguous().transpose(dim1, dim2).contiguous()
|
333 |
+
return x
|
334 |
+
|
335 |
+
def merge_meshes(vertices_list, faces_list):
|
336 |
+
tot_vertices = []
|
337 |
+
tot_faces = []
|
338 |
+
nn_verts = 0
|
339 |
+
for cur_vertices, cur_faces in zip(vertices_list, faces_list):
|
340 |
+
tot_vertices.append(cur_vertices)
|
341 |
+
new_cur_faces = []
|
342 |
+
for cur_face_idx in cur_faces:
|
343 |
+
new_cur_face_idx = [vert_idx + nn_verts for vert_idx in cur_face_idx]
|
344 |
+
new_cur_faces.append(new_cur_face_idx)
|
345 |
+
nn_verts += cur_vertices.shape[0]
|
346 |
+
tot_faces += new_cur_faces # get total-faces
|
347 |
+
tot_vertices = np.concatenate(tot_vertices, axis=0)
|
348 |
+
return tot_vertices, tot_faces
|
349 |
+
|
350 |
+
|
351 |
+
def read_obj(obj_path):
|
352 |
+
"""Open .obj file from the path provided and read vertices and faces."""
|
353 |
+
|
354 |
+
with open(obj_path) as obj_file:
|
355 |
+
return read_obj_file_ours(obj_path, sub_one=True)
|
356 |
+
# return read_obj_file(obj_file)
|
357 |
+
|
358 |
+
|
359 |
+
|
360 |
+
|
361 |
+
def center_vertices(vertices):
|
362 |
+
"""Translate the vertices so that bounding box is centered at zero."""
|
363 |
+
vert_min = vertices.min(axis=0)
|
364 |
+
vert_max = vertices.max(axis=0)
|
365 |
+
vert_center = 0.5 * (vert_min + vert_max)
|
366 |
+
return vertices - vert_center
|
367 |
+
|
368 |
+
|
369 |
+
def normalize_vertices_scale(vertices):
|
370 |
+
"""Scale the vertices so that the long diagonal of the bounding box is one."""
|
371 |
+
vert_min = vertices.min(axis=0)
|
372 |
+
vert_max = vertices.max(axis=0)
|
373 |
+
extents = vert_max - vert_min
|
374 |
+
scale = np.sqrt(np.sum(extents**2)) # normalize the diagonal line to 1.
|
375 |
+
return vertices / scale
|
376 |
+
|
377 |
+
|
378 |
+
def get_vertices_center(vertices):
|
379 |
+
vert_min = vertices.min(axis=0)
|
380 |
+
vert_max = vertices.max(axis=0)
|
381 |
+
vert_center = 0.5 * (vert_min + vert_max)
|
382 |
+
return vert_center
|
383 |
+
|
384 |
+
def get_batched_vertices_center(vertices):
|
385 |
+
vert_min = vertices.min(axis=1)
|
386 |
+
vert_max = vertices.max(axis=1)
|
387 |
+
vert_center = 0.5 * (vert_min + vert_max)
|
388 |
+
return vert_center
|
389 |
+
|
390 |
+
def get_vertices_scale(vertices):
|
391 |
+
vert_min = vertices.min(axis=0)
|
392 |
+
vert_max = vertices.max(axis=0)
|
393 |
+
extents = vert_max - vert_min
|
394 |
+
scale = np.sqrt(np.sum(extents**2))
|
395 |
+
return scale
|
396 |
+
|
397 |
+
def quantize_verts(verts, n_bits=8, min_range=None, max_range=None):
|
398 |
+
"""Convert vertices in [-1., 1.] to discrete values in [0, n_bits**2 - 1]."""
|
399 |
+
min_range = -0.5 if min_range is None else min_range
|
400 |
+
max_range = 0.5 if max_range is None else max_range
|
401 |
+
range_quantize = 2**n_bits - 1
|
402 |
+
verts_quantize = (verts - min_range) * range_quantize / (
|
403 |
+
max_range - min_range)
|
404 |
+
return verts_quantize.astype('int32')
|
405 |
+
|
406 |
+
def quantize_verts_torch(verts, n_bits=8, min_range=None, max_range=None):
|
407 |
+
min_range = -0.5 if min_range is None else min_range
|
408 |
+
max_range = 0.5 if max_range is None else max_range
|
409 |
+
range_quantize = 2**n_bits - 1
|
410 |
+
verts_quantize = (verts - min_range) * range_quantize / (
|
411 |
+
max_range - min_range)
|
412 |
+
return verts_quantize.long()
|
413 |
+
|
414 |
+
def dequantize_verts(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
|
415 |
+
"""Convert quantized vertices to floats."""
|
416 |
+
min_range = -0.5 if min_range is None else min_range
|
417 |
+
max_range = 0.5 if max_range is None else max_range
|
418 |
+
range_quantize = 2**n_bits - 1
|
419 |
+
verts = verts.astype('float32')
|
420 |
+
verts = verts * (max_range - min_range) / range_quantize + min_range
|
421 |
+
if add_noise:
|
422 |
+
verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
|
423 |
+
return verts
|
424 |
+
|
425 |
+
def dequantize_verts_torch(verts, n_bits=8, add_noise=False, min_range=None, max_range=None):
|
426 |
+
min_range = -0.5 if min_range is None else min_range
|
427 |
+
max_range = 0.5 if max_range is None else max_range
|
428 |
+
range_quantize = 2**n_bits - 1
|
429 |
+
verts = verts.float()
|
430 |
+
verts = verts * (max_range - min_range) / range_quantize + min_range
|
431 |
+
# if add_noise:
|
432 |
+
# verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
|
433 |
+
return verts
|
434 |
+
|
435 |
+
|
436 |
+
### dump vertices and faces to the obj file
|
437 |
+
def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
|
438 |
+
"""Write vertices and faces to obj."""
|
439 |
+
if transpose:
|
440 |
+
vertices = vertices[:, [1, 2, 0]]
|
441 |
+
vertices *= scale
|
442 |
+
if faces is not None:
|
443 |
+
if min(min(faces)) == 0:
|
444 |
+
f_add = 1
|
445 |
+
else:
|
446 |
+
f_add = 0
|
447 |
+
else:
|
448 |
+
faces = []
|
449 |
+
with open(file_path, 'w') as f:
|
450 |
+
for v in vertices:
|
451 |
+
f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
|
452 |
+
for face in faces:
|
453 |
+
line = 'f'
|
454 |
+
for i in face:
|
455 |
+
line += ' {}'.format(i + f_add)
|
456 |
+
line += '\n'
|
457 |
+
f.write(line)
|
458 |
+
|
459 |
+
|
460 |
+
def face_to_cycles(face):
|
461 |
+
"""Find cycles in face."""
|
462 |
+
g = nx.Graph()
|
463 |
+
for v in range(len(face) - 1):
|
464 |
+
g.add_edge(face[v], face[v + 1])
|
465 |
+
g.add_edge(face[-1], face[0])
|
466 |
+
return list(nx.cycle_basis(g))
|
467 |
+
|
468 |
+
|
469 |
+
def flatten_faces(faces):
|
470 |
+
"""Converts from list of faces to flat face array with stopping indices."""
|
471 |
+
if not faces:
|
472 |
+
return np.array([0])
|
473 |
+
else:
|
474 |
+
l = [f + [-1] for f in faces[:-1]]
|
475 |
+
l += [faces[-1] + [-2]]
|
476 |
+
return np.array([item for sublist in l for item in sublist]) + 2 # pylint: disable=g-complex-comprehension
|
477 |
+
|
478 |
+
|
479 |
+
def unflatten_faces(flat_faces):
|
480 |
+
"""Converts from flat face sequence to a list of separate faces."""
|
481 |
+
def group(seq):
|
482 |
+
g = []
|
483 |
+
for el in seq:
|
484 |
+
if el == 0 or el == -1:
|
485 |
+
yield g
|
486 |
+
g = []
|
487 |
+
else:
|
488 |
+
g.append(el - 1)
|
489 |
+
yield g
|
490 |
+
outputs = list(group(flat_faces - 1))[:-1]
|
491 |
+
# Remove empty faces
|
492 |
+
return [o for o in outputs if len(o) > 2]
|
493 |
+
|
494 |
+
|
495 |
+
|
496 |
+
def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8, remove_du=True):
|
497 |
+
"""Quantize vertices, remove resulting duplicates and reindex faces."""
|
498 |
+
vertices = quantize_verts(vertices, quantization_bits)
|
499 |
+
vertices, inv = np.unique(vertices, axis=0, return_inverse=True) # return inverse and unique the vertices
|
500 |
+
|
501 |
+
#
|
502 |
+
if opt.dataset.sort_dist:
|
503 |
+
if opt.model.debug:
|
504 |
+
print("sorting via dist...")
|
505 |
+
vertices_max = np.max(vertices, axis=0)
|
506 |
+
vertices_min = np.min(vertices, axis=0)
|
507 |
+
dist_vertices = np.minimum(np.abs(vertices - np.array([[vertices_min[0], vertices_min[1], 0]])), np.abs(vertices - np.array([[vertices_max[0], vertices_max[1], 0]])))
|
508 |
+
dist_vertices = np.concatenate([dist_vertices[:, 0:1] + dist_vertices[:, 1:2], dist_vertices[:, 2:]], axis=-1)
|
509 |
+
sort_inds = np.lexsort(dist_vertices.T)
|
510 |
+
else:
|
511 |
+
# Sort vertices by z then y then x.
|
512 |
+
sort_inds = np.lexsort(vertices.T) # sorted indices...
|
513 |
+
vertices = vertices[sort_inds]
|
514 |
+
|
515 |
+
# Re-index faces and tris to re-ordered vertices.
|
516 |
+
faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
|
517 |
+
if tris is not None:
|
518 |
+
tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])
|
519 |
+
|
520 |
+
# Merging duplicate vertices and re-indexing the faces causes some faces to
|
521 |
+
# contain loops (e.g [2, 3, 5, 2, 4]). Split these faces into distinct
|
522 |
+
# sub-faces.
|
523 |
+
sub_faces = []
|
524 |
+
for f in faces:
|
525 |
+
cliques = face_to_cycles(f)
|
526 |
+
for c in cliques:
|
527 |
+
c_length = len(c)
|
528 |
+
# Only append faces with more than two verts.
|
529 |
+
if c_length > 2:
|
530 |
+
d = np.argmin(c)
|
531 |
+
# Cyclically permute faces just that first index is the smallest.
|
532 |
+
sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
|
533 |
+
faces = sub_faces
|
534 |
+
if tris is not None:
|
535 |
+
tris = np.array([v for v in tris if len(set(v)) == len(v)])
|
536 |
+
|
537 |
+
# Sort faces by lowest vertex indices. If two faces have the same lowest
|
538 |
+
# index then sort by next lowest and so on.
|
539 |
+
faces.sort(key=lambda f: tuple(sorted(f)))
|
540 |
+
if tris is not None:
|
541 |
+
tris = tris.tolist()
|
542 |
+
tris.sort(key=lambda f: tuple(sorted(f)))
|
543 |
+
tris = np.array(tris)
|
544 |
+
|
545 |
+
# After removing degenerate faces some vertices are now unreferenced. # Vertices
|
546 |
+
# Remove these. # Vertices
|
547 |
+
num_verts = vertices.shape[0]
|
548 |
+
# print(f"remove_du: {remove_du}")
|
549 |
+
if remove_du: ##### num_verts
|
550 |
+
print("Removing du..")
|
551 |
+
try:
|
552 |
+
vert_connected = np.equal(
|
553 |
+
np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
|
554 |
+
vertices = vertices[vert_connected]
|
555 |
+
|
556 |
+
|
557 |
+
# Re-index faces and tris to re-ordered vertices.
|
558 |
+
vert_indices = (
|
559 |
+
np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
|
560 |
+
faces = [vert_indices[f].tolist() for f in faces]
|
561 |
+
except:
|
562 |
+
pass
|
563 |
+
if tris is not None:
|
564 |
+
tris = np.array([vert_indices[t].tolist() for t in tris])
|
565 |
+
|
566 |
+
return vertices, faces, tris
|
567 |
+
|
568 |
+
|
569 |
+
def process_mesh(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
|
570 |
+
"""Process mesh vertices and faces."""
|
571 |
+
|
572 |
+
|
573 |
+
|
574 |
+
# Transpose so that z-axis is vertical.
|
575 |
+
vertices = vertices[:, [2, 0, 1]]
|
576 |
+
|
577 |
+
if recenter_mesh:
|
578 |
+
# Translate the vertices so that bounding box is centered at zero.
|
579 |
+
vertices = center_vertices(vertices)
|
580 |
+
|
581 |
+
# Scale the vertices so that the long diagonal of the bounding box is equal
|
582 |
+
# to one.
|
583 |
+
vertices = normalize_vertices_scale(vertices)
|
584 |
+
|
585 |
+
# Quantize and sort vertices, remove resulting duplicates, sort and reindex
|
586 |
+
# faces.
|
587 |
+
vertices, faces, _ = quantize_process_mesh(
|
588 |
+
vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
|
589 |
+
|
590 |
+
# unflatten_faces = np.array(faces, dtype=np.long) ### start from zero
|
591 |
+
|
592 |
+
# Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
|
593 |
+
faces = flatten_faces(faces)
|
594 |
+
|
595 |
+
# Discard degenerate meshes without faces.
|
596 |
+
return {
|
597 |
+
'vertices': vertices,
|
598 |
+
'faces': faces,
|
599 |
+
}
|
600 |
+
|
601 |
+
|
602 |
+
def process_mesh_ours(vertices, faces, quantization_bits=8, recenter_mesh=True, remove_du=True):
|
603 |
+
"""Process mesh vertices and faces."""
|
604 |
+
# Transpose so that z-axis is vertical.
|
605 |
+
vertices = vertices[:, [2, 0, 1]]
|
606 |
+
|
607 |
+
if recenter_mesh:
|
608 |
+
# Translate the vertices so that bounding box is centered at zero.
|
609 |
+
vertices = center_vertices(vertices)
|
610 |
+
|
611 |
+
# Scale the vertices so that the long diagonal of the bounding box is equal
|
612 |
+
# to one.
|
613 |
+
vertices = normalize_vertices_scale(vertices)
|
614 |
+
|
615 |
+
# Quantize and sort vertices, remove resulting duplicates, sort and reindex
|
616 |
+
# faces.
|
617 |
+
quant_vertices, faces, _ = quantize_process_mesh(
|
618 |
+
vertices, faces, quantization_bits=quantization_bits, remove_du=remove_du) ##### quantize_process_mesh
|
619 |
+
vertices = dequantize_verts(quant_vertices) #### dequantize vertices ####
|
620 |
+
### vertices: nn_verts x 3
|
621 |
+
# try:
|
622 |
+
# # print("faces", faces)
|
623 |
+
# unflatten_faces = np.array(faces, dtype=np.long)
|
624 |
+
# except:
|
625 |
+
# print("faces", faces)
|
626 |
+
# raise ValueError("Something bad happened when processing meshes...")
|
627 |
+
|
628 |
+
# Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
|
629 |
+
|
630 |
+
faces = flatten_faces(faces)
|
631 |
+
|
632 |
+
# Discard degenerate meshes without faces.
|
633 |
+
return {
|
634 |
+
'vertices': quant_vertices,
|
635 |
+
'faces': faces,
|
636 |
+
# 'unflatten_faces': unflatten_faces,
|
637 |
+
'dequant_vertices': vertices,
|
638 |
+
'class_label': 0
|
639 |
+
}
|
640 |
+
|
641 |
+
def read_mesh_from_obj_file(fn, quantization_bits=8, recenter_mesh=True, remove_du=True):
|
642 |
+
vertices, faces = read_obj(fn)
|
643 |
+
# print(vertices.shape)
|
644 |
+
mesh_dict = process_mesh_ours(vertices, faces, quantization_bits=quantization_bits, recenter_mesh=recenter_mesh, remove_du=remove_du)
|
645 |
+
return mesh_dict
|
646 |
+
|
647 |
+
def process_mesh_list(vertices, faces, quantization_bits=8, recenter_mesh=True):
|
648 |
+
"""Process mesh vertices and faces."""
|
649 |
+
|
650 |
+
vertices = [cur_vert[:, [2, 0, 1]] for cur_vert in vertices]
|
651 |
+
|
652 |
+
tot_vertices = np.concatenate(vertices, axis=0) # center and scale of those vertices
|
653 |
+
vert_center = get_vertices_center(tot_vertices)
|
654 |
+
vert_scale = get_vertices_scale(tot_vertices)
|
655 |
+
|
656 |
+
processed_vertices, processed_faces = [], []
|
657 |
+
|
658 |
+
for cur_verts, cur_faces in zip(vertices, faces):
|
659 |
+
# print(f"current vertices: {cur_verts.shape}, faces: {len(cur_faces)}")
|
660 |
+
normalized_verts = (cur_verts - vert_center) / vert_scale
|
661 |
+
cur_processed_verts, cur_processed_faces, _ = quantize_process_mesh(
|
662 |
+
normalized_verts, cur_faces, quantization_bits=quantization_bits
|
663 |
+
)
|
664 |
+
processed_vertices.append(cur_processed_verts)
|
665 |
+
processed_faces.append(cur_processed_faces)
|
666 |
+
vertices, faces = merge_meshes(processed_vertices, processed_faces)
|
667 |
+
faces = flatten_faces(faces=faces)
|
668 |
+
|
669 |
+
|
670 |
+
# Discard degenerate meshes without faces.
|
671 |
+
return {
|
672 |
+
'vertices': vertices,
|
673 |
+
|
674 |
+
'faces': faces,
|
675 |
+
|
676 |
+
}
|
677 |
+
|
678 |
+
|
679 |
+
def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
|
680 |
+
|
681 |
+
if not os.path.exists(sv_mesh_folder):
|
682 |
+
os.mkdir(sv_mesh_folder)
|
683 |
+
|
684 |
+
part_vertex_samples = [v_sample['left'], v_sample['rgt']]
|
685 |
+
part_face_samples = [f_sample['left'], f_sample['rgt']]
|
686 |
+
|
687 |
+
|
688 |
+
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
|
689 |
+
tot_n_part = 2
|
690 |
+
|
691 |
+
if predict_joint:
|
692 |
+
pred_dir = v_sample['joint_dir']
|
693 |
+
pred_pvp = v_sample['joint_pvp']
|
694 |
+
print("pred_dir", pred_dir.shape, pred_dir)
|
695 |
+
print("pred_pvp", pred_pvp.shape, pred_pvp)
|
696 |
+
else:
|
697 |
+
pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
|
698 |
+
|
699 |
+
|
700 |
+
|
701 |
+
|
702 |
+
tot_mesh_list = []
|
703 |
+
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
|
704 |
+
mesh_list = []
|
705 |
+
for i_n in range(tot_n_samples):
|
706 |
+
mesh_list.append(
|
707 |
+
{
|
708 |
+
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
|
709 |
+
'faces': unflatten_faces(
|
710 |
+
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
|
711 |
+
}
|
712 |
+
)
|
713 |
+
tot_mesh_list.append(mesh_list)
|
714 |
+
# and write this obj file?
|
715 |
+
# write_obj(vertices, faces, file_path, transpose=True, scale=1.):
|
716 |
+
# write mesh objs
|
717 |
+
for i_n in range(tot_n_samples):
|
718 |
+
cur_mesh = mesh_list[i_n]
|
719 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
720 |
+
cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
721 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
722 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
|
723 |
+
|
724 |
+
|
725 |
+
|
726 |
+
###### plot mesh (predicted) ######
|
727 |
+
tot_samples_mesh_dict = []
|
728 |
+
for i_s in range(tot_n_samples):
|
729 |
+
cur_s_tot_vertices = []
|
730 |
+
cur_s_tot_faces = []
|
731 |
+
cur_s_n_vertices = 0
|
732 |
+
|
733 |
+
for i_p in range(tot_n_part):
|
734 |
+
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
|
735 |
+
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
|
736 |
+
cur_s_cur_part_mesh_dict['faces']
|
737 |
+
cur_s_cur_part_new_faces = []
|
738 |
+
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
|
739 |
+
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
|
740 |
+
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
|
741 |
+
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
|
742 |
+
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
|
743 |
+
cur_s_tot_faces += cur_s_cur_part_new_faces
|
744 |
+
|
745 |
+
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
|
746 |
+
cur_s_mesh_dict = {
|
747 |
+
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
|
748 |
+
}
|
749 |
+
tot_samples_mesh_dict.append(cur_s_mesh_dict)
|
750 |
+
|
751 |
+
for i_s in range(tot_n_samples):
|
752 |
+
cur_mesh = tot_samples_mesh_dict[i_s]
|
753 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
754 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}.obj")
|
755 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
756 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
|
757 |
+
###### plot mesh (predicted) ######
|
758 |
+
|
759 |
+
|
760 |
+
###### plot mesh (translated) ######
|
761 |
+
tot_samples_mesh_dict = []
|
762 |
+
for i_s in range(tot_n_samples):
|
763 |
+
cur_s_tot_vertices = []
|
764 |
+
cur_s_tot_faces = []
|
765 |
+
cur_s_n_vertices = 0
|
766 |
+
cur_s_pred_pvp = pred_pvp[i_s]
|
767 |
+
|
768 |
+
for i_p in range(tot_n_part):
|
769 |
+
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
|
770 |
+
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
|
771 |
+
cur_s_cur_part_mesh_dict['faces']
|
772 |
+
cur_s_cur_part_new_faces = []
|
773 |
+
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
|
774 |
+
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
|
775 |
+
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
|
776 |
+
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
|
777 |
+
|
778 |
+
if i_p == 1:
|
779 |
+
# min_rngs = cur_s_cur_part_vertices.min(1)
|
780 |
+
# max_rngs = cur_s_cur_part_vertices.max(1)
|
781 |
+
min_rngs = cur_s_cur_part_vertices.min(0)
|
782 |
+
max_rngs = cur_s_cur_part_vertices.max(0)
|
783 |
+
# shifted; cur_s_pred_pvp
|
784 |
+
# shifted = np.array([0., cur_s_pred_pvp[1] - max_rngs[1], cur_s_pred_pvp[2] - min_rngs[2]], dtype=np.float)
|
785 |
+
# shifted = np.reshape(shifted, [1, 3]) #
|
786 |
+
cur_s_pred_pvp = np.array([0., max_rngs[1], min_rngs[2]], dtype=np.float32)
|
787 |
+
pvp_sample_pred_err = np.sum((cur_s_pred_pvp - pred_pvp[i_s]) ** 2)
|
788 |
+
# print prediction err, pred pvp and real pvp
|
789 |
+
# print("cur_s, sample_pred_pvp_err:", pvp_sample_pred_err.item(), ";real val:", cur_s_pred_pvp, "; pred_val:", pred_pvp[i_s])
|
790 |
+
pred_pvp[i_s] = cur_s_pred_pvp
|
791 |
+
shifted = np.zeros((1, 3), dtype=np.float32)
|
792 |
+
cur_s_cur_part_vertices = cur_s_cur_part_vertices + shifted # shift vertices... # min_rngs
|
793 |
+
# shifted
|
794 |
+
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
|
795 |
+
cur_s_tot_faces += cur_s_cur_part_new_faces
|
796 |
+
|
797 |
+
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
|
798 |
+
cur_s_mesh_dict = {
|
799 |
+
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
|
800 |
+
}
|
801 |
+
tot_samples_mesh_dict.append(cur_s_mesh_dict)
|
802 |
+
|
803 |
+
for i_s in range(tot_n_samples):
|
804 |
+
cur_mesh = tot_samples_mesh_dict[i_s]
|
805 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
806 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_shifted.obj")
|
807 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
808 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
|
809 |
+
###### plot mesh (translated) ######
|
810 |
+
|
811 |
+
|
812 |
+
|
813 |
+
###### plot mesh (rotated) ######
|
814 |
+
if predict_joint:
|
815 |
+
from revolute_transform import revoluteTransform
|
816 |
+
tot_samples_mesh_dict = []
|
817 |
+
for i_s in range(tot_n_samples):
|
818 |
+
cur_s_tot_vertices = []
|
819 |
+
cur_s_tot_faces = []
|
820 |
+
cur_s_n_vertices = 0
|
821 |
+
|
822 |
+
# cur_s_pred_dir = pred_dir[i_s]
|
823 |
+
cur_s_pred_pvp = pred_pvp[i_s]
|
824 |
+
print("current pred dir:", cur_s_pred_dir, "; current pred pvp:", cur_s_pred_pvp)
|
825 |
+
cur_s_pred_dir = np.array([1.0, 0.0, 0.0], dtype=np.float)
|
826 |
+
# cur_s_pred_pvp = cur_s_pred_pvp[[1, 2, 0]]
|
827 |
+
|
828 |
+
for i_p in range(tot_n_part):
|
829 |
+
cur_s_cur_part_mesh_dict = tot_mesh_list[i_p][i_s]
|
830 |
+
cur_s_cur_part_vertices, cur_s_cur_part_faces = cur_s_cur_part_mesh_dict['vertices'], \
|
831 |
+
cur_s_cur_part_mesh_dict['faces']
|
832 |
+
|
833 |
+
if i_p == 1:
|
834 |
+
cur_s_cur_part_vertices, _ = revoluteTransform(cur_s_cur_part_vertices, cur_s_pred_pvp, cur_s_pred_dir, 0.5 * np.pi) # reverse revolute vertices of the upper piece
|
835 |
+
cur_s_cur_part_vertices = cur_s_cur_part_vertices[:, :3] #
|
836 |
+
cur_s_cur_part_new_faces = []
|
837 |
+
for cur_s_cur_part_cur_face in cur_s_cur_part_faces:
|
838 |
+
cur_s_cur_part_cur_new_face = [fid + cur_s_n_vertices for fid in cur_s_cur_part_cur_face]
|
839 |
+
cur_s_cur_part_new_faces.append(cur_s_cur_part_cur_new_face)
|
840 |
+
cur_s_n_vertices += cur_s_cur_part_vertices.shape[0]
|
841 |
+
cur_s_tot_vertices.append(cur_s_cur_part_vertices)
|
842 |
+
# print(f"i_s: {i_s}, i_p: {i_p}, n_vertices: {cur_s_cur_part_vertices.shape[0]}")
|
843 |
+
cur_s_tot_faces += cur_s_cur_part_new_faces
|
844 |
+
|
845 |
+
cur_s_tot_vertices = np.concatenate(cur_s_tot_vertices, axis=0)
|
846 |
+
# print(f"i_s: {i_s}, n_cur_s_tot_vertices: {cur_s_tot_vertices.shape[0]}")
|
847 |
+
cur_s_mesh_dict = {
|
848 |
+
'vertices': cur_s_tot_vertices, 'faces': cur_s_tot_faces
|
849 |
+
}
|
850 |
+
tot_samples_mesh_dict.append(cur_s_mesh_dict)
|
851 |
+
# plot_meshes(tot_samples_mesh_dict, ax_lims=0.5, mesh_sv_fn=f"./figs/training_step_{n}_part_{tot_n_part}_rot.png") # plot the mesh;
|
852 |
+
for i_s in range(tot_n_samples):
|
853 |
+
cur_mesh = tot_samples_mesh_dict[i_s]
|
854 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
855 |
+
# rotated mesh fn
|
856 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_s}_rot.obj")
|
857 |
+
# write object in the file...
|
858 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
859 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=False, scale=1.)
|
860 |
+
|
861 |
+
|
862 |
+
|
863 |
+
|
864 |
+
def sample_pts_from_mesh(vertices, faces, npoints=512, minus_one=True):
|
865 |
+
return vertices
|
866 |
+
sampled_pcts = []
|
867 |
+
pts_to_seg_idx = []
|
868 |
+
# triangles and pts
|
869 |
+
minus_val = 0 if not minus_one else 1
|
870 |
+
for i in range(len(faces)): #
|
871 |
+
cur_face = faces[i]
|
872 |
+
n_tris = len(cur_face) - 2
|
873 |
+
v_as, v_bs, v_cs = [cur_face[0] for _ in range(n_tris)], cur_face[1: len(cur_face) - 1], cur_face[2: len(cur_face)]
|
874 |
+
for v_a, v_b, v_c in zip(v_as, v_bs, v_cs):
|
875 |
+
|
876 |
+
v_a, v_b, v_c = vertices[v_a - minus_val], vertices[v_b - minus_val], vertices[v_c - minus_val]
|
877 |
+
ab, ac = v_b - v_a, v_c - v_a
|
878 |
+
cos_ab_ac = (np.sum(ab * ac) / np.clip(np.sqrt(np.sum(ab ** 2)) * np.sqrt(np.sum(ac ** 2)), a_min=1e-9,
|
879 |
+
a_max=9999999.)).item()
|
880 |
+
sin_ab_ac = math.sqrt(min(max(0., 1. - cos_ab_ac ** 2), 1.))
|
881 |
+
cur_area = 0.5 * sin_ab_ac * np.sqrt(np.sum(ab ** 2)).item() * np.sqrt(np.sum(ac ** 2)).item()
|
882 |
+
|
883 |
+
cur_sampled_pts = int(cur_area * npoints)
|
884 |
+
cur_sampled_pts = 1 if cur_sampled_pts == 0 else cur_sampled_pts
|
885 |
+
# if cur_sampled_pts == 0:
|
886 |
+
|
887 |
+
tmp_x, tmp_y = np.random.uniform(0, 1., (cur_sampled_pts,)).tolist(), np.random.uniform(0., 1., (
|
888 |
+
cur_sampled_pts,)).tolist()
|
889 |
+
|
890 |
+
for xx, yy in zip(tmp_x, tmp_y):
|
891 |
+
sqrt_xx, sqrt_yy = math.sqrt(xx), math.sqrt(yy)
|
892 |
+
aa = 1. - sqrt_xx
|
893 |
+
bb = sqrt_xx * (1. - yy)
|
894 |
+
cc = yy * sqrt_xx
|
895 |
+
cur_pos = v_a * aa + v_b * bb + v_c * cc
|
896 |
+
sampled_pcts.append(cur_pos)
|
897 |
+
# pts_to_seg_idx.append(cur_tri_seg)
|
898 |
+
|
899 |
+
# seg_idx_to_sampled_pts = {}
|
900 |
+
sampled_pcts = np.array(sampled_pcts, dtype=np.float)
|
901 |
+
|
902 |
+
return sampled_pcts
|
903 |
+
|
904 |
+
|
905 |
+
def fps_fr_numpy(np_pts, n_sampling=4096):
|
906 |
+
pts = torch.from_numpy(np_pts).float().cuda()
|
907 |
+
pts_fps_idx = farthest_point_sampling(pts.unsqueeze(0), n_sampling=n_sampling) # farthes points sampling ##
|
908 |
+
pts = pts[pts_fps_idx].cpu()
|
909 |
+
return pts
|
910 |
+
|
911 |
+
|
912 |
+
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
|
913 |
+
bz, N = pos.size(0), pos.size(1)
|
914 |
+
feat_dim = pos.size(-1)
|
915 |
+
device = pos.device
|
916 |
+
sampling_ratio = float(n_sampling / N)
|
917 |
+
pos_float = pos.float()
|
918 |
+
|
919 |
+
batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
|
920 |
+
mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
|
921 |
+
|
922 |
+
batch = batch * mult_one
|
923 |
+
batch = batch.view(-1)
|
924 |
+
pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous() # (bz x N, 3)
|
925 |
+
# sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
|
926 |
+
# batch = torch.zeros((N, ), dtype=torch.long, device=device)
|
927 |
+
sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
|
928 |
+
# shape of sampled_idx?
|
929 |
+
return sampled_idx
|
930 |
+
|
931 |
+
|
932 |
+
def plot_sampled_meshes_single_part(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
|
933 |
+
|
934 |
+
if not os.path.exists(sv_mesh_folder):
|
935 |
+
os.mkdir(sv_mesh_folder)
|
936 |
+
|
937 |
+
part_vertex_samples = [v_sample]
|
938 |
+
part_face_samples = [f_sample]
|
939 |
+
|
940 |
+
|
941 |
+
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
|
942 |
+
tot_n_part = 2
|
943 |
+
|
944 |
+
# not predict joints here
|
945 |
+
# if predict_joint:
|
946 |
+
# pred_dir = v_sample['joint_dir']
|
947 |
+
# pred_pvp = v_sample['joint_pvp']
|
948 |
+
# print("pred_dir", pred_dir.shape, pred_dir)
|
949 |
+
# print("pred_pvp", pred_pvp.shape, pred_pvp)
|
950 |
+
# else:
|
951 |
+
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
|
952 |
+
|
953 |
+
|
954 |
+
tot_mesh_list = []
|
955 |
+
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
|
956 |
+
mesh_list = []
|
957 |
+
for i_n in range(tot_n_samples):
|
958 |
+
mesh_list.append(
|
959 |
+
{
|
960 |
+
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
|
961 |
+
'faces': unflatten_faces(
|
962 |
+
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
|
963 |
+
}
|
964 |
+
)
|
965 |
+
tot_mesh_list.append(mesh_list)
|
966 |
+
|
967 |
+
for i_n in range(tot_n_samples):
|
968 |
+
cur_mesh = mesh_list[i_n]
|
969 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
970 |
+
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
971 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
972 |
+
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
|
973 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
974 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
|
975 |
+
|
976 |
+
|
977 |
+
def plot_sampled_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
|
978 |
+
|
979 |
+
if not os.path.exists(sv_mesh_folder):
|
980 |
+
os.mkdir(sv_mesh_folder)
|
981 |
+
|
982 |
+
part_vertex_samples = [v_sample]
|
983 |
+
part_face_samples = [f_sample]
|
984 |
+
|
985 |
+
|
986 |
+
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
|
987 |
+
# tot_n_part = 2
|
988 |
+
|
989 |
+
# not predict joints here
|
990 |
+
# if predict_joint:
|
991 |
+
# pred_dir = v_sample['joint_dir']
|
992 |
+
# pred_pvp = v_sample['joint_pvp']
|
993 |
+
# print("pred_dir", pred_dir.shape, pred_dir)
|
994 |
+
# print("pred_pvp", pred_pvp.shape, pred_pvp)
|
995 |
+
# else:
|
996 |
+
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
|
997 |
+
|
998 |
+
|
999 |
+
tot_mesh_list = []
|
1000 |
+
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
|
1001 |
+
mesh_list = []
|
1002 |
+
for i_n in range(tot_n_samples):
|
1003 |
+
mesh_list.append(
|
1004 |
+
{
|
1005 |
+
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
|
1006 |
+
'faces': unflatten_faces(
|
1007 |
+
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
|
1008 |
+
}
|
1009 |
+
)
|
1010 |
+
tot_mesh_list.append(mesh_list)
|
1011 |
+
|
1012 |
+
for i_n in range(tot_n_samples):
|
1013 |
+
cur_mesh = mesh_list[i_n]
|
1014 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
1015 |
+
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1016 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1017 |
+
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
|
1018 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
1019 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1020 |
+
|
1021 |
+
|
1022 |
+
def filter_masked_vertices(vertices, mask_indicator):
|
1023 |
+
# vertices: n_verts x 3
|
1024 |
+
mask_indicator = np.reshape(mask_indicator, (vertices.shape[0], 3))
|
1025 |
+
tot_vertices = []
|
1026 |
+
for i_v in range(vertices.shape[0]):
|
1027 |
+
cur_vert = vertices[i_v]
|
1028 |
+
cur_vert_indicator = mask_indicator[i_v][0].item()
|
1029 |
+
if cur_vert_indicator < 0.5:
|
1030 |
+
tot_vertices.append(cur_vert)
|
1031 |
+
tot_vertices = np.array(tot_vertices)
|
1032 |
+
return tot_vertices
|
1033 |
+
|
1034 |
+
|
1035 |
+
def plot_sampled_ar_subd_meshes(v_sample, f_sample, sv_mesh_folder, cur_step=0, ):
|
1036 |
+
if not os.path.exists(sv_mesh_folder): ### vertices_mask
|
1037 |
+
os.mkdir(sv_mesh_folder)
|
1038 |
+
### v_sample: bsz x nn_verts x 3
|
1039 |
+
vertices_mask = v_sample['vertices_mask']
|
1040 |
+
vertices = v_sample['vertices']
|
1041 |
+
faces = f_sample['faces']
|
1042 |
+
num_face_indices = f_sample['num_face_indices'] #### num_faces_indices
|
1043 |
+
bsz = vertices.shape[0]
|
1044 |
+
|
1045 |
+
for i_bsz in range(bsz):
|
1046 |
+
cur_vertices = vertices[i_bsz]
|
1047 |
+
cur_vertices_mask = vertices_mask[i_bsz]
|
1048 |
+
cur_faces = faces[i_bsz]
|
1049 |
+
cur_num_face_indices = num_face_indices[i_bsz]
|
1050 |
+
cur_nn_verts = cur_vertices_mask.sum(-1).item()
|
1051 |
+
cur_nn_verts = int(cur_nn_verts)
|
1052 |
+
cur_vertices = cur_vertices[:cur_nn_verts]
|
1053 |
+
cur_faces = unflatten_faces(
|
1054 |
+
cur_faces[:int(cur_num_face_indices)])
|
1055 |
+
|
1056 |
+
cur_num_faces = len(cur_faces)
|
1057 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_inst_{i_bsz}.obj")
|
1058 |
+
# cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
|
1059 |
+
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_nn_verts}, nn_faces: {cur_num_faces}")
|
1060 |
+
if cur_nn_verts > 0 and cur_num_faces > 0:
|
1061 |
+
write_obj(cur_vertices, cur_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1062 |
+
|
1063 |
+
|
1064 |
+
|
1065 |
+
def plot_sampled_meshes_single_part_for_pretraining(v_sample, f_sample, context, sv_mesh_folder, cur_step=0, predict_joint=True, with_context=True):
|
1066 |
+
|
1067 |
+
if not os.path.exists(sv_mesh_folder):
|
1068 |
+
os.mkdir(sv_mesh_folder)
|
1069 |
+
|
1070 |
+
part_vertex_samples = [v_sample]
|
1071 |
+
part_face_samples = [f_sample]
|
1072 |
+
|
1073 |
+
context_vertices = [context['vertices']]
|
1074 |
+
context_faces = [context['faces']]
|
1075 |
+
context_vertices_mask = [context['vertices_mask']]
|
1076 |
+
context_faces_mask = [context['faces_mask']]
|
1077 |
+
|
1078 |
+
|
1079 |
+
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
|
1080 |
+
tot_n_part = 2
|
1081 |
+
|
1082 |
+
# not predict joints here
|
1083 |
+
# if predict_joint:
|
1084 |
+
# pred_dir = v_sample['joint_dir']
|
1085 |
+
# pred_pvp = v_sample['joint_pvp']
|
1086 |
+
# print("pred_dir", pred_dir.shape, pred_dir)
|
1087 |
+
# print("pred_pvp", pred_pvp.shape, pred_pvp)
|
1088 |
+
# else:
|
1089 |
+
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
|
1090 |
+
|
1091 |
+
#
|
1092 |
+
|
1093 |
+
|
1094 |
+
tot_mesh_list = []
|
1095 |
+
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
|
1096 |
+
mesh_list = []
|
1097 |
+
context_mesh_list = []
|
1098 |
+
for i_n in range(tot_n_samples):
|
1099 |
+
mesh_list.append(
|
1100 |
+
{
|
1101 |
+
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
|
1102 |
+
'faces': unflatten_faces(
|
1103 |
+
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
|
1104 |
+
}
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
cur_context_vertices = context_vertices[i_p][i_n]
|
1108 |
+
cur_context_faces = context_faces[i_p][i_n]
|
1109 |
+
cur_context_vertices_mask = context_vertices_mask[i_p][i_n]
|
1110 |
+
cur_context_faces_mask = context_faces_mask[i_p][i_n]
|
1111 |
+
cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
|
1112 |
+
cur_nn_faces = np.sum(cur_context_faces_mask).item()
|
1113 |
+
cur_nn_vertices, cur_nn_faces = int(cur_nn_vertices), int(cur_nn_faces)
|
1114 |
+
cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
|
1115 |
+
if 'mask_vertices_flat_indicator' in context:
|
1116 |
+
cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
|
1117 |
+
cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
|
1118 |
+
cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
|
1119 |
+
cur_context_faces = cur_context_faces[:cur_nn_faces] # context faces...
|
1120 |
+
context_mesh_dict = {
|
1121 |
+
'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': unflatten_faces(cur_context_faces)
|
1122 |
+
}
|
1123 |
+
context_mesh_list.append(context_mesh_dict)
|
1124 |
+
|
1125 |
+
tot_mesh_list.append(mesh_list)
|
1126 |
+
|
1127 |
+
# if with_context:
|
1128 |
+
for i_n in range(tot_n_samples):
|
1129 |
+
cur_mesh = mesh_list[i_n]
|
1130 |
+
cur_context_mesh = context_mesh_list[i_n]
|
1131 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
1132 |
+
cur_context_vertices, cur_context_faces = cur_context_mesh['vertices'], cur_context_mesh['faces']
|
1133 |
+
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1134 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1135 |
+
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_part_{i_p}_ins_{i_n}_context.obj")
|
1136 |
+
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
|
1137 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
1138 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1139 |
+
if cur_context_vertices.shape[0] > 0 and len(cur_context_faces) > 0:
|
1140 |
+
write_obj(cur_context_vertices, cur_context_faces, cur_context_mesh_sv_fn, transpose=True, scale=1.)
|
1141 |
+
|
1142 |
+
|
1143 |
+
def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):
|
1144 |
+
|
1145 |
+
if not os.path.exists(sv_mesh_folder):
|
1146 |
+
os.mkdir(sv_mesh_folder)
|
1147 |
+
|
1148 |
+
mesh_list = []
|
1149 |
+
context_mesh_list = []
|
1150 |
+
tot_n_samples = v_sample['vertices'].shape[0]
|
1151 |
+
|
1152 |
+
for i_n in range(tot_n_samples):
|
1153 |
+
mesh_list.append(
|
1154 |
+
{
|
1155 |
+
'vertices': v_sample['vertices'][i_n][:v_sample['num_vertices'][i_n]],
|
1156 |
+
'faces': []
|
1157 |
+
}
|
1158 |
+
)
|
1159 |
+
|
1160 |
+
cur_context_vertices = context['vertices'][i_n]
|
1161 |
+
cur_context_vertices_mask = context['vertices_mask'][i_n]
|
1162 |
+
cur_nn_vertices = np.sum(cur_context_vertices_mask).item()
|
1163 |
+
cur_nn_vertices = int(cur_nn_vertices)
|
1164 |
+
cur_context_vertices = cur_context_vertices[:cur_nn_vertices]
|
1165 |
+
if 'mask_vertices_flat_indicator' in context:
|
1166 |
+
cur_context_vertices_mask_indicator = context['mask_vertices_flat_indicator'][i_n]
|
1167 |
+
cur_context_vertices_mask_indicator = cur_context_vertices_mask_indicator[:cur_nn_vertices * 3]
|
1168 |
+
cur_context_vertices = filter_masked_vertices(cur_context_vertices, cur_context_vertices_mask_indicator)
|
1169 |
+
context_mesh_dict = {
|
1170 |
+
'vertices': dequantize_verts(cur_context_vertices, n_bits=8), 'faces': []
|
1171 |
+
}
|
1172 |
+
context_mesh_list.append(context_mesh_dict)
|
1173 |
+
|
1174 |
+
# tot_mesh_list.append(mesh_list)
|
1175 |
+
|
1176 |
+
# if with_context:
|
1177 |
+
for i_n in range(tot_n_samples):
|
1178 |
+
cur_mesh = mesh_list[i_n]
|
1179 |
+
cur_context_mesh = context_mesh_list[i_n]
|
1180 |
+
cur_mesh_vertices = cur_mesh['vertices']
|
1181 |
+
cur_context_vertices = cur_context_mesh['vertices']
|
1182 |
+
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1183 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}.obj")
|
1184 |
+
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_n}_context.obj")
|
1185 |
+
# print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
|
1186 |
+
print(f"saving the sample to {cur_mesh_sv_fn}, context sample to {cur_context_mesh_sv_fn}")
|
1187 |
+
if cur_mesh_vertices.shape[0] > 0:
|
1188 |
+
write_obj(cur_mesh_vertices, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1189 |
+
if cur_context_vertices.shape[0] > 0:
|
1190 |
+
write_obj(cur_context_vertices, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)
|
1191 |
+
|
1192 |
+
|
1193 |
+
def get_grid_content_from_grids(grid_xyzs, grid_values, grid_size=2):
|
1194 |
+
cur_bsz_grid_xyzs = grid_xyzs # grid_length x 3 # grids pts for a sinlge batch
|
1195 |
+
cur_bsz_grid_values = grid_values # grid_length x gs x gs x gs
|
1196 |
+
pts = []
|
1197 |
+
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
|
1198 |
+
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
|
1199 |
+
if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
|
1200 |
+
break
|
1201 |
+
if len(cur_bsz_grid_values.shape) > 1:
|
1202 |
+
cur_grid_values = cur_bsz_grid_values[i_grid]
|
1203 |
+
else:
|
1204 |
+
cur_grid_content = cur_bsz_grid_values[i_grid].item()
|
1205 |
+
if cur_grid_content >= MASK_GRID_VALIE:
|
1206 |
+
continue
|
1207 |
+
inde = 2
|
1208 |
+
cur_grid_values = []
|
1209 |
+
for i_s in range(grid_size ** 3):
|
1210 |
+
cur_mod_value = cur_grid_content % inde
|
1211 |
+
cur_grid_content = cur_grid_content // inde
|
1212 |
+
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
|
1213 |
+
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
|
1214 |
+
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
|
1215 |
+
# if words
|
1216 |
+
# flip words
|
1217 |
+
for i_x in range(cur_grid_values.shape[0]):
|
1218 |
+
for i_y in range(cur_grid_values.shape[1]):
|
1219 |
+
for i_z in range(cur_grid_values.shape[2]):
|
1220 |
+
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
|
1221 |
+
if cur_grid_xyz_value > 0.5:
|
1222 |
+
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
|
1223 |
+
pts.append([cur_x, cur_y, cur_z])
|
1224 |
+
return pts
|
1225 |
+
|
1226 |
+
def plot_grids_for_pretraining(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
|
1227 |
+
|
1228 |
+
##### plot grids
|
1229 |
+
if not os.path.exists(sv_mesh_folder):
|
1230 |
+
os.mkdir(sv_mesh_folder)
|
1231 |
+
|
1232 |
+
# part_vertex_samples = [v_sample] # vertex samples
|
1233 |
+
# part_face_samples = [f_sample] # face samples
|
1234 |
+
|
1235 |
+
grid_xyzs = v_sample['grid_xyzs']
|
1236 |
+
grid_values = v_sample['grid_values']
|
1237 |
+
|
1238 |
+
bsz = grid_xyzs.shape[0]
|
1239 |
+
grid_size = opt.vertex_model.grid_size
|
1240 |
+
|
1241 |
+
|
1242 |
+
for i_bsz in range(bsz):
|
1243 |
+
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
|
1244 |
+
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
|
1245 |
+
pts = []
|
1246 |
+
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
|
1247 |
+
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
|
1248 |
+
if cur_grid_xyz[0] == -1 or cur_grid_xyz[1] == -1 or cur_grid_xyz[2] == -1:
|
1249 |
+
break
|
1250 |
+
if len(cur_bsz_grid_values.shape) > 1:
|
1251 |
+
cur_grid_values = cur_bsz_grid_values[i_grid]
|
1252 |
+
else:
|
1253 |
+
cur_grid_content = cur_bsz_grid_values[i_grid].item()
|
1254 |
+
if cur_grid_content >= MASK_GRID_VALIE:
|
1255 |
+
continue
|
1256 |
+
inde = 2
|
1257 |
+
cur_grid_values = []
|
1258 |
+
for i_s in range(grid_size ** 3):
|
1259 |
+
cur_mod_value = cur_grid_content % inde
|
1260 |
+
cur_grid_content = cur_grid_content // inde
|
1261 |
+
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
|
1262 |
+
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
|
1263 |
+
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
|
1264 |
+
# if
|
1265 |
+
for i_x in range(cur_grid_values.shape[0]):
|
1266 |
+
for i_y in range(cur_grid_values.shape[1]):
|
1267 |
+
for i_z in range(cur_grid_values.shape[2]):
|
1268 |
+
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
|
1269 |
+
if cur_grid_xyz_value > 0.5:
|
1270 |
+
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
|
1271 |
+
pts.append([cur_x, cur_y, cur_z])
|
1272 |
+
|
1273 |
+
|
1274 |
+
if len(pts) == 0:
|
1275 |
+
print("zzz, len(pts) == 0")
|
1276 |
+
continue
|
1277 |
+
pts = np.array(pts, dtype=np.float32)
|
1278 |
+
# pts = center_vertices(pts)
|
1279 |
+
# pts = normalize_vertices_scale(pts)
|
1280 |
+
pts = pts[:, [2, 1, 0]]
|
1281 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj" if sv_mesh_fn is None else sv_mesh_fn)
|
1282 |
+
|
1283 |
+
print(f"Saving obj to {cur_mesh_sv_fn}")
|
1284 |
+
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1285 |
+
|
1286 |
+
|
1287 |
+
def plot_grids_for_pretraining_obj_corpus(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
|
1288 |
+
|
1289 |
+
##### plot grids
|
1290 |
+
if not os.path.exists(sv_mesh_folder):
|
1291 |
+
os.mkdir(sv_mesh_folder)
|
1292 |
+
|
1293 |
+
# part_vertex_samples = [v_sample] # vertex samples
|
1294 |
+
# part_face_samples = [f_sample] # face samples
|
1295 |
+
|
1296 |
+
grid_xyzs = v_sample['grid_xyzs']
|
1297 |
+
grid_values = v_sample['grid_values']
|
1298 |
+
|
1299 |
+
bsz = grid_xyzs.shape[0]
|
1300 |
+
grid_size = opt.vertex_model.grid_size
|
1301 |
+
|
1302 |
+
|
1303 |
+
for i_bsz in range(bsz):
|
1304 |
+
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
|
1305 |
+
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
|
1306 |
+
part_pts = []
|
1307 |
+
pts = []
|
1308 |
+
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
|
1309 |
+
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
|
1310 |
+
##### grid_xyz; grid_
|
1311 |
+
if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
|
1312 |
+
part_pts.append(pts)
|
1313 |
+
pts = []
|
1314 |
+
continue
|
1315 |
+
##### cur_grid_xyz... #####
|
1316 |
+
elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
|
1317 |
+
continue
|
1318 |
+
if len(cur_bsz_grid_values.shape) > 1:
|
1319 |
+
cur_grid_values = cur_bsz_grid_values[i_grid]
|
1320 |
+
else:
|
1321 |
+
cur_grid_content = cur_bsz_grid_values[i_grid].item()
|
1322 |
+
if cur_grid_content >= MASK_GRID_VALIE: # mask grid value
|
1323 |
+
continue
|
1324 |
+
inde = 2
|
1325 |
+
cur_grid_values = []
|
1326 |
+
for i_s in range(grid_size ** 3):
|
1327 |
+
cur_mod_value = cur_grid_content % inde
|
1328 |
+
cur_grid_content = cur_grid_content // inde
|
1329 |
+
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
|
1330 |
+
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
|
1331 |
+
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
|
1332 |
+
# if
|
1333 |
+
for i_x in range(cur_grid_values.shape[0]):
|
1334 |
+
for i_y in range(cur_grid_values.shape[1]):
|
1335 |
+
for i_z in range(cur_grid_values.shape[2]):
|
1336 |
+
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
|
1337 |
+
##### gird-xyz-values #####
|
1338 |
+
if cur_grid_xyz_value > 0.5: # cur_grid_xyz_value
|
1339 |
+
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
|
1340 |
+
pts.append([cur_x, cur_y, cur_z])
|
1341 |
+
|
1342 |
+
if len(pts) > 0:
|
1343 |
+
part_pts.append(pts)
|
1344 |
+
pts = []
|
1345 |
+
tot_nn_pts = sum([len(aa) for aa in part_pts])
|
1346 |
+
if tot_nn_pts == 0:
|
1347 |
+
print("zzz, tot_nn_pts == 0")
|
1348 |
+
continue
|
1349 |
+
|
1350 |
+
for i_p, pts in enumerate(part_pts):
|
1351 |
+
if len(pts) == 0:
|
1352 |
+
continue
|
1353 |
+
pts = np.array(pts, dtype=np.float32)
|
1354 |
+
pts = center_vertices(pts)
|
1355 |
+
# pts = normalize_vertices_scale(pts)
|
1356 |
+
pts = pts[:, [2, 1, 0]]
|
1357 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)
|
1358 |
+
|
1359 |
+
print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
|
1360 |
+
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1361 |
+
|
1362 |
+
|
1363 |
+
|
1364 |
+
def plot_grids_for_pretraining_obj_part(v_sample, sv_mesh_folder="", cur_step=0, context=None, sv_mesh_fn=None):
|
1365 |
+
|
1366 |
+
##### plot grids
|
1367 |
+
if not os.path.exists(sv_mesh_folder):
|
1368 |
+
os.mkdir(sv_mesh_folder)
|
1369 |
+
|
1370 |
+
# part_vertex_samples = [v_sample] # vertex samples
|
1371 |
+
# part_face_samples = [f_sample] # face samples
|
1372 |
+
|
1373 |
+
grid_xyzs = v_sample['grid_xyzs']
|
1374 |
+
grid_values = v_sample['grid_values']
|
1375 |
+
|
1376 |
+
bsz = grid_xyzs.shape[0]
|
1377 |
+
grid_size = opt.vertex_model.grid_size
|
1378 |
+
|
1379 |
+
|
1380 |
+
for i_bsz in range(bsz):
|
1381 |
+
cur_bsz_grid_xyzs = grid_xyzs[i_bsz] # grid_length x 3
|
1382 |
+
cur_bsz_grid_values = grid_values[i_bsz] # grid_length x gs x gs x gs
|
1383 |
+
part_pts = []
|
1384 |
+
pts = []
|
1385 |
+
for i_grid in range(cur_bsz_grid_xyzs.shape[0]): # cur_bsz_grid_xyzs
|
1386 |
+
cur_grid_xyz = cur_bsz_grid_xyzs[i_grid].tolist()
|
1387 |
+
##### grid_xyz; grid_
|
1388 |
+
if cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == -1:
|
1389 |
+
part_pts.append(pts)
|
1390 |
+
pts = []
|
1391 |
+
break
|
1392 |
+
elif cur_grid_xyz[0] == -1 and cur_grid_xyz[1] == -1 and cur_grid_xyz[2] == 0:
|
1393 |
+
part_pts.append(pts)
|
1394 |
+
pts = []
|
1395 |
+
continue
|
1396 |
+
##### cur_grid_xyz... #####
|
1397 |
+
elif not (cur_grid_xyz[0] >= 0 and cur_grid_xyz[1] >= 0 and cur_grid_xyz[2] >= 0):
|
1398 |
+
continue
|
1399 |
+
if len(cur_bsz_grid_values.shape) > 1:
|
1400 |
+
cur_grid_values = cur_bsz_grid_values[i_grid]
|
1401 |
+
else:
|
1402 |
+
cur_grid_content = cur_bsz_grid_values[i_grid].item()
|
1403 |
+
if cur_grid_content >= MASK_GRID_VALIE: # invalid jor dummy content value s
|
1404 |
+
continue
|
1405 |
+
inde = 2
|
1406 |
+
cur_grid_values = []
|
1407 |
+
for i_s in range(grid_size ** 3):
|
1408 |
+
cur_mod_value = cur_grid_content % inde
|
1409 |
+
cur_grid_content = cur_grid_content // inde
|
1410 |
+
cur_grid_values = [cur_mod_value] + cur_grid_values # higher values should be put to the front of the list
|
1411 |
+
cur_grid_values = np.array(cur_grid_values, dtype=np.long)
|
1412 |
+
cur_grid_values = np.reshape(cur_grid_values, (grid_size, grid_size, grid_size))
|
1413 |
+
# if
|
1414 |
+
for i_x in range(cur_grid_values.shape[0]):
|
1415 |
+
for i_y in range(cur_grid_values.shape[1]):
|
1416 |
+
for i_z in range(cur_grid_values.shape[2]):
|
1417 |
+
cur_grid_xyz_value = int(cur_grid_values[i_x, i_y, i_z].item())
|
1418 |
+
##### gird-xyz-values #####
|
1419 |
+
if cur_grid_xyz_value > 0.5: # cur_grid_xyz_value
|
1420 |
+
cur_x, cur_y, cur_z = cur_grid_xyz[0] * grid_size + i_x, cur_grid_xyz[1] * grid_size + i_y, cur_grid_xyz[2] * grid_size + i_z
|
1421 |
+
pts.append([cur_x, cur_y, cur_z])
|
1422 |
+
|
1423 |
+
if len(pts) > 0:
|
1424 |
+
part_pts.append(pts)
|
1425 |
+
pts = []
|
1426 |
+
tot_nn_pts = sum([len(aa) for aa in part_pts])
|
1427 |
+
if tot_nn_pts == 0:
|
1428 |
+
print("zzz, tot_nn_pts == 0")
|
1429 |
+
continue
|
1430 |
+
|
1431 |
+
for i_p, pts in enumerate(part_pts):
|
1432 |
+
if len(pts) == 0:
|
1433 |
+
continue
|
1434 |
+
pts = np.array(pts, dtype=np.float32)
|
1435 |
+
pts = center_vertices(pts)
|
1436 |
+
# pts = normalize_vertices_scale(pts)
|
1437 |
+
pts = pts[:, [2, 1, 0]]
|
1438 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_ip_{i_p}.obj" if sv_mesh_fn is None else sv_mesh_fn)
|
1439 |
+
|
1440 |
+
print(f"Saving {i_p}-th part obj to {cur_mesh_sv_fn}")
|
1441 |
+
write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1442 |
+
|
1443 |
+
|
1444 |
+
def plot_grids_for_pretraining_ml(v_sample, sv_mesh_folder="", cur_step=0, context=None):
|
1445 |
+
|
1446 |
+
if not os.path.exists(sv_mesh_folder):
|
1447 |
+
os.mkdir(sv_mesh_folder)
|
1448 |
+
|
1449 |
+
# part_vertex_samples = [v_sample] # vertex samples
|
1450 |
+
# part_face_samples = [f_sample] # face samples
|
1451 |
+
|
1452 |
+
grid_xyzs = v_sample['grid_xyzs']
|
1453 |
+
grid_values = v_sample['grid_values']
|
1454 |
+
|
1455 |
+
context_grid_xyzs = context['grid_xyzs'] - 1
|
1456 |
+
# context_grid_values = context['grid_content']
|
1457 |
+
context_grid_values = context['mask_grid_content']
|
1458 |
+
|
1459 |
+
bsz = grid_xyzs.shape[0]
|
1460 |
+
grid_size = opt.vertex_model.grid_size
|
1461 |
+
|
1462 |
+
|
1463 |
+
for i_bsz in range(bsz):
|
1464 |
+
cur_bsz_grid_pts = get_grid_content_from_grids(grid_xyzs[i_bsz], grid_values[i_bsz], grid_size=grid_size)
|
1465 |
+
cur_context_grid_pts = get_grid_content_from_grids(context_grid_xyzs[i_bsz], context_grid_values[i_bsz], grid_size=grid_size)
|
1466 |
+
|
1467 |
+
if len(cur_bsz_grid_pts) > 0 and len(cur_context_grid_pts) > 0:
|
1468 |
+
cur_bsz_grid_pts = np.array(cur_bsz_grid_pts, dtype=np.float32)
|
1469 |
+
cur_bsz_grid_pts = center_vertices(cur_bsz_grid_pts)
|
1470 |
+
cur_bsz_grid_pts = normalize_vertices_scale(cur_bsz_grid_pts)
|
1471 |
+
cur_bsz_grid_pts = cur_bsz_grid_pts[:, [2, 1, 0]]
|
1472 |
+
#### plot current mesh / sampled points ####
|
1473 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
|
1474 |
+
print(f"Saving predicted grid points to {cur_mesh_sv_fn}")
|
1475 |
+
write_obj(cur_bsz_grid_pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1476 |
+
|
1477 |
+
cur_context_grid_pts = np.array(cur_context_grid_pts, dtype=np.float32)
|
1478 |
+
cur_context_grid_pts = center_vertices(cur_context_grid_pts)
|
1479 |
+
cur_context_grid_pts = normalize_vertices_scale(cur_context_grid_pts)
|
1480 |
+
cur_context_grid_pts = cur_context_grid_pts[:, [2, 1, 0]]
|
1481 |
+
#### plot current mesh / sampled points ####
|
1482 |
+
cur_context_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}_context.obj")
|
1483 |
+
print(f"Saving context grid points to {cur_context_mesh_sv_fn}")
|
1484 |
+
write_obj(cur_context_grid_pts, None, cur_context_mesh_sv_fn, transpose=True, scale=1.)
|
1485 |
+
|
1486 |
+
# print(f"Saving obj to {cur_mesh_sv_fn}")
|
1487 |
+
# write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1488 |
+
|
1489 |
+
# if len(cur_bsz_grid_pts) == 0:
|
1490 |
+
# print("zzz, len(pts) == 0")
|
1491 |
+
# continue
|
1492 |
+
# pts = np.array(pts, dtype=np.float32)
|
1493 |
+
# pts = center_vertices(pts)
|
1494 |
+
# pts = normalize_vertices_scale(pts)
|
1495 |
+
# pts = pts[:, [2, 1, 0]]
|
1496 |
+
# cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"training_step_{cur_step}_ins_{i_bsz}.obj")
|
1497 |
+
|
1498 |
+
# print(f"Saving obj to {cur_mesh_sv_fn}")
|
1499 |
+
# write_obj(pts, None, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1500 |
+
|
1501 |
+
|
1502 |
+
|
1503 |
+
def plot_sampled_meshes_single_part_for_sampling(v_sample, f_sample, sv_mesh_folder, cur_step=0, predict_joint=True,):
|
1504 |
+
|
1505 |
+
if not os.path.exists(sv_mesh_folder):
|
1506 |
+
os.mkdir(sv_mesh_folder)
|
1507 |
+
|
1508 |
+
part_vertex_samples = [v_sample]
|
1509 |
+
part_face_samples = [f_sample]
|
1510 |
+
|
1511 |
+
|
1512 |
+
tot_n_samples = part_vertex_samples[0]['vertices'].shape[0]
|
1513 |
+
tot_n_part = 2
|
1514 |
+
|
1515 |
+
# not predict joints here
|
1516 |
+
# if predict_joint:
|
1517 |
+
# pred_dir = v_sample['joint_dir']
|
1518 |
+
# pred_pvp = v_sample['joint_pvp']
|
1519 |
+
# print("pred_dir", pred_dir.shape, pred_dir)
|
1520 |
+
# print("pred_pvp", pred_pvp.shape, pred_pvp)
|
1521 |
+
# else:
|
1522 |
+
# pred_pvp = np.zeros(shape=[tot_n_samples, 3], dtype=np.float32)
|
1523 |
+
|
1524 |
+
|
1525 |
+
tot_mesh_list = []
|
1526 |
+
for i_p, (cur_part_v_samples_np, cur_part_f_samples_np) in enumerate(zip(part_vertex_samples, part_face_samples)):
|
1527 |
+
mesh_list = []
|
1528 |
+
for i_n in range(tot_n_samples):
|
1529 |
+
mesh_list.append(
|
1530 |
+
{
|
1531 |
+
'vertices': cur_part_v_samples_np['vertices'][i_n][:cur_part_v_samples_np['num_vertices'][i_n]],
|
1532 |
+
'faces': unflatten_faces(
|
1533 |
+
cur_part_f_samples_np['faces'][i_n][:cur_part_f_samples_np['num_face_indices'][i_n]])
|
1534 |
+
}
|
1535 |
+
)
|
1536 |
+
tot_mesh_list.append(mesh_list)
|
1537 |
+
|
1538 |
+
for i_n in range(tot_n_samples):
|
1539 |
+
cur_mesh = mesh_list[i_n]
|
1540 |
+
cur_mesh_vertices, cur_mesh_faces = cur_mesh['vertices'], cur_mesh['faces']
|
1541 |
+
# cur_mesh_sv_fn = os.path.join("./meshes", f"training_step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1542 |
+
cur_mesh_sv_fn = os.path.join(sv_mesh_folder, f"step_{cur_step}_part_{i_p}_ins_{i_n}.obj")
|
1543 |
+
print(f"saving to {cur_mesh_sv_fn}, nn_verts: {cur_mesh_vertices.shape[0]}, nn_faces: {len(cur_mesh_faces)}")
|
1544 |
+
if cur_mesh_vertices.shape[0] > 0 and len(cur_mesh_faces) > 0:
|
1545 |
+
write_obj(cur_mesh_vertices, cur_mesh_faces, cur_mesh_sv_fn, transpose=True, scale=1.)
|
1546 |
+
|
1547 |
+
|
models/dataset.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import cv2 as cv
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
from glob import glob
|
7 |
+
from icecream import ic
|
8 |
+
from scipy.spatial.transform import Rotation as Rot
|
9 |
+
from scipy.spatial.transform import Slerp
|
10 |
+
|
11 |
+
|
12 |
+
# This function is borrowed from IDR: https://github.com/lioryariv/idr
|
13 |
+
def load_K_Rt_from_P(filename, P=None):
|
14 |
+
if P is None:
|
15 |
+
lines = open(filename).read().splitlines()
|
16 |
+
if len(lines) == 4:
|
17 |
+
lines = lines[1:]
|
18 |
+
lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
|
19 |
+
P = np.asarray(lines).astype(np.float32).squeeze()
|
20 |
+
|
21 |
+
out = cv.decomposeProjectionMatrix(P)
|
22 |
+
K = out[0]
|
23 |
+
R = out[1]
|
24 |
+
t = out[2]
|
25 |
+
|
26 |
+
K = K / K[2, 2]
|
27 |
+
intrinsics = np.eye(4)
|
28 |
+
intrinsics[:3, :3] = K
|
29 |
+
|
30 |
+
pose = np.eye(4, dtype=np.float32)
|
31 |
+
pose[:3, :3] = R.transpose()
|
32 |
+
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
33 |
+
|
34 |
+
return intrinsics, pose
|
35 |
+
|
36 |
+
def filter_iamges_via_pixel_values(data_dir):
|
37 |
+
images_lis = sorted(glob(os.path.join(data_dir, 'image/*.png'))) ## images lis ##
|
38 |
+
n_images = len(images_lis)
|
39 |
+
images_np = np.stack([cv.imread(im_name) for im_name in images_lis]) / 255.0
|
40 |
+
print(f"images_np: {images_np.shape}")
|
41 |
+
# nn_frames x res x res x 3 #
|
42 |
+
images_np = 1. - images_np
|
43 |
+
has_density_values = (np.sum(images_np, axis=-1) > 0.7).astype(np.float32)
|
44 |
+
has_density_values = np.sum(np.sum(has_density_values, axis=-1), axis=-1)
|
45 |
+
tot_res_nns = float(images_np.shape[1] * images_np.shape[2])
|
46 |
+
has_density_ratio = has_density_values / tot_res_nns ### has density ratio and ratio #
|
47 |
+
print(f"has_density_values: {has_density_values.shape}")
|
48 |
+
paried_has_density_ratio_list = [(i_fr, has_density_ratio[i_fr].item()) for i_fr in range(has_density_ratio.shape[0])]
|
49 |
+
paried_has_density_ratio_list = sorted(paried_has_density_ratio_list, key=lambda ii: ii[1], reverse=True)
|
50 |
+
mid_rnk_value = len(paried_has_density_ratio_list) // 4
|
51 |
+
print(f"mid value of the density ratio")
|
52 |
+
print(paried_has_density_ratio_list[mid_rnk_value])
|
53 |
+
iamge_idx = paried_has_density_ratio_list[mid_rnk_value][0]
|
54 |
+
print(f"iamge idx: {images_lis[iamge_idx]}")
|
55 |
+
print(paried_has_density_ratio_list[:mid_rnk_value])
|
56 |
+
tot_selected_img_idx_list = [ii[0] for ii in paried_has_density_ratio_list[:mid_rnk_value]]
|
57 |
+
tot_selected_img_idx_list =sorted(tot_selected_img_idx_list)
|
58 |
+
print(len(tot_selected_img_idx_list))
|
59 |
+
# print(tot_selected_img_idx_list[54])
|
60 |
+
print(tot_selected_img_idx_list)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
class Dataset:
|
65 |
+
def __init__(self, conf):
|
66 |
+
super(Dataset, self).__init__()
|
67 |
+
print('Load data: Begin')
|
68 |
+
self.device = torch.device('cuda')
|
69 |
+
self.conf = conf
|
70 |
+
|
71 |
+
self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 13, 14, 15, 35, 36, 42, 43, 44, 48, 49, 50, 51, 55, 56, 57, 61, 62, 63, 69, 84, 90, 91, 92, 96, 97]
|
72 |
+
# self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 28, 29, 35, 36, 37, 40, 41, 70, 71, 79, 82, 83, 84, 85, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 133, 134, 135, 139, 174, 175, 176, 177, 180, 188, 189, 190, 191, 194, 195]
|
73 |
+
|
74 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97]
|
75 |
+
|
76 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97, 98, 99, 104, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 134, 135, 139, 174, 175, 176, 177, 180, 181, 182, 183, 188, 189, 190, 191, 194, 195]
|
77 |
+
|
78 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 20, 21, 22, 23, 26, 27, 70, 78, 83, 84, 85, 91, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 119, 120, 121, 124, 125, 175, 176, 181, 182, 188, 189, 190, 191, 194, 195]
|
79 |
+
# or the timestep to the dataset instance ## # selected img idxes list #
|
80 |
+
self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
self.data_dir = conf.get_string('data_dir')
|
87 |
+
self.render_cameras_name = conf.get_string('render_cameras_name')
|
88 |
+
self.object_cameras_name = conf.get_string('object_cameras_name')
|
89 |
+
|
90 |
+
## camera outside sphere ##
|
91 |
+
self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
|
92 |
+
self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)
|
93 |
+
|
94 |
+
camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
|
95 |
+
# camera_dict = np.load("/home/xueyi/diffsim/NeuS/public_data/dtu_scan24/cameras_sphere.npz")
|
96 |
+
self.camera_dict = camera_dict # rendr camera dict #
|
97 |
+
# render camera dict # # number of pixels in the views -> very thin geometry is not useful
|
98 |
+
self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))
|
99 |
+
|
100 |
+
# iamges_lis # and the images_lis and the images_lis #
|
101 |
+
# self.images_lis = self.images_lis[:1] # totoal views and poses of the camera; # and select cameras for rendering #
|
102 |
+
|
103 |
+
self.n_images = len(self.images_lis)
|
104 |
+
self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0
|
105 |
+
|
106 |
+
|
107 |
+
self.selected_img_idxes_list = list(range(self.images_np.shape[0]))
|
108 |
+
self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)
|
109 |
+
|
110 |
+
self.images_np = self.images_np[self.selected_img_idxes] ## get selected iamges_np #
|
111 |
+
|
112 |
+
### if we deal with the backgound carefully ### ### get
|
113 |
+
self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0
|
114 |
+
self.images_np = self.images_np[self.selected_img_idxes]
|
115 |
+
self.images_np = 1. - self.images_np ###
|
116 |
+
|
117 |
+
|
118 |
+
self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png')))
|
119 |
+
|
120 |
+
# self.masks_lis = self.masks_lis[:1]
|
121 |
+
|
122 |
+
try:
|
123 |
+
self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0
|
124 |
+
self.masks_np = self.masks_np[self.selected_img_idxes]
|
125 |
+
except:
|
126 |
+
self.masks_np = self.images_np.copy()
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
# world_mat is a projection matrix from world to image
|
134 |
+
self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
135 |
+
|
136 |
+
self.scale_mats_np = []
|
137 |
+
|
138 |
+
# scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
|
139 |
+
self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
140 |
+
|
141 |
+
self.intrinsics_all = []
|
142 |
+
self.pose_all = []
|
143 |
+
|
144 |
+
# for idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, self.world_mats_np)):
|
145 |
+
for idx in self.selected_img_idxes_list:
|
146 |
+
scale_mat = self.scale_mats_np[idx]
|
147 |
+
world_mat = self.world_mats_np[idx]
|
148 |
+
|
149 |
+
if "hand" in self.data_dir:
|
150 |
+
intrinsics = np.eye(4)
|
151 |
+
fov = 512. / 2. # * 2
|
152 |
+
res = 512.
|
153 |
+
intrinsics[:3, :3] = np.array([
|
154 |
+
[fov, 0, 0.5* res], # res #
|
155 |
+
[0, fov, 0.5* res], # res #
|
156 |
+
[0, 0, 1]
|
157 |
+
], dtype=np.float32)
|
158 |
+
pose = camera_dict['camera_mat_%d' % idx].astype(np.float32)
|
159 |
+
else:
|
160 |
+
P = world_mat @ scale_mat
|
161 |
+
P = P[:3, :4]
|
162 |
+
intrinsics, pose = load_K_Rt_from_P(None, P)
|
163 |
+
|
164 |
+
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
|
165 |
+
self.pose_all.append(torch.from_numpy(pose).float())
|
166 |
+
|
167 |
+
### images, masks,
|
168 |
+
self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3] #
|
169 |
+
self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3] #
|
170 |
+
self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] # optimize sdf field # rigid model hand
|
171 |
+
self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4]
|
172 |
+
self.focal = self.intrinsics_all[0][0, 0]
|
173 |
+
self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4]
|
174 |
+
self.H, self.W = self.images.shape[1], self.images.shape[2]
|
175 |
+
self.image_pixels = self.H * self.W
|
176 |
+
|
177 |
+
object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
|
178 |
+
object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0])
|
179 |
+
# Object scale mat: region of interest to **extract mesh**
|
180 |
+
object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
|
181 |
+
object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
|
182 |
+
object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
|
183 |
+
self.object_bbox_min = object_bbox_min[:3, 0]
|
184 |
+
self.object_bbox_max = object_bbox_max[:3, 0]
|
185 |
+
|
186 |
+
self.n_images = self.images.size(0)
|
187 |
+
|
188 |
+
print('Load data: End')
|
189 |
+
|
190 |
+
def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
|
191 |
+
i, j = torch.meshgrid( # meshgrid #
|
192 |
+
torch.linspace(0, W-1, W, device=c2w.device),
|
193 |
+
torch.linspace(0, H-1, H, device=c2w.device))
|
194 |
+
i = i.t().float()
|
195 |
+
j = j.t().float()
|
196 |
+
if mode == 'lefttop':
|
197 |
+
pass
|
198 |
+
elif mode == 'center':
|
199 |
+
i, j = i+0.5, j+0.5
|
200 |
+
elif mode == 'random':
|
201 |
+
i = i+torch.rand_like(i)
|
202 |
+
j = j+torch.rand_like(j)
|
203 |
+
else:
|
204 |
+
raise NotImplementedError
|
205 |
+
|
206 |
+
if flip_x:
|
207 |
+
i = i.flip((1,))
|
208 |
+
if flip_y:
|
209 |
+
j = j.flip((0,))
|
210 |
+
if inverse_y:
|
211 |
+
dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
|
212 |
+
else:
|
213 |
+
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
|
214 |
+
# Rotate ray directions from camera frame to the world frame
|
215 |
+
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
|
216 |
+
# Translate camera frame's origin to the world frame. It is the origin of all rays.
|
217 |
+
rays_o = c2w[:3,3].expand(rays_d.shape)
|
218 |
+
return rays_o, rays_d
|
219 |
+
|
220 |
+
def gen_rays_at(self, img_idx, resolution_level=1):
|
221 |
+
"""
|
222 |
+
Generate rays at world space from one camera.
|
223 |
+
"""
|
224 |
+
l = resolution_level
|
225 |
+
tx = torch.linspace(0, self.W - 1, self.W // l)
|
226 |
+
ty = torch.linspace(0, self.H - 1, self.H // l)
|
227 |
+
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
228 |
+
|
229 |
+
##### previous method #####
|
230 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
231 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
232 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
233 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
234 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
235 |
+
# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
|
236 |
+
##### previous method #####
|
237 |
+
|
238 |
+
fov = 512.; res = 512.
|
239 |
+
K = np.array([
|
240 |
+
[fov, 0, 0.5* res],
|
241 |
+
[0, fov, 0.5* res],
|
242 |
+
[0, 0, 1]
|
243 |
+
], dtype=np.float32)
|
244 |
+
K = torch.from_numpy(K).float().cuda()
|
245 |
+
|
246 |
+
|
247 |
+
# ### `center` mode ### #
|
248 |
+
c2w = self.pose_all[img_idx]
|
249 |
+
pixels_x, pixels_y = pixels_x+0.5, pixels_y+0.5
|
250 |
+
|
251 |
+
dirs = torch.stack([(pixels_x-K[0][2])/K[0][0], -(pixels_y-K[1][2])/K[1][1], -torch.ones_like(pixels_x)], -1)
|
252 |
+
rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
|
253 |
+
rays_o = c2w[:3,3].expand(rays_v.shape)
|
254 |
+
# dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
|
255 |
+
|
256 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
257 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
258 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
259 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
260 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
261 |
+
# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
|
262 |
+
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
263 |
+
|
264 |
+
def gen_random_rays_at(self, img_idx, batch_size):
|
265 |
+
"""
|
266 |
+
Generate random rays at world space from one camera.
|
267 |
+
"""
|
268 |
+
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
|
269 |
+
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
|
270 |
+
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
271 |
+
|
272 |
+
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
273 |
+
|
274 |
+
|
275 |
+
##### previous method #####
|
276 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
|
277 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
|
278 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3
|
279 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
|
280 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3
|
281 |
+
# rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3
|
282 |
+
##### previous method #####
|
283 |
+
|
284 |
+
fov = 512.; res = 512.
|
285 |
+
K = np.array([
|
286 |
+
[fov, 0, 0.5* res],
|
287 |
+
[0, fov, 0.5* res],
|
288 |
+
[0, 0, 1]
|
289 |
+
], dtype=np.float32)
|
290 |
+
K = torch.from_numpy(K).float().cuda()
|
291 |
+
|
292 |
+
|
293 |
+
# ### `center` mode ### #
|
294 |
+
c2w = self.pose_all[img_idx]
|
295 |
+
pixels_x, pixels_y = pixels_x+0.5, pixels_y+0.5
|
296 |
+
|
297 |
+
dirs = torch.stack([(pixels_x-K[0][2])/K[0][0], -(pixels_y-K[1][2])/K[1][1], -torch.ones_like(pixels_x)], -1)
|
298 |
+
rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
|
299 |
+
rays_o = c2w[:3,3].expand(rays_v.shape)
|
300 |
+
|
301 |
+
|
302 |
+
return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10
|
303 |
+
|
304 |
+
def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
|
305 |
+
"""
|
306 |
+
Interpolate pose between two cameras.
|
307 |
+
"""
|
308 |
+
l = resolution_level
|
309 |
+
tx = torch.linspace(0, self.W - 1, self.W // l)
|
310 |
+
ty = torch.linspace(0, self.H - 1, self.H // l)
|
311 |
+
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
312 |
+
p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
313 |
+
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
314 |
+
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
315 |
+
trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
|
316 |
+
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
|
317 |
+
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
|
318 |
+
pose_0 = np.linalg.inv(pose_0)
|
319 |
+
pose_1 = np.linalg.inv(pose_1)
|
320 |
+
rot_0 = pose_0[:3, :3]
|
321 |
+
rot_1 = pose_1[:3, :3]
|
322 |
+
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
|
323 |
+
key_times = [0, 1]
|
324 |
+
slerp = Slerp(key_times, rots)
|
325 |
+
rot = slerp(ratio)
|
326 |
+
pose = np.diag([1.0, 1.0, 1.0, 1.0])
|
327 |
+
pose = pose.astype(np.float32)
|
328 |
+
pose[:3, :3] = rot.as_matrix()
|
329 |
+
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
|
330 |
+
pose = np.linalg.inv(pose)
|
331 |
+
rot = torch.from_numpy(pose[:3, :3]).cuda()
|
332 |
+
trans = torch.from_numpy(pose[:3, 3]).cuda()
|
333 |
+
rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
334 |
+
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
|
335 |
+
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
336 |
+
|
337 |
+
def near_far_from_sphere(self, rays_o, rays_d):
|
338 |
+
a = torch.sum(rays_d**2, dim=-1, keepdim=True)
|
339 |
+
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
|
340 |
+
mid = 0.5 * (-b) / a
|
341 |
+
near = mid - 1.0
|
342 |
+
far = mid + 1.0
|
343 |
+
return near, far
|
344 |
+
|
345 |
+
## iamge_at ##
|
346 |
+
def image_at(self, idx, resolution_level):
|
347 |
+
if self.selected_img_idxes_list is not None:
|
348 |
+
img = cv.imread(self.images_lis[self.selected_img_idxes_list[idx]])
|
349 |
+
else:
|
350 |
+
img = cv.imread(self.images_lis[idx])
|
351 |
+
return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)
|
352 |
+
|
353 |
+
|
354 |
+
if __name__=='__main__':
|
355 |
+
data_dir = "/data/datasets/genn/diffsim/diffredmax/save_res/goal_optimize_model_hand_sphere_test_obj_type_active_nfr_10_view_divide_0.5_n_views_7_three_planes_False_recon_dvgo_new_Nposes_7_routine_2"
|
356 |
+
data_dir = "/data/datasets/genn/diffsim/neus/public_data/hand_test"
|
357 |
+
data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2"
|
358 |
+
data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2_light_color"
|
359 |
+
filter_iamges_via_pixel_values(data_dir=data_dir)
|
models/dataset_wtime.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import cv2 as cv
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
from glob import glob
|
7 |
+
from icecream import ic
|
8 |
+
from scipy.spatial.transform import Rotation as Rot
|
9 |
+
from scipy.spatial.transform import Slerp
|
10 |
+
|
11 |
+
|
12 |
+
# This function is borrowed from IDR: https://github.com/lioryariv/idr
|
13 |
+
def load_K_Rt_from_P(filename, P=None):
|
14 |
+
if P is None:
|
15 |
+
lines = open(filename).read().splitlines()
|
16 |
+
if len(lines) == 4:
|
17 |
+
lines = lines[1:]
|
18 |
+
lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
|
19 |
+
P = np.asarray(lines).astype(np.float32).squeeze()
|
20 |
+
|
21 |
+
out = cv.decomposeProjectionMatrix(P)
|
22 |
+
K = out[0]
|
23 |
+
R = out[1]
|
24 |
+
t = out[2]
|
25 |
+
|
26 |
+
K = K / K[2, 2]
|
27 |
+
intrinsics = np.eye(4)
|
28 |
+
intrinsics[:3, :3] = K
|
29 |
+
|
30 |
+
pose = np.eye(4, dtype=np.float32)
|
31 |
+
pose[:3, :3] = R.transpose()
|
32 |
+
pose[:3, 3] = (t[:3] / t[3])[:, 0]
|
33 |
+
|
34 |
+
return intrinsics, pose
|
35 |
+
|
36 |
+
def filter_iamges_via_pixel_values(data_dir):
|
37 |
+
images_lis = sorted(glob(os.path.join(data_dir, 'image/*.png'))) ## images lis ##
|
38 |
+
n_images = len(images_lis)
|
39 |
+
images_np = np.stack([cv.imread(im_name) for im_name in images_lis]) / 255.0
|
40 |
+
print(f"images_np: {images_np.shape}")
|
41 |
+
# nn_frames x res x res x 3 #
|
42 |
+
images_np = 1. - images_np
|
43 |
+
has_density_values = (np.sum(images_np, axis=-1) > 0.7).astype(np.float32)
|
44 |
+
has_density_values = np.sum(np.sum(has_density_values, axis=-1), axis=-1)
|
45 |
+
tot_res_nns = float(images_np.shape[1] * images_np.shape[2])
|
46 |
+
has_density_ratio = has_density_values / tot_res_nns ### has density ratio and ratio #
|
47 |
+
print(f"has_density_values: {has_density_values.shape}")
|
48 |
+
paried_has_density_ratio_list = [(i_fr, has_density_ratio[i_fr].item()) for i_fr in range(has_density_ratio.shape[0])]
|
49 |
+
paried_has_density_ratio_list = sorted(paried_has_density_ratio_list, key=lambda ii: ii[1], reverse=True)
|
50 |
+
mid_rnk_value = len(paried_has_density_ratio_list) // 4
|
51 |
+
print(f"mid value of the density ratio")
|
52 |
+
print(paried_has_density_ratio_list[mid_rnk_value])
|
53 |
+
iamge_idx = paried_has_density_ratio_list[mid_rnk_value][0]
|
54 |
+
print(f"iamge idx: {images_lis[iamge_idx]}")
|
55 |
+
print(paried_has_density_ratio_list[:mid_rnk_value])
|
56 |
+
tot_selected_img_idx_list = [ii[0] for ii in paried_has_density_ratio_list[:mid_rnk_value]]
|
57 |
+
tot_selected_img_idx_list =sorted(tot_selected_img_idx_list)
|
58 |
+
print(len(tot_selected_img_idx_list))
|
59 |
+
# print(tot_selected_img_idx_list[54])
|
60 |
+
print(tot_selected_img_idx_list)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
class Dataset:
|
65 |
+
def __init__(self, conf, time_idx, mode='train'):
|
66 |
+
super(Dataset, self).__init__()
|
67 |
+
print('Load data: Begin')
|
68 |
+
self.device = torch.device('cuda')
|
69 |
+
self.conf = conf
|
70 |
+
|
71 |
+
self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 13, 14, 15, 35, 36, 42, 43, 44, 48, 49, 50, 51, 55, 56, 57, 61, 62, 63, 69, 84, 90, 91, 92, 96, 97]
|
72 |
+
# self.selected_img_idxes_list = [0, 1, 5, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 28, 29, 35, 36, 37, 40, 41, 70, 71, 79, 82, 83, 84, 85, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 133, 134, 135, 139, 174, 175, 176, 177, 180, 188, 189, 190, 191, 194, 195]
|
73 |
+
|
74 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97]
|
75 |
+
|
76 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 20, 21, 22, 23, 26, 27, 36, 40, 41, 70, 71, 78, 82, 83, 84, 85, 90, 91, 92, 93, 96, 97, 98, 99, 104, 105, 106, 107, 110, 111, 112, 113, 118, 119, 120, 121, 124, 125, 134, 135, 139, 174, 175, 176, 177, 180, 181, 182, 183, 188, 189, 190, 191, 194, 195]
|
77 |
+
# selected img idxes list #
|
78 |
+
self.selected_img_idxes_list = [0, 1, 6, 7, 8, 9, 12, 13, 14, 20, 21, 22, 23, 26, 27, 70, 78, 83, 84, 85, 91, 92, 93, 96, 97, 98, 99, 105, 106, 107, 110, 111, 112, 113, 119, 120, 121, 124, 125, 175, 176, 181, 182, 188, 189, 190, 191, 194, 195]
|
79 |
+
# or the timestep to the dataset instance ## # selected img idxes list #
|
80 |
+
self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
self.data_dir = conf.get_string('data_dir')
|
87 |
+
|
88 |
+
self.data_dir = os.path.join(self.data_dir, f"{time_idx}") # the time_idx #
|
89 |
+
|
90 |
+
self.render_cameras_name = conf.get_string('render_cameras_name')
|
91 |
+
self.object_cameras_name = conf.get_string('object_cameras_name')
|
92 |
+
|
93 |
+
## camera outside sphere ##
|
94 |
+
self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
|
95 |
+
self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)
|
96 |
+
|
97 |
+
camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
|
98 |
+
# camera_dict = np.load("/home/xueyi/diffsim/NeuS/public_data/dtu_scan24/cameras_sphere.npz")
|
99 |
+
self.camera_dict = camera_dict # rendr camera dict #
|
100 |
+
# render camera dict # # number of pixels in the views -> very thin geometry is not useful
|
101 |
+
self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))
|
102 |
+
|
103 |
+
# iamges_lis # and the images_lis and the images_lis #
|
104 |
+
# self.images_lis = self.images_lis[:1] # totoal views and poses of the camera; # and select cameras for rendering #
|
105 |
+
|
106 |
+
self.n_images = len(self.images_lis)
|
107 |
+
|
108 |
+
if mode == 'train_from_model_rules':
|
109 |
+
self.images_np = cv.imread(self.images_lis[0]) / 256.0
|
110 |
+
print(self.images_np.shape)
|
111 |
+
self.images_np = np.reshape(self.images_np, (1, self.images_np.shape[0], self.images_np.shape[1], self.images_np.shape[2]))
|
112 |
+
self.images_np = [self.images_np for _ in range(len(self.images_lis))]
|
113 |
+
self.images_np = np.concatenate(self.images_np, axis=0)
|
114 |
+
else:
|
115 |
+
presaved_imags_npy_fn = os.path.join(self.data_dir, "processed_images.npy")
|
116 |
+
if not os.path.exists(presaved_imags_npy_fn):
|
117 |
+
self.images_np = []
|
118 |
+
for i_im_idx, im_name in enumerate(self.images_lis):
|
119 |
+
print(f"loading {i_im_idx} / {len(self.images_lis)}")
|
120 |
+
cur_im = cv.imread(im_name) # for im_name in self.images_lis
|
121 |
+
self.images_np.append(cur_im)
|
122 |
+
self.images_np = np.stack(self.images_np) / 256.0
|
123 |
+
np.save(presaved_imags_npy_fn, self.images_np)
|
124 |
+
else:
|
125 |
+
print(f"Loading from {presaved_imags_npy_fn}")
|
126 |
+
self.images_np = np.load(presaved_imags_npy_fn, allow_pickle=True)
|
127 |
+
|
128 |
+
# self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0
|
129 |
+
|
130 |
+
|
131 |
+
# self.selected_img_idxes_list = list(range(self.images_np.shape[0]))
|
132 |
+
# self.selected_img_idxes = np.array(self.selected_img_idxes_list).astype(np.int32)
|
133 |
+
|
134 |
+
# get
|
135 |
+
self.images_np = self.images_np[self.selected_img_idxes] ## get selected iamges_np #
|
136 |
+
|
137 |
+
### if we deal with the backgound carefully ### ### get
|
138 |
+
# self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 255.0
|
139 |
+
# self.images_np = self.images_np[self.selected_img_idxes]
|
140 |
+
self.images_np = 1. - self.images_np ###
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png')))
|
145 |
+
|
146 |
+
if mode == 'train_from_model_rules':
|
147 |
+
self.masks_np = cv.imread(self.masks_lis[0]) / 256.0
|
148 |
+
print("masks shape:", self.masks_np.shape)
|
149 |
+
self.masks_np = np.reshape(self.masks_np, (1, self.masks_np.shape[0], self.masks_np.shape[1], self.masks_np.shape[2])) # .repeat(len(self.masks_lis), 1, 1)
|
150 |
+
self.masks_np = [self.masks_np for _ in range(len(self.masks_lis))]
|
151 |
+
self.masks_np = np.concatenate(self.masks_np, axis=0)
|
152 |
+
else:
|
153 |
+
presaved_masks_npy_fn = os.path.join(self.data_dir, "processed_masks.npy")
|
154 |
+
# self.masks_lis = self.masks_lis[:1]
|
155 |
+
|
156 |
+
if not os.path.exists(presaved_masks_npy_fn):
|
157 |
+
try:
|
158 |
+
self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0
|
159 |
+
self.masks_np = self.masks_np[self.selected_img_idxes]
|
160 |
+
except:
|
161 |
+
self.masks_np = self.images_np.copy()
|
162 |
+
np.save(presaved_masks_npy_fn, self.masks_np)
|
163 |
+
else:
|
164 |
+
print(f"Loading from {presaved_masks_npy_fn}")
|
165 |
+
self.masks_np = np.load(presaved_masks_npy_fn, allow_pickle=True)
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
# world_mat is a projection matrix from world to image
|
173 |
+
self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
174 |
+
|
175 |
+
self.scale_mats_np = []
|
176 |
+
|
177 |
+
# scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
|
178 |
+
self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
|
179 |
+
|
180 |
+
self.intrinsics_all = []
|
181 |
+
self.pose_all = []
|
182 |
+
|
183 |
+
# for idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, self.world_mats_np)):
|
184 |
+
for idx in self.selected_img_idxes_list:
|
185 |
+
scale_mat = self.scale_mats_np[idx]
|
186 |
+
world_mat = self.world_mats_np[idx]
|
187 |
+
|
188 |
+
if "hand" in self.data_dir:
|
189 |
+
intrinsics = np.eye(4)
|
190 |
+
fov = 512. / 2. # * 2
|
191 |
+
res = 512.
|
192 |
+
intrinsics[:3, :3] = np.array([
|
193 |
+
[fov, 0, 0.5* res], # res #
|
194 |
+
[0, fov, 0.5* res], # res #
|
195 |
+
[0, 0, 1]
|
196 |
+
], dtype=np.float32)
|
197 |
+
pose = camera_dict['camera_mat_%d' % idx].astype(np.float32)
|
198 |
+
else:
|
199 |
+
P = world_mat @ scale_mat
|
200 |
+
P = P[:3, :4]
|
201 |
+
intrinsics, pose = load_K_Rt_from_P(None, P)
|
202 |
+
|
203 |
+
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
|
204 |
+
self.pose_all.append(torch.from_numpy(pose).float())
|
205 |
+
|
206 |
+
### images, masks,
|
207 |
+
self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3] #
|
208 |
+
self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3] #
|
209 |
+
self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] # optimize sdf field # rigid model hand
|
210 |
+
self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4]
|
211 |
+
self.focal = self.intrinsics_all[0][0, 0]
|
212 |
+
self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4]
|
213 |
+
self.H, self.W = self.images.shape[1], self.images.shape[2]
|
214 |
+
self.image_pixels = self.H * self.W
|
215 |
+
|
216 |
+
object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
|
217 |
+
object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0])
|
218 |
+
# Object scale mat: region of interest to **extract mesh**
|
219 |
+
object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
|
220 |
+
object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
|
221 |
+
object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
|
222 |
+
self.object_bbox_min = object_bbox_min[:3, 0]
|
223 |
+
self.object_bbox_max = object_bbox_max[:3, 0]
|
224 |
+
|
225 |
+
self.n_images = self.images.size(0)
|
226 |
+
|
227 |
+
print('Load data: End')
|
228 |
+
|
229 |
+
def get_rays(H, W, K, c2w, inverse_y, flip_x, flip_y, mode='center'):
|
230 |
+
i, j = torch.meshgrid( # meshgrid #
|
231 |
+
torch.linspace(0, W-1, W, device=c2w.device),
|
232 |
+
torch.linspace(0, H-1, H, device=c2w.device))
|
233 |
+
i = i.t().float()
|
234 |
+
j = j.t().float()
|
235 |
+
if mode == 'lefttop':
|
236 |
+
pass
|
237 |
+
elif mode == 'center':
|
238 |
+
i, j = i+0.5, j+0.5
|
239 |
+
elif mode == 'random':
|
240 |
+
i = i+torch.rand_like(i)
|
241 |
+
j = j+torch.rand_like(j)
|
242 |
+
else:
|
243 |
+
raise NotImplementedError
|
244 |
+
|
245 |
+
if flip_x:
|
246 |
+
i = i.flip((1,))
|
247 |
+
if flip_y:
|
248 |
+
j = j.flip((0,))
|
249 |
+
if inverse_y:
|
250 |
+
dirs = torch.stack([(i-K[0][2])/K[0][0], (j-K[1][2])/K[1][1], torch.ones_like(i)], -1)
|
251 |
+
else:
|
252 |
+
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
|
253 |
+
# Rotate ray directions from camera frame to the world frame
|
254 |
+
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
|
255 |
+
# Translate camera frame's origin to the world frame. It is the origin of all rays.
|
256 |
+
rays_o = c2w[:3,3].expand(rays_d.shape)
|
257 |
+
return rays_o, rays_d
|
258 |
+
|
259 |
+
def gen_rays_at(self, img_idx, resolution_level=1):
|
260 |
+
"""
|
261 |
+
Generate rays at world space from one camera.
|
262 |
+
"""
|
263 |
+
l = resolution_level
|
264 |
+
tx = torch.linspace(0, self.W - 1, self.W // l)
|
265 |
+
ty = torch.linspace(0, self.H - 1, self.H // l)
|
266 |
+
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
267 |
+
|
268 |
+
##### previous method #####
|
269 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
270 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
271 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
272 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
273 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
274 |
+
# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
|
275 |
+
##### previous method #####
|
276 |
+
|
277 |
+
fov = 512.; res = 512.
|
278 |
+
K = np.array([
|
279 |
+
[fov, 0, 0.5* res],
|
280 |
+
[0, fov, 0.5* res],
|
281 |
+
[0, 0, 1]
|
282 |
+
], dtype=np.float32)
|
283 |
+
K = torch.from_numpy(K).float().cuda()
|
284 |
+
|
285 |
+
|
286 |
+
# ### `center` mode ### #
|
287 |
+
c2w = self.pose_all[img_idx]
|
288 |
+
pixels_x, pixels_y = pixels_x+0.5, pixels_y+0.5
|
289 |
+
|
290 |
+
dirs = torch.stack([(pixels_x-K[0][2])/K[0][0], -(pixels_y-K[1][2])/K[1][1], -torch.ones_like(pixels_x)], -1)
|
291 |
+
rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
|
292 |
+
rays_o = c2w[:3,3].expand(rays_v.shape)
|
293 |
+
# dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
|
294 |
+
|
295 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
296 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
297 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
298 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
299 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
300 |
+
# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
|
301 |
+
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
302 |
+
|
303 |
+
def gen_random_rays_at(self, img_idx, batch_size):
|
304 |
+
"""
|
305 |
+
Generate random rays at world space from one camera.
|
306 |
+
"""
|
307 |
+
img_idx = img_idx.cpu()
|
308 |
+
pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]).cpu()
|
309 |
+
pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]).cpu()
|
310 |
+
|
311 |
+
# print(self.images.device, img_idx.device, pixels_y.device)
|
312 |
+
color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
313 |
+
|
314 |
+
mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
|
315 |
+
|
316 |
+
|
317 |
+
##### previous method #####
|
318 |
+
# p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
|
319 |
+
# # p = torch.stack([pixels_x, pixels_y, -1. * torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
|
320 |
+
# p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3
|
321 |
+
# rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
|
322 |
+
# rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3
|
323 |
+
# rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3
|
324 |
+
##### previous method #####
|
325 |
+
|
326 |
+
fov = 512.; res = 512.
|
327 |
+
K = np.array([
|
328 |
+
[fov, 0, 0.5* res],
|
329 |
+
[0, fov, 0.5* res],
|
330 |
+
[0, 0, 1]
|
331 |
+
], dtype=np.float32)
|
332 |
+
K = torch.from_numpy(K).float().cuda()
|
333 |
+
|
334 |
+
|
335 |
+
# ### `center` mode ### #
|
336 |
+
c2w = self.pose_all[img_idx]
|
337 |
+
|
338 |
+
pixels_x = pixels_x.cuda()
|
339 |
+
pixels_y = pixels_y.cuda()
|
340 |
+
pixels_x, pixels_y = pixels_x+0.5, pixels_y+0.5
|
341 |
+
|
342 |
+
dirs = torch.stack([(pixels_x-K[0][2])/K[0][0], -(pixels_y-K[1][2])/K[1][1], -torch.ones_like(pixels_x)], -1)
|
343 |
+
rays_v = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
|
344 |
+
rays_o = c2w[:3,3].expand(rays_v.shape)
|
345 |
+
|
346 |
+
|
347 |
+
return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10
|
348 |
+
|
349 |
+
def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
|
350 |
+
"""
|
351 |
+
Interpolate pose between two cameras.
|
352 |
+
"""
|
353 |
+
l = resolution_level
|
354 |
+
tx = torch.linspace(0, self.W - 1, self.W // l)
|
355 |
+
ty = torch.linspace(0, self.H - 1, self.H // l)
|
356 |
+
pixels_x, pixels_y = torch.meshgrid(tx, ty)
|
357 |
+
p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
|
358 |
+
p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
|
359 |
+
rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
|
360 |
+
trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
|
361 |
+
pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
|
362 |
+
pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
|
363 |
+
pose_0 = np.linalg.inv(pose_0)
|
364 |
+
pose_1 = np.linalg.inv(pose_1)
|
365 |
+
rot_0 = pose_0[:3, :3]
|
366 |
+
rot_1 = pose_1[:3, :3]
|
367 |
+
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
|
368 |
+
key_times = [0, 1]
|
369 |
+
slerp = Slerp(key_times, rots)
|
370 |
+
rot = slerp(ratio)
|
371 |
+
pose = np.diag([1.0, 1.0, 1.0, 1.0])
|
372 |
+
pose = pose.astype(np.float32)
|
373 |
+
pose[:3, :3] = rot.as_matrix()
|
374 |
+
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
|
375 |
+
pose = np.linalg.inv(pose)
|
376 |
+
rot = torch.from_numpy(pose[:3, :3]).cuda()
|
377 |
+
trans = torch.from_numpy(pose[:3, 3]).cuda()
|
378 |
+
rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
|
379 |
+
rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
|
380 |
+
return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
|
381 |
+
|
382 |
+
def near_far_from_sphere(self, rays_o, rays_d):
|
383 |
+
a = torch.sum(rays_d**2, dim=-1, keepdim=True)
|
384 |
+
b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
|
385 |
+
mid = 0.5 * (-b) / a
|
386 |
+
near = mid - 1.0
|
387 |
+
far = mid + 1.0
|
388 |
+
return near, far
|
389 |
+
|
390 |
+
def image_at(self, idx, resolution_level):
|
391 |
+
if self.selected_img_idxes_list is not None:
|
392 |
+
img = cv.imread(self.images_lis[self.selected_img_idxes_list[idx]])
|
393 |
+
else:
|
394 |
+
img = cv.imread(self.images_lis[idx])
|
395 |
+
return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)
|
396 |
+
|
397 |
+
|
398 |
+
if __name__=='__main__':
|
399 |
+
data_dir = "/data/datasets/genn/diffsim/diffredmax/save_res/goal_optimize_model_hand_sphere_test_obj_type_active_nfr_10_view_divide_0.5_n_views_7_three_planes_False_recon_dvgo_new_Nposes_7_routine_2"
|
400 |
+
data_dir = "/data/datasets/genn/diffsim/neus/public_data/hand_test"
|
401 |
+
data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2"
|
402 |
+
data_dir = "/data2/datasets/diffsim/neus/public_data/hand_test_routine_2_light_color"
|
403 |
+
filter_iamges_via_pixel_values(data_dir=data_dir)
|
models/dyn_model_act.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/dyn_model_act_v2.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/dyn_model_act_v2_deformable.py
ADDED
@@ -0,0 +1,1582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
# import torch
|
4 |
+
# from ..utils import Timer
|
5 |
+
import numpy as np
|
6 |
+
# import torch.nn.functional as F
|
7 |
+
import os
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
|
11 |
+
from xml.etree.ElementTree import ElementTree
|
12 |
+
|
13 |
+
import trimesh
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
# import List
|
17 |
+
# class link; joint; body
|
18 |
+
###
|
19 |
+
|
20 |
+
from scipy.spatial.transform import Rotation as R
|
21 |
+
from torch.distributions.uniform import Uniform
|
22 |
+
|
23 |
+
# deformable articulated objects with the articulated models #
|
24 |
+
|
25 |
+
DAMPING = 1.0
|
26 |
+
DAMPING = 0.3
|
27 |
+
|
28 |
+
def plane_rotation_matrix_from_angle_xz(angle):
|
29 |
+
sin_ = torch.sin(angle)
|
30 |
+
cos_ = torch.cos(angle)
|
31 |
+
zero_padding = torch.zeros_like(cos_)
|
32 |
+
one_padding = torch.ones_like(cos_)
|
33 |
+
col_a = torch.stack(
|
34 |
+
[cos_, zero_padding, sin_], dim=0
|
35 |
+
)
|
36 |
+
col_b = torch.stack(
|
37 |
+
[zero_padding, one_padding, zero_padding], dim=0
|
38 |
+
)
|
39 |
+
col_c = torch.stack(
|
40 |
+
[-1. * sin_, zero_padding, cos_], dim=0
|
41 |
+
)
|
42 |
+
rot_mtx = torch.stack(
|
43 |
+
[col_a, col_b, col_c], dim=-1
|
44 |
+
)
|
45 |
+
return rot_mtx
|
46 |
+
|
47 |
+
def plane_rotation_matrix_from_angle(angle):
|
48 |
+
## angle of
|
49 |
+
sin_ = torch.sin(angle)
|
50 |
+
cos_ = torch.cos(angle)
|
51 |
+
col_a = torch.stack(
|
52 |
+
[cos_, sin_], dim=0 ### col of the rotation matrix
|
53 |
+
)
|
54 |
+
col_b = torch.stack(
|
55 |
+
[-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
|
56 |
+
)
|
57 |
+
rot_mtx = torch.stack(
|
58 |
+
[col_a, col_b], dim=-1 ### rotation matrix
|
59 |
+
)
|
60 |
+
return rot_mtx
|
61 |
+
|
62 |
+
def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
|
63 |
+
# sin_ = np.sin(angle) # ti.math.sin(angle)
|
64 |
+
# cos_ = np.cos(angle) # ti.math.cos(angle)
|
65 |
+
sin_ = torch.sin(angle) # ti.math.sin(angle)
|
66 |
+
cos_ = torch.cos(angle) # ti.math.cos(angle)
|
67 |
+
u_x, u_y, u_z = axis[0], axis[1], axis[2]
|
68 |
+
u_xx = u_x * u_x
|
69 |
+
u_yy = u_y * u_y
|
70 |
+
u_zz = u_z * u_z
|
71 |
+
u_xy = u_x * u_y
|
72 |
+
u_xz = u_x * u_z
|
73 |
+
u_yz = u_y * u_z
|
74 |
+
|
75 |
+
row_a = torch.stack(
|
76 |
+
[cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
|
77 |
+
)
|
78 |
+
# print(f"row_a: {row_a.size()}")
|
79 |
+
row_b = torch.stack(
|
80 |
+
[u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
|
81 |
+
)
|
82 |
+
# print(f"row_b: {row_b.size()}")
|
83 |
+
row_c = torch.stack(
|
84 |
+
[u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
|
85 |
+
)
|
86 |
+
# print(f"row_c: {row_c.size()}")
|
87 |
+
|
88 |
+
### rot_mtx for the rot_mtx ###
|
89 |
+
rot_mtx = torch.stack(
|
90 |
+
[row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
|
91 |
+
)
|
92 |
+
|
93 |
+
return rot_mtx
|
94 |
+
|
95 |
+
|
96 |
+
def update_quaternion(delta_angle, prev_quat):
|
97 |
+
s1 = 0
|
98 |
+
s2 = prev_quat[0]
|
99 |
+
v2 = prev_quat[1:]
|
100 |
+
v1 = delta_angle / 2
|
101 |
+
new_v = s1 * v2 + s2 * v1 + torch.cross(v1, v2)
|
102 |
+
new_s = s1 * s2 - torch.sum(v1 * v2)
|
103 |
+
new_quat = torch.cat([new_s.unsqueeze(0), new_v], dim=0)
|
104 |
+
return new_quat
|
105 |
+
|
106 |
+
|
107 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
|
108 |
+
"""
|
109 |
+
Convert rotations given as quaternions to rotation matrices.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
quaternions: quaternions with real part first,
|
113 |
+
as tensor of shape (..., 4).
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
117 |
+
"""
|
118 |
+
r, i, j, k = torch.unbind(quaternions, -1) # -1 for the quaternion matrix #
|
119 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
120 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
121 |
+
|
122 |
+
o = torch.stack(
|
123 |
+
(
|
124 |
+
1 - two_s * (j * j + k * k),
|
125 |
+
two_s * (i * j - k * r),
|
126 |
+
two_s * (i * k + j * r),
|
127 |
+
two_s * (i * j + k * r),
|
128 |
+
1 - two_s * (i * i + k * k),
|
129 |
+
two_s * (j * k - i * r),
|
130 |
+
two_s * (i * k - j * r),
|
131 |
+
two_s * (j * k + i * r),
|
132 |
+
1 - two_s * (i * i + j * j),
|
133 |
+
),
|
134 |
+
-1,
|
135 |
+
)
|
136 |
+
|
137 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
class Inertial:
|
143 |
+
def __init__(self, origin_rpy, origin_xyz, mass, inertia) -> None:
|
144 |
+
self.origin_rpy = origin_rpy
|
145 |
+
self.origin_xyz = origin_xyz
|
146 |
+
self.mass = mass
|
147 |
+
self.inertia = inertia
|
148 |
+
if torch.sum(self.inertia).item() < 1e-4:
|
149 |
+
self.inertia = self.inertia + torch.eye(3, dtype=torch.float32).cuda()
|
150 |
+
pass
|
151 |
+
|
152 |
+
class Visual:
|
153 |
+
def __init__(self, visual_xyz, visual_rpy, geometry_mesh_fn, geometry_mesh_scale) -> None:
|
154 |
+
# self.visual_origin = visual_origin
|
155 |
+
self.visual_xyz = visual_xyz
|
156 |
+
self.visual_rpy = visual_rpy
|
157 |
+
self.mesh_nm = geometry_mesh_fn.split("/")[-1].split(".")[0]
|
158 |
+
mesh_root = "/home/xueyi/diffsim/NeuS/rsc/mano" ## mano models of the mesh root ##
|
159 |
+
if not os.path.exists(mesh_root):
|
160 |
+
mesh_root = "/data/xueyi/diffsim/NeuS/rsc/mano"
|
161 |
+
self.mesh_root = mesh_root
|
162 |
+
self.geometry_mesh_fn = os.path.join(mesh_root, geometry_mesh_fn)
|
163 |
+
self.geometry_mesh_scale = geometry_mesh_scale
|
164 |
+
# tranformed by xyz #
|
165 |
+
self.vertices, self.faces = self.load_geoemtry_mesh()
|
166 |
+
self.cur_expanded_visual_pts = None
|
167 |
+
pass
|
168 |
+
|
169 |
+
def load_geoemtry_mesh(self, ):
|
170 |
+
# mesh_root =
|
171 |
+
mesh = trimesh.load_mesh(self.geometry_mesh_fn)
|
172 |
+
vertices = mesh.vertices
|
173 |
+
faces = mesh.faces
|
174 |
+
|
175 |
+
vertices = torch.from_numpy(vertices).float().cuda()
|
176 |
+
faces =torch.from_numpy(faces).long().cuda()
|
177 |
+
|
178 |
+
vertices = vertices * self.geometry_mesh_scale.unsqueeze(0) + self.visual_xyz.unsqueeze(0)
|
179 |
+
|
180 |
+
return vertices, faces
|
181 |
+
|
182 |
+
# init_visual_meshes = get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes)
|
183 |
+
def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes):
|
184 |
+
# cur_vertices = torch.matmul(parent_rot, self.vertices.transpose(1, 0)).contiguous().transpose(1, 0).contiguous() + parent_trans.unsqueeze(0)
|
185 |
+
cur_vertices = self.vertices
|
186 |
+
# print(f"adding mesh loaded from {self.geometry_mesh_fn}")
|
187 |
+
init_visual_meshes['vertices'].append(cur_vertices) # cur vertices # trans #
|
188 |
+
init_visual_meshes['faces'].append(self.faces)
|
189 |
+
return init_visual_meshes
|
190 |
+
|
191 |
+
def expand_visual_pts(self, ):
|
192 |
+
expand_factor = 0.2
|
193 |
+
nn_expand_pts = 20
|
194 |
+
|
195 |
+
expand_factor = 0.4
|
196 |
+
nn_expand_pts = 40 ### number of the expanded points ### ## points ##
|
197 |
+
expand_save_fn = f"{self.mesh_nm}_expanded_pts_factor_{expand_factor}_nnexp_{nn_expand_pts}.npy"
|
198 |
+
expand_save_fn = os.path.join(self.mesh_root, expand_save_fn)
|
199 |
+
|
200 |
+
if not os.path.exists(expand_save_fn):
|
201 |
+
cur_expanded_visual_pts = []
|
202 |
+
if self.cur_expanded_visual_pts is None:
|
203 |
+
cur_src_pts = self.vertices
|
204 |
+
else:
|
205 |
+
cur_src_pts = self.cur_expanded_visual_pts
|
206 |
+
maxx_verts, _ = torch.max(cur_src_pts, dim=0)
|
207 |
+
minn_verts, _ = torch.min(cur_src_pts, dim=0)
|
208 |
+
extent_verts = maxx_verts - minn_verts ## (3,)-dim vecotr
|
209 |
+
norm_extent_verts = torch.norm(extent_verts, dim=-1).item() ## (1,)-dim vector
|
210 |
+
expand_r = norm_extent_verts * expand_factor
|
211 |
+
# nn_expand_pts = 5 # expand the vertices to 5 times of the original vertices
|
212 |
+
for i_pts in range(self.vertices.size(0)):
|
213 |
+
cur_pts = cur_src_pts[i_pts]
|
214 |
+
# sample from the circile with cur_pts as thejcenter and the radius as expand_r
|
215 |
+
# (-r, r) # sample the offset vector in the size of (nn_expand_pts, 3)
|
216 |
+
offset_dist = Uniform(-1. * expand_r, expand_r)
|
217 |
+
offset_vec = offset_dist.sample((nn_expand_pts, 3)).cuda()
|
218 |
+
cur_expanded_pts = cur_pts + offset_vec
|
219 |
+
cur_expanded_visual_pts.append(cur_expanded_pts)
|
220 |
+
cur_expanded_visual_pts = torch.cat(cur_expanded_visual_pts, dim=0)
|
221 |
+
np.save(expand_save_fn, cur_expanded_visual_pts.detach().cpu().numpy())
|
222 |
+
else:
|
223 |
+
print(f"Loading visual pts from {expand_save_fn}") # load from the fn #
|
224 |
+
cur_expanded_visual_pts = np.load(expand_save_fn, allow_pickle=True)
|
225 |
+
cur_expanded_visual_pts = torch.from_numpy(cur_expanded_visual_pts).float().cuda()
|
226 |
+
self.cur_expanded_visual_pts = cur_expanded_visual_pts # expanded visual pts #
|
227 |
+
return self.cur_expanded_visual_pts
|
228 |
+
# cur_pts #
|
229 |
+
# use r as the search direction # # expande save fn #
|
230 |
+
def get_transformed_visual_pts(self, visual_pts_list):
|
231 |
+
visual_pts_list.append(self.cur_expanded_visual_pts) #
|
232 |
+
return visual_pts_list
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
## link urdf ## expand the visual pts to form the expanded visual grids pts #
|
237 |
+
# use get_name_to_visual_pts_faces to get the transformed visual pts and faces #
|
238 |
+
## Link_urdf ##
|
239 |
+
class Link_urdf: # get_transformed_visual_pts #
|
240 |
+
def __init__(self, name, inertial: Inertial, visual: Visual=None) -> None:
|
241 |
+
|
242 |
+
self.name = name
|
243 |
+
self.inertial = inertial
|
244 |
+
self.visual = visual # vsiual meshes #
|
245 |
+
|
246 |
+
# self.joint = joint
|
247 |
+
# self.body = body
|
248 |
+
# self.children = children
|
249 |
+
# self.name = name
|
250 |
+
|
251 |
+
self.link_idx = ...
|
252 |
+
|
253 |
+
# self.args = args
|
254 |
+
|
255 |
+
self.joint = None # joint name to struct
|
256 |
+
# self.join
|
257 |
+
self.children = ...
|
258 |
+
self.children = {} # joint name to child sruct
|
259 |
+
|
260 |
+
### dyn_model_act ###
|
261 |
+
# parent_rot_mtx, parent_trans_vec #
|
262 |
+
# parent_rot_mtx, parent_trans_vec # # link urdf #
|
263 |
+
# self.parent_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
|
264 |
+
# self.parent_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)
|
265 |
+
# self.curr_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
|
266 |
+
# self.curr_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)
|
267 |
+
# #
|
268 |
+
# self.tot_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(), requires_grad=True)
|
269 |
+
# self.tot_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(), requires_grad=True)
|
270 |
+
|
271 |
+
# expand visual pts #
|
272 |
+
def expand_visual_pts(self, expanded_visual_pts, link_name_to_visited, link_name_to_link_struct):
|
273 |
+
link_name_to_visited[self.name] = 1
|
274 |
+
if self.visual is not None:
|
275 |
+
cur_expanded_visual_pts = self.visual.expand_visual_pts()
|
276 |
+
expanded_visual_pts.append(cur_expanded_visual_pts)
|
277 |
+
|
278 |
+
for cur_link in self.children:
|
279 |
+
cur_link_struct = link_name_to_link_struct[self.children[cur_link]]
|
280 |
+
cur_link_name = cur_link_struct.name
|
281 |
+
if cur_link_name in link_name_to_visited:
|
282 |
+
continue
|
283 |
+
expanded_visual_pts = cur_link_struct.expand_visual_pts(expanded_visual_pts, link_name_to_visited, link_name_to_link_struct)
|
284 |
+
return expanded_visual_pts
|
285 |
+
|
286 |
+
def get_transformed_visual_pts(self, visual_pts_list, link_name_to_visited, link_name_to_link_struct):
|
287 |
+
link_name_to_visited[self.name] = 1
|
288 |
+
|
289 |
+
if self.joint is not None:
|
290 |
+
for cur_joint_name in self.joint:
|
291 |
+
cur_joint = self.joint[cur_joint_name]
|
292 |
+
cur_child_name = self.children[cur_joint_name]
|
293 |
+
cur_child = link_name_to_link_struct[cur_child_name] # parent and the child_visual, cur_child.visual #
|
294 |
+
# parent #
|
295 |
+
# print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
|
296 |
+
# print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
|
297 |
+
# joint_origin_xyz = cur_joint.origin_xyz ## transformed_visual_pts ####
|
298 |
+
if cur_child_name in link_name_to_visited:
|
299 |
+
continue
|
300 |
+
# cur_child_visual_pts = {'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []}
|
301 |
+
cur_child_visual_pts_list = []
|
302 |
+
cur_child_visual_pts_list = cur_child.get_transformed_visual_pts(cur_child_visual_pts_list, link_name_to_visited, link_name_to_link_struct)
|
303 |
+
|
304 |
+
if len(cur_child_visual_pts_list) > 0:
|
305 |
+
cur_child_visual_pts = torch.cat(cur_child_visual_pts_list, dim=0)
|
306 |
+
# cur_child_verts, cur_child_faces = cur_child_visual_pts['vertices'], cur_child_visual_pts['faces']
|
307 |
+
# cur_child_link_idxes = cur_child_visual_pts['link_idxes']
|
308 |
+
# cur_transformed_joint_pos = cur_child_visual_pts['transformed_joint_pos']
|
309 |
+
# joint_link_idxes = cur_child_visual_pts['joint_link_idxes']
|
310 |
+
# if len(cur_child_verts) > 0:
|
311 |
+
# cur_child_verts, cur_child_faces = merge_meshes(cur_child_verts, cur_child_faces)
|
312 |
+
cur_child_visual_pts = cur_child_visual_pts + cur_joint.origin_xyz.unsqueeze(0)
|
313 |
+
cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state()
|
314 |
+
cur_child_visual_pts = torch.matmul(cur_joint_rot, cur_child_visual_pts.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
|
315 |
+
|
316 |
+
# if len(cur_transformed_joint_pos) > 0:
|
317 |
+
# cur_transformed_joint_pos = torch.cat(cur_transformed_joint_pos, dim=0)
|
318 |
+
# cur_transformed_joint_pos = cur_transformed_joint_pos + cur_joint.origin_xyz.unsqueeze(0)
|
319 |
+
# cur_transformed_joint_pos = torch.matmul(cur_joint_rot, cur_transformed_joint_pos.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
|
320 |
+
# cur_joint_pos = cur_joint_trans.unsqueeze(0).clone()
|
321 |
+
# cur_transformed_joint_pos = torch.cat(
|
322 |
+
# [cur_transformed_joint_pos, cur_joint_pos], dim=0 ##### joint poses #####
|
323 |
+
# )
|
324 |
+
# else:
|
325 |
+
# cur_transformed_joint_pos = cur_joint_trans.unsqueeze(0).clone()
|
326 |
+
|
327 |
+
# if len(joint_link_idxes) > 0:
|
328 |
+
# joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
|
329 |
+
# cur_joint_idx = cur_child.link_idx
|
330 |
+
# joint_link_idxes = torch.cat(
|
331 |
+
# [joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long).cuda()], dim=-1
|
332 |
+
# )
|
333 |
+
# else:
|
334 |
+
# joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).cuda().view(1,)
|
335 |
+
|
336 |
+
|
337 |
+
visual_pts_list.append(cur_child_visual_pts)
|
338 |
+
# cur_child_verts = cur_child_verts + # transformed joint pos #
|
339 |
+
# cur_child_link_idxes = torch.cat(cur_child_link_idxes, dim=-1)
|
340 |
+
# # joint_link_idxes = torch.cat(joint_link_idxes, dim=-1)
|
341 |
+
# init_visual_meshes['vertices'].append(cur_child_verts)
|
342 |
+
# init_visual_meshes['faces'].append(cur_child_faces)
|
343 |
+
# init_visual_meshes['link_idxes'].append(cur_child_link_idxes)
|
344 |
+
# init_visual_meshes['transformed_joint_pos'].append(cur_transformed_joint_pos)
|
345 |
+
# init_visual_meshes['joint_link_idxes'].append(joint_link_idxes)
|
346 |
+
|
347 |
+
|
348 |
+
if self.visual is not None:
|
349 |
+
# get_transformed_visual_pts #
|
350 |
+
visual_pts_list = self.visual.get_transformed_visual_pts(visual_pts_list)
|
351 |
+
|
352 |
+
# for cur_link in self.children:
|
353 |
+
# cur_link_name = cur_link.name
|
354 |
+
# if cur_link_name in link_name_to_visited: # link name to visited #
|
355 |
+
# continue
|
356 |
+
# visual_pts_list = cur_link.get_transformed_visual_pts(visual_pts_list, link_name_to_visited, link_name_to_link_struct)
|
357 |
+
return visual_pts_list
|
358 |
+
|
359 |
+
# use both the articulated motion and the frre form
|
360 |
+
def set_initial_state(self, states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):
|
361 |
+
|
362 |
+
link_name_to_visited[self.name] = 1
|
363 |
+
|
364 |
+
if self.joint is not None:
|
365 |
+
for cur_joint_name in self.joint:
|
366 |
+
cur_joint = self.joint[cur_joint_name]
|
367 |
+
cur_joint_name = cur_joint.name
|
368 |
+
cur_child = self.children[cur_joint_name]
|
369 |
+
cur_child_struct = link_name_to_link_struct[cur_child]
|
370 |
+
cur_child_name = cur_child_struct.name
|
371 |
+
|
372 |
+
if cur_child_name in link_name_to_visited:
|
373 |
+
continue
|
374 |
+
if cur_joint.type in ['revolute']:
|
375 |
+
cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name] # action joint name to joint idx #
|
376 |
+
# cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name] #
|
377 |
+
# cur_joint = self.joint[cur_joint_name]
|
378 |
+
cur_state = states[cur_joint_idx] ### joint state ###
|
379 |
+
cur_joint.set_initial_state(cur_state)
|
380 |
+
cur_child_struct.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)
|
381 |
+
|
382 |
+
|
383 |
+
|
384 |
+
def get_init_visual_meshes(self, parent_rot, parent_trans, init_visual_meshes, link_name_to_link_struct, link_name_to_visited):
|
385 |
+
link_name_to_visited[self.name] = 1
|
386 |
+
|
387 |
+
# 'transformed_joint_pos': [], 'link_idxes': []
|
388 |
+
if self.joint is not None:
|
389 |
+
# for i_ch, (cur_joint, cur_child) in enumerate(zip(self.joint, self.children)):
|
390 |
+
# print(f"joint: {cur_joint.name}, child: {cur_child.name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
|
391 |
+
# joint_origin_xyz = cur_joint.origin_xyz
|
392 |
+
# init_visual_meshes = cur_child.get_init_visual_meshes(parent_rot, parent_trans + joint_origin_xyz, init_visual_meshes)
|
393 |
+
# print(f"name: {self.name}, keys: {self.joint.keys()}")
|
394 |
+
for cur_joint_name in self.joint: #
|
395 |
+
cur_joint = self.joint[cur_joint_name]
|
396 |
+
cur_child_name = self.children[cur_joint_name]
|
397 |
+
cur_child = link_name_to_link_struct[cur_child_name]
|
398 |
+
# print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
|
399 |
+
# print(f"joint: {cur_joint.name}, child: {cur_child_name}, parent: {self.name}, child_visual: {cur_child.visual is not None}")
|
400 |
+
joint_origin_xyz = cur_joint.origin_xyz
|
401 |
+
if cur_child_name in link_name_to_visited:
|
402 |
+
continue
|
403 |
+
cur_child_visual_pts = {'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []}
|
404 |
+
cur_child_visual_pts = cur_child.get_init_visual_meshes(parent_rot, parent_trans + joint_origin_xyz, cur_child_visual_pts, link_name_to_link_struct, link_name_to_visited)
|
405 |
+
cur_child_verts, cur_child_faces = cur_child_visual_pts['vertices'], cur_child_visual_pts['faces']
|
406 |
+
cur_child_link_idxes = cur_child_visual_pts['link_idxes']
|
407 |
+
cur_transformed_joint_pos = cur_child_visual_pts['transformed_joint_pos']
|
408 |
+
joint_link_idxes = cur_child_visual_pts['joint_link_idxes']
|
409 |
+
if len(cur_child_verts) > 0:
|
410 |
+
cur_child_verts, cur_child_faces = merge_meshes(cur_child_verts, cur_child_faces)
|
411 |
+
cur_child_verts = cur_child_verts + cur_joint.origin_xyz.unsqueeze(0)
|
412 |
+
cur_joint_rot, cur_joint_trans = cur_joint.compute_transformation_from_current_state()
|
413 |
+
cur_child_verts = torch.matmul(cur_joint_rot, cur_child_verts.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
|
414 |
+
|
415 |
+
if len(cur_transformed_joint_pos) > 0:
|
416 |
+
cur_transformed_joint_pos = torch.cat(cur_transformed_joint_pos, dim=0)
|
417 |
+
cur_transformed_joint_pos = cur_transformed_joint_pos + cur_joint.origin_xyz.unsqueeze(0)
|
418 |
+
cur_transformed_joint_pos = torch.matmul(cur_joint_rot, cur_transformed_joint_pos.transpose(1, 0).contiguous()).transpose(1, 0).contiguous() + cur_joint_trans.unsqueeze(0)
|
419 |
+
cur_joint_pos = cur_joint_trans.unsqueeze(0).clone()
|
420 |
+
cur_transformed_joint_pos = torch.cat(
|
421 |
+
[cur_transformed_joint_pos, cur_joint_pos], dim=0 ##### joint poses #####
|
422 |
+
)
|
423 |
+
else:
|
424 |
+
cur_transformed_joint_pos = cur_joint_trans.unsqueeze(0).clone()
|
425 |
+
|
426 |
+
if len(joint_link_idxes) > 0:
|
427 |
+
joint_link_idxes = torch.cat(joint_link_idxes, dim=-1) ### joint_link idxes ###
|
428 |
+
cur_joint_idx = cur_child.link_idx
|
429 |
+
joint_link_idxes = torch.cat(
|
430 |
+
[joint_link_idxes, torch.tensor([cur_joint_idx], dtype=torch.long).cuda()], dim=-1
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
joint_link_idxes = torch.tensor([cur_child.link_idx], dtype=torch.long).cuda().view(1,)
|
434 |
+
|
435 |
+
|
436 |
+
|
437 |
+
# cur_child_verts = cur_child_verts + # transformed joint pos #
|
438 |
+
cur_child_link_idxes = torch.cat(cur_child_link_idxes, dim=-1)
|
439 |
+
# joint_link_idxes = torch.cat(joint_link_idxes, dim=-1)
|
440 |
+
init_visual_meshes['vertices'].append(cur_child_verts)
|
441 |
+
init_visual_meshes['faces'].append(cur_child_faces)
|
442 |
+
init_visual_meshes['link_idxes'].append(cur_child_link_idxes)
|
443 |
+
init_visual_meshes['transformed_joint_pos'].append(cur_transformed_joint_pos)
|
444 |
+
init_visual_meshes['joint_link_idxes'].append(joint_link_idxes)
|
445 |
+
|
446 |
+
|
447 |
+
# joint_origin_xyz = self.joint.origin_xyz
|
448 |
+
else:
|
449 |
+
joint_origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
|
450 |
+
# self.parent_rot_mtx = parent_rot
|
451 |
+
# self.parent_trans_vec = parent_trans + joint_origin_xyz
|
452 |
+
|
453 |
+
|
454 |
+
if self.visual is not None:
|
455 |
+
init_visual_meshes = self.visual.get_init_visual_meshes(parent_rot, parent_trans, init_visual_meshes)
|
456 |
+
cur_visual_mesh_pts_nn = self.visual.vertices.size(0)
|
457 |
+
cur_link_idxes = torch.zeros((cur_visual_mesh_pts_nn, ), dtype=torch.long).cuda()+ self.link_idx
|
458 |
+
init_visual_meshes['link_idxes'].append(cur_link_idxes)
|
459 |
+
|
460 |
+
# for cur_link in self.children: #
|
461 |
+
# init_visual_meshes = cur_link.get_init_visual_meshes(self.parent_rot_mtx, self.parent_trans_vec, init_visual_meshes)
|
462 |
+
return init_visual_meshes ## init visual meshes ##
|
463 |
+
|
464 |
+
# calculate inerti
|
465 |
+
def calculate_inertia(self, link_name_to_visited, link_name_to_link_struct):
|
466 |
+
link_name_to_visited[self.name] = 1
|
467 |
+
self.cur_inertia = torch.zeros((3, 3), dtype=torch.float32).cuda()
|
468 |
+
|
469 |
+
if self.joint is not None:
|
470 |
+
for joint_nm in self.joint:
|
471 |
+
cur_joint = self.joint[joint_nm]
|
472 |
+
cur_child = self.children[joint_nm]
|
473 |
+
cur_child_struct = link_name_to_link_struct[cur_child]
|
474 |
+
cur_child_name = cur_child_struct.name
|
475 |
+
if cur_child_name in link_name_to_visited:
|
476 |
+
continue
|
477 |
+
joint_rot, joint_trans = cur_joint.compute_transformation_from_current_state(n_grad=True)
|
478 |
+
# cur_parent_rot = torch.matmul(parent_rot, joint_rot) #
|
479 |
+
# cur_parent_trans = torch.matmul(parent_rot, joint_trans.unsqueeze(-1)).squeeze(-1) + parent_trans #
|
480 |
+
child_inertia = cur_child_struct.calculate_inertia(link_name_to_visited, link_name_to_link_struct)
|
481 |
+
child_inertia = torch.matmul(
|
482 |
+
joint_rot.detach(), torch.matmul(child_inertia, joint_rot.detach().transpose(1, 0).contiguous())
|
483 |
+
).detach()
|
484 |
+
self.cur_inertia += child_inertia
|
485 |
+
# if self.visual is not None:
|
486 |
+
# self.cur_inertia += self.visual.inertia
|
487 |
+
self.cur_inertia += self.inertial.inertia.detach()
|
488 |
+
return self.cur_inertia
|
489 |
+
|
490 |
+
|
491 |
+
def set_delta_state_and_update(self, states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct):
|
492 |
+
|
493 |
+
link_name_to_visited[self.name] = 1
|
494 |
+
|
495 |
+
if self.joint is not None:
|
496 |
+
for cur_joint_name in self.joint:
|
497 |
+
|
498 |
+
cur_joint = self.joint[cur_joint_name] # joint model
|
499 |
+
|
500 |
+
cur_child = self.children[cur_joint_name] # child model #
|
501 |
+
|
502 |
+
cur_child_struct = link_name_to_link_struct[cur_child]
|
503 |
+
|
504 |
+
cur_child_name = cur_child_struct.name
|
505 |
+
|
506 |
+
if cur_child_name in link_name_to_visited:
|
507 |
+
continue
|
508 |
+
|
509 |
+
## cur child inertia ##
|
510 |
+
# cur_child_inertia = cur_child_struct.cur_inertia
|
511 |
+
|
512 |
+
|
513 |
+
if cur_joint.type in ['revolute']:
|
514 |
+
cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
|
515 |
+
cur_state = states[cur_joint_idx]
|
516 |
+
### get the child struct ###
|
517 |
+
# set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
|
518 |
+
# set actions and update states #
|
519 |
+
cur_joint.set_delta_state_and_update(cur_state, cur_timestep)
|
520 |
+
|
521 |
+
cur_child_struct.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, link_name_to_link_struct)
|
522 |
+
|
523 |
+
|
524 |
+
|
525 |
+
# the joint #
|
526 |
+
# set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
|
527 |
+
def set_actions_and_update_states(self, actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct):
|
528 |
+
|
529 |
+
link_name_to_visited[self.name] = 1
|
530 |
+
|
531 |
+
# the current joint of the
|
532 |
+
if self.joint is not None:
|
533 |
+
for cur_joint_name in self.joint:
|
534 |
+
|
535 |
+
cur_joint = self.joint[cur_joint_name] # joint model
|
536 |
+
|
537 |
+
cur_child = self.children[cur_joint_name] # child model #
|
538 |
+
|
539 |
+
cur_child_struct = link_name_to_link_struct[cur_child]
|
540 |
+
|
541 |
+
cur_child_name = cur_child_struct.name
|
542 |
+
|
543 |
+
if cur_child_name in link_name_to_visited:
|
544 |
+
continue
|
545 |
+
|
546 |
+
cur_child_inertia = cur_child_struct.cur_inertia
|
547 |
+
|
548 |
+
|
549 |
+
if cur_joint.type in ['revolute']:
|
550 |
+
cur_joint_idx = action_joint_name_to_joint_idx[cur_joint_name]
|
551 |
+
cur_action = actions[cur_joint_idx]
|
552 |
+
### get the child struct ###
|
553 |
+
# set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
|
554 |
+
# set actions and update states #
|
555 |
+
cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia.detach())
|
556 |
+
|
557 |
+
cur_child_struct.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, link_name_to_link_struct)
|
558 |
+
|
559 |
+
|
560 |
+
def set_init_states_target_value(self, init_states):
|
561 |
+
if self.joint.type == 'revolute':
|
562 |
+
self.joint_angle = init_states[self.joint.joint_idx]
|
563 |
+
joint_axis = self.joint.axis
|
564 |
+
self.rot_vec = self.joint_angle * joint_axis
|
565 |
+
self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
|
566 |
+
self.joint.state = self.joint.state + update_quaternion(self.rot_vec, self.joint.state)
|
567 |
+
self.joint.timestep_to_states[0] = self.joint.state.detach()
|
568 |
+
self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
|
569 |
+
for cur_link in self.children:
|
570 |
+
cur_link.set_init_states_target_value(init_states)
|
571 |
+
|
572 |
+
# should forward for one single step -> use the action #
|
573 |
+
def set_init_states(self, ):
|
574 |
+
self.joint.state = torch.tensor([1, 0, 0, 0], dtype=torch.float32).cuda()
|
575 |
+
self.joint.timestep_to_states[0] = self.joint.state.detach()
|
576 |
+
self.joint.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
|
577 |
+
for cur_link in self.children:
|
578 |
+
cur_link.set_init_states()
|
579 |
+
|
580 |
+
|
581 |
+
def get_visual_pts(self, visual_pts_list):
|
582 |
+
visual_pts_list = self.body.get_visual_pts(visual_pts_list)
|
583 |
+
for cur_link in self.children:
|
584 |
+
visual_pts_list = cur_link.get_visual_pts(visual_pts_list)
|
585 |
+
visual_pts_list = torch.cat(visual_pts_list, dim=0)
|
586 |
+
return visual_pts_list
|
587 |
+
|
588 |
+
def get_visual_faces_list(self, visual_faces_list):
|
589 |
+
visual_faces_list = self.body.get_visual_faces_list(visual_faces_list)
|
590 |
+
for cur_link in self.children:
|
591 |
+
visual_faces_list = cur_link.get_visual_faces_list(visual_faces_list)
|
592 |
+
return visual_faces_list
|
593 |
+
# pass
|
594 |
+
|
595 |
+
|
596 |
+
def set_state(self, name_to_state):
|
597 |
+
self.joint.set_state(name_to_state=name_to_state)
|
598 |
+
for child_link in self.children:
|
599 |
+
child_link.set_state(name_to_state)
|
600 |
+
|
601 |
+
def set_state_via_vec(self, state_vec):
|
602 |
+
self.joint.set_state_via_vec(state_vec)
|
603 |
+
for child_link in self.children:
|
604 |
+
child_link.set_state_via_vec(state_vec)
|
605 |
+
|
606 |
+
|
607 |
+
|
608 |
+
|
609 |
+
class Joint_Limit:
|
610 |
+
def __init__(self, effort, lower, upper, velocity) -> None:
|
611 |
+
self.effort = effort
|
612 |
+
self.lower = lower
|
613 |
+
self.velocity = velocity
|
614 |
+
self.upper = upper
|
615 |
+
pass
|
616 |
+
|
617 |
+
# Joint_urdf(name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit)
|
618 |
+
class Joint_urdf: #
|
619 |
+
|
620 |
+
def __init__(self, name, joint_type, parent_link, child_link, origin_xyz, axis_xyz, limit: Joint_Limit) -> None:
|
621 |
+
self.name = name
|
622 |
+
self.type = joint_type
|
623 |
+
self.parent_link = parent_link
|
624 |
+
self.child_link = child_link
|
625 |
+
self.origin_xyz = origin_xyz
|
626 |
+
self.axis_xyz = axis_xyz
|
627 |
+
self.limit = limit
|
628 |
+
|
629 |
+
# joint angle; joint state #
|
630 |
+
self.timestep_to_vels = {}
|
631 |
+
self.timestep_to_states = {}
|
632 |
+
|
633 |
+
self.init_pos = self.origin_xyz.clone()
|
634 |
+
|
635 |
+
#### only for the current state #### # joint urdf #
|
636 |
+
self.state = nn.Parameter(
|
637 |
+
torch.tensor([1., 0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
|
638 |
+
)
|
639 |
+
self.action = nn.Parameter(
|
640 |
+
torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True
|
641 |
+
)
|
642 |
+
# self.rot_mtx = np.eye(3, dtypes=np.float32)
|
643 |
+
# self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
|
644 |
+
self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
|
645 |
+
self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(), requires_grad=True)
|
646 |
+
|
647 |
+
def set_initial_state(self, state):
|
648 |
+
# joint angle as the state value #
|
649 |
+
self.timestep_to_vels[0] = torch.zeros((3,), dtype=torch.float32).cuda().detach() ## velocity ##
|
650 |
+
delta_rot_vec = self.axis_xyz * state
|
651 |
+
# self.timestep_to_states[0] = state.detach()
|
652 |
+
cur_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
|
653 |
+
init_state = cur_state + update_quaternion(delta_rot_vec, cur_state)
|
654 |
+
self.timestep_to_states[0] = init_state.detach()
|
655 |
+
self.state = init_state
|
656 |
+
|
657 |
+
def set_delta_state_and_update(self, state, cur_timestep):
|
658 |
+
self.timestep_to_vels[cur_timestep] = torch.zeros((3,), dtype=torch.float32).cuda().detach()
|
659 |
+
delta_rot_vec = self.axis_xyz * state
|
660 |
+
if cur_timestep == 0:
|
661 |
+
prev_state = torch.tensor([1., 0., 0., 0.], dtype=torch.float32).cuda()
|
662 |
+
else:
|
663 |
+
prev_state = self.timestep_to_states[cur_timestep - 1].detach()
|
664 |
+
cur_state = prev_state + update_quaternion(delta_rot_vec, prev_state)
|
665 |
+
self.timestep_to_states[cur_timestep] = cur_state.detach()
|
666 |
+
self.state = cur_state
|
667 |
+
|
668 |
+
|
669 |
+
|
670 |
+
def compute_transformation_from_current_state(self, n_grad=False):
|
671 |
+
# together with the parent rot mtx and the parent trans vec #
|
672 |
+
# cur_joint_state = self.state
|
673 |
+
if self.type == "revolute":
|
674 |
+
# rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
|
675 |
+
# trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
|
676 |
+
if n_grad:
|
677 |
+
rot_mtx = quaternion_to_matrix(self.state.detach())
|
678 |
+
else:
|
679 |
+
rot_mtx = quaternion_to_matrix(self.state)
|
680 |
+
# trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
|
681 |
+
trans_vec = self.origin_xyz - torch.matmul(rot_mtx, self.origin_xyz.view(3, 1)).view(3).contiguous()
|
682 |
+
self.rot_mtx = rot_mtx
|
683 |
+
self.trans_vec = trans_vec
|
684 |
+
elif self.type == "fixed":
|
685 |
+
rot_mtx = torch.eye(3, dtype=torch.float32).cuda()
|
686 |
+
trans_vec = torch.zeros((3,), dtype=torch.float32).cuda()
|
687 |
+
# trans_vec = self.origin_xyz
|
688 |
+
self.rot_mtx = rot_mtx
|
689 |
+
self.trans_vec = trans_vec #
|
690 |
+
else:
|
691 |
+
pass
|
692 |
+
return self.rot_mtx, self.trans_vec
|
693 |
+
|
694 |
+
|
695 |
+
# set actions # set actions and udpate states #
|
696 |
+
def set_actions_and_update_states(self, action, cur_timestep, time_cons, cur_inertia):
|
697 |
+
|
698 |
+
# timestep_to_vels, timestep_to_states, state #
|
699 |
+
if self.type in ['revolute']:
|
700 |
+
|
701 |
+
self.action = action
|
702 |
+
#
|
703 |
+
# visual_pts and visual_pts_mass #
|
704 |
+
# cur_joint_pos = self.joint.pos #
|
705 |
+
# TODO: check whether the following is correct #
|
706 |
+
torque = self.action * self.axis_xyz
|
707 |
+
|
708 |
+
# # Compute inertia matrix #
|
709 |
+
# inertial = torch.zeros((3, 3), dtype=torch.float32).cuda()
|
710 |
+
# for i_pts in range(self.visual_pts.size(0)):
|
711 |
+
# cur_pts = self.visual_pts[i_pts]
|
712 |
+
# cur_pts_mass = self.visual_pts_mass[i_pts]
|
713 |
+
# cur_r = cur_pts - cur_joint_pos # r_i
|
714 |
+
# # cur_vert = init_passive_mesh[i_v]
|
715 |
+
# # cur_r = cur_vert - init_passive_mesh_center
|
716 |
+
# dot_r_r = torch.sum(cur_r * cur_r)
|
717 |
+
# cur_eye_mtx = torch.eye(3, dtype=torch.float32).cuda()
|
718 |
+
# r_mult_rT = torch.matmul(cur_r.unsqueeze(-1), cur_r.unsqueeze(0))
|
719 |
+
# inertial += (dot_r_r * cur_eye_mtx - r_mult_rT) * cur_pts_mass
|
720 |
+
# m = torch.sum(self.visual_pts_mass)
|
721 |
+
# # Use torque to update angular velocity -> state #
|
722 |
+
# inertia_inv = torch.linalg.inv(inertial)
|
723 |
+
|
724 |
+
# axis-angle of
|
725 |
+
# inertia_inv = self.cur_inertia_inv
|
726 |
+
# print(f"updating actions and states for the joint {self.name} with type {self.type}")
|
727 |
+
inertia_inv = torch.linalg.inv(cur_inertia).detach()
|
728 |
+
|
729 |
+
delta_omega = torch.matmul(inertia_inv, torque.unsqueeze(-1)).squeeze(-1)
|
730 |
+
|
731 |
+
# delta_omega = torque / 400 #
|
732 |
+
|
733 |
+
# timestep_to_vels, timestep_to_states, state #
|
734 |
+
|
735 |
+
# TODO: dt should be an optimizable constant? should it be the same value as that optimized for the passive object? #
|
736 |
+
delta_angular_vel = delta_omega * time_cons # * self.args.dt
|
737 |
+
delta_angular_vel = delta_angular_vel.squeeze(0)
|
738 |
+
if cur_timestep > 0: ## cur_timestep - 1 ##
|
739 |
+
prev_angular_vel = self.timestep_to_vels[cur_timestep - 1].detach()
|
740 |
+
cur_angular_vel = prev_angular_vel + delta_angular_vel * DAMPING
|
741 |
+
else:
|
742 |
+
cur_angular_vel = delta_angular_vel
|
743 |
+
|
744 |
+
self.timestep_to_vels[cur_timestep] = cur_angular_vel.detach()
|
745 |
+
|
746 |
+
cur_delta_quat = cur_angular_vel * time_cons # * self.args.dt
|
747 |
+
cur_delta_quat = cur_delta_quat.squeeze(0)
|
748 |
+
cur_state = self.timestep_to_states[cur_timestep].detach() # quaternion #
|
749 |
+
# print(f"cur_delta_quat: {cur_delta_quat.size()}, cur_state: {cur_state.size()}")
|
750 |
+
nex_state = cur_state + update_quaternion(cur_delta_quat, cur_state)
|
751 |
+
self.timestep_to_states[cur_timestep + 1] = nex_state.detach()
|
752 |
+
self.state = nex_state # set the joint state #
|
753 |
+
|
754 |
+
|
755 |
+
# get_transformed_visual_pts() --- transformed_visual_pts ##
|
756 |
+
# use the transformed visual # the articulated motion field #
|
757 |
+
# then we should add the free motion field here # # add the free motion field # # hwo to use that? #
|
758 |
+
# another rules for optimizing articulation motion field #
|
759 |
+
# -> the articulated model predicted transformations #
|
760 |
+
# -> the free motion field -> the motion field predicted by the network for each timestep -> an implicit motion field #
|
761 |
+
|
762 |
+
class Robot_urdf:
|
763 |
+
def __init__(self, links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx) -> None:
|
764 |
+
self.links = links
|
765 |
+
self.link_name_to_link_idxes = link_name_to_link_idxes
|
766 |
+
self.link_name_to_link_struct = link_name_to_link_struct
|
767 |
+
# joint_name_to_joint_idx, actions_joint_name_to_joint_idx
|
768 |
+
self.joint_name_to_joint_idx = joint_name_to_joint_idx
|
769 |
+
self.actions_joint_name_to_joint_idx = actions_joint_name_to_joint_idx
|
770 |
+
|
771 |
+
#
|
772 |
+
# particles
|
773 |
+
# sample particles
|
774 |
+
# how to sample particles?
|
775 |
+
# how to expand the particles? # -> you can use weights in the model dict #
|
776 |
+
# from grids and jample from grids #
|
777 |
+
# link idx to the
|
778 |
+
# robot #
|
779 |
+
# init vertices, init faces #
|
780 |
+
# expande the aprticles #
|
781 |
+
# expanede particles #
|
782 |
+
# use particles to conduct the simulation #
|
783 |
+
|
784 |
+
|
785 |
+
|
786 |
+
self.init_vertices, self.init_faces = self.get_init_visual_pts()
|
787 |
+
|
788 |
+
|
789 |
+
init_visual_pts_sv_fn = "robot_expanded_visual_pts.npy"
|
790 |
+
np.save(init_visual_pts_sv_fn, self.init_vertices.detach().cpu().numpy())
|
791 |
+
|
792 |
+
joint_name_to_joint_idx_sv_fn = "mano_joint_name_to_joint_idx.npy"
|
793 |
+
np.save(joint_name_to_joint_idx_sv_fn, self.joint_name_to_joint_idx)
|
794 |
+
|
795 |
+
actions_joint_name_to_joint_idx_sv_fn = "mano_actions_joint_name_to_joint_idx.npy"
|
796 |
+
np.save(actions_joint_name_to_joint_idx_sv_fn, self.actions_joint_name_to_joint_idx)
|
797 |
+
|
798 |
+
tot_joints = len(self.joint_name_to_joint_idx)
|
799 |
+
tot_actions_joints = len(self.actions_joint_name_to_joint_idx)
|
800 |
+
|
801 |
+
print(f"tot_joints: {tot_joints}, tot_actions_joints: {tot_actions_joints}")
|
802 |
+
|
803 |
+
pass
|
804 |
+
|
805 |
+
|
806 |
+
|
807 |
+
|
808 |
+
def expand_visual_pts(self, ):
|
809 |
+
link_name_to_visited = {}
|
810 |
+
# transform the visual pts #
|
811 |
+
# action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
|
812 |
+
|
813 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
814 |
+
palm_link = self.links[palm_idx]
|
815 |
+
expanded_visual_pts = []
|
816 |
+
# expanded the visual pts # # transformed viusal pts # or the translations of the visual pts #
|
817 |
+
expanded_visual_pts = palm_link.expand_visual_pts(expanded_visual_pts, link_name_to_visited, self.link_name_to_link_struct)
|
818 |
+
expanded_visual_pts = torch.cat(expanded_visual_pts, dim=0)
|
819 |
+
# pass
|
820 |
+
return expanded_visual_pts
|
821 |
+
|
822 |
+
|
823 |
+
# get_transformed_visual_pts() # get_transformed_visual_pts of the visual pts ### get_transformed_visual_pts ## get_transformed_visual_pts ### #
|
824 |
+
def get_transformed_visual_pts(self, ):
|
825 |
+
init_visual_pts = []
|
826 |
+
link_name_to_visited = {}
|
827 |
+
|
828 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
829 |
+
palm_link = self.links[palm_idx]
|
830 |
+
|
831 |
+
### init_visual_pts # from the pal mink to get the total transformed visual pts ##
|
832 |
+
init_visual_pts = palm_link.get_transformed_visual_pts(init_visual_pts, link_name_to_visited, self.link_name_to_link_struct)
|
833 |
+
|
834 |
+
init_visual_pts = torch.cat(init_visual_pts, dim=0) ## get the inita visual pts from the palm link ###
|
835 |
+
return init_visual_pts
|
836 |
+
|
837 |
+
|
838 |
+
### samping issue? --- TODO` `
|
839 |
+
def get_init_visual_pts(self, ):
|
840 |
+
init_visual_meshes = {
|
841 |
+
'vertices': [], 'faces': [], 'link_idxes': [], 'transformed_joint_pos': [], 'link_idxes': [], 'transformed_joint_pos': [], 'joint_link_idxes': []
|
842 |
+
}
|
843 |
+
init_parent_rot = torch.eye(3, dtype=torch.float32).cuda()
|
844 |
+
init_parent_trans = torch.zeros((3,), dtype=torch.float32).cuda()
|
845 |
+
|
846 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
847 |
+
palm_link = self.links[palm_idx]
|
848 |
+
|
849 |
+
link_name_to_visited = {}
|
850 |
+
|
851 |
+
init_visual_meshes = palm_link.get_init_visual_meshes(init_parent_rot, init_parent_trans, init_visual_meshes, self.link_name_to_link_struct, link_name_to_visited)
|
852 |
+
|
853 |
+
self.link_idxes = torch.cat(init_visual_meshes['link_idxes'], dim=-1)
|
854 |
+
self.transformed_joint_pos = torch.cat(init_visual_meshes['transformed_joint_pos'], dim=0)
|
855 |
+
self.joint_link_idxes = torch.cat(init_visual_meshes['joint_link_idxes'], dim=-1) ###
|
856 |
+
|
857 |
+
|
858 |
+
|
859 |
+
# for cur_link in self.links:
|
860 |
+
# init_visual_meshes = cur_link.get_init_visual_meshes(init_parent_rot, init_parent_trans, init_visual_meshes, self.link_name_to_link_struct, link_name_to_visited)
|
861 |
+
|
862 |
+
init_vertices, init_faces = merge_meshes(init_visual_meshes['vertices'], init_visual_meshes['faces'])
|
863 |
+
return init_vertices, init_faces
|
864 |
+
|
865 |
+
|
866 |
+
def set_delta_state_and_update(self, states, cur_timestep):
|
867 |
+
link_name_to_visited = {}
|
868 |
+
|
869 |
+
action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
|
870 |
+
|
871 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
872 |
+
palm_link = self.links[palm_idx]
|
873 |
+
|
874 |
+
link_name_to_visited = {}
|
875 |
+
|
876 |
+
palm_link.set_delta_state_and_update(states, cur_timestep, link_name_to_visited, action_joint_name_to_joint_idx, self.link_name_to_link_struct)
|
877 |
+
|
878 |
+
# cur_joint.set_actions_and_update_states(cur_action, cur_timestep, time_cons, cur_child_inertia)
|
879 |
+
def set_actions_and_update_states(self, actions, cur_timestep, time_cons,):
|
880 |
+
# actions
|
881 |
+
# self.actions_joint_name_to_joint_idx as the action joint name to joint idx
|
882 |
+
link_name_to_visited = {}
|
883 |
+
|
884 |
+
action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
|
885 |
+
|
886 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
887 |
+
palm_link = self.links[palm_idx]
|
888 |
+
|
889 |
+
link_name_to_visited = {}
|
890 |
+
|
891 |
+
palm_link.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
|
892 |
+
|
893 |
+
# for cur_joint in
|
894 |
+
|
895 |
+
# for cur_link in self.links:
|
896 |
+
# if cur_link.joint is not None:
|
897 |
+
# for cur_joint_nm in cur_link.joint:
|
898 |
+
# if cur_link.joint[cur_joint_nm].type in ['revolute']:
|
899 |
+
# cur_link_joint_name = cur_link.joint[cur_joint_nm].name
|
900 |
+
# cur_link_joint_idx = self.actions_joint_name_to_joint_idx[cur_link_joint_name]
|
901 |
+
|
902 |
+
|
903 |
+
# for cur_link in self.links:
|
904 |
+
# cur_link.set_actions_and_update_states(actions, cur_timestep, time_cons, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
|
905 |
+
|
906 |
+
### TODO: add the contact torque when calculating the nextstep states ###
|
907 |
+
### TODO: not an accurate implementation since differen joints should be jconsidered for one single link ###
|
908 |
+
### TODO: the articulated force modle is not so easy as this one .... ###
|
909 |
+
def set_contact_forces(self, hard_selected_forces, hard_selected_manipulating_points, hard_selected_sampled_input_pts_idxes):
|
910 |
+
# transformed_joint_pos, joint_link_idxes, link_idxes #
|
911 |
+
selected_pts_link_idxes = self.link_idxes[hard_selected_sampled_input_pts_idxes]
|
912 |
+
# use the selected link idxes #
|
913 |
+
# selected pts idxes #
|
914 |
+
|
915 |
+
# self.joint_link_idxes, transformed_joint_pos #
|
916 |
+
self.link_idx_to_transformed_joint_pos = {}
|
917 |
+
for i_link in range(self.transformed_joint_pos.size(0)):
|
918 |
+
cur_link_idx = self.link_idxes[i_link].item()
|
919 |
+
cur_link_pos = self.transformed_joint_pos[i_link]
|
920 |
+
# if cur_link_idx not in self.link_idx_to_transformed_joint_pos:
|
921 |
+
self.link_idx_to_transformed_joint_pos[cur_link_idx] = cur_link_pos
|
922 |
+
# self.link_idx_to_transformed_joint_pos[cur_link_idx].append(cur_link_pos)
|
923 |
+
|
924 |
+
# from the
|
925 |
+
self.link_idx_to_contact_forces = {}
|
926 |
+
for i_c_pts in range(hard_selected_forces.size(0)):
|
927 |
+
cur_contact_force = hard_selected_forces[i_c_pts] ##
|
928 |
+
cur_link_idx = selected_pts_link_idxes[i_c_pts].item()
|
929 |
+
cur_link_pos = self.link_idx_to_transformed_joint_pos[cur_link_idx]
|
930 |
+
cur_link_action_pos = hard_selected_manipulating_points[i_c_pts]
|
931 |
+
# (action_pos - link_pos) x (-contact_force) #
|
932 |
+
cur_contact_torque = torch.cross(
|
933 |
+
cur_link_action_pos - cur_link_pos, -cur_contact_force
|
934 |
+
)
|
935 |
+
if cur_link_idx not in self.link_idx_to_contact_forces:
|
936 |
+
self.link_idx_to_contact_forces[cur_link_idx] = [cur_contact_torque]
|
937 |
+
else:
|
938 |
+
self.link_idx_to_contact_forces[cur_link_idx].append(cur_contact_torque)
|
939 |
+
for link_idx in self.link_idx_to_contact_forces:
|
940 |
+
self.link_idx_to_contact_forces[link_idx] = torch.stack(self.link_idx_to_contact_forces[link_idx], dim=0)
|
941 |
+
self.link_idx_to_contact_forces[link_idx] = torch.sum(self.link_idx_to_contact_forces[link_idx] , dim=0)
|
942 |
+
for link_idx, link_struct in enumerate(self.links):
|
943 |
+
if link_idx in self.link_idx_to_contact_forces:
|
944 |
+
cur_link_contact_force = self.link_idx_to_contact_forces[link_idx]
|
945 |
+
link_struct.contact_torque = cur_link_contact_force
|
946 |
+
else:
|
947 |
+
link_struct.contact_torque = None
|
948 |
+
|
949 |
+
|
950 |
+
# def se ### from the optimizable initial states ###
|
951 |
+
def set_initial_state(self, states):
|
952 |
+
action_joint_name_to_joint_idx = self.actions_joint_name_to_joint_idx
|
953 |
+
link_name_to_visited = {}
|
954 |
+
|
955 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
956 |
+
palm_link = self.links[palm_idx]
|
957 |
+
|
958 |
+
link_name_to_visited = {}
|
959 |
+
|
960 |
+
palm_link.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
|
961 |
+
|
962 |
+
# for cur_link in self.links:
|
963 |
+
# cur_link.set_initial_state(states, action_joint_name_to_joint_idx, link_name_to_visited, self.link_name_to_link_struct)
|
964 |
+
|
965 |
+
### after each timestep -> re-calculate the inertial matrix using the current simulated states and the set the new actiosn and forward the simulation #
|
966 |
+
def calculate_inertia(self):
|
967 |
+
link_name_to_visited = {}
|
968 |
+
|
969 |
+
palm_idx = self.link_name_to_link_idxes["palm"]
|
970 |
+
palm_link = self.links[palm_idx]
|
971 |
+
|
972 |
+
link_name_to_visited = {}
|
973 |
+
|
974 |
+
palm_link.calculate_inertia(link_name_to_visited, self.link_name_to_link_struct)
|
975 |
+
|
976 |
+
# for cur_link in self.links:
|
977 |
+
# cur_link.calculate_inertia(link_name_to_visited, self.link_name_to_link_struct)
|
978 |
+
|
979 |
+
|
980 |
+
|
981 |
+
|
982 |
+
|
983 |
+
def parse_nparray_from_string(strr, args=None):
|
984 |
+
vals = strr.split(" ")
|
985 |
+
vals = [float(val) for val in vals]
|
986 |
+
vals = np.array(vals, dtype=np.float32)
|
987 |
+
vals = torch.from_numpy(vals).float()
|
988 |
+
## vals ##
|
989 |
+
vals = nn.Parameter(vals.cuda(), requires_grad=True)
|
990 |
+
|
991 |
+
return vals
|
992 |
+
|
993 |
+
|
994 |
+
### parse link data ###
|
995 |
+
def parse_link_data(link, args):
|
996 |
+
|
997 |
+
link_name = link.attrib["name"]
|
998 |
+
# print(f"parsing link: {link_name}") ## joints body meshes #
|
999 |
+
|
1000 |
+
joint = link.find("./joint")
|
1001 |
+
|
1002 |
+
joint_name = joint.attrib["name"]
|
1003 |
+
joint_type = joint.attrib["type"]
|
1004 |
+
if joint_type in ["revolute"]: ## a general xml parser here?
|
1005 |
+
axis = joint.attrib["axis"]
|
1006 |
+
axis = parse_nparray_from_string(axis, args=args)
|
1007 |
+
else:
|
1008 |
+
axis = None
|
1009 |
+
pos = joint.attrib["pos"] #
|
1010 |
+
pos = parse_nparray_from_string(pos, args=args)
|
1011 |
+
quat = joint.attrib["quat"]
|
1012 |
+
quat = parse_nparray_from_string(quat, args=args)
|
1013 |
+
|
1014 |
+
try:
|
1015 |
+
frame = joint.attrib["frame"]
|
1016 |
+
except:
|
1017 |
+
frame = "WORLD"
|
1018 |
+
|
1019 |
+
if joint_type not in ["fixed"]:
|
1020 |
+
damping = joint.attrib["damping"]
|
1021 |
+
damping = float(damping)
|
1022 |
+
else:
|
1023 |
+
damping = 0.0
|
1024 |
+
|
1025 |
+
cur_joint = Joint(joint_name, joint_type, axis, pos, quat, frame, damping, args=args)
|
1026 |
+
|
1027 |
+
body = link.find("./body")
|
1028 |
+
body_name = body.attrib["name"]
|
1029 |
+
body_type = body.attrib["type"]
|
1030 |
+
if body_type == "mesh":
|
1031 |
+
filename = body.attrib["filename"]
|
1032 |
+
else:
|
1033 |
+
filename = ""
|
1034 |
+
|
1035 |
+
if body_type == "sphere":
|
1036 |
+
radius = body.attrib["radius"]
|
1037 |
+
radius = float(radius)
|
1038 |
+
else:
|
1039 |
+
radius = 0.
|
1040 |
+
|
1041 |
+
pos = body.attrib["pos"]
|
1042 |
+
pos = parse_nparray_from_string(pos, args=args)
|
1043 |
+
quat = body.attrib["quat"]
|
1044 |
+
quat = joint.attrib["quat"]
|
1045 |
+
try:
|
1046 |
+
transform_type = body.attrib["transform_type"]
|
1047 |
+
except:
|
1048 |
+
transform_type = "OBJ_TO_WORLD"
|
1049 |
+
density = body.attrib["density"]
|
1050 |
+
density = float(density)
|
1051 |
+
mu = body.attrib["mu"]
|
1052 |
+
mu = float(mu)
|
1053 |
+
try: ## rgba ##
|
1054 |
+
rgba = body.attrib["rgba"]
|
1055 |
+
rgba = parse_nparray_from_string(rgba, args=args)
|
1056 |
+
except:
|
1057 |
+
rgba = np.zeros((4,), dtype=np.float32)
|
1058 |
+
|
1059 |
+
cur_body = Body(body_name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args=args)
|
1060 |
+
|
1061 |
+
children_link = []
|
1062 |
+
links = link.findall("./link")
|
1063 |
+
for child_link in links: #
|
1064 |
+
cur_child_link = parse_link_data(child_link, args=args)
|
1065 |
+
children_link.append(cur_child_link)
|
1066 |
+
|
1067 |
+
link_name = link.attrib["name"]
|
1068 |
+
link_obj = Link(link_name, joint=cur_joint, body=cur_body, children=children_link, args=args)
|
1069 |
+
return link_obj
|
1070 |
+
|
1071 |
+
|
1072 |
+
### parse link data ###
|
1073 |
+
def parse_link_data_urdf(link):
|
1074 |
+
|
1075 |
+
link_name = link.attrib["name"]
|
1076 |
+
# print(f"parsing link: {link_name}") ## joints body meshes #
|
1077 |
+
|
1078 |
+
inertial = link.find("./inertial")
|
1079 |
+
|
1080 |
+
origin = inertial.find("./origin")
|
1081 |
+
inertial_pos = origin.attrib["xyz"]
|
1082 |
+
inertial_pos = parse_nparray_from_string(inertial_pos)
|
1083 |
+
|
1084 |
+
inertial_rpy = origin.attrib["rpy"]
|
1085 |
+
inertial_rpy = parse_nparray_from_string(inertial_rpy)
|
1086 |
+
|
1087 |
+
inertial_mass = inertial.find("./mass")
|
1088 |
+
inertial_mass = inertial_mass.attrib["value"]
|
1089 |
+
|
1090 |
+
inertial_inertia = inertial.find("./inertia")
|
1091 |
+
inertial_ixx = inertial_inertia.attrib["ixx"]
|
1092 |
+
inertial_ixx = float(inertial_ixx)
|
1093 |
+
inertial_ixy = inertial_inertia.attrib["ixy"]
|
1094 |
+
inertial_ixy = float(inertial_ixy)
|
1095 |
+
inertial_ixz = inertial_inertia.attrib["ixz"]
|
1096 |
+
inertial_ixz = float(inertial_ixz)
|
1097 |
+
inertial_iyy = inertial_inertia.attrib["iyy"]
|
1098 |
+
inertial_iyy = float(inertial_iyy)
|
1099 |
+
inertial_iyz = inertial_inertia.attrib["iyz"]
|
1100 |
+
inertial_iyz = float(inertial_iyz)
|
1101 |
+
inertial_izz = inertial_inertia.attrib["izz"]
|
1102 |
+
inertial_izz = float(inertial_izz)
|
1103 |
+
|
1104 |
+
inertial_inertia_mtx = torch.zeros((3, 3), dtype=torch.float32).cuda()
|
1105 |
+
inertial_inertia_mtx[0, 0] = inertial_ixx
|
1106 |
+
inertial_inertia_mtx[0, 1] = inertial_ixy
|
1107 |
+
inertial_inertia_mtx[0, 2] = inertial_ixz
|
1108 |
+
inertial_inertia_mtx[1, 0] = inertial_ixy
|
1109 |
+
inertial_inertia_mtx[1, 1] = inertial_iyy
|
1110 |
+
inertial_inertia_mtx[1, 2] = inertial_iyz
|
1111 |
+
inertial_inertia_mtx[2, 0] = inertial_ixz
|
1112 |
+
inertial_inertia_mtx[2, 1] = inertial_iyz
|
1113 |
+
inertial_inertia_mtx[2, 2] = inertial_izz
|
1114 |
+
|
1115 |
+
# [xx, xy, xz] #
|
1116 |
+
# [0, yy, yz] #
|
1117 |
+
# [0, 0, zz] #
|
1118 |
+
|
1119 |
+
# a strange inertia value ... #
|
1120 |
+
# TODO: how to compute the inertia matrix? #
|
1121 |
+
|
1122 |
+
visual = link.find("./visual")
|
1123 |
+
|
1124 |
+
if visual is not None:
|
1125 |
+
origin = visual.find("./origin")
|
1126 |
+
visual_pos = origin.attrib["xyz"]
|
1127 |
+
visual_pos = parse_nparray_from_string(visual_pos)
|
1128 |
+
visual_rpy = origin.attrib["rpy"]
|
1129 |
+
visual_rpy = parse_nparray_from_string(visual_rpy)
|
1130 |
+
geometry = visual.find("./geometry")
|
1131 |
+
geometry_mesh = geometry.find("./mesh")
|
1132 |
+
mesh_fn = geometry_mesh.attrib["filename"]
|
1133 |
+
mesh_scale = geometry_mesh.attrib["scale"]
|
1134 |
+
|
1135 |
+
mesh_scale = parse_nparray_from_string(mesh_scale)
|
1136 |
+
mesh_fn = str(mesh_fn)
|
1137 |
+
|
1138 |
+
|
1139 |
+
link_struct = Link_urdf(name=link_name, inertial=Inertial(origin_rpy=inertial_rpy, origin_xyz=inertial_pos, mass=inertial_mass, inertia=inertial_inertia_mtx), visual=Visual(visual_rpy=visual_rpy, visual_xyz=visual_pos, geometry_mesh_fn=mesh_fn, geometry_mesh_scale=mesh_scale) if visual is not None else None)
|
1140 |
+
|
1141 |
+
return link_struct
|
1142 |
+
|
1143 |
+
def parse_joint_data_urdf(joint):
|
1144 |
+
joint_name = joint.attrib["name"]
|
1145 |
+
joint_type = joint.attrib["type"]
|
1146 |
+
|
1147 |
+
parent = joint.find("./parent")
|
1148 |
+
child = joint.find("./child")
|
1149 |
+
parent_name = parent.attrib["link"]
|
1150 |
+
child_name = child.attrib["link"]
|
1151 |
+
|
1152 |
+
joint_origin = joint.find("./origin")
|
1153 |
+
# if joint_origin.
|
1154 |
+
try:
|
1155 |
+
origin_xyz = joint_origin.attrib["xyz"]
|
1156 |
+
origin_xyz = parse_nparray_from_string(origin_xyz)
|
1157 |
+
except:
|
1158 |
+
origin_xyz = torch.tensor([0., 0., 0.], dtype=torch.float32).cuda()
|
1159 |
+
|
1160 |
+
joint_axis = joint.find("./axis")
|
1161 |
+
if joint_axis is not None:
|
1162 |
+
joint_axis = joint_axis.attrib["xyz"]
|
1163 |
+
joint_axis = parse_nparray_from_string(joint_axis)
|
1164 |
+
else:
|
1165 |
+
joint_axis = torch.tensor([1, 0., 0.], dtype=torch.float32).cuda()
|
1166 |
+
|
1167 |
+
joint_limit = joint.find("./limit")
|
1168 |
+
if joint_limit is not None:
|
1169 |
+
joint_lower = joint_limit.attrib["lower"]
|
1170 |
+
joint_lower = float(joint_lower)
|
1171 |
+
joint_upper = joint_limit.attrib["upper"]
|
1172 |
+
joint_upper = float(joint_upper)
|
1173 |
+
joint_effort = joint_limit.attrib["effort"]
|
1174 |
+
joint_effort = float(joint_effort)
|
1175 |
+
joint_velocity = joint_limit.attrib["velocity"]
|
1176 |
+
joint_velocity = float(joint_velocity)
|
1177 |
+
else:
|
1178 |
+
joint_lower = -0.5000
|
1179 |
+
joint_upper = 1.57
|
1180 |
+
joint_effort = 1000
|
1181 |
+
joint_velocity = 0.5
|
1182 |
+
|
1183 |
+
# cosntruct the joint data #
|
1184 |
+
joint_limit = Joint_Limit(effort=joint_effort, lower=joint_lower, upper=joint_upper, velocity=joint_velocity)
|
1185 |
+
cur_joint_struct = Joint_urdf(joint_name, joint_type, parent_name, child_name, origin_xyz, joint_axis, joint_limit)
|
1186 |
+
return cur_joint_struct
|
1187 |
+
|
1188 |
+
|
1189 |
+
|
1190 |
+
def parse_data_from_urdf(xml_fn):
|
1191 |
+
|
1192 |
+
tree = ElementTree()
|
1193 |
+
tree.parse(xml_fn)
|
1194 |
+
print(f"{xml_fn}")
|
1195 |
+
### get total robots ###
|
1196 |
+
# robots = tree.findall("link")
|
1197 |
+
cur_robot = tree
|
1198 |
+
# i_robot = 0 #
|
1199 |
+
# tot_robots = [] #
|
1200 |
+
# for cur_robot in robots: #
|
1201 |
+
# print(f"Getting robot: {i_robot}") #
|
1202 |
+
# i_robot += 1 #
|
1203 |
+
# print(f"len(robots): {len(robots)}") #
|
1204 |
+
# cur_robot = robots[0] #
|
1205 |
+
cur_links = cur_robot.findall("./link")
|
1206 |
+
# i_link = 0
|
1207 |
+
link_name_to_link_idxes = {}
|
1208 |
+
cur_robot_links = []
|
1209 |
+
link_name_to_link_struct = {}
|
1210 |
+
for i_link_idx, cur_link in enumerate(cur_links):
|
1211 |
+
cur_link_struct = parse_link_data_urdf(cur_link)
|
1212 |
+
print(f"Adding link {cur_link_struct.name}")
|
1213 |
+
cur_link_struct.link_idx = i_link_idx
|
1214 |
+
cur_robot_links.append(cur_link_struct)
|
1215 |
+
|
1216 |
+
link_name_to_link_idxes[cur_link_struct.name] = i_link_idx
|
1217 |
+
link_name_to_link_struct[cur_link_struct.name] = cur_link_struct
|
1218 |
+
# for cur_link in cur_links:
|
1219 |
+
# cur_robot_links.append(parse_link_data_urdf(cur_link, args=args))
|
1220 |
+
|
1221 |
+
print(f"link_name_to_link_struct: {len(link_name_to_link_struct)}, ")
|
1222 |
+
|
1223 |
+
tot_robot_joints = []
|
1224 |
+
|
1225 |
+
joint_name_to_joint_idx = {}
|
1226 |
+
|
1227 |
+
actions_joint_name_to_joint_idx = {}
|
1228 |
+
|
1229 |
+
cur_joints = cur_robot.findall("./joint")
|
1230 |
+
for i_joint, cur_joint in enumerate(cur_joints):
|
1231 |
+
cur_joint_struct = parse_joint_data_urdf(cur_joint)
|
1232 |
+
cur_joint_parent_link = cur_joint_struct.parent_link
|
1233 |
+
cur_joint_child_link = cur_joint_struct.child_link
|
1234 |
+
|
1235 |
+
cur_joint_idx = len(tot_robot_joints)
|
1236 |
+
cur_joint_name = cur_joint_struct.name
|
1237 |
+
|
1238 |
+
joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
|
1239 |
+
|
1240 |
+
cur_joint_type = cur_joint_struct.type
|
1241 |
+
if cur_joint_type in ['revolute']:
|
1242 |
+
actions_joint_name_to_joint_idx[cur_joint_name] = cur_joint_idx
|
1243 |
+
|
1244 |
+
|
1245 |
+
#### add the current joint to tot joints ###
|
1246 |
+
tot_robot_joints.append(cur_joint_struct)
|
1247 |
+
|
1248 |
+
parent_link_idx = link_name_to_link_idxes[cur_joint_parent_link]
|
1249 |
+
cur_parent_link_struct = cur_robot_links[parent_link_idx]
|
1250 |
+
|
1251 |
+
|
1252 |
+
child_link_idx = link_name_to_link_idxes[cur_joint_child_link]
|
1253 |
+
cur_child_link_struct = cur_robot_links[child_link_idx]
|
1254 |
+
# parent link struct #
|
1255 |
+
if link_name_to_link_struct[cur_joint_parent_link].joint is not None:
|
1256 |
+
link_name_to_link_struct[cur_joint_parent_link].joint[cur_joint_struct.name] = cur_joint_struct
|
1257 |
+
link_name_to_link_struct[cur_joint_parent_link].children[cur_joint_struct.name] = cur_child_link_struct.name
|
1258 |
+
# cur_child_link_struct
|
1259 |
+
# cur_parent_link_struct.joint.append(cur_joint_struct)
|
1260 |
+
# cur_parent_link_struct.children.append(cur_child_link_struct)
|
1261 |
+
else:
|
1262 |
+
link_name_to_link_struct[cur_joint_parent_link].joint = {
|
1263 |
+
cur_joint_struct.name: cur_joint_struct
|
1264 |
+
}
|
1265 |
+
link_name_to_link_struct[cur_joint_parent_link].children = {
|
1266 |
+
cur_joint_struct.name: cur_child_link_struct.name
|
1267 |
+
# cur_child_link_struct
|
1268 |
+
}
|
1269 |
+
# cur_parent_link_struct.joint = [cur_joint_struct]
|
1270 |
+
# cur_parent_link_struct.children.append(cur_child_link_struct)
|
1271 |
+
# pass
|
1272 |
+
|
1273 |
+
|
1274 |
+
cur_robot_obj = Robot_urdf(cur_robot_links, link_name_to_link_idxes, link_name_to_link_struct, joint_name_to_joint_idx, actions_joint_name_to_joint_idx)
|
1275 |
+
# tot_robots.append(cur_robot_obj)
|
1276 |
+
|
1277 |
+
# for the joint robots #
|
1278 |
+
# for every joint
|
1279 |
+
# tot_actuators = []
|
1280 |
+
# actuators = tree.findall("./actuator/motor")
|
1281 |
+
# joint_nm_to_joint_idx = {}
|
1282 |
+
# i_act = 0
|
1283 |
+
# for cur_act in actuators:
|
1284 |
+
# cur_act_joint_nm = cur_act.attrib["joint"]
|
1285 |
+
# joint_nm_to_joint_idx[cur_act_joint_nm] = i_act
|
1286 |
+
# i_act += 1 ### add the act ###
|
1287 |
+
|
1288 |
+
# tot_robots[0].set_joint_idx(joint_nm_to_joint_idx) ### set joint idx here ### # tot robots #
|
1289 |
+
# tot_robots[0].get_nn_pts()
|
1290 |
+
# tot_robots[1].get_nn_pts()
|
1291 |
+
|
1292 |
+
return cur_robot_obj
|
1293 |
+
|
1294 |
+
|
1295 |
+
def get_name_to_state_from_str(states_str):
|
1296 |
+
tot_states = states_str.split(" ")
|
1297 |
+
tot_states = [float(cur_state) for cur_state in tot_states]
|
1298 |
+
joint_name_to_state = {}
|
1299 |
+
for i in range(len(tot_states)):
|
1300 |
+
cur_joint_name = f"joint{i + 1}"
|
1301 |
+
cur_joint_state = tot_states[i]
|
1302 |
+
joint_name_to_state[cur_joint_name] = cur_joint_state
|
1303 |
+
return joint_name_to_state
|
1304 |
+
|
1305 |
+
|
1306 |
+
def merge_meshes(verts_list, faces_list):
|
1307 |
+
nn_verts = 0
|
1308 |
+
tot_verts_list = []
|
1309 |
+
tot_faces_list = []
|
1310 |
+
for i_vv, cur_verts in enumerate(verts_list):
|
1311 |
+
cur_verts_nn = cur_verts.size(0)
|
1312 |
+
tot_verts_list.append(cur_verts)
|
1313 |
+
tot_faces_list.append(faces_list[i_vv] + nn_verts)
|
1314 |
+
nn_verts = nn_verts + cur_verts_nn
|
1315 |
+
tot_verts_list = torch.cat(tot_verts_list, dim=0)
|
1316 |
+
tot_faces_list = torch.cat(tot_faces_list, dim=0)
|
1317 |
+
return tot_verts_list, tot_faces_list
|
1318 |
+
|
1319 |
+
|
1320 |
+
|
1321 |
+
class RobotAgent: # robot and the robot #
|
1322 |
+
def __init__(self, xml_fn) -> None:
|
1323 |
+
self.xml_fn = xml_fn
|
1324 |
+
# self.args = args
|
1325 |
+
|
1326 |
+
##
|
1327 |
+
active_robot = parse_data_from_urdf(xml_fn)
|
1328 |
+
|
1329 |
+
self.time_constant = nn.Embedding(
|
1330 |
+
num_embeddings=3, embedding_dim=1
|
1331 |
+
).cuda()
|
1332 |
+
torch.nn.init.ones_(self.time_constant.weight) #
|
1333 |
+
self.time_constant.weight.data = self.time_constant.weight.data * 0.2 ### time_constant data #
|
1334 |
+
|
1335 |
+
self.optimizable_actions = nn.Embedding(
|
1336 |
+
num_embeddings=100, embedding_dim=60,
|
1337 |
+
).cuda()
|
1338 |
+
torch.nn.init.zeros_(self.optimizable_actions.weight) #
|
1339 |
+
|
1340 |
+
self.learning_rate = 5e-4
|
1341 |
+
|
1342 |
+
self.active_robot = active_robot
|
1343 |
+
|
1344 |
+
# # get init states # #
|
1345 |
+
self.set_init_states()
|
1346 |
+
init_visual_pts = self.get_init_state_visual_pts()
|
1347 |
+
self.init_visual_pts = init_visual_pts
|
1348 |
+
|
1349 |
+
|
1350 |
+
def set_init_states_target_value(self, init_states):
|
1351 |
+
# glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
|
1352 |
+
# glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
|
1353 |
+
|
1354 |
+
# tot_init_states = {}
|
1355 |
+
# tot_init_states['glb_rot'] = glb_rot;
|
1356 |
+
# tot_init_states['glb_trans'] = glb_trans;
|
1357 |
+
# tot_init_states['links_init_states'] = init_states
|
1358 |
+
# self.active_robot.set_init_states_target_value(tot_init_states)
|
1359 |
+
# init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
|
1360 |
+
self.active_robot.set_initial_state(init_states)
|
1361 |
+
|
1362 |
+
def set_init_states(self):
|
1363 |
+
# glb_rot = torch.eye(n=3, dtype=torch.float32).cuda()
|
1364 |
+
# glb_trans = torch.zeros((3,), dtype=torch.float32).cuda() ### glb_trans #### and the rot 3##
|
1365 |
+
|
1366 |
+
# ### random rotation ###
|
1367 |
+
# # glb_rot_np = R.random().as_matrix()
|
1368 |
+
# # glb_rot = torch.from_numpy(glb_rot_np).float().cuda()
|
1369 |
+
# ### random rotation ###
|
1370 |
+
|
1371 |
+
# # glb_rot, glb_trans #
|
1372 |
+
# init_states = {} # init states #
|
1373 |
+
# init_states['glb_rot'] = glb_rot; #
|
1374 |
+
# init_states['glb_trans'] = glb_trans;
|
1375 |
+
# self.active_robot.set_init_states(init_states)
|
1376 |
+
|
1377 |
+
init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
|
1378 |
+
self.active_robot.set_initial_state(init_joint_states)
|
1379 |
+
|
1380 |
+
def get_init_state_visual_pts(self, ):
|
1381 |
+
# visual_pts_list = [] # compute the transformation via current state #
|
1382 |
+
# visual_pts_list, visual_pts_mass_list = self.active_robot.compute_transformation_via_current_state( visual_pts_list)
|
1383 |
+
cur_verts, cur_faces = self.active_robot.get_init_visual_pts()
|
1384 |
+
self.faces = cur_faces
|
1385 |
+
# init_visual_pts = visual_pts_list
|
1386 |
+
return cur_verts
|
1387 |
+
|
1388 |
+
def set_actions_and_update_states(self, actions, cur_timestep):
|
1389 |
+
#
|
1390 |
+
time_cons = self.time_constant(torch.zeros((1,), dtype=torch.long).cuda()) ### time constant of the system ##
|
1391 |
+
self.active_robot.set_actions_and_update_states(actions, cur_timestep, time_cons) ###
|
1392 |
+
pass
|
1393 |
+
|
1394 |
+
def forward_stepping_test(self, ):
|
1395 |
+
# delta_glb_rot; delta_glb_trans #
|
1396 |
+
timestep_to_visual_pts = {}
|
1397 |
+
for i_step in range(50):
|
1398 |
+
actions = {}
|
1399 |
+
actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
|
1400 |
+
actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
|
1401 |
+
actions_link_actions = torch.ones((22, ), dtype=torch.float32).cuda()
|
1402 |
+
# actions_link_actions = actions_link_actions * 0.2
|
1403 |
+
actions_link_actions = actions_link_actions * -1. #
|
1404 |
+
actions['link_actions'] = actions_link_actions
|
1405 |
+
self.set_actions_and_update_states(actions=actions, cur_timestep=i_step)
|
1406 |
+
|
1407 |
+
cur_visual_pts = robot_agent.get_init_state_visual_pts()
|
1408 |
+
cur_visual_pts = cur_visual_pts.detach().cpu().numpy()
|
1409 |
+
timestep_to_visual_pts[i_step + 1] = cur_visual_pts
|
1410 |
+
return timestep_to_visual_pts
|
1411 |
+
|
1412 |
+
def initialize_optimization(self, reference_pts_dict):
|
1413 |
+
self.n_timesteps = 50
|
1414 |
+
# self.n_timesteps = 19 # first 19-timesteps optimization #
|
1415 |
+
self.nn_tot_optimization_iters = 100
|
1416 |
+
# self.nn_tot_optimization_iters = 57
|
1417 |
+
# TODO: load reference points #
|
1418 |
+
self.ts_to_reference_pts = np.load(reference_pts_dict, allow_pickle=True).item() ####
|
1419 |
+
self.ts_to_reference_pts = {
|
1420 |
+
ts // 2 + 1: torch.from_numpy(self.ts_to_reference_pts[ts]).float().cuda() for ts in self.ts_to_reference_pts
|
1421 |
+
}
|
1422 |
+
|
1423 |
+
|
1424 |
+
def forward_stepping_optimization(self, ):
|
1425 |
+
nn_tot_optimization_iters = self.nn_tot_optimization_iters
|
1426 |
+
params_to_train = []
|
1427 |
+
params_to_train += list(self.optimizable_actions.parameters())
|
1428 |
+
self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
|
1429 |
+
|
1430 |
+
for i_iter in range(nn_tot_optimization_iters):
|
1431 |
+
|
1432 |
+
tot_losses = []
|
1433 |
+
ts_to_robot_points = {}
|
1434 |
+
for cur_ts in range(self.n_timesteps):
|
1435 |
+
# print(f"iter: {i_iter}, cur_ts: {cur_ts}")
|
1436 |
+
# actions = {}
|
1437 |
+
# actions['delta_glb_rot'] = torch.eye(3, dtype=torch.float32).cuda()
|
1438 |
+
# actions['delta_glb_trans'] = torch.zeros((3,), dtype=torch.float32).cuda()
|
1439 |
+
actions_link_actions = self.optimizable_actions(torch.zeros((1,), dtype=torch.long).cuda() + cur_ts).squeeze(0)
|
1440 |
+
# actions_link_actions = actions_link_actions * 0.2
|
1441 |
+
# actions_link_actions = actions_link_actions * -1. #
|
1442 |
+
# actions['link_actions'] = actions_link_actions
|
1443 |
+
# self.set_actions_and_update_states(actions=actions, cur_timestep=cur_ts) # update the interaction #
|
1444 |
+
|
1445 |
+
with torch.no_grad():
|
1446 |
+
self.active_robot.calculate_inertia()
|
1447 |
+
|
1448 |
+
self.active_robot.set_actions_and_update_states(actions_link_actions, cur_ts, 0.2)
|
1449 |
+
|
1450 |
+
cur_visual_pts, cur_faces = self.active_robot.get_init_visual_pts()
|
1451 |
+
ts_to_robot_points[cur_ts + 1] = cur_visual_pts.clone()
|
1452 |
+
|
1453 |
+
cur_reference_pts = self.ts_to_reference_pts[cur_ts + 1]
|
1454 |
+
diff = torch.sum((cur_visual_pts - cur_reference_pts) ** 2, dim=-1)
|
1455 |
+
diff = diff.mean()
|
1456 |
+
|
1457 |
+
# diff.
|
1458 |
+
self.optimizer.zero_grad()
|
1459 |
+
diff.backward(retain_graph=True)
|
1460 |
+
# diff.backward(retain_graph=False)
|
1461 |
+
self.optimizer.step()
|
1462 |
+
|
1463 |
+
tot_losses.append(diff.item())
|
1464 |
+
|
1465 |
+
|
1466 |
+
loss = sum(tot_losses) / float(len(tot_losses))
|
1467 |
+
print(f"Iter: {i_iter}, average loss: {loss}")
|
1468 |
+
# print(f"Iter: {i_iter}, average loss: {loss.item()}, start optimizing")
|
1469 |
+
# self.optimizer.zero_grad()
|
1470 |
+
# loss.backward()
|
1471 |
+
# self.optimizer.step()
|
1472 |
+
|
1473 |
+
self.ts_to_robot_points = {
|
1474 |
+
ts: ts_to_robot_points[ts].detach().cpu().numpy() for ts in ts_to_robot_points
|
1475 |
+
}
|
1476 |
+
self.ts_to_ref_points = {
|
1477 |
+
ts: self.ts_to_reference_pts[ts].detach().cpu().numpy() for ts in ts_to_robot_points
|
1478 |
+
}
|
1479 |
+
return self.ts_to_robot_points, self.ts_to_ref_points
|
1480 |
+
|
1481 |
+
|
1482 |
+
|
1483 |
+
|
1484 |
+
def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
|
1485 |
+
# sin_ = np.sin(angle) # ti.math.sin(angle)
|
1486 |
+
# cos_ = np.cos(angle) # ti.math.cos(angle)
|
1487 |
+
sin_ = torch.sin(angle) # ti.math.sin(angle)
|
1488 |
+
cos_ = torch.cos(angle) # ti.math.cos(angle)
|
1489 |
+
u_x, u_y, u_z = axis[0], axis[1], axis[2]
|
1490 |
+
u_xx = u_x * u_x
|
1491 |
+
u_yy = u_y * u_y
|
1492 |
+
u_zz = u_z * u_z
|
1493 |
+
u_xy = u_x * u_y
|
1494 |
+
u_xz = u_x * u_z
|
1495 |
+
u_yz = u_y * u_z ##
|
1496 |
+
|
1497 |
+
|
1498 |
+
row_a = torch.stack(
|
1499 |
+
[cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
|
1500 |
+
)
|
1501 |
+
# print(f"row_a: {row_a.size()}")
|
1502 |
+
row_b = torch.stack(
|
1503 |
+
[u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
|
1504 |
+
)
|
1505 |
+
# print(f"row_b: {row_b.size()}")
|
1506 |
+
row_c = torch.stack(
|
1507 |
+
[u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
|
1508 |
+
)
|
1509 |
+
# print(f"row_c: {row_c.size()}")
|
1510 |
+
|
1511 |
+
### rot_mtx for the rot_mtx ###
|
1512 |
+
rot_mtx = torch.stack(
|
1513 |
+
[row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
|
1514 |
+
)
|
1515 |
+
|
1516 |
+
return rot_mtx
|
1517 |
+
|
1518 |
+
|
1519 |
+
|
1520 |
+
#### Big TODO: the external contact forces from the manipulated object to the robot ####
|
1521 |
+
if __name__=='__main__':
|
1522 |
+
|
1523 |
+
urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
|
1524 |
+
robot_agent = RobotAgent(urdf_fn)
|
1525 |
+
|
1526 |
+
ref_dict_npy = "reference_verts.npy"
|
1527 |
+
robot_agent.initialize_optimization(ref_dict_npy)
|
1528 |
+
ts_to_robot_points, ts_to_ref_points = robot_agent.forward_stepping_optimization()
|
1529 |
+
np.save(f"ts_to_robot_points.npy", ts_to_robot_points)
|
1530 |
+
np.save(f"ts_to_ref_points.npy", ts_to_ref_points)
|
1531 |
+
exit(0)
|
1532 |
+
|
1533 |
+
urdf_fn = "/home/xueyi/diffsim/NeuS/rsc/mano/mano_mean_nocoll_simplified.urdf"
|
1534 |
+
cur_robot = parse_data_from_urdf(urdf_fn)
|
1535 |
+
# self.init_vertices, self.init_faces
|
1536 |
+
init_vertices, init_faces = cur_robot.init_vertices, cur_robot.init_faces
|
1537 |
+
|
1538 |
+
|
1539 |
+
|
1540 |
+
init_vertices = init_vertices.detach().cpu().numpy()
|
1541 |
+
init_faces = init_faces.detach().cpu().numpy()
|
1542 |
+
|
1543 |
+
|
1544 |
+
## initial states ehre ##3
|
1545 |
+
# mesh_obj = trimesh.Trimesh(vertices=init_vertices, faces=init_faces)
|
1546 |
+
# mesh_obj.export(f"hand_urdf.ply")
|
1547 |
+
|
1548 |
+
##### Test the set initial state function #####
|
1549 |
+
init_joint_states = torch.zeros((60, ), dtype=torch.float32).cuda()
|
1550 |
+
cur_robot.set_initial_state(init_joint_states)
|
1551 |
+
##### Test the set initial state function #####
|
1552 |
+
|
1553 |
+
|
1554 |
+
|
1555 |
+
|
1556 |
+
cur_zeros_actions = torch.zeros((60, ), dtype=torch.float32).cuda()
|
1557 |
+
cur_ones_actions = torch.ones((60, ), dtype=torch.float32).cuda() # * 100
|
1558 |
+
|
1559 |
+
ts_to_mesh_verts = {}
|
1560 |
+
for i_ts in range(50):
|
1561 |
+
cur_robot.calculate_inertia()
|
1562 |
+
|
1563 |
+
cur_robot.set_actions_and_update_states(cur_ones_actions, i_ts, 0.2) ###
|
1564 |
+
|
1565 |
+
|
1566 |
+
cur_verts, cur_faces = cur_robot.get_init_visual_pts()
|
1567 |
+
cur_mesh = trimesh.Trimesh(vertices=cur_verts.detach().cpu().numpy(), faces=cur_faces.detach().cpu().numpy())
|
1568 |
+
|
1569 |
+
ts_to_mesh_verts[i_ts + i_ts] = cur_verts.detach().cpu().numpy()
|
1570 |
+
# cur_mesh.export(f"stated_mano_mesh.ply")
|
1571 |
+
# cur_mesh.export(f"zero_actioned_mano_mesh.ply")
|
1572 |
+
cur_mesh.export(f"ones_actioned_mano_mesh_ts_{i_ts}.ply")
|
1573 |
+
|
1574 |
+
np.save(f"reference_verts.npy", ts_to_mesh_verts)
|
1575 |
+
|
1576 |
+
exit(0)
|
1577 |
+
|
1578 |
+
xml_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere.xml"
|
1579 |
+
robot_agent = RobotAgent(xml_fn=xml_fn, args=None)
|
1580 |
+
init_visual_pts = robot_agent.init_visual_pts.detach().cpu().numpy()
|
1581 |
+
exit(0)
|
1582 |
+
|
models/dyn_model_utils.py
ADDED
@@ -0,0 +1,1369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import math
|
3 |
+
# import torch
|
4 |
+
# from ..utils import Timer
|
5 |
+
import numpy as np
|
6 |
+
# import torch.nn.functional as F
|
7 |
+
import os
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
|
11 |
+
from xml.etree.ElementTree import ElementTree
|
12 |
+
|
13 |
+
import trimesh
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
# import List
|
17 |
+
# class link; joint; body
|
18 |
+
###
|
19 |
+
|
20 |
+
## calculate transformation to the frame ##
|
21 |
+
|
22 |
+
## the joint idx ##
|
23 |
+
##### th_cuda_idx #####
|
24 |
+
|
25 |
+
# name to the main axis? #
|
26 |
+
|
27 |
+
# def get_body_name_to_main_axis()
|
28 |
+
# another one is just getting the joint offset positions?
|
29 |
+
# and after the revolute transformation # # all revolute joint points ####
|
30 |
+
def get_body_name_to_main_axis():
|
31 |
+
# negative y; positive x #
|
32 |
+
body_name_to_main_axis = {
|
33 |
+
"body2": -2, "body6": 1, "body10": 1, "body14": 1, "body17": 1
|
34 |
+
}
|
35 |
+
return body_name_to_main_axis ## get the body name to main axis ##
|
36 |
+
|
37 |
+
## insert one
|
38 |
+
def plane_rotation_matrix_from_angle_xz(angle):
|
39 |
+
## angle of
|
40 |
+
sin_ = torch.sin(angle)
|
41 |
+
cos_ = torch.cos(angle)
|
42 |
+
zero_padding = torch.zeros_like(cos_)
|
43 |
+
one_padding = torch.ones_like(cos_)
|
44 |
+
col_a = torch.stack(
|
45 |
+
[cos_, zero_padding, sin_], dim=0
|
46 |
+
)
|
47 |
+
col_b = torch.stack(
|
48 |
+
[zero_padding, one_padding, zero_padding], dim=0
|
49 |
+
)
|
50 |
+
col_c = torch.stack(
|
51 |
+
[-1. * sin_, zero_padding, cos_], dim=0
|
52 |
+
)
|
53 |
+
rot_mtx = torch.stack(
|
54 |
+
[col_a, col_b, col_c], dim=-1
|
55 |
+
)
|
56 |
+
# col_a = torch.stack(
|
57 |
+
# [cos_, sin_], dim=0 ### col of the rotation matrix
|
58 |
+
# )
|
59 |
+
# col_b = torch.stack(
|
60 |
+
# [-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
|
61 |
+
# )
|
62 |
+
# rot_mtx = torch.stack(
|
63 |
+
# [col_a, col_b], dim=-1 ### rotation matrix
|
64 |
+
# )
|
65 |
+
return rot_mtx
|
66 |
+
|
67 |
+
def plane_rotation_matrix_from_angle(angle):
|
68 |
+
## angle of
|
69 |
+
sin_ = torch.sin(angle)
|
70 |
+
cos_ = torch.cos(angle)
|
71 |
+
col_a = torch.stack(
|
72 |
+
[cos_, sin_], dim=0 ### col of the rotation matrix
|
73 |
+
)
|
74 |
+
col_b = torch.stack(
|
75 |
+
[-1. * sin_, cos_], dim=0 ## cols of the rotation matrix
|
76 |
+
)
|
77 |
+
rot_mtx = torch.stack(
|
78 |
+
[col_a, col_b], dim=-1 ### rotation matrix
|
79 |
+
)
|
80 |
+
return rot_mtx
|
81 |
+
|
82 |
+
def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
|
83 |
+
# sin_ = np.sin(angle) # ti.math.sin(angle)
|
84 |
+
# cos_ = np.cos(angle) # ti.math.cos(angle)
|
85 |
+
sin_ = torch.sin(angle) # ti.math.sin(angle)
|
86 |
+
cos_ = torch.cos(angle) # ti.math.cos(angle)
|
87 |
+
u_x, u_y, u_z = axis[0], axis[1], axis[2]
|
88 |
+
u_xx = u_x * u_x
|
89 |
+
u_yy = u_y * u_y
|
90 |
+
u_zz = u_z * u_z
|
91 |
+
u_xy = u_x * u_y
|
92 |
+
u_xz = u_x * u_z
|
93 |
+
u_yz = u_y * u_z ##
|
94 |
+
# rot_mtx = np.stack(
|
95 |
+
# [np.array([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=np.float32),
|
96 |
+
# np.array([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=np.float32),
|
97 |
+
# np.array([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=np.float32)
|
98 |
+
# ], axis=-1 ### np stack
|
99 |
+
# ) ## a single
|
100 |
+
|
101 |
+
# rot_mtx = torch.stack(
|
102 |
+
# [
|
103 |
+
# torch.tensor([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=torch.float32),
|
104 |
+
# torch.tensor([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=torch.float32),
|
105 |
+
# torch.tensor([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=torch.float32)
|
106 |
+
# ], dim=-1 ## stack those torch tensors ##
|
107 |
+
# )
|
108 |
+
|
109 |
+
row_a = torch.stack(
|
110 |
+
[cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
|
111 |
+
)
|
112 |
+
# print(f"row_a: {row_a.size()}")
|
113 |
+
row_b = torch.stack(
|
114 |
+
[u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
|
115 |
+
)
|
116 |
+
# print(f"row_b: {row_b.size()}")
|
117 |
+
row_c = torch.stack(
|
118 |
+
[u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
|
119 |
+
)
|
120 |
+
# print(f"row_c: {row_c.size()}")
|
121 |
+
|
122 |
+
### rot_mtx for the rot_mtx ###
|
123 |
+
rot_mtx = torch.stack(
|
124 |
+
[row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
|
125 |
+
)
|
126 |
+
|
127 |
+
# rot_mtx = torch.stack(
|
128 |
+
# [
|
129 |
+
|
130 |
+
# torch.tensor([cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dtype=torch.float32),
|
131 |
+
# torch.tensor([u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dtype=torch.float32),
|
132 |
+
# torch.tensor([u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dtype=torch.float32)
|
133 |
+
# ], dim=-1 ## stack those torch tensors ##
|
134 |
+
# )
|
135 |
+
|
136 |
+
# rot_mtx_numpy = rot_mtx.to_numpy()
|
137 |
+
# rot_mtx_at_rot_mtx = rot_mtx @ rot_mtx.transpose()
|
138 |
+
# print(rot_mtx_at_rot_mtx)
|
139 |
+
return rot_mtx
|
140 |
+
|
141 |
+
## joint name = "joint3" ##
|
142 |
+
# <joint name="joint3" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e7"/>
|
143 |
+
class Joint:
|
144 |
+
def __init__(self, name, joint_type, axis, pos, quat, frame, damping, args) -> None:
|
145 |
+
self.name = name
|
146 |
+
self.type = joint_type
|
147 |
+
self.axis = axis
|
148 |
+
self.pos = pos
|
149 |
+
self.quat = quat
|
150 |
+
self.frame = frame
|
151 |
+
self.damping = damping
|
152 |
+
|
153 |
+
self.args = args
|
154 |
+
|
155 |
+
#### TODO: the dimension of the state vector ? ####
|
156 |
+
# self.state = 0. ## parameter
|
157 |
+
self.state = nn.Parameter(
|
158 |
+
torch.zeros((1,), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
159 |
+
)
|
160 |
+
# self.rot_mtx = np.eye(3, dtypes=np.float32)
|
161 |
+
# self.trans_vec = np.zeros((3,), dtype=np.float32) ## rot m
|
162 |
+
self.rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True)
|
163 |
+
self.trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True)
|
164 |
+
# self.rot_mtx = np.eye(3, dtype=np.float32)
|
165 |
+
# self.trans_vec = np.zeros((3,), dtype=np.float32)
|
166 |
+
|
167 |
+
self.axis_rot_mtx = torch.tensor(
|
168 |
+
[
|
169 |
+
[1, 0, 0], [0, -1, 0], [0, 0, -1]
|
170 |
+
], dtype=torch.float32
|
171 |
+
).cuda(self.args.th_cuda_idx)
|
172 |
+
|
173 |
+
self.joint_idx = -1
|
174 |
+
|
175 |
+
self.transformed_joint_pts = self.pos.clone()
|
176 |
+
|
177 |
+
def print_grads(self, ):
|
178 |
+
print(f"rot_mtx: {self.rot_mtx.grad}")
|
179 |
+
print(f"trans_vec: {self.trans_vec.grad}")
|
180 |
+
|
181 |
+
def clear_grads(self,):
|
182 |
+
if self.rot_mtx.grad is not None:
|
183 |
+
self.rot_mtx.grad.data = self.rot_mtx.grad.data * 0.
|
184 |
+
if self.trans_vec.grad is not None:
|
185 |
+
self.trans_vec.grad.data = self.trans_vec.grad.data * 0.
|
186 |
+
|
187 |
+
def compute_transformation(self,):
|
188 |
+
# use the state to transform them # # transform # ## transform the state ##
|
189 |
+
# use the state to transform them # # transform them for the state #
|
190 |
+
if self.type == "revolute":
|
191 |
+
# print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
|
192 |
+
# rotation matrix from the axis angle #
|
193 |
+
rot_mtx = rotation_matrix_from_axis_angle(self.axis, self.state)
|
194 |
+
# rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
|
195 |
+
# trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
|
196 |
+
# self.rot_mtx = np.copy(rot_mtx)
|
197 |
+
# self.trans_vec = np.copy(trans_vec)
|
198 |
+
trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
|
199 |
+
self.rot_mtx = rot_mtx
|
200 |
+
self.trans_vec = trans_vec
|
201 |
+
else:
|
202 |
+
### TODO: implement transformations for joints in other types ###
|
203 |
+
pass
|
204 |
+
|
205 |
+
def set_state(self, name_to_state):
|
206 |
+
if self.name in name_to_state:
|
207 |
+
# self.state = name_to_state["name"]
|
208 |
+
self.state = name_to_state[self.name] ##
|
209 |
+
|
210 |
+
def set_state_via_vec(self, state_vec): ### transform points via the state vectors here ###
|
211 |
+
if self.joint_idx >= 0:
|
212 |
+
self.state = state_vec[self.joint_idx] ## give the parameter to the parameters ##
|
213 |
+
|
214 |
+
def set_joint_idx(self, joint_name_to_idx):
|
215 |
+
if self.name in joint_name_to_idx:
|
216 |
+
self.joint_idx = joint_name_to_idx[self.name]
|
217 |
+
|
218 |
+
|
219 |
+
def set_args(self, args):
|
220 |
+
self.args = args
|
221 |
+
|
222 |
+
|
223 |
+
def compute_transformation_via_state_vals(self, state_vals):
|
224 |
+
if self.joint_idx >= 0:
|
225 |
+
cur_joint_state = state_vals[self.joint_idx]
|
226 |
+
else:
|
227 |
+
cur_joint_state = self.state
|
228 |
+
# use the state to transform them # # transform # ## transform the state ##
|
229 |
+
# use the state to transform them # # transform them for the state #
|
230 |
+
if self.type == "revolute":
|
231 |
+
# print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
|
232 |
+
# rotation matrix from the axis angle #
|
233 |
+
rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
|
234 |
+
# rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
|
235 |
+
# trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
|
236 |
+
# self.rot_mtx = np.copy(rot_mtx)
|
237 |
+
# self.trans_vec = np.copy(trans_vec)
|
238 |
+
trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
|
239 |
+
self.rot_mtx = rot_mtx
|
240 |
+
self.trans_vec = trans_vec
|
241 |
+
elif self.type == "free2d":
|
242 |
+
cur_joint_state = state_vals # still only for the current scene #
|
243 |
+
# cur_joint_state
|
244 |
+
cur_joint_rot_val = state_vals[2]
|
245 |
+
### rot_mtx ### ### rot_mtx ###
|
246 |
+
rot_mtx = plane_rotation_matrix_from_angle_xz(cur_joint_rot_val)
|
247 |
+
# rot_mtx = plane_rotation_matrix_from_angle(cur_joint_rot_val) ### 2 x 2 rot matrix #
|
248 |
+
# cur joint rot val #
|
249 |
+
# rot mtx of the rotation
|
250 |
+
# xy_val =
|
251 |
+
# axis_rot_mtx
|
252 |
+
# R_axis^T ( R R_axis (p) + trans (with the y-axis padded) )
|
253 |
+
cur_trans_vec = torch.stack(
|
254 |
+
[state_vals[0], torch.zeros_like(state_vals[0]), state_vals[1]], dim=0
|
255 |
+
)
|
256 |
+
# cur_trans_vec #
|
257 |
+
rot_mtx = torch.matmul(self.axis_rot_mtx.transpose(1, 0), torch.matmul(rot_mtx, self.axis_rot_mtx))
|
258 |
+
trans_vec = torch.matmul(self.axis_rot_mtx.transpose(1, 0), cur_trans_vec.unsqueeze(-1).contiguous()).squeeze(-1).contiguous() + self.pos
|
259 |
+
|
260 |
+
self.rot_mtx = rot_mtx
|
261 |
+
self.trans_vec = trans_vec ## rot_mtx and trans_vec #
|
262 |
+
else:
|
263 |
+
### TODO: implement transformations for joints in other types ###
|
264 |
+
pass
|
265 |
+
return self.rot_mtx, self.trans_vec
|
266 |
+
|
267 |
+
def transform_joints_via_parent_rot_trans_infos(self, parent_rot_mtx, parent_trans_vec):
|
268 |
+
#
|
269 |
+
# if self.type == "revolute" or self.type == "free2d":
|
270 |
+
transformed_joint_pts = torch.matmul(parent_rot_mtx, self.pos.view(3 ,1).contiguous()).view(3).contiguous() + parent_trans_vec
|
271 |
+
# else:
|
272 |
+
self.transformed_joint_pts = transformed_joint_pts ### get self transformed joint pts here ###
|
273 |
+
return transformed_joint_pts
|
274 |
+
# if self.joint_idx >= 0:
|
275 |
+
# cur_joint_state = state_vals[self.joint_idx]
|
276 |
+
# else:
|
277 |
+
# cur_joint_state = self.state # state #
|
278 |
+
# # use the state to transform them # # transform ### transform the state ##
|
279 |
+
# # use the state to transform them # # transform them for the state # transform for the state #
|
280 |
+
# if self.type == "revolute":
|
281 |
+
# # print(f"computing transformation matrices with axis: {self.axis}, state: {self.state}")
|
282 |
+
# # rotation matrix from the axis angle #
|
283 |
+
# rot_mtx = rotation_matrix_from_axis_angle(self.axis, cur_joint_state)
|
284 |
+
# # rot_mtx(p - p_v) + p_v -> rot_mtx p - rot_mtx p_v + p_v
|
285 |
+
# # trans_vec = self.pos - np.matmul(rot_mtx, self.pos.reshape(3, 1)).reshape(3)
|
286 |
+
# # self.rot_mtx = np.copy(rot_mtx)
|
287 |
+
# # self.trans_vec = np.copy(trans_vec)
|
288 |
+
# trans_vec = self.pos - torch.matmul(rot_mtx, self.pos.view(3, 1)).view(3).contiguous()
|
289 |
+
# self.rot_mtx = rot_mtx
|
290 |
+
# self.trans_vec = trans_vec
|
291 |
+
# elif self.type == "free2d":
|
292 |
+
# cur_joint_state = state_vals # still only for the current scene #
|
293 |
+
# # cur_joint_state
|
294 |
+
# cur_joint_rot_val = state_vals[2]
|
295 |
+
# ### rot_mtx ### ### rot_mtx ###
|
296 |
+
# rot_mtx = plane_rotation_matrix_from_angle_xz(cur_joint_rot_val)
|
297 |
+
# # rot_mtx = plane_rotation_matrix_from_angle(cur_joint_rot_val) ### 2 x 2 rot matrix #
|
298 |
+
# # cur joint rot val #
|
299 |
+
# # rot mtx of the rotation
|
300 |
+
# # xy_val =
|
301 |
+
# # axis_rot_mtx
|
302 |
+
# # R_axis^T ( R R_axis (p) + trans (with the y-axis padded) )
|
303 |
+
# cur_trans_vec = torch.stack(
|
304 |
+
# [state_vals[0], torch.zeros_like(state_vals[0]), state_vals[1]], dim=0
|
305 |
+
# )
|
306 |
+
# # cur_trans_vec #
|
307 |
+
# rot_mtx = torch.matmul(self.axis_rot_mtx.transpose(1, 0), torch.matmul(rot_mtx, self.axis_rot_mtx))
|
308 |
+
# trans_vec = torch.matmul(self.axis_rot_mtx.transpose(1, 0), cur_trans_vec.unsqueeze(-1).contiguous()).squeeze(-1).contiguous() + self.pos
|
309 |
+
|
310 |
+
# self.rot_mtx = rot_mtx
|
311 |
+
# self.trans_vec = trans_vec ## rot_mtx and trans_vec #
|
312 |
+
# else:
|
313 |
+
# ### TODO: implement transformations for joints in other types ###
|
314 |
+
# pass
|
315 |
+
# return self.rot_mtx, self.trans_vec
|
316 |
+
|
317 |
+
## fixed joint -> translation and rotation ##
|
318 |
+
## revolute joint -> can be actuated ##
|
319 |
+
## set states and compute the transfromations in a top-to-down manner ##
|
320 |
+
|
321 |
+
## trnasform the robot -> a list of qs ##
|
322 |
+
## a list of qs ##
|
323 |
+
## transform from the root of the robot; pass qs from the root to the leaf node ##
|
324 |
+
## visual meshes or visual meshes from the basic description of robots ##
|
325 |
+
## visual meshes; or visual points ##
|
326 |
+
## visual meshes -> transform them into the visual density values here ##
|
327 |
+
## visual meshes -> transform them into the ## into the visual counterparts ##
|
328 |
+
## ## visual meshes -> ## ## ##
|
329 |
+
# <body name="body0" type="mesh" filename="hand/body0.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.700000 0.700000 0.700000 1"/>
|
330 |
+
class Body:
|
331 |
+
def __init__(self, name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args) -> None:
|
332 |
+
self.name = name
|
333 |
+
self.body_type = body_type
|
334 |
+
### for mesh object ###
|
335 |
+
self.filename = filename
|
336 |
+
self.args = args
|
337 |
+
|
338 |
+
self.pos = pos
|
339 |
+
self.quat = quat
|
340 |
+
self.transform_type = transform_type
|
341 |
+
self.density = density
|
342 |
+
self.mu = mu
|
343 |
+
self.rgba = rgba
|
344 |
+
|
345 |
+
### for sphere object ###
|
346 |
+
self.radius = radius
|
347 |
+
## or vertices here #
|
348 |
+
## pass them to the child and treat them as the parent transformation ##
|
349 |
+
|
350 |
+
self.visual_pts_ref = None
|
351 |
+
self.visual_faces_ref = None
|
352 |
+
|
353 |
+
self.visual_pts = None ## visual pts and
|
354 |
+
|
355 |
+
self.body_name_to_main_axis = get_body_name_to_main_axis() ### get the body name to main axis here #
|
356 |
+
|
357 |
+
self.get_visual_counterparts()
|
358 |
+
|
359 |
+
|
360 |
+
def update_radius(self,):
|
361 |
+
self.radius.data = self.radius.data - self.radius.grad.data
|
362 |
+
|
363 |
+
self.radius.grad.data = self.radius.grad.data * 0.
|
364 |
+
|
365 |
+
|
366 |
+
def update_xml_file(self,):
|
367 |
+
xml_content_with_flexible_radius = f"""<redmax model="hand">
|
368 |
+
<!-- 1) change the damping value here? -->
|
369 |
+
<!-- 2) change the center of mass -->
|
370 |
+
<option integrator="BDF2" timestep="0.01" gravity="0. 0. -0.000098"/>
|
371 |
+
<ground pos="0 0 -10" normal="0 0 1"/>
|
372 |
+
<default>
|
373 |
+
<general_primitive_contact kn="1e6" kt="1e3" mu="0.8" damping="3e1" />
|
374 |
+
</default>
|
375 |
+
|
376 |
+
<robot>
|
377 |
+
<link name="link0">
|
378 |
+
<joint name="joint0" type="fixed" pos="0 0 0" quat="1 0 0 0" frame="WORLD"/>
|
379 |
+
<body name="body0" type="mesh" filename="hand/body0.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.700000 0.700000 0.700000 1"/>
|
380 |
+
<link name="link1">
|
381 |
+
<joint name="joint1" type="revolute" axis="0.000000 0.000000 -1.000000" pos="-3.300000 -5.689700 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
382 |
+
<body name="body1" type="mesh" filename="hand/body1.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
|
383 |
+
<link name="link2">
|
384 |
+
<joint name="joint2" type="revolute" axis="1.000000 0.000000 0.000000" pos="-3.300000 -7.680000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
385 |
+
<body name="body2" type="mesh" filename="hand/body2.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
|
386 |
+
</link>
|
387 |
+
</link>
|
388 |
+
<link name="link3">
|
389 |
+
<!-- revolute joint -->
|
390 |
+
<joint name="joint3" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
391 |
+
<body name="body3" type="mesh" filename="hand/body3.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
|
392 |
+
<link name="link4">
|
393 |
+
<joint name="joint4" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
394 |
+
<body name="body4" type="mesh" filename="hand/body4.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
|
395 |
+
<link name="link5">
|
396 |
+
<joint name="joint5" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
397 |
+
<body name="body5" type="mesh" filename="hand/body5.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
|
398 |
+
<link name="link6">
|
399 |
+
<joint name="joint6" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 -4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
400 |
+
<body name="body6" type="mesh" filename="hand/body6.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
|
401 |
+
</link>
|
402 |
+
</link>
|
403 |
+
</link>
|
404 |
+
</link>
|
405 |
+
<link name="link7">
|
406 |
+
<joint name="joint7" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
407 |
+
<body name="body7" type="mesh" filename="hand/body7.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
|
408 |
+
<link name="link8">
|
409 |
+
<joint name="joint8" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
410 |
+
<body name="body8" type="mesh" filename="hand/body8.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
|
411 |
+
<link name="link9">
|
412 |
+
<joint name="joint9" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
413 |
+
<body name="body9" type="mesh" filename="hand/body9.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
|
414 |
+
<link name="link10">
|
415 |
+
<joint name="joint10" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 -1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
416 |
+
<body name="body10" type="mesh" filename="hand/body10.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
|
417 |
+
</link>
|
418 |
+
</link>
|
419 |
+
</link>
|
420 |
+
</link>
|
421 |
+
<link name="link11">
|
422 |
+
<joint name="joint11" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
423 |
+
<body name="body11" type="mesh" filename="hand/body11.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
|
424 |
+
<link name="link12">
|
425 |
+
<joint name="joint12" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
426 |
+
<body name="body12" type="mesh" filename="hand/body12.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
|
427 |
+
<link name="link13">
|
428 |
+
<joint name="joint13" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
429 |
+
<body name="body13" type="mesh" filename="hand/body13.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
|
430 |
+
<link name="link14">
|
431 |
+
<joint name="joint14" type="revolute" axis="0.000000 1.000000 0.000000" pos="15.480000 1.475000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
432 |
+
<body name="body14" type="mesh" filename="hand/body14.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.300000 0.300000 0.300000 1"/>
|
433 |
+
</link>
|
434 |
+
</link>
|
435 |
+
</link>
|
436 |
+
</link>
|
437 |
+
<link name="link15">
|
438 |
+
<joint name="joint15" type="revolute" axis="0.000000 0.000000 -1.000000" pos="4.689700 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
439 |
+
<body name="body15" type="mesh" filename="hand/body15.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.600000 0.600000 0.600000 1"/>
|
440 |
+
<link name="link16">
|
441 |
+
<joint name="joint16" type="revolute" axis="0.000000 1.000000 0.000000" pos="6.680000 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
442 |
+
<body name="body16" type="mesh" filename="hand/body16.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.500000 0.500000 0.500000 1"/>
|
443 |
+
<link name="link17">
|
444 |
+
<joint name="joint17" type="revolute" axis="0.000000 1.000000 0.000000" pos="11.080000 4.425000 0.000000" quat="1 0 0 0" frame="WORLD" damping="1e4"/>
|
445 |
+
<body name="body17" type="mesh" filename="hand/body17.obj" pos="0 0 0" quat="1 0 0 0" transform_type="OBJ_TO_WORLD" density="1" mu="0" rgba="0.400000 0.400000 0.400000 1"/>
|
446 |
+
</link>
|
447 |
+
</link>
|
448 |
+
</link>
|
449 |
+
</link>
|
450 |
+
</robot>
|
451 |
+
|
452 |
+
<robot>
|
453 |
+
<link name="sphere">
|
454 |
+
<joint name="sphere" type="free2d" pos = "10. 0.0 3.5" quat="1 -1 0 0" format="LOCAL" damping="0"/>
|
455 |
+
<body name="sphere" type="sphere" radius="{self.radius[0].detach().cpu().item()}" pos="0 0 0" quat="1 0 0 0" density="0.5" mu="0" texture="resources/textures/sphere.jpg"/>
|
456 |
+
</link>
|
457 |
+
</robot>
|
458 |
+
|
459 |
+
<contact>
|
460 |
+
<ground_contact body="sphere" kn="1e6" kt="1e3" mu="0.8" damping="3e1"/>
|
461 |
+
<general_primitive_contact general_body="body0" primitive_body="sphere"/>
|
462 |
+
<general_primitive_contact general_body="body1" primitive_body="sphere"/>
|
463 |
+
<general_primitive_contact general_body="body2" primitive_body="sphere"/>
|
464 |
+
<general_primitive_contact general_body="body3" primitive_body="sphere"/>
|
465 |
+
<general_primitive_contact general_body="body4" primitive_body="sphere"/>
|
466 |
+
<general_primitive_contact general_body="body5" primitive_body="sphere"/>
|
467 |
+
<general_primitive_contact general_body="body6" primitive_body="sphere"/>
|
468 |
+
<general_primitive_contact general_body="body7" primitive_body="sphere"/>
|
469 |
+
<general_primitive_contact general_body="body8" primitive_body="sphere"/>
|
470 |
+
<general_primitive_contact general_body="body9" primitive_body="sphere"/>
|
471 |
+
<general_primitive_contact general_body="body10" primitive_body="sphere"/>
|
472 |
+
<general_primitive_contact general_body="body11" primitive_body="sphere"/>
|
473 |
+
<general_primitive_contact general_body="body12" primitive_body="sphere"/>
|
474 |
+
<general_primitive_contact general_body="body13" primitive_body="sphere"/>
|
475 |
+
<general_primitive_contact general_body="body14" primitive_body="sphere"/>
|
476 |
+
<general_primitive_contact general_body="body15" primitive_body="sphere"/>
|
477 |
+
<general_primitive_contact general_body="body16" primitive_body="sphere"/>
|
478 |
+
<general_primitive_contact general_body="body17" primitive_body="sphere"/>
|
479 |
+
</contact>
|
480 |
+
|
481 |
+
<actuator>
|
482 |
+
<motor joint="joint1" ctrl="force" ctrl_range="-3e5 3e5"/>
|
483 |
+
<motor joint="joint2" ctrl="force" ctrl_range="-3e5 3e5"/>
|
484 |
+
<motor joint="joint3" ctrl="force" ctrl_range="-3e5 3e5"/>
|
485 |
+
<motor joint="joint4" ctrl="force" ctrl_range="-3e5 3e5"/>
|
486 |
+
<motor joint="joint5" ctrl="force" ctrl_range="-3e5 3e5"/>
|
487 |
+
<motor joint="joint6" ctrl="force" ctrl_range="-3e5 3e5"/>
|
488 |
+
<motor joint="joint7" ctrl="force" ctrl_range="-3e5 3e5"/>
|
489 |
+
<motor joint="joint8" ctrl="force" ctrl_range="-3e5 3e5"/>
|
490 |
+
<motor joint="joint9" ctrl="force" ctrl_range="-3e5 3e5"/>
|
491 |
+
<motor joint="joint10" ctrl="force" ctrl_range="-3e5 3e5"/>
|
492 |
+
<motor joint="joint11" ctrl="force" ctrl_range="-3e5 3e5"/>
|
493 |
+
<motor joint="joint12" ctrl="force" ctrl_range="-3e5 3e5"/>
|
494 |
+
<motor joint="joint13" ctrl="force" ctrl_range="-3e5 3e5"/>
|
495 |
+
<motor joint="joint14" ctrl="force" ctrl_range="-3e5 3e5"/>
|
496 |
+
<motor joint="joint15" ctrl="force" ctrl_range="-3e5 3e5"/>
|
497 |
+
<motor joint="joint16" ctrl="force" ctrl_range="-3e5 3e5"/>
|
498 |
+
<motor joint="joint17" ctrl="force" ctrl_range="-3e5 3e5"/>
|
499 |
+
</actuator>
|
500 |
+
</redmax>
|
501 |
+
"""
|
502 |
+
xml_loading_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere_free_sphere_geo_test.xml"
|
503 |
+
with open(xml_loading_fn, "w") as wf:
|
504 |
+
wf.write(xml_content_with_flexible_radius)
|
505 |
+
wf.close()
|
506 |
+
|
507 |
+
### get visual pts colorrs ### ###
|
508 |
+
def get_visual_pts_colors(self, ):
|
509 |
+
tot_visual_pts_nn = self.visual_pts_ref.size(0)
|
510 |
+
# self.pts_rgba = [torch.from_numpy(self.rgba).float().cuda(self.args.th_cuda_idx) for _ in range(tot_visual_pts_nn)] # total visual pts nn
|
511 |
+
self.pts_rgba = [torch.tensor(self.rgba.data).cuda(self.args.th_cuda_idx) for _ in range(tot_visual_pts_nn)] # total visual pts nn skeletong
|
512 |
+
self.pts_rgba = torch.stack(self.pts_rgba, dim=0) #
|
513 |
+
return self.pts_rgba
|
514 |
+
|
515 |
+
def get_visual_counterparts(self,):
|
516 |
+
### TODO: implement this for visual counterparts ### mid line regression and name to body mapping relations --- for each body, how to calculate the midline and other properties?
|
517 |
+
######## get body type ########## get visual midline of the input mesh and the mesh vertices? ######## # skeleton of the hand -> 21 points ? retarget from this hand to the mano hand and use the mano hand priors?
|
518 |
+
if self.body_type == "sphere":
|
519 |
+
filename = "/home/xueyi/diffsim/DiffHand/examples/save_res/hand_sphere_demo/meshes/18.obj"
|
520 |
+
if not os.path.exists(filename):
|
521 |
+
filename = "/data/xueyi/diffsim/DiffHand/assets/18.obj"
|
522 |
+
else:
|
523 |
+
filename = self.filename
|
524 |
+
rt_asset_path = "/home/xueyi/diffsim/DiffHand/assets" ### assets folder ###
|
525 |
+
if not os.path.exists(rt_asset_path):
|
526 |
+
rt_asset_path = "/data/xueyi/diffsim/DiffHand/assets"
|
527 |
+
filename = os.path.join(rt_asset_path, filename)
|
528 |
+
body_mesh = trimesh.load(filename, process=False)
|
529 |
+
# verts = np.array(body_mesh.vertices)
|
530 |
+
# faces = np.array(body_mesh.faces, dtype=np.long)
|
531 |
+
|
532 |
+
# self.visual_pts_ref = np.copy(verts) ## verts ##
|
533 |
+
# self.visual_faces_ref = np.copy(faces) ## faces
|
534 |
+
# self.visual_pts_ref #
|
535 |
+
|
536 |
+
#### body_mesh.vertices ####
|
537 |
+
# verts = torch.tensor(body_mesh.vertices, dtype=torch.float32).cuda(self.args.th_cuda_idx)
|
538 |
+
# faces = torch.tensor(body_mesh.faces, dtype=torch.long).cuda(self.args.th_cuda_idx)
|
539 |
+
#### body_mesh.vertices ####
|
540 |
+
|
541 |
+
# self.pos = nn.Parameter(
|
542 |
+
# torch.tensor([0., 0., 0.], dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
543 |
+
# )
|
544 |
+
|
545 |
+
self.pos = nn.Parameter(
|
546 |
+
torch.tensor(self.pos.detach().cpu().tolist(), dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
547 |
+
)
|
548 |
+
|
549 |
+
### Step 1 ### -> set the pos to the correct initial pose ###
|
550 |
+
|
551 |
+
self.radius = nn.Parameter(
|
552 |
+
torch.tensor([self.args.initial_radius], dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
553 |
+
)
|
554 |
+
### visual pts ref ### ## body_mesh.vertices -> #
|
555 |
+
self.visual_pts_ref = torch.tensor(body_mesh.vertices, dtype=torch.float32).cuda(self.args.th_cuda_idx)
|
556 |
+
|
557 |
+
# if self.name == "sphere":
|
558 |
+
# self.visual_pts_ref = self.visual_pts_ref / 2. # the initial radius
|
559 |
+
# self.visual_pts_ref = self.visual_pts_ref * self.radius ## multiple the initla radius #
|
560 |
+
|
561 |
+
# self.visual_pts_ref = nn.Parameter(
|
562 |
+
# torch.tensor(body_mesh.vertices, dtype=torch.float32, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
563 |
+
# )
|
564 |
+
# self.visual_faces_ref = nn.Parameter(
|
565 |
+
# torch.tensor(body_mesh.faces, dtype=torch.long, requires_grad=True).cuda(self.args.th_cuda_idx), requires_grad=True
|
566 |
+
# )
|
567 |
+
self.visual_faces_ref = torch.tensor(body_mesh.faces, dtype=torch.long).cuda(self.args.th_cuda_idx)
|
568 |
+
|
569 |
+
# body_name_to_main_axis
|
570 |
+
# body_name_to_main_axis for the body_name_to_main_axis #
|
571 |
+
# visual_faces_ref #
|
572 |
+
# visual_pts_ref #
|
573 |
+
|
574 |
+
minn_pts, _ = torch.min(self.visual_pts_ref, dim=0) ### get the visual pts minn ###
|
575 |
+
maxx_pts, _ = torch.max(self.visual_pts_ref, dim=0) ### visual pts maxx ###
|
576 |
+
mean_pts = torch.mean(self.visual_pts_ref, dim=0) ### mean_pts of the mean_pts ###
|
577 |
+
|
578 |
+
if self.name in self.body_name_to_main_axis:
|
579 |
+
cur_main_axis = self.body_name_to_main_axis[self.name] ## get the body name ##
|
580 |
+
|
581 |
+
if cur_main_axis == -2:
|
582 |
+
main_axis_pts = minn_pts[1] # the main axis pts
|
583 |
+
full_main_axis_pts = torch.tensor([mean_pts[0], main_axis_pts, mean_pts[2]], dtype=torch.float32).cuda(self.args.th_cuda_idx)
|
584 |
+
elif cur_main_axis == 1:
|
585 |
+
main_axis_pts = maxx_pts[0] # the maxx axis pts
|
586 |
+
full_main_axis_pts = torch.tensor([main_axis_pts, mean_pts[1], mean_pts[2]], dtype=torch.float32).cuda(self.args.th_cuda_idx)
|
587 |
+
self.full_main_axis_pts_ref = full_main_axis_pts
|
588 |
+
else:
|
589 |
+
self.full_main_axis_pts_ref = mean_pts.clone() ### get the mean pts ###
|
590 |
+
# mean_pts
|
591 |
+
# main_axis_pts =
|
592 |
+
|
593 |
+
|
594 |
+
# self.visual_pts_ref = verts #
|
595 |
+
# self.visual_faces_ref = faces #
|
596 |
+
# get visual points colors # the color should be an optimizable property # # # or init visual point colors here ## # or init visual point colors here #
|
597 |
+
# simulatoable assets ## for the
|
598 |
+
|
599 |
+
def transform_visual_pts_ref(self,):
|
600 |
+
if self.name == "sphere":
|
601 |
+
visual_pts_ref = self.visual_pts_ref / 2. #
|
602 |
+
visual_pts_ref = visual_pts_ref * self.radius
|
603 |
+
else:
|
604 |
+
visual_pts_ref = self.visual_pts_ref
|
605 |
+
return visual_pts_ref
|
606 |
+
|
607 |
+
def transform_visual_pts(self, rot_mtx, trans_vec):
|
608 |
+
visual_pts_ref = self.transform_visual_pts_ref()
|
609 |
+
# rot_mtx: 3 x 3 numpy array
|
610 |
+
# trans_vec: 3 numpy array
|
611 |
+
# print(f"transforming body with rot_mtx: {rot_mtx} and trans_vec: {trans_vec}")
|
612 |
+
# self.visual_pts = np.matmul(rot_mtx, self.visual_pts_ref.T).T + trans_vec.reshape(1, 3) # reshape #
|
613 |
+
# print(f"rot_mtx: {rot_mtx}, trans_vec: {trans_vec}")
|
614 |
+
self.visual_pts = torch.matmul(rot_mtx, visual_pts_ref.transpose(1, 0)).transpose(1, 0) + trans_vec.unsqueeze(0)
|
615 |
+
|
616 |
+
# full_main_axis_pts ->
|
617 |
+
self.full_main_axis_pts = torch.matmul(rot_mtx, self.full_main_axis_pts_ref.unsqueeze(-1)).contiguous().squeeze(-1) + trans_vec
|
618 |
+
self.full_main_axis_pts = self.full_main_axis_pts.unsqueeze(0)
|
619 |
+
|
620 |
+
return self.visual_pts
|
621 |
+
|
622 |
+
def get_tot_transformed_joints(self, transformed_joints):
|
623 |
+
if self.name in self.body_name_to_main_axis:
|
624 |
+
transformed_joints.append(self.full_main_axis_pts)
|
625 |
+
return transformed_joints
|
626 |
+
|
627 |
+
def get_nn_pts(self,):
|
628 |
+
self.nn_pts = self.visual_pts_ref.size(0)
|
629 |
+
return self.nn_pts
|
630 |
+
|
631 |
+
def set_args(self, args):
|
632 |
+
self.args = args
|
633 |
+
|
634 |
+
def clear_grad(self, ):
|
635 |
+
if self.pos.grad is not None:
|
636 |
+
self.pos.grad.data = self.pos.grad.data * 0.
|
637 |
+
if self.radius.grad is not None:
|
638 |
+
self.radius.grad.data = self.radius.grad.data * 0.
|
639 |
+
|
640 |
+
|
641 |
+
# get the visual counterparts of the boyd mesh or elements #
|
642 |
+
|
643 |
+
# xyz attribute ## ## xyz attribute #
|
644 |
+
|
645 |
+
# use get_name_to_visual_pts
|
646 |
+
# use get_name_to_visual_pts_faces to get the transformed visual pts and faces #
|
647 |
+
class Link:
|
648 |
+
def __init__(self, name, joint: Joint, body: Body, children, args) -> None:
|
649 |
+
|
650 |
+
self.joint = joint
|
651 |
+
self.body = body
|
652 |
+
self.children = children
|
653 |
+
self.name = name
|
654 |
+
|
655 |
+
self.args = args
|
656 |
+
|
657 |
+
# joint # parent_rot_mtx, parent_trans_vec
|
658 |
+
self.parent_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
|
659 |
+
self.parent_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
|
660 |
+
self.curr_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
|
661 |
+
self.curr_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
|
662 |
+
#
|
663 |
+
self.tot_rot_mtx = nn.Parameter(torch.eye(n=3, dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True)
|
664 |
+
self.tot_trans_vec = nn.Parameter(torch.zeros((3,), dtype=torch.float32).cuda(self.args.th_cuda_idx), requires_grad=True) ## torch zeros #
|
665 |
+
|
666 |
+
def print_grads(self, ): ### print grads here ###
|
667 |
+
print(f"parent_rot_mtx: {self.parent_rot_mtx.grad}")
|
668 |
+
print(f"parent_trans_vec: {self.parent_trans_vec.grad}")
|
669 |
+
print(f"curr_rot_mtx: {self.curr_rot_mtx.grad}")
|
670 |
+
print(f"curr_trans_vec: {self.curr_trans_vec.grad}")
|
671 |
+
print(f"tot_rot_mtx: {self.tot_rot_mtx.grad}")
|
672 |
+
print(f"tot_trans_vec: {self.tot_trans_vec.grad}")
|
673 |
+
print(f"Joint")
|
674 |
+
self.joint.print_grads()
|
675 |
+
for cur_link in self.children:
|
676 |
+
cur_link.print_grads()
|
677 |
+
|
678 |
+
|
679 |
+
def set_state(self, name_to_state):
|
680 |
+
self.joint.set_state(name_to_state=name_to_state)
|
681 |
+
for child_link in self.children:
|
682 |
+
child_link.set_state(name_to_state)
|
683 |
+
|
684 |
+
|
685 |
+
def set_state_via_vec(self, state_vec): #
|
686 |
+
self.joint.set_state_via_vec(state_vec)
|
687 |
+
for child_link in self.children:
|
688 |
+
child_link.set_state_via_vec(state_vec)
|
689 |
+
# if self.joint_idx >= 0:
|
690 |
+
# self.state = state_vec[self.joint_idx]
|
691 |
+
|
692 |
+
##
|
693 |
+
def get_tot_transformed_joints(self, transformed_joints):
|
694 |
+
cur_joint_transformed_pts = self.joint.transformed_joint_pts.unsqueeze(0) ### 3 pts
|
695 |
+
transformed_joints.append(cur_joint_transformed_pts)
|
696 |
+
transformed_joints = self.body.get_tot_transformed_joints(transformed_joints)
|
697 |
+
# if self.joint.name
|
698 |
+
for cur_link in self.children:
|
699 |
+
transformed_joints = cur_link.get_tot_transformed_joints(transformed_joints)
|
700 |
+
return transformed_joints
|
701 |
+
|
702 |
+
def compute_transformation_via_state_vecs(self, state_vals, parent_rot_mtx, parent_trans_vec, visual_pts_list):
|
703 |
+
# state vecs and rot mtx # state vecs #####
|
704 |
+
joint_rot_mtx, joint_trans_vec = self.joint.compute_transformation_via_state_vals(state_vals=state_vals)
|
705 |
+
|
706 |
+
self.curr_rot_mtx = joint_rot_mtx
|
707 |
+
self.curr_trans_vec = joint_trans_vec
|
708 |
+
|
709 |
+
self.joint.transform_joints_via_parent_rot_trans_infos(parent_rot_mtx=parent_rot_mtx, parent_trans_vec=parent_trans_vec) ## get rot and trans mtx and vecs ###
|
710 |
+
|
711 |
+
tot_parent_rot_mtx = torch.matmul(parent_rot_mtx, joint_rot_mtx)
|
712 |
+
tot_parent_trans_vec = torch.matmul(parent_rot_mtx, joint_trans_vec.unsqueeze(-1)).view(3) + parent_trans_vec
|
713 |
+
|
714 |
+
self.tot_rot_mtx = tot_parent_rot_mtx
|
715 |
+
self.tot_trans_vec = tot_parent_trans_vec
|
716 |
+
|
717 |
+
# self.tot_rot_mtx = np.copy(tot_parent_rot_mtx)
|
718 |
+
# self.tot_trans_vec = np.copy(tot_parent_trans_vec)
|
719 |
+
|
720 |
+
### visual_pts_list for recording visual pts ###
|
721 |
+
|
722 |
+
cur_body_visual_pts = self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
|
723 |
+
visual_pts_list.append(cur_body_visual_pts)
|
724 |
+
|
725 |
+
for cur_link in self.children:
|
726 |
+
# cur_link.parent_rot_mtx = np.copy(tot_parent_rot_mtx) ### set children parent rot mtx and the trans vec
|
727 |
+
# cur_link.parent_trans_vec = np.copy(tot_parent_trans_vec) ##
|
728 |
+
cur_link.parent_rot_mtx = tot_parent_rot_mtx ### set children parent rot mtx and the trans vec #
|
729 |
+
cur_link.parent_trans_vec = tot_parent_trans_vec ##
|
730 |
+
# cur_link.compute_transformation() ## compute self's transformations
|
731 |
+
cur_link.compute_transformation_via_state_vecs(state_vals, tot_parent_rot_mtx, tot_parent_trans_vec, visual_pts_list)
|
732 |
+
|
733 |
+
def get_visual_pts_rgba_values(self, pts_rgba_vals_list):
|
734 |
+
|
735 |
+
cur_body_visual_rgba_vals = self.body.get_visual_pts_colors()
|
736 |
+
pts_rgba_vals_list.append(cur_body_visual_rgba_vals)
|
737 |
+
|
738 |
+
for cur_link in self.children:
|
739 |
+
cur_link.get_visual_pts_rgba_values(pts_rgba_vals_list)
|
740 |
+
|
741 |
+
|
742 |
+
|
743 |
+
def compute_transformation(self,):
|
744 |
+
self.joint.compute_transformation()
|
745 |
+
# self.curr_rot_mtx = np.copy(self.joint.rot_mtx)
|
746 |
+
# self.curr_trans_vec = np.copy(self.joint.trans_vec)
|
747 |
+
|
748 |
+
self.curr_rot_mtx = self.joint.rot_mtx
|
749 |
+
self.curr_trans_vec = self.joint.trans_vec
|
750 |
+
# rot_p (rot_c p + trans_c) + trans_p #
|
751 |
+
# rot_p rot_c p + rot_p trans_c + trans_p #
|
752 |
+
#### matmul ####
|
753 |
+
# tot_parent_rot_mtx = np.matmul(self.parent_rot_mtx, self.curr_rot_mtx)
|
754 |
+
# tot_parent_trans_vec = np.matmul(self.parent_rot_mtx, self.curr_trans_vec.reshape(3, 1)).reshape(3) + self.parent_trans_vec
|
755 |
+
|
756 |
+
tot_parent_rot_mtx = torch.matmul(self.parent_rot_mtx, self.curr_rot_mtx)
|
757 |
+
tot_parent_trans_vec = torch.matmul(self.parent_rot_mtx, self.curr_trans_vec.unsqueeze(-1)).view(3) + self.parent_trans_vec
|
758 |
+
|
759 |
+
self.tot_rot_mtx = tot_parent_rot_mtx
|
760 |
+
self.tot_trans_vec = tot_parent_trans_vec
|
761 |
+
|
762 |
+
# self.tot_rot_mtx = np.copy(tot_parent_rot_mtx)
|
763 |
+
# self.tot_trans_vec = np.copy(tot_parent_trans_vec)
|
764 |
+
|
765 |
+
for cur_link in self.children:
|
766 |
+
# cur_link.parent_rot_mtx = np.copy(tot_parent_rot_mtx) ### set children parent rot mtx and the trans vec
|
767 |
+
# cur_link.parent_trans_vec = np.copy(tot_parent_trans_vec) ##
|
768 |
+
cur_link.parent_rot_mtx = tot_parent_rot_mtx ### set children parent rot mtx and the trans vec #
|
769 |
+
cur_link.parent_trans_vec = tot_parent_trans_vec ##
|
770 |
+
cur_link.compute_transformation() ## compute self's transformations
|
771 |
+
|
772 |
+
def get_name_to_visual_pts_faces(self, name_to_visual_pts_faces):
|
773 |
+
# transform_visual_pts # ## rot_mt
|
774 |
+
self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
|
775 |
+
name_to_visual_pts_faces[self.body.name] = {"pts": self.body.visual_pts, "faces": self.body.visual_faces_ref}
|
776 |
+
for cur_link in self.children:
|
777 |
+
cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces) ## transform the pts faces
|
778 |
+
|
779 |
+
def get_visual_pts_list(self, visual_pts_list):
|
780 |
+
# transform_visual_pts # ## rot_mt
|
781 |
+
self.body.transform_visual_pts(rot_mtx=self.tot_rot_mtx, trans_vec=self.tot_trans_vec)
|
782 |
+
visual_pts_list.append(self.body.visual_pts) # body template #
|
783 |
+
# name_to_visual_pts_faces[self.body.name] = {"pts": self.body.visual_pts, "faces": self.body.visual_faces_ref}
|
784 |
+
for cur_link in self.children:
|
785 |
+
# cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces) ## transform the pts faces
|
786 |
+
cur_link.get_visual_pts_list(visual_pts_list)
|
787 |
+
|
788 |
+
|
789 |
+
|
790 |
+
def set_joint_idx(self, joint_name_to_idx):
|
791 |
+
self.joint.set_joint_idx(joint_name_to_idx)
|
792 |
+
for cur_link in self.children:
|
793 |
+
cur_link.set_joint_idx(joint_name_to_idx)
|
794 |
+
# if self.name in joint_name_to_idx:
|
795 |
+
# self.joint_idx = joint_name_to_idx[self.name]
|
796 |
+
|
797 |
+
def get_nn_pts(self,):
|
798 |
+
nn_pts = 0
|
799 |
+
nn_pts += self.body.get_nn_pts()
|
800 |
+
for cur_link in self.children:
|
801 |
+
nn_pts += cur_link.get_nn_pts()
|
802 |
+
self.nn_pts = nn_pts
|
803 |
+
return self.nn_pts
|
804 |
+
|
805 |
+
def clear_grads(self,):
|
806 |
+
|
807 |
+
if self.parent_rot_mtx.grad is not None:
|
808 |
+
self.parent_rot_mtx.grad.data = self.parent_rot_mtx.grad.data * 0.
|
809 |
+
if self.parent_trans_vec.grad is not None:
|
810 |
+
self.parent_trans_vec.grad.data = self.parent_trans_vec.grad.data * 0.
|
811 |
+
if self.curr_rot_mtx.grad is not None:
|
812 |
+
self.curr_rot_mtx.grad.data = self.curr_rot_mtx.grad.data * 0.
|
813 |
+
if self.curr_trans_vec.grad is not None:
|
814 |
+
self.curr_trans_vec.grad.data = self.curr_trans_vec.grad.data * 0.
|
815 |
+
if self.tot_rot_mtx.grad is not None:
|
816 |
+
self.tot_rot_mtx.grad.data = self.tot_rot_mtx.grad.data * 0.
|
817 |
+
if self.tot_trans_vec.grad is not None:
|
818 |
+
self.tot_trans_vec.grad.data = self.tot_trans_vec.grad.data * 0.
|
819 |
+
# print(f"parent_rot_mtx: {self.parent_rot_mtx.grad}")
|
820 |
+
# print(f"parent_trans_vec: {self.parent_trans_vec.grad}")
|
821 |
+
# print(f"curr_rot_mtx: {self.curr_rot_mtx.grad}")
|
822 |
+
# print(f"curr_trans_vec: {self.curr_trans_vec.grad}")
|
823 |
+
# print(f"tot_rot_mtx: {self.tot_rot_mtx.grad}")
|
824 |
+
# print(f"tot_trans_vec: {self.tot_trans_vec.grad}")
|
825 |
+
# print(f"Joint")
|
826 |
+
self.joint.clear_grads()
|
827 |
+
self.body.clear_grad()
|
828 |
+
for cur_link in self.children:
|
829 |
+
cur_link.clear_grads()
|
830 |
+
|
831 |
+
def set_args(self, args):
|
832 |
+
self.args = args
|
833 |
+
for cur_link in self.children:
|
834 |
+
cur_link.set_args(args)
|
835 |
+
|
836 |
+
|
837 |
+
|
838 |
+
|
839 |
+
class Robot: # robot and the robot #
|
840 |
+
def __init__(self, children_links, args) -> None:
|
841 |
+
self.children = children_links
|
842 |
+
self.args = args
|
843 |
+
|
844 |
+
def set_state(self, name_to_state):
|
845 |
+
for cur_link in self.children:
|
846 |
+
cur_link.set_state(name_to_state)
|
847 |
+
|
848 |
+
def compute_transformation(self,):
|
849 |
+
for cur_link in self.children:
|
850 |
+
cur_link.compute_transformation()
|
851 |
+
|
852 |
+
def get_name_to_visual_pts_faces(self, name_to_visual_pts_faces):
|
853 |
+
for cur_link in self.children:
|
854 |
+
cur_link.get_name_to_visual_pts_faces(name_to_visual_pts_faces)
|
855 |
+
|
856 |
+
def get_visual_pts_list(self, visual_pts_list):
|
857 |
+
for cur_link in self.children:
|
858 |
+
cur_link.get_visual_pts_list(visual_pts_list)
|
859 |
+
|
860 |
+
def set_joint_idx(self, joint_name_to_idx):
|
861 |
+
for cur_link in self.children:
|
862 |
+
cur_link.set_joint_idx(joint_name_to_idx) ### set joint idx ###
|
863 |
+
|
864 |
+
def set_state_via_vec(self, state_vec): ### set the state vec for the state vec ###
|
865 |
+
for cur_link in self.children: ### set the state vec for the state vec ###
|
866 |
+
cur_link.set_state_via_vec(state_vec)
|
867 |
+
# self.joint.set_state_via_vec(state_vec)
|
868 |
+
# for child_link in self.children:
|
869 |
+
# child_link.set_state_via_vec(state_vec)
|
870 |
+
|
871 |
+
# get_tot_transformed_joints
|
872 |
+
def get_tot_transformed_joints(self, transformed_joints):
|
873 |
+
for cur_link in self.children: #
|
874 |
+
transformed_joints = cur_link.get_tot_transformed_joints(transformed_joints)
|
875 |
+
return transformed_joints
|
876 |
+
|
877 |
+
def get_nn_pts(self):
|
878 |
+
nn_pts = 0
|
879 |
+
for cur_link in self.children:
|
880 |
+
nn_pts += cur_link.get_nn_pts()
|
881 |
+
self.nn_pts = nn_pts
|
882 |
+
return self.nn_pts
|
883 |
+
|
884 |
+
def set_args(self, args):
|
885 |
+
self.args = args
|
886 |
+
for cur_link in self.children: ## args ##
|
887 |
+
cur_link.set_args(args)
|
888 |
+
|
889 |
+
def print_grads(self):
|
890 |
+
for cur_link in self.children:
|
891 |
+
cur_link.print_grads()
|
892 |
+
|
893 |
+
def clear_grads(self,): ## clear grads ##
|
894 |
+
for cur_link in self.children:
|
895 |
+
cur_link.clear_grads()
|
896 |
+
|
897 |
+
def compute_transformation_via_state_vecs(self, state_vals, visual_pts_list):
|
898 |
+
# parent_rot_mtx, parent_trans_vec
|
899 |
+
for cur_link in self.children:
|
900 |
+
cur_link.compute_transformation_via_state_vecs(state_vals, cur_link.parent_rot_mtx, cur_link.parent_trans_vec, visual_pts_list)
|
901 |
+
return visual_pts_list
|
902 |
+
|
903 |
+
# get_visual_pts_rgba_values(self, pts_rgba_vals_list):
|
904 |
+
def get_visual_pts_rgba_values(self, pts_rgba_vals_list):
|
905 |
+
for cur_link in self.children:
|
906 |
+
cur_link.get_visual_pts_rgba_values(pts_rgba_vals_list)
|
907 |
+
return pts_rgba_vals_list ## compute pts rgba vals list ##
|
908 |
+
|
909 |
+
def parse_nparray_from_string(strr, args):
|
910 |
+
vals = strr.split(" ")
|
911 |
+
vals = [float(val) for val in vals]
|
912 |
+
vals = np.array(vals, dtype=np.float32)
|
913 |
+
vals = torch.from_numpy(vals).float()
|
914 |
+
## vals ##
|
915 |
+
vals = nn.Parameter(vals.cuda(args.th_cuda_idx), requires_grad=True)
|
916 |
+
|
917 |
+
return vals
|
918 |
+
|
919 |
+
|
920 |
+
### parse link data ###
|
921 |
+
def parse_link_data(link, args):
|
922 |
+
|
923 |
+
link_name = link.attrib["name"]
|
924 |
+
# print(f"parsing link: {link_name}") ## joints body meshes #
|
925 |
+
|
926 |
+
joint = link.find("./joint")
|
927 |
+
|
928 |
+
joint_name = joint.attrib["name"]
|
929 |
+
joint_type = joint.attrib["type"]
|
930 |
+
if joint_type in ["revolute"]: ## a general xml parser here?
|
931 |
+
axis = joint.attrib["axis"]
|
932 |
+
axis = parse_nparray_from_string(axis, args=args)
|
933 |
+
else:
|
934 |
+
axis = None
|
935 |
+
pos = joint.attrib["pos"] #
|
936 |
+
pos = parse_nparray_from_string(pos, args=args)
|
937 |
+
quat = joint.attrib["quat"]
|
938 |
+
quat = parse_nparray_from_string(quat, args=args)
|
939 |
+
|
940 |
+
try:
|
941 |
+
frame = joint.attrib["frame"]
|
942 |
+
except:
|
943 |
+
frame = "WORLD"
|
944 |
+
|
945 |
+
if joint_type not in ["fixed"]:
|
946 |
+
damping = joint.attrib["damping"]
|
947 |
+
damping = float(damping)
|
948 |
+
else:
|
949 |
+
damping = 0.0
|
950 |
+
|
951 |
+
cur_joint = Joint(joint_name, joint_type, axis, pos, quat, frame, damping, args=args)
|
952 |
+
|
953 |
+
body = link.find("./body")
|
954 |
+
body_name = body.attrib["name"]
|
955 |
+
body_type = body.attrib["type"]
|
956 |
+
if body_type == "mesh":
|
957 |
+
filename = body.attrib["filename"]
|
958 |
+
else:
|
959 |
+
filename = ""
|
960 |
+
|
961 |
+
if body_type == "sphere":
|
962 |
+
radius = body.attrib["radius"]
|
963 |
+
radius = float(radius)
|
964 |
+
else:
|
965 |
+
radius = 0.
|
966 |
+
|
967 |
+
pos = body.attrib["pos"]
|
968 |
+
pos = parse_nparray_from_string(pos, args=args)
|
969 |
+
quat = body.attrib["quat"]
|
970 |
+
quat = joint.attrib["quat"]
|
971 |
+
try:
|
972 |
+
transform_type = body.attrib["transform_type"]
|
973 |
+
except:
|
974 |
+
transform_type = "OBJ_TO_WORLD"
|
975 |
+
density = body.attrib["density"]
|
976 |
+
density = float(density)
|
977 |
+
mu = body.attrib["mu"]
|
978 |
+
mu = float(mu)
|
979 |
+
try: ## rgba ##
|
980 |
+
rgba = body.attrib["rgba"]
|
981 |
+
rgba = parse_nparray_from_string(rgba, args=args)
|
982 |
+
except:
|
983 |
+
rgba = np.zeros((4,), dtype=np.float32)
|
984 |
+
|
985 |
+
cur_body = Body(body_name, body_type, filename, pos, quat, transform_type, density, mu, rgba, radius, args=args)
|
986 |
+
|
987 |
+
children_link = []
|
988 |
+
links = link.findall("./link")
|
989 |
+
for child_link in links: #
|
990 |
+
cur_child_link = parse_link_data(child_link, args=args)
|
991 |
+
children_link.append(cur_child_link)
|
992 |
+
|
993 |
+
link_name = link.attrib["name"]
|
994 |
+
link_obj = Link(link_name, joint=cur_joint, body=cur_body, children=children_link, args=args)
|
995 |
+
return link_obj
|
996 |
+
|
997 |
+
|
998 |
+
|
999 |
+
|
1000 |
+
def parse_data_from_xml(xml_fn, args):
|
1001 |
+
|
1002 |
+
tree = ElementTree()
|
1003 |
+
tree.parse(xml_fn)
|
1004 |
+
|
1005 |
+
### get total robots ###
|
1006 |
+
robots = tree.findall("./robot")
|
1007 |
+
i_robot = 0
|
1008 |
+
tot_robots = []
|
1009 |
+
for cur_robot in robots:
|
1010 |
+
print(f"Getting robot: {i_robot}")
|
1011 |
+
i_robot += 1
|
1012 |
+
cur_links = cur_robot.findall("./link")
|
1013 |
+
# i_link = 0
|
1014 |
+
cur_robot_links = []
|
1015 |
+
for cur_link in cur_links: ## child of the link ##
|
1016 |
+
### a parse link util -> the child of the link is composed of (the joint; body; and children links (with children or with no child here))
|
1017 |
+
# cur_link_name = cur_link.attrib["name"]
|
1018 |
+
# print(f"Getting link: {i_link} with name: {cur_link_name}")
|
1019 |
+
# i_link += 1 ##
|
1020 |
+
cur_robot_links.append(parse_link_data(cur_link, args=args))
|
1021 |
+
cur_robot_obj = Robot(cur_robot_links, args=args)
|
1022 |
+
tot_robots.append(cur_robot_obj)
|
1023 |
+
|
1024 |
+
|
1025 |
+
tot_actuators = []
|
1026 |
+
actuators = tree.findall("./actuator/motor")
|
1027 |
+
joint_nm_to_joint_idx = {}
|
1028 |
+
i_act = 0
|
1029 |
+
for cur_act in actuators:
|
1030 |
+
cur_act_joint_nm = cur_act.attrib["joint"]
|
1031 |
+
joint_nm_to_joint_idx[cur_act_joint_nm] = i_act
|
1032 |
+
i_act += 1 ### add the act ###
|
1033 |
+
|
1034 |
+
tot_robots[0].set_joint_idx(joint_nm_to_joint_idx) ### set joint idx here ### # tot robots #
|
1035 |
+
tot_robots[0].get_nn_pts()
|
1036 |
+
tot_robots[1].get_nn_pts()
|
1037 |
+
|
1038 |
+
return tot_robots
|
1039 |
+
|
1040 |
+
def get_name_to_state_from_str(states_str):
|
1041 |
+
tot_states = states_str.split(" ")
|
1042 |
+
tot_states = [float(cur_state) for cur_state in tot_states]
|
1043 |
+
joint_name_to_state = {}
|
1044 |
+
for i in range(len(tot_states)):
|
1045 |
+
cur_joint_name = f"joint{i + 1}"
|
1046 |
+
cur_joint_state = tot_states[i]
|
1047 |
+
joint_name_to_state[cur_joint_name] = cur_joint_state
|
1048 |
+
return joint_name_to_state
|
1049 |
+
|
1050 |
+
def create_zero_states():
|
1051 |
+
nn_joints = 17
|
1052 |
+
joint_name_to_state = {}
|
1053 |
+
for i_j in range(nn_joints):
|
1054 |
+
cur_joint_name = f"joint{i_j + 1}"
|
1055 |
+
joint_name_to_state[cur_joint_name] = 0.
|
1056 |
+
return joint_name_to_state
|
1057 |
+
|
1058 |
+
# [6.96331033e-17 3.54807679e-06 1.74046190e-15 2.66367417e-05
|
1059 |
+
# 1.22444894e-05 3.38976792e-06 1.46917635e-15 2.66367383e-05
|
1060 |
+
# 1.22444882e-05 3.38976786e-06 1.97778813e-15 2.66367383e-05
|
1061 |
+
# 1.22444882e-05 3.38976786e-06 4.76033293e-16 1.26279884e-05
|
1062 |
+
# 3.51189993e-06 0.00000000e+00 4.89999978e-03 0.00000000e+00]
|
1063 |
+
|
1064 |
+
|
1065 |
+
def rotation_matrix_from_axis_angle_np(axis, angle): # rotation_matrix_from_axis_angle ->
|
1066 |
+
sin_ = np.sin(angle) # ti.math.sin(angle)
|
1067 |
+
cos_ = np.cos(angle) # ti.math.cos(angle)
|
1068 |
+
# sin_ = torch.sin(angle) # ti.math.sin(angle)
|
1069 |
+
# cos_ = torch.cos(angle) # ti.math.cos(angle)
|
1070 |
+
u_x, u_y, u_z = axis[0], axis[1], axis[2]
|
1071 |
+
u_xx = u_x * u_x
|
1072 |
+
u_yy = u_y * u_y
|
1073 |
+
u_zz = u_z * u_z
|
1074 |
+
u_xy = u_x * u_y
|
1075 |
+
u_xz = u_x * u_z
|
1076 |
+
u_yz = u_y * u_z ##
|
1077 |
+
|
1078 |
+
|
1079 |
+
row_a = np.stack(
|
1080 |
+
[cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], axis=0
|
1081 |
+
)
|
1082 |
+
# print(f"row_a: {row_a.size()}")
|
1083 |
+
row_b = np.stack(
|
1084 |
+
[u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], axis=0
|
1085 |
+
)
|
1086 |
+
# print(f"row_b: {row_b.size()}")
|
1087 |
+
row_c = np.stack(
|
1088 |
+
[u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], axis=0
|
1089 |
+
)
|
1090 |
+
# print(f"row_c: {row_c.size()}")
|
1091 |
+
|
1092 |
+
### rot_mtx for the rot_mtx ###
|
1093 |
+
rot_mtx = np.stack(
|
1094 |
+
[row_a, row_b, row_c], axis=-1 ### rot_matrix of he matrix ##
|
1095 |
+
)
|
1096 |
+
|
1097 |
+
return rot_mtx
|
1098 |
+
|
1099 |
+
|
1100 |
+
|
1101 |
+
|
1102 |
+
def rotation_matrix_from_axis_angle(axis, angle): # rotation_matrix_from_axis_angle ->
|
1103 |
+
# sin_ = np.sin(angle) # ti.math.sin(angle)
|
1104 |
+
# cos_ = np.cos(angle) # ti.math.cos(angle)
|
1105 |
+
sin_ = torch.sin(angle) # ti.math.sin(angle)
|
1106 |
+
cos_ = torch.cos(angle) # ti.math.cos(angle)
|
1107 |
+
u_x, u_y, u_z = axis[0], axis[1], axis[2]
|
1108 |
+
u_xx = u_x * u_x
|
1109 |
+
u_yy = u_y * u_y
|
1110 |
+
u_zz = u_z * u_z
|
1111 |
+
u_xy = u_x * u_y
|
1112 |
+
u_xz = u_x * u_z
|
1113 |
+
u_yz = u_y * u_z ##
|
1114 |
+
|
1115 |
+
|
1116 |
+
row_a = torch.stack(
|
1117 |
+
[cos_ + u_xx * (1 - cos_), u_xy * (1. - cos_) + u_z * sin_, u_xz * (1. - cos_) - u_y * sin_], dim=0
|
1118 |
+
)
|
1119 |
+
# print(f"row_a: {row_a.size()}")
|
1120 |
+
row_b = torch.stack(
|
1121 |
+
[u_xy * (1. - cos_) - u_z * sin_, cos_ + u_yy * (1. - cos_), u_yz * (1. - cos_) + u_x * sin_], dim=0
|
1122 |
+
)
|
1123 |
+
# print(f"row_b: {row_b.size()}")
|
1124 |
+
row_c = torch.stack(
|
1125 |
+
[u_xz * (1. - cos_) + u_y * sin_, u_yz * (1. - cos_) - u_x * sin_, cos_ + u_zz * (1. - cos_)], dim=0
|
1126 |
+
)
|
1127 |
+
# print(f"row_c: {row_c.size()}")
|
1128 |
+
|
1129 |
+
### rot_mtx for the rot_mtx ###
|
1130 |
+
rot_mtx = torch.stack(
|
1131 |
+
[row_a, row_b, row_c], dim=-1 ### rot_matrix of he matrix ##
|
1132 |
+
)
|
1133 |
+
|
1134 |
+
return rot_mtx
|
1135 |
+
|
1136 |
+
|
1137 |
+
def get_camera_to_world_poses(n=10, ):
|
1138 |
+
## sample from the upper half sphere ##
|
1139 |
+
# theta and phi for the
|
1140 |
+
theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
|
1141 |
+
phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
|
1142 |
+
# theta = torch.from_numpy(theta).float().cuda()
|
1143 |
+
tot_c2w_matrix = []
|
1144 |
+
for i_n in range(n):
|
1145 |
+
# y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
|
1146 |
+
# y_rot_mtx = load_utils.rotation_matrix_from_axis_angle(rot_vec, rot_angle)
|
1147 |
+
|
1148 |
+
|
1149 |
+
z_axis_rot_axis = np.array([0, 0, 1.], dtype=np.float32)
|
1150 |
+
z_axis_rot_angle = np.pi - theta[i_n]
|
1151 |
+
z_axis_rot_matrix = rotation_matrix_from_axis_angle_np(z_axis_rot_axis, z_axis_rot_angle)
|
1152 |
+
rotated_plane_rot_axis_ori = np.array([1, -1, 0], dtype=np.float32)
|
1153 |
+
rotated_plane_rot_axis_ori = rotated_plane_rot_axis_ori / np.sqrt(np.sum(rotated_plane_rot_axis_ori ** 2))
|
1154 |
+
rotated_plane_rot_axis = np.matmul(z_axis_rot_matrix, rotated_plane_rot_axis_ori)
|
1155 |
+
|
1156 |
+
plane_rot_angle = phi[i_n]
|
1157 |
+
plane_rot_matrix = rotation_matrix_from_axis_angle_np(rotated_plane_rot_axis, plane_rot_angle)
|
1158 |
+
|
1159 |
+
c2w_matrix = np.matmul(plane_rot_matrix, z_axis_rot_matrix)
|
1160 |
+
c2w_trans_matrix = np.array(
|
1161 |
+
[np.cos(theta[i_n]) * np.sin(phi[i_n]), np.sin(theta[i_n]) * np.sin(phi[i_n]), np.cos(phi[i_n])], dtype=np.float32
|
1162 |
+
)
|
1163 |
+
c2w_matrix = np.concatenate(
|
1164 |
+
[c2w_matrix, c2w_trans_matrix.reshape(3, 1)], axis=-1
|
1165 |
+
) ##c2w matrix
|
1166 |
+
tot_c2w_matrix.append(c2w_matrix)
|
1167 |
+
tot_c2w_matrix = np.stack(tot_c2w_matrix, axis=0)
|
1168 |
+
return tot_c2w_matrix
|
1169 |
+
|
1170 |
+
|
1171 |
+
def get_camera_to_world_poses_th(n=10, th_cuda_idx=0):
|
1172 |
+
## sample from the upper half sphere ##
|
1173 |
+
# theta and phi for the
|
1174 |
+
theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
|
1175 |
+
phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
|
1176 |
+
|
1177 |
+
# n_total = 14
|
1178 |
+
# n_xz = 14
|
1179 |
+
# n_y = 7
|
1180 |
+
# theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
|
1181 |
+
# phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
|
1182 |
+
|
1183 |
+
|
1184 |
+
theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
|
1185 |
+
phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
|
1186 |
+
|
1187 |
+
tot_c2w_matrix = []
|
1188 |
+
for i_n in range(n): # if use veyr dense views like those
|
1189 |
+
y_rot_angle = theta[i_n]
|
1190 |
+
y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
|
1191 |
+
y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
|
1192 |
+
|
1193 |
+
x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
|
1194 |
+
y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
|
1195 |
+
|
1196 |
+
x_rot_angle = phi[i_n]
|
1197 |
+
x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
|
1198 |
+
|
1199 |
+
rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
|
1200 |
+
xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
|
1201 |
+
rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
|
1202 |
+
|
1203 |
+
c2w_matrix = torch.cat(
|
1204 |
+
[rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
|
1205 |
+
)
|
1206 |
+
tot_c2w_matrix.append(c2w_matrix)
|
1207 |
+
|
1208 |
+
|
1209 |
+
# z_axis_rot_axis = np.array([0, 0, 1.], dtype=np.float32)
|
1210 |
+
# z_axis_rot_angle = np.pi - theta[i_n]
|
1211 |
+
# z_axis_rot_matrix = rotation_matrix_from_axis_angle_np(z_axis_rot_axis, z_axis_rot_angle)
|
1212 |
+
# rotated_plane_rot_axis_ori = np.array([1, -1, 0], dtype=np.float32)
|
1213 |
+
# rotated_plane_rot_axis_ori = rotated_plane_rot_axis_ori / np.sqrt(np.sum(rotated_plane_rot_axis_ori ** 2))
|
1214 |
+
# rotated_plane_rot_axis = np.matmul(z_axis_rot_matrix, rotated_plane_rot_axis_ori)
|
1215 |
+
|
1216 |
+
# plane_rot_angle = phi[i_n]
|
1217 |
+
# plane_rot_matrix = rotation_matrix_from_axis_angle_np(rotated_plane_rot_axis, plane_rot_angle)
|
1218 |
+
|
1219 |
+
# c2w_matrix = np.matmul(plane_rot_matrix, z_axis_rot_matrix)
|
1220 |
+
# c2w_trans_matrix = np.array(
|
1221 |
+
# [np.cos(theta[i_n]) * np.sin(phi[i_n]), np.sin(theta[i_n]) * np.sin(phi[i_n]), np.cos(phi[i_n])], dtype=np.float32
|
1222 |
+
# )
|
1223 |
+
# c2w_matrix = np.concatenate(
|
1224 |
+
# [c2w_matrix, c2w_trans_matrix.reshape(3, 1)], axis=-1
|
1225 |
+
# ) ##c2w matrix
|
1226 |
+
# tot_c2w_matrix.append(c2w_matrix)
|
1227 |
+
# tot_c2w_matrix = np.stack(tot_c2w_matrix, axis=0)
|
1228 |
+
tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
|
1229 |
+
return tot_c2w_matrix
|
1230 |
+
|
1231 |
+
|
1232 |
+
def get_camera_to_world_poses_th_routine_1(n=7, th_cuda_idx=0):
|
1233 |
+
## sample from the upper half sphere ##
|
1234 |
+
# theta and phi for the
|
1235 |
+
|
1236 |
+
# theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
|
1237 |
+
# phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
|
1238 |
+
|
1239 |
+
# n_total = 14
|
1240 |
+
n_xz = 2 * n # 14
|
1241 |
+
n_y = n # 7
|
1242 |
+
theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
|
1243 |
+
phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
|
1244 |
+
|
1245 |
+
theta = torch.tensor(theta).float().cuda(th_cuda_idx)
|
1246 |
+
phi = torch.tensor(phi).float().cuda(th_cuda_idx)
|
1247 |
+
# theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
|
1248 |
+
# phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
|
1249 |
+
|
1250 |
+
tot_c2w_matrix = []
|
1251 |
+
|
1252 |
+
for i_theta in range(theta.size(0)):
|
1253 |
+
for i_phi in range(phi.size(0)):
|
1254 |
+
y_rot_angle = theta[i_theta]
|
1255 |
+
y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
|
1256 |
+
y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
|
1257 |
+
|
1258 |
+
x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
|
1259 |
+
y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
|
1260 |
+
|
1261 |
+
x_rot_angle = phi[i_phi]
|
1262 |
+
x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
|
1263 |
+
|
1264 |
+
rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
|
1265 |
+
xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
|
1266 |
+
rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
|
1267 |
+
|
1268 |
+
c2w_matrix = torch.cat(
|
1269 |
+
[rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
|
1270 |
+
)
|
1271 |
+
tot_c2w_matrix.append(c2w_matrix)
|
1272 |
+
|
1273 |
+
tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
|
1274 |
+
return tot_c2w_matrix
|
1275 |
+
|
1276 |
+
|
1277 |
+
def get_camera_to_world_poses_th_routine_2(n=7, th_cuda_idx=0):
|
1278 |
+
## sample from the upper half sphere ##
|
1279 |
+
# theta and phi for the
|
1280 |
+
|
1281 |
+
# theta = np.random.uniform(low=0.0, high=1.0, size=(n,)) * np.pi * 2. # xz palne #
|
1282 |
+
# phi = np.random.uniform(low=-1.0, high=0.0, size=(n,)) * np.pi ## [-0.5 \pi, 0.5 \pi] ## negative pi to the original pi
|
1283 |
+
|
1284 |
+
# n_total = 14
|
1285 |
+
n_xz = 2 * n # 14
|
1286 |
+
n_y = 2 * n # 7
|
1287 |
+
theta = [i_xz * 1.0 / float(n_xz) * np.pi * 2. for i_xz in range(n_xz)]
|
1288 |
+
# phi = [i_y * (-1.0) / float(n_y) * np.pi for i_y in range(n_y)]
|
1289 |
+
phi = [i_y * (-1.0) / float(n_y) * np.pi * 2. for i_y in range(n_y)]
|
1290 |
+
|
1291 |
+
theta = torch.tensor(theta).float().cuda(th_cuda_idx)
|
1292 |
+
phi = torch.tensor(phi).float().cuda(th_cuda_idx)
|
1293 |
+
# theta = torch.from_numpy(theta).float().cuda(th_cuda_idx)
|
1294 |
+
# phi = torch.from_numpy(phi).float().cuda(th_cuda_idx)
|
1295 |
+
|
1296 |
+
tot_c2w_matrix = []
|
1297 |
+
|
1298 |
+
for i_theta in range(theta.size(0)):
|
1299 |
+
for i_phi in range(phi.size(0)):
|
1300 |
+
y_rot_angle = theta[i_theta]
|
1301 |
+
y_rot_vec = torch.tensor([0., 1., 0.]).float().cuda(th_cuda_idx)
|
1302 |
+
y_rot_mtx = rotation_matrix_from_axis_angle(y_rot_vec, y_rot_angle)
|
1303 |
+
|
1304 |
+
x_axis = torch.tensor([1., 0., 0.]).float().cuda(th_cuda_idx)
|
1305 |
+
y_rot_x_axis = torch.matmul(y_rot_mtx, x_axis.unsqueeze(-1)).squeeze(-1) ### y_rot_x_axis #
|
1306 |
+
|
1307 |
+
x_rot_angle = phi[i_phi]
|
1308 |
+
x_rot_mtx = rotation_matrix_from_axis_angle(y_rot_x_axis, x_rot_angle)
|
1309 |
+
|
1310 |
+
rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
|
1311 |
+
xyz_offset = torch.tensor([0., 0., 1.5]).float().cuda(th_cuda_idx)
|
1312 |
+
rot_xyz_offset = torch.matmul(rot_mtx, xyz_offset.unsqueeze(-1)).squeeze(-1).contiguous() + 0.5 ### 3 for the xyz offset
|
1313 |
+
|
1314 |
+
c2w_matrix = torch.cat(
|
1315 |
+
[rot_mtx, rot_xyz_offset.unsqueeze(-1)], dim=-1
|
1316 |
+
)
|
1317 |
+
tot_c2w_matrix.append(c2w_matrix)
|
1318 |
+
|
1319 |
+
tot_c2w_matrix = torch.stack(tot_c2w_matrix, dim=0)
|
1320 |
+
return tot_c2w_matrix
|
1321 |
+
|
1322 |
+
|
1323 |
+
|
1324 |
+
|
1325 |
+
|
1326 |
+
|
1327 |
+
if __name__=='__main__':
|
1328 |
+
xml_fn = "/home/xueyi/diffsim/DiffHand/assets/hand_sphere.xml"
|
1329 |
+
tot_robots = parse_data_from_xml(xml_fn=xml_fn)
|
1330 |
+
# tot_robots =
|
1331 |
+
|
1332 |
+
active_optimized_states = """-0.00025872 -0.00025599 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025296 -0.00022881 -0.00024449 -0.0002549 -0.00025694 -0.00024656 -0.00025556 0. 0.0049 0."""
|
1333 |
+
active_optimized_states = """-1.10617972 -1.10742263 -1.06198363 -1.03212746 -1.05429142 -1.08617289 -1.05868192 -1.01624365 -1.04478191 -1.08260959 -1.06719107 -1.04082455 -1.05995886 -1.08674006 -1.09396691 -1.08965532 -1.10036577 -10.7117466 -3.62511998 1.49450353"""
|
1334 |
+
# active_goal_optimized_states = """-1.10617972 -1.10742263 -1.0614858 -1.03189609 -1.05404354 -1.08610468 -1.05863293 -1.0174248 -1.04576456 -1.08297396 -1.06719107 -1.04082455 -1.05995886 -1.08674006 -1.09396691 -1.08965532 -1.10036577 -10.73396897 -3.68095432 1.50679285"""
|
1335 |
+
active_optimized_states = """-0.42455298 -0.42570447 -0.40567708 -0.39798589 -0.40953955 -0.42025055 -0.37910662 -0.496165 -0.37664644 -0.41942727 -0.40596508 -0.3982109 -0.40959847 -0.42024905 -0.41835001 -0.41929961 -0.42365131 -1.18756073 -2.90337822 0.4224685"""
|
1336 |
+
active_optimized_states = """-0.42442816 -0.42557961 -0.40366201 -0.3977891 -0.40947627 -0.4201424 -0.3799285 -0.3808375 -0.37953552 -0.42039598 -0.4058405 -0.39808804 -0.40947487 -0.42012458 -0.41822534 -0.41917521 -0.4235266 -0.87189658 -1.42093761 0.21977979"""
|
1337 |
+
|
1338 |
+
active_robot = tot_robots[0]
|
1339 |
+
zero_states = create_zero_states()
|
1340 |
+
active_robot.set_state(zero_states)
|
1341 |
+
active_robot.compute_transformation()
|
1342 |
+
name_to_visual_pts_surfaces = {}
|
1343 |
+
active_robot.get_name_to_visual_pts_faces(name_to_visual_pts_surfaces)
|
1344 |
+
print(len(name_to_visual_pts_surfaces))
|
1345 |
+
|
1346 |
+
sv_res_rt = "/home/xueyi/diffsim/DiffHand/examples/save_res"
|
1347 |
+
sv_res_rt = os.path.join(sv_res_rt, "load_utils_test")
|
1348 |
+
os.makedirs(sv_res_rt, exist_ok=True)
|
1349 |
+
|
1350 |
+
tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"res_with_zero_states.npy")
|
1351 |
+
np.save(tmp_visual_res_sv_fn, name_to_visual_pts_surfaces)
|
1352 |
+
print(f"tmp visual res saved to {tmp_visual_res_sv_fn}")
|
1353 |
+
|
1354 |
+
|
1355 |
+
optimized_states = get_name_to_state_from_str(active_optimized_states)
|
1356 |
+
active_robot.set_state(optimized_states)
|
1357 |
+
active_robot.compute_transformation()
|
1358 |
+
name_to_visual_pts_surfaces = {}
|
1359 |
+
active_robot.get_name_to_visual_pts_faces(name_to_visual_pts_surfaces)
|
1360 |
+
print(len(name_to_visual_pts_surfaces))
|
1361 |
+
# sv_res_rt = "/home/xueyi/diffsim/DiffHand/examples/save_res"
|
1362 |
+
# sv_res_rt = os.path.join(sv_res_rt, "load_utils_test")
|
1363 |
+
# os.makedirs(sv_res_rt, exist_ok=True)
|
1364 |
+
|
1365 |
+
# tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"res_with_optimized_states.npy")
|
1366 |
+
tmp_visual_res_sv_fn = os.path.join(sv_res_rt, f"active_ngoal_res_with_optimized_states_goal_n3.npy")
|
1367 |
+
np.save(tmp_visual_res_sv_fn, name_to_visual_pts_surfaces)
|
1368 |
+
print(f"tmp visual res with optimized states saved to {tmp_visual_res_sv_fn}")
|
1369 |
+
|
models/embedder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
# Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
|
6 |
+
class Embedder:
|
7 |
+
def __init__(self, **kwargs):
|
8 |
+
self.kwargs = kwargs
|
9 |
+
self.create_embedding_fn()
|
10 |
+
|
11 |
+
def create_embedding_fn(self):
|
12 |
+
embed_fns = []
|
13 |
+
d = self.kwargs['input_dims']
|
14 |
+
out_dim = 0
|
15 |
+
if self.kwargs['include_input']:
|
16 |
+
embed_fns.append(lambda x: x)
|
17 |
+
out_dim += d
|
18 |
+
|
19 |
+
max_freq = self.kwargs['max_freq_log2']
|
20 |
+
N_freqs = self.kwargs['num_freqs']
|
21 |
+
|
22 |
+
if self.kwargs['log_sampling']:
|
23 |
+
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
24 |
+
else:
|
25 |
+
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
26 |
+
|
27 |
+
for freq in freq_bands:
|
28 |
+
for p_fn in self.kwargs['periodic_fns']:
|
29 |
+
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
|
30 |
+
out_dim += d
|
31 |
+
|
32 |
+
self.embed_fns = embed_fns
|
33 |
+
self.out_dim = out_dim
|
34 |
+
|
35 |
+
def embed(self, inputs):
|
36 |
+
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
37 |
+
|
38 |
+
|
39 |
+
def get_embedder(multires, input_dims=3):
|
40 |
+
embed_kwargs = {
|
41 |
+
'include_input': True,
|
42 |
+
'input_dims': input_dims,
|
43 |
+
'max_freq_log2': multires-1,
|
44 |
+
'num_freqs': multires,
|
45 |
+
'log_sampling': True,
|
46 |
+
'periodic_fns': [torch.sin, torch.cos],
|
47 |
+
}
|
48 |
+
|
49 |
+
embedder_obj = Embedder(**embed_kwargs)
|
50 |
+
def embed(x, eo=embedder_obj): return eo.embed(x)
|
51 |
+
return embed, embedder_obj.out_dim
|
models/fields.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/fields_old.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/renderer.py
ADDED
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import logging
|
6 |
+
import mcubes
|
7 |
+
from icecream import ic
|
8 |
+
import os
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
from pysdf import SDF
|
12 |
+
|
13 |
+
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
|
14 |
+
|
15 |
+
def create_mt_variable(device):
|
16 |
+
triangle_table = torch.tensor(
|
17 |
+
[
|
18 |
+
[-1, -1, -1, -1, -1, -1],
|
19 |
+
[1, 0, 2, -1, -1, -1],
|
20 |
+
[4, 0, 3, -1, -1, -1],
|
21 |
+
[1, 4, 2, 1, 3, 4],
|
22 |
+
[3, 1, 5, -1, -1, -1],
|
23 |
+
[2, 3, 0, 2, 5, 3],
|
24 |
+
[1, 4, 0, 1, 5, 4],
|
25 |
+
[4, 2, 5, -1, -1, -1],
|
26 |
+
[4, 5, 2, -1, -1, -1],
|
27 |
+
[4, 1, 0, 4, 5, 1],
|
28 |
+
[3, 2, 0, 3, 5, 2],
|
29 |
+
[1, 3, 5, -1, -1, -1],
|
30 |
+
[4, 1, 2, 4, 3, 1],
|
31 |
+
[3, 0, 4, -1, -1, -1],
|
32 |
+
[2, 0, 1, -1, -1, -1],
|
33 |
+
[-1, -1, -1, -1, -1, -1]
|
34 |
+
], dtype=torch.long, device=device)
|
35 |
+
|
36 |
+
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
37 |
+
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
38 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
39 |
+
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func):
|
44 |
+
# load tet via resolution #
|
45 |
+
# scale them via bounds #
|
46 |
+
# extract the geometry #
|
47 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
48 |
+
device = bound_min.device
|
49 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
50 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
51 |
+
# else:
|
52 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
53 |
+
tets = np.load(tet_fn)
|
54 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
55 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
56 |
+
# split #
|
57 |
+
# verts; verts; #
|
58 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
59 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
60 |
+
# scale_verts = maxx_verts - minn_verts
|
61 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
62 |
+
|
63 |
+
### scale the vertices ###
|
64 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
65 |
+
|
66 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
67 |
+
|
68 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
69 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
70 |
+
|
71 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
72 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
73 |
+
|
74 |
+
sdf_values = []
|
75 |
+
N = 64
|
76 |
+
query_bundles = N ** 3 ### N^3
|
77 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
78 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
79 |
+
query_NNs += 1
|
80 |
+
for i_query in range(query_NNs):
|
81 |
+
cur_bundle_st = i_query * query_bundles
|
82 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
83 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
84 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
85 |
+
cur_query_vals = query_func(cur_query_pts)
|
86 |
+
sdf_values.append(cur_query_vals)
|
87 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
88 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
89 |
+
|
90 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
91 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
92 |
+
|
93 |
+
# intrinsic, tet values, pts values, sdf network #
|
94 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
95 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
96 |
+
|
97 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
98 |
+
|
99 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
100 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
101 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
102 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
103 |
+
# marching_tets_tetmesh ##
|
104 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
105 |
+
### use the GT sdf values for the marching tets ###
|
106 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
107 |
+
|
108 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
109 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
110 |
+
|
111 |
+
def extract_fields(bound_min, bound_max, resolution, query_func):
|
112 |
+
N = 64
|
113 |
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
114 |
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
115 |
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
116 |
+
|
117 |
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
118 |
+
with torch.no_grad():
|
119 |
+
for xi, xs in enumerate(X):
|
120 |
+
for yi, ys in enumerate(Y):
|
121 |
+
for zi, zs in enumerate(Z):
|
122 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
123 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
124 |
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
125 |
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
126 |
+
# should save u here #
|
127 |
+
# save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
|
128 |
+
# np.save(save_u_path, u)
|
129 |
+
# print(f"u saved to {save_u_path}")
|
130 |
+
return u
|
131 |
+
|
132 |
+
|
133 |
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
134 |
+
print('threshold: {}'.format(threshold))
|
135 |
+
|
136 |
+
## using maching cubes ###
|
137 |
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
138 |
+
vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
139 |
+
b_max_np = bound_max.detach().cpu().numpy()
|
140 |
+
b_min_np = bound_min.detach().cpu().numpy()
|
141 |
+
|
142 |
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
143 |
+
### using maching cubes ###
|
144 |
+
|
145 |
+
### using marching tets ###
|
146 |
+
# vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
147 |
+
# vertices = vertices.detach().cpu().numpy()
|
148 |
+
# triangles = triangles.detach().cpu().numpy()
|
149 |
+
### using marching tets ###
|
150 |
+
|
151 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
152 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
153 |
+
|
154 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
155 |
+
return vertices, triangles
|
156 |
+
|
157 |
+
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func):
|
158 |
+
# print('threshold: {}'.format(threshold))
|
159 |
+
|
160 |
+
### using maching cubes ###
|
161 |
+
# u = extract_fields(bound_min, bound_max, resolution, query_func)
|
162 |
+
# vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
163 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
164 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
165 |
+
|
166 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
167 |
+
### using maching cubes ###
|
168 |
+
|
169 |
+
##
|
170 |
+
### using marching tets ### fiels from tets ##
|
171 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
172 |
+
# vertices = vertices.detach().cpu().numpy()
|
173 |
+
# triangles = triangles.detach().cpu().numpy()
|
174 |
+
### using marching tets ###
|
175 |
+
|
176 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
177 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
178 |
+
#
|
179 |
+
|
180 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
181 |
+
return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
|
182 |
+
|
183 |
+
|
184 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
185 |
+
# This implementation is from NeRF
|
186 |
+
# Get pdf
|
187 |
+
weights = weights + 1e-5 # prevent nans
|
188 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
189 |
+
cdf = torch.cumsum(pdf, -1)
|
190 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
191 |
+
# Take uniform samples
|
192 |
+
if det:
|
193 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
|
194 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
195 |
+
else:
|
196 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
|
197 |
+
|
198 |
+
# Invert CDF # invert cdf #
|
199 |
+
u = u.contiguous()
|
200 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
201 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
202 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
203 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
204 |
+
|
205 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
206 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
207 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
208 |
+
|
209 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
210 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
211 |
+
t = (u - cdf_g[..., 0]) / denom
|
212 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
213 |
+
|
214 |
+
return samples
|
215 |
+
|
216 |
+
|
217 |
+
def load_GT_vertices(GT_meshes_folder):
|
218 |
+
tot_meshes_fns = os.listdir(GT_meshes_folder)
|
219 |
+
tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
|
220 |
+
tot_mesh_verts = []
|
221 |
+
tot_mesh_faces = []
|
222 |
+
n_tot_verts = 0
|
223 |
+
for fn in tot_meshes_fns:
|
224 |
+
cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
|
225 |
+
obj_mesh = trimesh.load(cur_mesh_fn, process=False)
|
226 |
+
# obj_mesh.remove_degenerate_faces(height=1e-06)
|
227 |
+
|
228 |
+
verts_obj = np.array(obj_mesh.vertices)
|
229 |
+
faces_obj = np.array(obj_mesh.faces)
|
230 |
+
|
231 |
+
tot_mesh_verts.append(verts_obj)
|
232 |
+
tot_mesh_faces.append(faces_obj + n_tot_verts)
|
233 |
+
n_tot_verts += verts_obj.shape[0]
|
234 |
+
|
235 |
+
# tot_mesh_faces.append(faces_obj)
|
236 |
+
tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
|
237 |
+
tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
|
238 |
+
return tot_mesh_verts, tot_mesh_faces
|
239 |
+
|
240 |
+
|
241 |
+
class NeuSRenderer:
|
242 |
+
def __init__(self,
|
243 |
+
nerf,
|
244 |
+
sdf_network,
|
245 |
+
deviation_network,
|
246 |
+
color_network,
|
247 |
+
n_samples,
|
248 |
+
n_importance,
|
249 |
+
n_outside,
|
250 |
+
up_sample_steps,
|
251 |
+
perturb):
|
252 |
+
self.nerf = nerf
|
253 |
+
self.sdf_network = sdf_network
|
254 |
+
self.deviation_network = deviation_network
|
255 |
+
self.color_network = color_network
|
256 |
+
self.n_samples = n_samples
|
257 |
+
self.n_importance = n_importance
|
258 |
+
self.n_outside = n_outside
|
259 |
+
self.up_sample_steps = up_sample_steps
|
260 |
+
self.perturb = perturb
|
261 |
+
|
262 |
+
GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
|
263 |
+
self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
|
264 |
+
maxx_pts = 25.
|
265 |
+
minn_pts = -15.
|
266 |
+
self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
|
267 |
+
f = SDF(self.mesh_vertices, self.mesh_faces)
|
268 |
+
self.gt_sdf = f ## a unite sphere or box
|
269 |
+
|
270 |
+
self.minn_pts = 0
|
271 |
+
self.maxx_pts = 1.
|
272 |
+
|
273 |
+
# self.minn_pts = -1.5
|
274 |
+
# self.maxx_pts = 1.5
|
275 |
+
self.bkg_pts = ... # TODO: the bkg pts
|
276 |
+
self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
|
277 |
+
self.dist_interp_thres = ... # TODO
|
278 |
+
|
279 |
+
# get the pts and render the pts #
|
280 |
+
# pts and the rendering pts #
|
281 |
+
def deform_pts(self, pts):
|
282 |
+
# pts: nn_batch x nn_samples x 3
|
283 |
+
nnb, nns = pts.size(0), pts.size(1)
|
284 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
285 |
+
dist_pts_to_bkg_pts = torch.sum(
|
286 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
287 |
+
)
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
|
293 |
+
"""
|
294 |
+
Render background
|
295 |
+
"""
|
296 |
+
batch_size, n_samples = z_vals.shape
|
297 |
+
|
298 |
+
# Section length
|
299 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
300 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
301 |
+
mid_z_vals = z_vals + dists * 0.5
|
302 |
+
|
303 |
+
# Section midpoints #
|
304 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
|
305 |
+
|
306 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
307 |
+
pts = pts * 2 - 1
|
308 |
+
|
309 |
+
dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
310 |
+
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
|
311 |
+
|
312 |
+
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
313 |
+
|
314 |
+
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
|
315 |
+
dirs = dirs.reshape(-1, 3)
|
316 |
+
|
317 |
+
density, sampled_color = nerf(pts, dirs)
|
318 |
+
sampled_color = torch.sigmoid(sampled_color)
|
319 |
+
alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
320 |
+
alpha = alpha.reshape(batch_size, n_samples)
|
321 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
322 |
+
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
323 |
+
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
324 |
+
if background_rgb is not None:
|
325 |
+
color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
|
326 |
+
|
327 |
+
return {
|
328 |
+
'color': color,
|
329 |
+
'sampled_color': sampled_color,
|
330 |
+
'alpha': alpha,
|
331 |
+
'weights': weights,
|
332 |
+
}
|
333 |
+
|
334 |
+
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
|
335 |
+
"""
|
336 |
+
Up sampling give a fixed inv_s
|
337 |
+
"""
|
338 |
+
batch_size, n_samples = z_vals.shape
|
339 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
340 |
+
|
341 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
342 |
+
pts = pts * 2 - 1
|
343 |
+
|
344 |
+
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
345 |
+
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
346 |
+
sdf = sdf.reshape(batch_size, n_samples)
|
347 |
+
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
348 |
+
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
349 |
+
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
350 |
+
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
351 |
+
|
352 |
+
# ----------------------------------------------------------------------------------------------------------
|
353 |
+
# Use min value of [ cos, prev_cos ]
|
354 |
+
# Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
|
355 |
+
# robust when meeting situations like below:
|
356 |
+
#
|
357 |
+
# SDF
|
358 |
+
# ^
|
359 |
+
# |\ -----x----...
|
360 |
+
# | \ /
|
361 |
+
# | x x
|
362 |
+
# |---\----/-------------> 0 level
|
363 |
+
# | \ /
|
364 |
+
# | \/
|
365 |
+
# |
|
366 |
+
# ----------------------------------------------------------------------------------------------------------
|
367 |
+
prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
|
368 |
+
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
369 |
+
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
370 |
+
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
371 |
+
|
372 |
+
dist = (next_z_vals - prev_z_vals)
|
373 |
+
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
374 |
+
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
375 |
+
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
376 |
+
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
377 |
+
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
378 |
+
weights = alpha * torch.cumprod(
|
379 |
+
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
380 |
+
|
381 |
+
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
382 |
+
return z_samples
|
383 |
+
|
384 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
|
385 |
+
batch_size, n_samples = z_vals.shape
|
386 |
+
_, n_importance = new_z_vals.shape
|
387 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
388 |
+
|
389 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
390 |
+
pts = pts * 2 - 1
|
391 |
+
|
392 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
393 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
394 |
+
|
395 |
+
if not last:
|
396 |
+
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
397 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
398 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
399 |
+
index = index.reshape(-1)
|
400 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
401 |
+
|
402 |
+
return z_vals, sdf
|
403 |
+
|
404 |
+
def render_core(self,
|
405 |
+
rays_o,
|
406 |
+
rays_d,
|
407 |
+
z_vals,
|
408 |
+
sample_dist,
|
409 |
+
sdf_network,
|
410 |
+
deviation_network,
|
411 |
+
color_network,
|
412 |
+
background_alpha=None,
|
413 |
+
background_sampled_color=None,
|
414 |
+
background_rgb=None,
|
415 |
+
cos_anneal_ratio=0.0):
|
416 |
+
batch_size, n_samples = z_vals.shape
|
417 |
+
|
418 |
+
# Section length
|
419 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
420 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
421 |
+
mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
|
422 |
+
|
423 |
+
# Section midpoints
|
424 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
425 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
426 |
+
|
427 |
+
pts = pts.reshape(-1, 3) # pts, nn_ou
|
428 |
+
dirs = dirs.reshape(-1, 3)
|
429 |
+
|
430 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
431 |
+
|
432 |
+
# pts = pts.flip((-1,)) * 2 - 1 #
|
433 |
+
pts = pts * 2 - 1
|
434 |
+
|
435 |
+
sdf_nn_output = sdf_network(pts)
|
436 |
+
sdf = sdf_nn_output[:, :1]
|
437 |
+
feature_vector = sdf_nn_output[:, 1:]
|
438 |
+
|
439 |
+
gradients = sdf_network.gradient(pts).squeeze()
|
440 |
+
sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
|
441 |
+
|
442 |
+
# deviation network #
|
443 |
+
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
|
444 |
+
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
445 |
+
|
446 |
+
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
447 |
+
|
448 |
+
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
449 |
+
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
450 |
+
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
451 |
+
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
452 |
+
|
453 |
+
# Estimate signed distances at section points
|
454 |
+
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
455 |
+
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
456 |
+
|
457 |
+
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
458 |
+
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
459 |
+
|
460 |
+
p = prev_cdf - next_cdf
|
461 |
+
c = prev_cdf
|
462 |
+
|
463 |
+
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
|
464 |
+
|
465 |
+
pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
466 |
+
inside_sphere = (pts_norm < 1.0).float().detach()
|
467 |
+
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
468 |
+
|
469 |
+
# Render with background
|
470 |
+
if background_alpha is not None:
|
471 |
+
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
|
472 |
+
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
473 |
+
sampled_color = sampled_color * inside_sphere[:, :, None] +\
|
474 |
+
background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
|
475 |
+
sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
|
476 |
+
|
477 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
478 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
479 |
+
|
480 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
481 |
+
if background_rgb is not None: # Fixed background, usually black
|
482 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
483 |
+
|
484 |
+
# Eikonal loss
|
485 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
|
486 |
+
dim=-1) - 1.0) ** 2
|
487 |
+
gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
|
488 |
+
|
489 |
+
return {
|
490 |
+
'color': color,
|
491 |
+
'sdf': sdf,
|
492 |
+
'dists': dists,
|
493 |
+
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
494 |
+
's_val': 1.0 / inv_s,
|
495 |
+
'mid_z_vals': mid_z_vals,
|
496 |
+
'weights': weights,
|
497 |
+
'cdf': c.reshape(batch_size, n_samples),
|
498 |
+
'gradient_error': gradient_error,
|
499 |
+
'inside_sphere': inside_sphere
|
500 |
+
}
|
501 |
+
|
502 |
+
def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
|
503 |
+
batch_size = len(rays_o)
|
504 |
+
sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
|
505 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples)
|
506 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
507 |
+
|
508 |
+
z_vals_outside = None
|
509 |
+
if self.n_outside > 0:
|
510 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
511 |
+
|
512 |
+
n_samples = self.n_samples
|
513 |
+
perturb = self.perturb
|
514 |
+
|
515 |
+
if perturb_overwrite >= 0:
|
516 |
+
perturb = perturb_overwrite
|
517 |
+
if perturb > 0:
|
518 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
519 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
520 |
+
|
521 |
+
if self.n_outside > 0:
|
522 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
523 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
524 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
525 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
526 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
527 |
+
|
528 |
+
if self.n_outside > 0:
|
529 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
530 |
+
|
531 |
+
background_alpha = None
|
532 |
+
background_sampled_color = None
|
533 |
+
|
534 |
+
# Up sample
|
535 |
+
if self.n_importance > 0:
|
536 |
+
with torch.no_grad():
|
537 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
538 |
+
|
539 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
540 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
541 |
+
# gt_sdf #
|
542 |
+
|
543 |
+
#
|
544 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
545 |
+
|
546 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
547 |
+
pts = pts * 2 - 1
|
548 |
+
|
549 |
+
pts_exp = pts.reshape(-1, 3)
|
550 |
+
# minn_pts, _ = torch.min(pts_exp, dim=0)
|
551 |
+
# maxx_pts, _ = torch.max(pts_exp, dim=0)
|
552 |
+
# print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
|
553 |
+
|
554 |
+
# pts_to_near = pts - near.unsqueeze(1)
|
555 |
+
# maxx_pts = 1.5; minn_pts = -1.5
|
556 |
+
# # maxx_pts = 3; minn_pts = -3
|
557 |
+
# # maxx_pts = 1; minn_pts = -1
|
558 |
+
# pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
|
559 |
+
|
560 |
+
## render and iamges ####
|
561 |
+
if use_gt_sdf:
|
562 |
+
### use the GT sdf field ####
|
563 |
+
# print(f"Using gt sdf :")
|
564 |
+
sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
|
565 |
+
sdf = torch.from_numpy(sdf).float().cuda()
|
566 |
+
sdf = sdf.reshape(batch_size, self.n_samples)
|
567 |
+
### use the GT sdf field ####
|
568 |
+
else:
|
569 |
+
#### use the optimized sdf field ####
|
570 |
+
sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
571 |
+
#### use the optimized sdf field ####
|
572 |
+
|
573 |
+
for i in range(self.up_sample_steps):
|
574 |
+
new_z_vals = self.up_sample(rays_o,
|
575 |
+
rays_d,
|
576 |
+
z_vals,
|
577 |
+
sdf,
|
578 |
+
self.n_importance // self.up_sample_steps,
|
579 |
+
64 * 2**i)
|
580 |
+
z_vals, sdf = self.cat_z_vals(rays_o,
|
581 |
+
rays_d,
|
582 |
+
z_vals,
|
583 |
+
new_z_vals,
|
584 |
+
sdf,
|
585 |
+
last=(i + 1 == self.up_sample_steps))
|
586 |
+
|
587 |
+
n_samples = self.n_samples + self.n_importance
|
588 |
+
|
589 |
+
# Background model
|
590 |
+
if self.n_outside > 0:
|
591 |
+
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
592 |
+
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
593 |
+
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
|
594 |
+
|
595 |
+
background_sampled_color = ret_outside['sampled_color']
|
596 |
+
background_alpha = ret_outside['alpha']
|
597 |
+
|
598 |
+
# Render core
|
599 |
+
ret_fine = self.render_core(rays_o, #
|
600 |
+
rays_d,
|
601 |
+
z_vals,
|
602 |
+
sample_dist,
|
603 |
+
self.sdf_network,
|
604 |
+
self.deviation_network,
|
605 |
+
self.color_network,
|
606 |
+
background_rgb=background_rgb,
|
607 |
+
background_alpha=background_alpha,
|
608 |
+
background_sampled_color=background_sampled_color,
|
609 |
+
cos_anneal_ratio=cos_anneal_ratio)
|
610 |
+
|
611 |
+
color_fine = ret_fine['color']
|
612 |
+
weights = ret_fine['weights']
|
613 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
614 |
+
gradients = ret_fine['gradients']
|
615 |
+
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
|
616 |
+
|
617 |
+
return {
|
618 |
+
'color_fine': color_fine,
|
619 |
+
's_val': s_val,
|
620 |
+
'cdf_fine': ret_fine['cdf'],
|
621 |
+
'weight_sum': weights_sum,
|
622 |
+
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
623 |
+
'gradients': gradients,
|
624 |
+
'weights': weights,
|
625 |
+
'gradient_error': ret_fine['gradient_error'],
|
626 |
+
'inside_sphere': ret_fine['inside_sphere']
|
627 |
+
}
|
628 |
+
|
629 |
+
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
|
630 |
+
return extract_geometry(bound_min, # extract geometry #
|
631 |
+
bound_max,
|
632 |
+
resolution=resolution,
|
633 |
+
threshold=threshold,
|
634 |
+
query_func=lambda pts: -self.sdf_network.sdf(pts))
|
635 |
+
|
636 |
+
def extract_geometry_tets(self, bound_min, bound_max, resolution, threshold=0.0):
|
637 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
638 |
+
bound_max,
|
639 |
+
resolution=resolution,
|
640 |
+
threshold=threshold,
|
641 |
+
query_func=lambda pts: -self.sdf_network.sdf(pts))
|
models/renderer_def.py
ADDED
@@ -0,0 +1,725 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import logging
|
6 |
+
import mcubes
|
7 |
+
from icecream import ic
|
8 |
+
import os
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
from pysdf import SDF
|
12 |
+
|
13 |
+
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
|
14 |
+
|
15 |
+
def create_mt_variable(device):
|
16 |
+
triangle_table = torch.tensor(
|
17 |
+
[
|
18 |
+
[-1, -1, -1, -1, -1, -1],
|
19 |
+
[1, 0, 2, -1, -1, -1],
|
20 |
+
[4, 0, 3, -1, -1, -1],
|
21 |
+
[1, 4, 2, 1, 3, 4],
|
22 |
+
[3, 1, 5, -1, -1, -1],
|
23 |
+
[2, 3, 0, 2, 5, 3],
|
24 |
+
[1, 4, 0, 1, 5, 4],
|
25 |
+
[4, 2, 5, -1, -1, -1],
|
26 |
+
[4, 5, 2, -1, -1, -1],
|
27 |
+
[4, 1, 0, 4, 5, 1],
|
28 |
+
[3, 2, 0, 3, 5, 2],
|
29 |
+
[1, 3, 5, -1, -1, -1],
|
30 |
+
[4, 1, 2, 4, 3, 1],
|
31 |
+
[3, 0, 4, -1, -1, -1],
|
32 |
+
[2, 0, 1, -1, -1, -1],
|
33 |
+
[-1, -1, -1, -1, -1, -1]
|
34 |
+
], dtype=torch.long, device=device)
|
35 |
+
|
36 |
+
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
37 |
+
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
38 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
39 |
+
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
|
44 |
+
# load tet via resolution #
|
45 |
+
# scale them via bounds #
|
46 |
+
# extract the geometry #
|
47 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
48 |
+
device = bound_min.device
|
49 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
50 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
51 |
+
# else:
|
52 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
53 |
+
tets = np.load(tet_fn)
|
54 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
55 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
56 |
+
# split #
|
57 |
+
# verts; verts; #
|
58 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
59 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
60 |
+
# scale_verts = maxx_verts - minn_verts
|
61 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
62 |
+
|
63 |
+
### scale the vertices ###
|
64 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
65 |
+
|
66 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
67 |
+
|
68 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
69 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
70 |
+
|
71 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
72 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
73 |
+
|
74 |
+
sdf_values = []
|
75 |
+
N = 64
|
76 |
+
query_bundles = N ** 3 ### N^3
|
77 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
78 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
79 |
+
query_NNs += 1
|
80 |
+
for i_query in range(query_NNs):
|
81 |
+
cur_bundle_st = i_query * query_bundles
|
82 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
83 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
84 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
85 |
+
if def_func is not None:
|
86 |
+
cur_query_pts = def_func(cur_query_pts)
|
87 |
+
cur_query_vals = query_func(cur_query_pts)
|
88 |
+
sdf_values.append(cur_query_vals)
|
89 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
90 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
91 |
+
|
92 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
93 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
94 |
+
|
95 |
+
# intrinsic, tet values, pts values, sdf network #
|
96 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
97 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
98 |
+
|
99 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
100 |
+
|
101 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
102 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
103 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
104 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
105 |
+
# marching_tets_tetmesh ##
|
106 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
107 |
+
### use the GT sdf values for the marching tets ###
|
108 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
109 |
+
|
110 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
111 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
112 |
+
|
113 |
+
def extract_fields(bound_min, bound_max, resolution, query_func):
|
114 |
+
N = 64
|
115 |
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
116 |
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
117 |
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
118 |
+
|
119 |
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
120 |
+
with torch.no_grad():
|
121 |
+
for xi, xs in enumerate(X):
|
122 |
+
for yi, ys in enumerate(Y):
|
123 |
+
for zi, zs in enumerate(Z):
|
124 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
125 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
126 |
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
127 |
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
128 |
+
# should save u here #
|
129 |
+
# save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
|
130 |
+
# np.save(save_u_path, u)
|
131 |
+
# print(f"u saved to {save_u_path}")
|
132 |
+
return u
|
133 |
+
|
134 |
+
|
135 |
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
136 |
+
print('threshold: {}'.format(threshold))
|
137 |
+
|
138 |
+
## using maching cubes ###
|
139 |
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
140 |
+
vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
141 |
+
b_max_np = bound_max.detach().cpu().numpy()
|
142 |
+
b_min_np = bound_min.detach().cpu().numpy()
|
143 |
+
|
144 |
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
145 |
+
### using maching cubes ###
|
146 |
+
|
147 |
+
### using marching tets ###
|
148 |
+
# vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
149 |
+
# vertices = vertices.detach().cpu().numpy()
|
150 |
+
# triangles = triangles.detach().cpu().numpy()
|
151 |
+
### using marching tets ###
|
152 |
+
|
153 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
154 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
155 |
+
|
156 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
157 |
+
return vertices, triangles
|
158 |
+
|
159 |
+
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
|
160 |
+
# print('threshold: {}'.format(threshold))
|
161 |
+
|
162 |
+
### using maching cubes ###
|
163 |
+
# u = extract_fields(bound_min, bound_max, resolution, query_func)
|
164 |
+
# vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
165 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
166 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
167 |
+
|
168 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
169 |
+
### using maching cubes ###
|
170 |
+
|
171 |
+
##
|
172 |
+
### using marching tets ### fiels from tets ##
|
173 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
|
174 |
+
# vertices = vertices.detach().cpu().numpy()
|
175 |
+
# triangles = triangles.detach().cpu().numpy()
|
176 |
+
### using marching tets ###
|
177 |
+
|
178 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
179 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
180 |
+
#
|
181 |
+
|
182 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
183 |
+
return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
|
184 |
+
|
185 |
+
|
186 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
187 |
+
# This implementation is from NeRF
|
188 |
+
# Get pdf
|
189 |
+
weights = weights + 1e-5 # prevent nans
|
190 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
191 |
+
cdf = torch.cumsum(pdf, -1)
|
192 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
193 |
+
# Take uniform samples
|
194 |
+
if det:
|
195 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
|
196 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
197 |
+
else:
|
198 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
|
199 |
+
|
200 |
+
# Invert CDF # invert cdf #
|
201 |
+
u = u.contiguous()
|
202 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
203 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
204 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
205 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
206 |
+
|
207 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
208 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
209 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
210 |
+
|
211 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
212 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
213 |
+
t = (u - cdf_g[..., 0]) / denom
|
214 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
215 |
+
|
216 |
+
return samples
|
217 |
+
|
218 |
+
|
219 |
+
def load_GT_vertices(GT_meshes_folder):
|
220 |
+
tot_meshes_fns = os.listdir(GT_meshes_folder)
|
221 |
+
tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
|
222 |
+
tot_mesh_verts = []
|
223 |
+
tot_mesh_faces = []
|
224 |
+
n_tot_verts = 0
|
225 |
+
for fn in tot_meshes_fns:
|
226 |
+
cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
|
227 |
+
obj_mesh = trimesh.load(cur_mesh_fn, process=False)
|
228 |
+
# obj_mesh.remove_degenerate_faces(height=1e-06)
|
229 |
+
|
230 |
+
verts_obj = np.array(obj_mesh.vertices)
|
231 |
+
faces_obj = np.array(obj_mesh.faces)
|
232 |
+
|
233 |
+
tot_mesh_verts.append(verts_obj)
|
234 |
+
tot_mesh_faces.append(faces_obj + n_tot_verts)
|
235 |
+
n_tot_verts += verts_obj.shape[0]
|
236 |
+
|
237 |
+
# tot_mesh_faces.append(faces_obj)
|
238 |
+
tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
|
239 |
+
tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
|
240 |
+
return tot_mesh_verts, tot_mesh_faces
|
241 |
+
|
242 |
+
|
243 |
+
class NeuSRenderer:
|
244 |
+
def __init__(self,
|
245 |
+
nerf,
|
246 |
+
sdf_network,
|
247 |
+
deviation_network,
|
248 |
+
color_network,
|
249 |
+
n_samples,
|
250 |
+
n_importance,
|
251 |
+
n_outside,
|
252 |
+
up_sample_steps,
|
253 |
+
perturb):
|
254 |
+
self.nerf = nerf
|
255 |
+
self.sdf_network = sdf_network
|
256 |
+
self.deviation_network = deviation_network
|
257 |
+
self.color_network = color_network
|
258 |
+
self.n_samples = n_samples
|
259 |
+
self.n_importance = n_importance
|
260 |
+
self.n_outside = n_outside
|
261 |
+
self.up_sample_steps = up_sample_steps
|
262 |
+
self.perturb = perturb
|
263 |
+
|
264 |
+
GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
|
265 |
+
self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
|
266 |
+
maxx_pts = 25.
|
267 |
+
minn_pts = -15.
|
268 |
+
self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
|
269 |
+
f = SDF(self.mesh_vertices, self.mesh_faces)
|
270 |
+
self.gt_sdf = f ## a unite sphere or box
|
271 |
+
|
272 |
+
self.minn_pts = 0
|
273 |
+
self.maxx_pts = 1.
|
274 |
+
|
275 |
+
# self.minn_pts = -1.5 # gorudn-truth states with the deformation -> update the sdf value fiedl
|
276 |
+
# self.maxx_pts = 1.5
|
277 |
+
self.bkg_pts = ... # TODO: the bkg pts # bkg_pts; # bkg_pts_defs #
|
278 |
+
self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
|
279 |
+
self.dist_interp_thres = ... # TODO: set the cur_bkg_pts_defs #
|
280 |
+
|
281 |
+
self.bending_network = ... # TODO: add the bending network #
|
282 |
+
self.use_bending_network = ... # TODO: set the property #
|
283 |
+
self.use_delta_bending = ... # TODO
|
284 |
+
# use bending network #
|
285 |
+
|
286 |
+
|
287 |
+
# get the pts and render the pts #
|
288 |
+
# pts and the rendering pts #
|
289 |
+
def deform_pts(self, pts, pts_ts=0):
|
290 |
+
|
291 |
+
if self.use_bending_network:
|
292 |
+
if len(pts.size()) == 3:
|
293 |
+
nnb, nns = pts.size(0), pts.size(1)
|
294 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
295 |
+
else:
|
296 |
+
pts_exp = pts
|
297 |
+
# pts_ts #
|
298 |
+
if self.use_delta_bending:
|
299 |
+
# if pts_ts >= 5:
|
300 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
301 |
+
# for cur_pts_ts in range(4, -1, -1):
|
302 |
+
# # print(f"using delta bending with pts_ts: {cur_pts_ts}")
|
303 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
304 |
+
# else:
|
305 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
306 |
+
# # print(f"using delta bending with pts_ts: {cur_pts_ts}")
|
307 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
308 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
309 |
+
# print(f"using delta bending with pts_ts: {cur_pts_ts}")
|
310 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
311 |
+
else:
|
312 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
313 |
+
if len(pts.size()) == 3:
|
314 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
315 |
+
else:
|
316 |
+
pts = pts_exp
|
317 |
+
return pts
|
318 |
+
|
319 |
+
# pts: nn_batch x nn_samples x 3
|
320 |
+
if len(pts.size()) == 3:
|
321 |
+
nnb, nns = pts.size(0), pts.size(1)
|
322 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
323 |
+
else:
|
324 |
+
pts_exp = pts
|
325 |
+
# print(f"prior to deforming: {pts.size()}")
|
326 |
+
|
327 |
+
dist_pts_to_bkg_pts = torch.sum(
|
328 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
329 |
+
)
|
330 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
331 |
+
dist_mask_float = dist_mask.float()
|
332 |
+
|
333 |
+
# dist_mask_float #
|
334 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
335 |
+
cur_fr_pts_def = torch.sum(
|
336 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
337 |
+
)
|
338 |
+
dist_mask_float_summ = torch.sum(
|
339 |
+
dist_mask_float, dim=1
|
340 |
+
)
|
341 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
342 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
343 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
344 |
+
if len(pts.size()) == 3:
|
345 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
346 |
+
else:
|
347 |
+
pts = pts_exp
|
348 |
+
return pts #
|
349 |
+
|
350 |
+
|
351 |
+
|
352 |
+
|
353 |
+
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
|
354 |
+
"""
|
355 |
+
Render background
|
356 |
+
"""
|
357 |
+
batch_size, n_samples = z_vals.shape
|
358 |
+
|
359 |
+
# Section length
|
360 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
361 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
362 |
+
mid_z_vals = z_vals + dists * 0.5
|
363 |
+
|
364 |
+
# Section midpoints #
|
365 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
|
366 |
+
|
367 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
368 |
+
pts = pts * 2 - 1
|
369 |
+
|
370 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
371 |
+
|
372 |
+
dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
373 |
+
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
|
374 |
+
|
375 |
+
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
376 |
+
|
377 |
+
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
|
378 |
+
dirs = dirs.reshape(-1, 3)
|
379 |
+
|
380 |
+
density, sampled_color = nerf(pts, dirs)
|
381 |
+
sampled_color = torch.sigmoid(sampled_color)
|
382 |
+
alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
383 |
+
alpha = alpha.reshape(batch_size, n_samples)
|
384 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
385 |
+
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
386 |
+
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
387 |
+
if background_rgb is not None:
|
388 |
+
color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
|
389 |
+
|
390 |
+
return {
|
391 |
+
'color': color,
|
392 |
+
'sampled_color': sampled_color,
|
393 |
+
'alpha': alpha,
|
394 |
+
'weights': weights,
|
395 |
+
}
|
396 |
+
|
397 |
+
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
|
398 |
+
"""
|
399 |
+
Up sampling give a fixed inv_s
|
400 |
+
"""
|
401 |
+
batch_size, n_samples = z_vals.shape
|
402 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
403 |
+
|
404 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
405 |
+
pts = pts * 2 - 1
|
406 |
+
|
407 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
408 |
+
|
409 |
+
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
410 |
+
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
411 |
+
sdf = sdf.reshape(batch_size, n_samples)
|
412 |
+
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
413 |
+
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
414 |
+
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
415 |
+
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
416 |
+
|
417 |
+
# ----------------------------------------------------------------------------------------------------------
|
418 |
+
# Use min value of [ cos, prev_cos ]
|
419 |
+
# Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
|
420 |
+
# robust when meeting situations like below:
|
421 |
+
#
|
422 |
+
# SDF
|
423 |
+
# ^
|
424 |
+
# |\ -----x----...
|
425 |
+
# | \ /
|
426 |
+
# | x x
|
427 |
+
# |---\----/-------------> 0 level
|
428 |
+
# | \ /
|
429 |
+
# | \/
|
430 |
+
# |
|
431 |
+
# ----------------------------------------------------------------------------------------------------------
|
432 |
+
prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
|
433 |
+
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
434 |
+
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
435 |
+
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
436 |
+
|
437 |
+
dist = (next_z_vals - prev_z_vals)
|
438 |
+
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
439 |
+
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
440 |
+
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
441 |
+
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
442 |
+
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
443 |
+
weights = alpha * torch.cumprod(
|
444 |
+
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
445 |
+
|
446 |
+
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
447 |
+
return z_samples
|
448 |
+
|
449 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
|
450 |
+
batch_size, n_samples = z_vals.shape
|
451 |
+
_, n_importance = new_z_vals.shape
|
452 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
453 |
+
|
454 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
455 |
+
pts = pts * 2 - 1
|
456 |
+
|
457 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
458 |
+
|
459 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
460 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
461 |
+
|
462 |
+
if not last:
|
463 |
+
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
464 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
465 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
466 |
+
index = index.reshape(-1)
|
467 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
468 |
+
|
469 |
+
return z_vals, sdf
|
470 |
+
|
471 |
+
def render_core(self,
|
472 |
+
rays_o,
|
473 |
+
rays_d,
|
474 |
+
z_vals,
|
475 |
+
sample_dist,
|
476 |
+
sdf_network,
|
477 |
+
deviation_network,
|
478 |
+
color_network,
|
479 |
+
background_alpha=None,
|
480 |
+
background_sampled_color=None,
|
481 |
+
background_rgb=None,
|
482 |
+
cos_anneal_ratio=0.0,
|
483 |
+
pts_ts=0):
|
484 |
+
batch_size, n_samples = z_vals.shape
|
485 |
+
|
486 |
+
# Section length
|
487 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
488 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
489 |
+
mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
|
490 |
+
|
491 |
+
# Section midpoints
|
492 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
493 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
494 |
+
|
495 |
+
pts = pts.reshape(-1, 3) # pts, nn_ou
|
496 |
+
dirs = dirs.reshape(-1, 3)
|
497 |
+
|
498 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
499 |
+
|
500 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
501 |
+
pts = pts * 2 - 1
|
502 |
+
|
503 |
+
|
504 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
505 |
+
|
506 |
+
sdf_nn_output = sdf_network(pts)
|
507 |
+
sdf = sdf_nn_output[:, :1]
|
508 |
+
feature_vector = sdf_nn_output[:, 1:]
|
509 |
+
|
510 |
+
gradients = sdf_network.gradient(pts).squeeze()
|
511 |
+
sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
|
512 |
+
|
513 |
+
# deviation network #
|
514 |
+
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
|
515 |
+
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
516 |
+
|
517 |
+
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
518 |
+
|
519 |
+
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
520 |
+
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
521 |
+
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
522 |
+
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
523 |
+
|
524 |
+
# Estimate signed distances at section points
|
525 |
+
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
526 |
+
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
527 |
+
|
528 |
+
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
529 |
+
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
530 |
+
|
531 |
+
p = prev_cdf - next_cdf
|
532 |
+
c = prev_cdf
|
533 |
+
|
534 |
+
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
|
535 |
+
|
536 |
+
pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
537 |
+
inside_sphere = (pts_norm < 1.0).float().detach()
|
538 |
+
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
539 |
+
|
540 |
+
# Render with background
|
541 |
+
if background_alpha is not None:
|
542 |
+
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
|
543 |
+
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
544 |
+
sampled_color = sampled_color * inside_sphere[:, :, None] +\
|
545 |
+
background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
|
546 |
+
sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
|
547 |
+
|
548 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
549 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
550 |
+
|
551 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
552 |
+
if background_rgb is not None: # Fixed background, usually black
|
553 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
554 |
+
|
555 |
+
# Eikonal loss
|
556 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
|
557 |
+
dim=-1) - 1.0) ** 2
|
558 |
+
gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
|
559 |
+
|
560 |
+
return {
|
561 |
+
'color': color,
|
562 |
+
'sdf': sdf,
|
563 |
+
'dists': dists,
|
564 |
+
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
565 |
+
's_val': 1.0 / inv_s,
|
566 |
+
'mid_z_vals': mid_z_vals,
|
567 |
+
'weights': weights,
|
568 |
+
'cdf': c.reshape(batch_size, n_samples),
|
569 |
+
'gradient_error': gradient_error,
|
570 |
+
'inside_sphere': inside_sphere
|
571 |
+
}
|
572 |
+
|
573 |
+
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
|
574 |
+
batch_size = len(rays_o)
|
575 |
+
sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
|
576 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples)
|
577 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
578 |
+
|
579 |
+
z_vals_outside = None
|
580 |
+
if self.n_outside > 0:
|
581 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
582 |
+
|
583 |
+
n_samples = self.n_samples
|
584 |
+
perturb = self.perturb
|
585 |
+
|
586 |
+
if perturb_overwrite >= 0:
|
587 |
+
perturb = perturb_overwrite
|
588 |
+
if perturb > 0:
|
589 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
590 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
591 |
+
|
592 |
+
if self.n_outside > 0:
|
593 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
594 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
595 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
596 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
597 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
598 |
+
|
599 |
+
if self.n_outside > 0:
|
600 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
601 |
+
|
602 |
+
background_alpha = None
|
603 |
+
background_sampled_color = None
|
604 |
+
|
605 |
+
# Up sample
|
606 |
+
if self.n_importance > 0:
|
607 |
+
with torch.no_grad():
|
608 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
609 |
+
|
610 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
611 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
612 |
+
# gt_sdf #
|
613 |
+
|
614 |
+
#
|
615 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
616 |
+
|
617 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
618 |
+
pts = pts * 2 - 1
|
619 |
+
|
620 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
621 |
+
|
622 |
+
pts_exp = pts.reshape(-1, 3)
|
623 |
+
# minn_pts, _ = torch.min(pts_exp, dim=0)
|
624 |
+
# maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
|
625 |
+
# print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
|
626 |
+
|
627 |
+
# pts_to_near = pts - near.unsqueeze(1)
|
628 |
+
# maxx_pts = 1.5; minn_pts = -1.5
|
629 |
+
# # maxx_pts = 3; minn_pts = -3
|
630 |
+
# # maxx_pts = 1; minn_pts = -1
|
631 |
+
# pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
|
632 |
+
|
633 |
+
## render and iamges ####
|
634 |
+
if use_gt_sdf:
|
635 |
+
### use the GT sdf field ####
|
636 |
+
# print(f"Using gt sdf :")
|
637 |
+
sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
|
638 |
+
sdf = torch.from_numpy(sdf).float().cuda()
|
639 |
+
sdf = sdf.reshape(batch_size, self.n_samples)
|
640 |
+
### use the GT sdf field ####
|
641 |
+
else:
|
642 |
+
#### use the optimized sdf field ####
|
643 |
+
sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
644 |
+
#### use the optimized sdf field ####
|
645 |
+
|
646 |
+
for i in range(self.up_sample_steps):
|
647 |
+
new_z_vals = self.up_sample(rays_o,
|
648 |
+
rays_d,
|
649 |
+
z_vals,
|
650 |
+
sdf,
|
651 |
+
self.n_importance // self.up_sample_steps,
|
652 |
+
64 * 2**i,
|
653 |
+
pts_ts=pts_ts)
|
654 |
+
z_vals, sdf = self.cat_z_vals(rays_o,
|
655 |
+
rays_d,
|
656 |
+
z_vals,
|
657 |
+
new_z_vals,
|
658 |
+
sdf,
|
659 |
+
last=(i + 1 == self.up_sample_steps),
|
660 |
+
pts_ts=pts_ts)
|
661 |
+
|
662 |
+
n_samples = self.n_samples + self.n_importance
|
663 |
+
|
664 |
+
# Background model
|
665 |
+
if self.n_outside > 0:
|
666 |
+
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
667 |
+
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
668 |
+
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
|
669 |
+
|
670 |
+
background_sampled_color = ret_outside['sampled_color']
|
671 |
+
background_alpha = ret_outside['alpha']
|
672 |
+
|
673 |
+
# Render core
|
674 |
+
ret_fine = self.render_core(rays_o, #
|
675 |
+
rays_d,
|
676 |
+
z_vals,
|
677 |
+
sample_dist,
|
678 |
+
self.sdf_network,
|
679 |
+
self.deviation_network,
|
680 |
+
self.color_network,
|
681 |
+
background_rgb=background_rgb,
|
682 |
+
background_alpha=background_alpha,
|
683 |
+
background_sampled_color=background_sampled_color,
|
684 |
+
cos_anneal_ratio=cos_anneal_ratio,
|
685 |
+
pts_ts=pts_ts)
|
686 |
+
|
687 |
+
color_fine = ret_fine['color']
|
688 |
+
weights = ret_fine['weights']
|
689 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
690 |
+
gradients = ret_fine['gradients']
|
691 |
+
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
|
692 |
+
|
693 |
+
return {
|
694 |
+
'color_fine': color_fine,
|
695 |
+
's_val': s_val,
|
696 |
+
'cdf_fine': ret_fine['cdf'],
|
697 |
+
'weight_sum': weights_sum,
|
698 |
+
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
699 |
+
'gradients': gradients,
|
700 |
+
'weights': weights,
|
701 |
+
'gradient_error': ret_fine['gradient_error'],
|
702 |
+
'inside_sphere': ret_fine['inside_sphere']
|
703 |
+
}
|
704 |
+
|
705 |
+
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
|
706 |
+
return extract_geometry(bound_min, # extract geometry #
|
707 |
+
bound_max,
|
708 |
+
resolution=resolution,
|
709 |
+
threshold=threshold,
|
710 |
+
query_func=lambda pts: -self.sdf_network.sdf(pts))
|
711 |
+
|
712 |
+
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
713 |
+
if wdef:
|
714 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
715 |
+
bound_max,
|
716 |
+
resolution=resolution,
|
717 |
+
threshold=threshold,
|
718 |
+
query_func=lambda pts: -self.sdf_network.sdf(pts),
|
719 |
+
def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts))
|
720 |
+
else:
|
721 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
722 |
+
bound_max,
|
723 |
+
resolution=resolution,
|
724 |
+
threshold=threshold,
|
725 |
+
query_func=lambda pts: -self.sdf_network.sdf(pts))
|
models/renderer_def_multi_objs.py
ADDED
@@ -0,0 +1,1088 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import logging
|
6 |
+
import mcubes
|
7 |
+
from icecream import ic
|
8 |
+
import os
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
from pysdf import SDF
|
12 |
+
|
13 |
+
import models.fields as fields
|
14 |
+
|
15 |
+
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
|
16 |
+
|
17 |
+
def batched_index_select(values, indices, dim = 1):
|
18 |
+
value_dims = values.shape[(dim + 1):]
|
19 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
20 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
21 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
22 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
23 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
24 |
+
|
25 |
+
value_expand_shape = [-1] * len(values.shape)
|
26 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
27 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
28 |
+
values = values.expand(*value_expand_shape)
|
29 |
+
|
30 |
+
dim += value_expand_len
|
31 |
+
return values.gather(dim, indices)
|
32 |
+
|
33 |
+
|
34 |
+
def create_mt_variable(device):
|
35 |
+
triangle_table = torch.tensor(
|
36 |
+
[
|
37 |
+
[-1, -1, -1, -1, -1, -1],
|
38 |
+
[1, 0, 2, -1, -1, -1],
|
39 |
+
[4, 0, 3, -1, -1, -1],
|
40 |
+
[1, 4, 2, 1, 3, 4],
|
41 |
+
[3, 1, 5, -1, -1, -1],
|
42 |
+
[2, 3, 0, 2, 5, 3],
|
43 |
+
[1, 4, 0, 1, 5, 4],
|
44 |
+
[4, 2, 5, -1, -1, -1],
|
45 |
+
[4, 5, 2, -1, -1, -1],
|
46 |
+
[4, 1, 0, 4, 5, 1],
|
47 |
+
[3, 2, 0, 3, 5, 2],
|
48 |
+
[1, 3, 5, -1, -1, -1],
|
49 |
+
[4, 1, 2, 4, 3, 1],
|
50 |
+
[3, 0, 4, -1, -1, -1],
|
51 |
+
[2, 0, 1, -1, -1, -1],
|
52 |
+
[-1, -1, -1, -1, -1, -1]
|
53 |
+
], dtype=torch.long, device=device)
|
54 |
+
|
55 |
+
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
56 |
+
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
57 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
58 |
+
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
|
63 |
+
# load tet via resolution #
|
64 |
+
# scale them via bounds #
|
65 |
+
# extract the geometry #
|
66 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
67 |
+
device = bound_min.device
|
68 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
69 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
70 |
+
# else:
|
71 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
72 |
+
tets = np.load(tet_fn)
|
73 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
74 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
75 |
+
# split #
|
76 |
+
# verts; verts; #
|
77 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
78 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
79 |
+
# scale_verts = maxx_verts - minn_verts
|
80 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
81 |
+
|
82 |
+
### scale the vertices ###
|
83 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
84 |
+
|
85 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
86 |
+
|
87 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
88 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
89 |
+
|
90 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
91 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
92 |
+
|
93 |
+
sdf_values = []
|
94 |
+
N = 64
|
95 |
+
query_bundles = N ** 3 ### N^3
|
96 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
97 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
98 |
+
query_NNs += 1
|
99 |
+
for i_query in range(query_NNs):
|
100 |
+
cur_bundle_st = i_query * query_bundles
|
101 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
102 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
103 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
104 |
+
if def_func is not None:
|
105 |
+
cur_query_pts = def_func(cur_query_pts)
|
106 |
+
cur_query_vals = query_func(cur_query_pts)
|
107 |
+
sdf_values.append(cur_query_vals)
|
108 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
109 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
110 |
+
|
111 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
112 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
113 |
+
|
114 |
+
# intrinsic, tet values, pts values, sdf network #
|
115 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
116 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
117 |
+
|
118 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
119 |
+
|
120 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
121 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
122 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
123 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
124 |
+
# marching_tets_tetmesh ##
|
125 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
126 |
+
### use the GT sdf values for the marching tets ###
|
127 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
128 |
+
|
129 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
130 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
131 |
+
|
132 |
+
def extract_fields(bound_min, bound_max, resolution, query_func):
|
133 |
+
N = 64
|
134 |
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
135 |
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
136 |
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
137 |
+
|
138 |
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
139 |
+
with torch.no_grad():
|
140 |
+
for xi, xs in enumerate(X):
|
141 |
+
for yi, ys in enumerate(Y):
|
142 |
+
for zi, zs in enumerate(Z):
|
143 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
144 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
145 |
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
146 |
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
147 |
+
# should save u here #
|
148 |
+
# save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
|
149 |
+
# np.save(save_u_path, u)
|
150 |
+
# print(f"u saved to {save_u_path}")
|
151 |
+
return u
|
152 |
+
|
153 |
+
|
154 |
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
155 |
+
print('threshold: {}'.format(threshold))
|
156 |
+
|
157 |
+
## using maching cubes ###
|
158 |
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
159 |
+
vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
160 |
+
b_max_np = bound_max.detach().cpu().numpy()
|
161 |
+
b_min_np = bound_min.detach().cpu().numpy()
|
162 |
+
|
163 |
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
164 |
+
### using maching cubes ###
|
165 |
+
|
166 |
+
### using marching tets ###
|
167 |
+
# vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
168 |
+
# vertices = vertices.detach().cpu().numpy()
|
169 |
+
# triangles = triangles.detach().cpu().numpy()
|
170 |
+
### using marching tets ###
|
171 |
+
|
172 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
173 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
174 |
+
|
175 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
176 |
+
return vertices, triangles
|
177 |
+
|
178 |
+
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
|
179 |
+
# print('threshold: {}'.format(threshold))
|
180 |
+
|
181 |
+
### using maching cubes ###
|
182 |
+
# u = extract_fields(bound_min, bound_max, resolution, query_func)
|
183 |
+
# vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
184 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
185 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
186 |
+
|
187 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
188 |
+
### using maching cubes ###
|
189 |
+
|
190 |
+
##
|
191 |
+
### using marching tets ### fiels from tets ##
|
192 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
|
193 |
+
# vertices = vertices.detach().cpu().numpy()
|
194 |
+
# triangles = triangles.detach().cpu().numpy()
|
195 |
+
### using marching tets ###
|
196 |
+
|
197 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
198 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
199 |
+
#
|
200 |
+
|
201 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
202 |
+
return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
|
203 |
+
|
204 |
+
|
205 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
206 |
+
# This implementation is from NeRF
|
207 |
+
# Get pdf
|
208 |
+
weights = weights + 1e-5 # prevent nans
|
209 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
210 |
+
cdf = torch.cumsum(pdf, -1)
|
211 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
212 |
+
# Take uniform samples
|
213 |
+
if det:
|
214 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
|
215 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
216 |
+
else:
|
217 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
|
218 |
+
|
219 |
+
# Invert CDF # invert cdf #
|
220 |
+
u = u.contiguous()
|
221 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
222 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
223 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
224 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
225 |
+
|
226 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
227 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
228 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
229 |
+
|
230 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
231 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
232 |
+
t = (u - cdf_g[..., 0]) / denom
|
233 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
234 |
+
|
235 |
+
return samples
|
236 |
+
|
237 |
+
|
238 |
+
def load_GT_vertices(GT_meshes_folder):
|
239 |
+
tot_meshes_fns = os.listdir(GT_meshes_folder)
|
240 |
+
tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
|
241 |
+
tot_mesh_verts = []
|
242 |
+
tot_mesh_faces = []
|
243 |
+
n_tot_verts = 0
|
244 |
+
for fn in tot_meshes_fns:
|
245 |
+
cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
|
246 |
+
obj_mesh = trimesh.load(cur_mesh_fn, process=False)
|
247 |
+
# obj_mesh.remove_degenerate_faces(height=1e-06)
|
248 |
+
|
249 |
+
verts_obj = np.array(obj_mesh.vertices)
|
250 |
+
faces_obj = np.array(obj_mesh.faces)
|
251 |
+
|
252 |
+
tot_mesh_verts.append(verts_obj)
|
253 |
+
tot_mesh_faces.append(faces_obj + n_tot_verts)
|
254 |
+
n_tot_verts += verts_obj.shape[0]
|
255 |
+
|
256 |
+
# tot_mesh_faces.append(faces_obj)
|
257 |
+
tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
|
258 |
+
tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
|
259 |
+
return tot_mesh_verts, tot_mesh_faces
|
260 |
+
|
261 |
+
|
262 |
+
class NeuSRenderer:
|
263 |
+
def __init__(self,
|
264 |
+
nerf,
|
265 |
+
sdf_network,
|
266 |
+
deviation_network,
|
267 |
+
color_network,
|
268 |
+
n_samples,
|
269 |
+
n_importance,
|
270 |
+
n_outside,
|
271 |
+
up_sample_steps,
|
272 |
+
perturb):
|
273 |
+
self.nerf = nerf # multiple sdf networks and deviation networks and xxx #
|
274 |
+
self.sdf_network = sdf_network
|
275 |
+
self.deviation_network = deviation_network
|
276 |
+
self.color_network = color_network
|
277 |
+
self.n_samples = n_samples
|
278 |
+
self.n_importance = n_importance
|
279 |
+
self.n_outside = n_outside
|
280 |
+
self.up_sample_steps = up_sample_steps
|
281 |
+
self.perturb = perturb
|
282 |
+
|
283 |
+
GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
|
284 |
+
self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
|
285 |
+
maxx_pts = 25.
|
286 |
+
minn_pts = -15.
|
287 |
+
self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
|
288 |
+
f = SDF(self.mesh_vertices, self.mesh_faces)
|
289 |
+
self.gt_sdf = f ## a unite sphere or box
|
290 |
+
|
291 |
+
self.minn_pts = 0
|
292 |
+
self.maxx_pts = 1.
|
293 |
+
|
294 |
+
# self.minn_pts = -1.5 # gorudn-truth states with the deformation -> update the sdf value fiedl
|
295 |
+
# self.maxx_pts = 1.5 #
|
296 |
+
self.bkg_pts = ... # TODO: the bkg pts # bkg_pts; # bkg_pts_defs #
|
297 |
+
self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
|
298 |
+
self.dist_interp_thres = ... # TODO: set the cur_bkg_pts_defs #
|
299 |
+
|
300 |
+
self.bending_network = ... # TODO: add the bending network #
|
301 |
+
self.use_bending_network = ... # TODO: set the property #
|
302 |
+
self.use_delta_bending = ... # TODO
|
303 |
+
self.prev_sdf_network = ... # TODO
|
304 |
+
self.use_selector = False
|
305 |
+
# use bending network #
|
306 |
+
# two bending netwrok
|
307 |
+
# two sdf networks
|
308 |
+
|
309 |
+
|
310 |
+
# get the pts and render the pts #
|
311 |
+
# pts and the rendering pts #
|
312 |
+
def deform_pts(self, pts, pts_ts=0): # deform pts #
|
313 |
+
|
314 |
+
if self.use_bending_network:
|
315 |
+
if len(pts.size()) == 3:
|
316 |
+
nnb, nns = pts.size(0), pts.size(1)
|
317 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
318 |
+
else:
|
319 |
+
pts_exp = pts
|
320 |
+
# pts_ts #
|
321 |
+
if self.use_delta_bending:
|
322 |
+
|
323 |
+
if isinstance(self.bending_network, list):
|
324 |
+
pts_offsets = []
|
325 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
326 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
327 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
328 |
+
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
329 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
330 |
+
cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
331 |
+
else:
|
332 |
+
raise ValueError('Encountered with unexpected bending network class...')
|
333 |
+
pts_offsets.append(cur_pts_exp - pts_exp)
|
334 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
335 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
336 |
+
pts_exp = pts_exp + pts_offsets
|
337 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
338 |
+
# if isinstance(self.bending_network, list): # pts ts #
|
339 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
340 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
341 |
+
# else:
|
342 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
343 |
+
else:
|
344 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
345 |
+
pts_offsets = []
|
346 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
347 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
348 |
+
pts_offsets.append(bended_pts_exp - pts_exp)
|
349 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
350 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
351 |
+
pts_exp = pts_exp + pts_offsets
|
352 |
+
else:
|
353 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
354 |
+
if len(pts.size()) == 3:
|
355 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
356 |
+
else:
|
357 |
+
pts = pts_exp
|
358 |
+
return pts
|
359 |
+
|
360 |
+
# pts: nn_batch x nn_samples x 3
|
361 |
+
if len(pts.size()) == 3:
|
362 |
+
nnb, nns = pts.size(0), pts.size(1)
|
363 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
364 |
+
else:
|
365 |
+
pts_exp = pts
|
366 |
+
# print(f"prior to deforming: {pts.size()}")
|
367 |
+
|
368 |
+
dist_pts_to_bkg_pts = torch.sum(
|
369 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
370 |
+
)
|
371 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
372 |
+
dist_mask_float = dist_mask.float()
|
373 |
+
|
374 |
+
# dist_mask_float #
|
375 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
376 |
+
cur_fr_pts_def = torch.sum(
|
377 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
378 |
+
)
|
379 |
+
dist_mask_float_summ = torch.sum(
|
380 |
+
dist_mask_float, dim=1
|
381 |
+
)
|
382 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
383 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
384 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
385 |
+
if len(pts.size()) == 3:
|
386 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
387 |
+
else:
|
388 |
+
pts = pts_exp
|
389 |
+
return pts #
|
390 |
+
|
391 |
+
|
392 |
+
def deform_pts_with_selector(self, pts, pts_ts=0): # deform pts #
|
393 |
+
|
394 |
+
if self.use_bending_network:
|
395 |
+
if len(pts.size()) == 3:
|
396 |
+
nnb, nns = pts.size(0), pts.size(1)
|
397 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
398 |
+
else:
|
399 |
+
pts_exp = pts
|
400 |
+
# pts_ts #
|
401 |
+
if self.use_delta_bending:
|
402 |
+
if isinstance(self.bending_network, list):
|
403 |
+
bended_pts = []
|
404 |
+
queries_sdfs_selector = []
|
405 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
406 |
+
if cur_bending_network.use_opt_rigid_translations:
|
407 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
408 |
+
else:
|
409 |
+
# bended_pts_exp = pts_exp.clone()
|
410 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
411 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
|
412 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
413 |
+
bended_pts.append(bended_pts_exp)
|
414 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
415 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
|
416 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
417 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
418 |
+
sdf_selector = queries_sdfs_selector[:, -1]
|
419 |
+
# sdf_selector = queries_sdfs_selector
|
420 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
421 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
422 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
423 |
+
pts_exp = bended_pts.squeeze(1)
|
424 |
+
|
425 |
+
|
426 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
427 |
+
# if isinstance(self.bending_network, list):
|
428 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
429 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
430 |
+
# else:
|
431 |
+
|
432 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
433 |
+
else:
|
434 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
435 |
+
# pts_offsets = []
|
436 |
+
bended_pts = []
|
437 |
+
queries_sdfs_selector = []
|
438 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
439 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
440 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
441 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
442 |
+
bended_pts.append(bended_pts_exp)
|
443 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
444 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
|
445 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
446 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
447 |
+
sdf_selector = queries_sdfs_selector[:, -1]
|
448 |
+
# sdf_selector = queries_sdfs_selector
|
449 |
+
|
450 |
+
|
451 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
452 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
453 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
454 |
+
pts_exp = bended_pts.squeeze(1)
|
455 |
+
|
456 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
457 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
458 |
+
# pts_exp = pts_exp + pts_offsets
|
459 |
+
else:
|
460 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
461 |
+
if len(pts.size()) == 3:
|
462 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
463 |
+
else:
|
464 |
+
pts = pts_exp
|
465 |
+
return pts
|
466 |
+
|
467 |
+
# pts: nn_batch x nn_samples x 3
|
468 |
+
if len(pts.size()) == 3:
|
469 |
+
nnb, nns = pts.size(0), pts.size(1)
|
470 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
471 |
+
else:
|
472 |
+
pts_exp = pts
|
473 |
+
# print(f"prior to deforming: {pts.size()}")
|
474 |
+
|
475 |
+
dist_pts_to_bkg_pts = torch.sum(
|
476 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
477 |
+
)
|
478 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
479 |
+
dist_mask_float = dist_mask.float()
|
480 |
+
|
481 |
+
# dist_mask_float #
|
482 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
483 |
+
cur_fr_pts_def = torch.sum(
|
484 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
485 |
+
)
|
486 |
+
dist_mask_float_summ = torch.sum(
|
487 |
+
dist_mask_float, dim=1
|
488 |
+
)
|
489 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
490 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
491 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
492 |
+
if len(pts.size()) == 3:
|
493 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
494 |
+
else:
|
495 |
+
pts = pts_exp
|
496 |
+
return pts #
|
497 |
+
|
498 |
+
|
499 |
+
def deform_pts_passive(self, pts, pts_ts=0):
|
500 |
+
|
501 |
+
if self.use_bending_network:
|
502 |
+
if len(pts.size()) == 3:
|
503 |
+
nnb, nns = pts.size(0), pts.size(1)
|
504 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
505 |
+
else:
|
506 |
+
pts_exp = pts
|
507 |
+
# pts_ts #
|
508 |
+
if self.use_delta_bending:
|
509 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
510 |
+
if isinstance(self.bending_network, list):
|
511 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
512 |
+
pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
513 |
+
else:
|
514 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
515 |
+
else:
|
516 |
+
# if isinstance(self.bending_network, list):
|
517 |
+
# pts_offsets = []
|
518 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
519 |
+
# bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
520 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
521 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
522 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
523 |
+
# pts_exp = pts_exp + pts_offsets
|
524 |
+
# else:
|
525 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
|
526 |
+
if len(pts.size()) == 3:
|
527 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
528 |
+
else:
|
529 |
+
pts = pts_exp
|
530 |
+
return pts
|
531 |
+
|
532 |
+
# pts: nn_batch x nn_samples x 3
|
533 |
+
if len(pts.size()) == 3:
|
534 |
+
nnb, nns = pts.size(0), pts.size(1)
|
535 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
536 |
+
else:
|
537 |
+
pts_exp = pts
|
538 |
+
# print(f"prior to deforming: {pts.size()}")
|
539 |
+
|
540 |
+
dist_pts_to_bkg_pts = torch.sum(
|
541 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
542 |
+
)
|
543 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
544 |
+
dist_mask_float = dist_mask.float()
|
545 |
+
|
546 |
+
# dist_mask_float #
|
547 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
548 |
+
cur_fr_pts_def = torch.sum(
|
549 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
550 |
+
)
|
551 |
+
dist_mask_float_summ = torch.sum(
|
552 |
+
dist_mask_float, dim=1
|
553 |
+
)
|
554 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
555 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
556 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
557 |
+
if len(pts.size()) == 3:
|
558 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
559 |
+
else:
|
560 |
+
pts = pts_exp
|
561 |
+
return pts #
|
562 |
+
|
563 |
+
|
564 |
+
def query_pts_sdf_fn_for_selector(self, pts):
|
565 |
+
# for negative
|
566 |
+
# 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
|
567 |
+
# 2) negative in thie field and also negative in the previous field --->
|
568 |
+
# 2) for positive values of this current field --->
|
569 |
+
cur_sdf = self.sdf_network.sdf(pts)
|
570 |
+
prev_sdf = self.prev_sdf_network.sdf(pts)
|
571 |
+
neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
572 |
+
neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
573 |
+
|
574 |
+
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
|
575 |
+
|
576 |
+
pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
577 |
+
pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
578 |
+
res_sdf = torch.zeros_like(cur_sdf)
|
579 |
+
res_sdf[neg_neg] = 1. #
|
580 |
+
res_sdf[neg_pos] = cur_sdf[neg_pos]
|
581 |
+
res_sdf[pos_neg] = cur_sdf[pos_neg]
|
582 |
+
|
583 |
+
# inside the residual mesh -> must be neg and pos
|
584 |
+
res_sdf_selector = torch.zeros_like(cur_sdf).long() #
|
585 |
+
# res_sdf_selector[neg_pos] = 1 # is the residual mesh
|
586 |
+
res_sdf_selector[neg_weq_pos] = 1
|
587 |
+
# res_sdf_selector[]
|
588 |
+
|
589 |
+
cat_cur_prev_sdf = torch.stack(
|
590 |
+
[cur_sdf, prev_sdf], dim=-1
|
591 |
+
)
|
592 |
+
minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
|
593 |
+
res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
|
594 |
+
|
595 |
+
return res_sdf, res_sdf_selector
|
596 |
+
|
597 |
+
def query_func_sdf(self, pts):
|
598 |
+
if isinstance(self.sdf_network, list):
|
599 |
+
tot_sdf_values = []
|
600 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
601 |
+
cur_sdf_values = cur_sdf_network.sdf(pts)
|
602 |
+
tot_sdf_values.append(cur_sdf_values)
|
603 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
604 |
+
tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
605 |
+
sdf = tot_sdf_values
|
606 |
+
else:
|
607 |
+
sdf = self.sdf_network.sdf(pts)
|
608 |
+
return sdf
|
609 |
+
|
610 |
+
def query_func_sdf_passive(self, pts):
|
611 |
+
# if isinstance(self.sdf_network, list):
|
612 |
+
# tot_sdf_values = []
|
613 |
+
# for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
614 |
+
# cur_sdf_values = cur_sdf_network.sdf(pts)
|
615 |
+
# tot_sdf_values.append(cur_sdf_values)
|
616 |
+
# tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
617 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
618 |
+
# sdf = tot_sdf_values
|
619 |
+
# else:
|
620 |
+
sdf = self.sdf_network[-1].sdf(pts)
|
621 |
+
|
622 |
+
return sdf
|
623 |
+
|
624 |
+
|
625 |
+
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
|
626 |
+
"""
|
627 |
+
Render background
|
628 |
+
"""
|
629 |
+
batch_size, n_samples = z_vals.shape
|
630 |
+
|
631 |
+
# Section length
|
632 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
633 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
634 |
+
mid_z_vals = z_vals + dists * 0.5
|
635 |
+
|
636 |
+
# Section midpoints #
|
637 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
|
638 |
+
|
639 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
640 |
+
pts = pts * 2 - 1
|
641 |
+
|
642 |
+
if self.use_selector:
|
643 |
+
pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
644 |
+
else:
|
645 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
646 |
+
|
647 |
+
dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
648 |
+
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
|
649 |
+
|
650 |
+
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
651 |
+
|
652 |
+
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
|
653 |
+
dirs = dirs.reshape(-1, 3)
|
654 |
+
|
655 |
+
density, sampled_color = nerf(pts, dirs)
|
656 |
+
sampled_color = torch.sigmoid(sampled_color)
|
657 |
+
alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
658 |
+
alpha = alpha.reshape(batch_size, n_samples)
|
659 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
660 |
+
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
661 |
+
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
662 |
+
if background_rgb is not None:
|
663 |
+
color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
|
664 |
+
|
665 |
+
return {
|
666 |
+
'color': color,
|
667 |
+
'sampled_color': sampled_color,
|
668 |
+
'alpha': alpha,
|
669 |
+
'weights': weights,
|
670 |
+
}
|
671 |
+
|
672 |
+
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
|
673 |
+
"""
|
674 |
+
Up sampling give a fixed inv_s
|
675 |
+
"""
|
676 |
+
batch_size, n_samples = z_vals.shape
|
677 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
678 |
+
|
679 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
680 |
+
pts = pts * 2 - 1
|
681 |
+
|
682 |
+
if self.use_selector:
|
683 |
+
pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
684 |
+
else:
|
685 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
686 |
+
|
687 |
+
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
688 |
+
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
689 |
+
sdf = sdf.reshape(batch_size, n_samples)
|
690 |
+
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
691 |
+
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
692 |
+
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
693 |
+
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
694 |
+
|
695 |
+
# ----------------------------------------------------------------------------------------------------------
|
696 |
+
# Use min value of [ cos, prev_cos ]
|
697 |
+
# Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
|
698 |
+
# robust when meeting situations like below:
|
699 |
+
#
|
700 |
+
# SDF
|
701 |
+
# ^
|
702 |
+
# |\ -----x----...
|
703 |
+
# | \ /
|
704 |
+
# | x x
|
705 |
+
# |---\----/-------------> 0 level
|
706 |
+
# | \ /
|
707 |
+
# | \/
|
708 |
+
# |
|
709 |
+
# ----------------------------------------------------------------------------------------------------------
|
710 |
+
prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
|
711 |
+
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
712 |
+
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
713 |
+
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
714 |
+
|
715 |
+
dist = (next_z_vals - prev_z_vals)
|
716 |
+
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
717 |
+
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
718 |
+
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
719 |
+
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
720 |
+
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
721 |
+
weights = alpha * torch.cumprod(
|
722 |
+
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
723 |
+
|
724 |
+
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
725 |
+
return z_samples
|
726 |
+
|
727 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
|
728 |
+
batch_size, n_samples = z_vals.shape
|
729 |
+
_, n_importance = new_z_vals.shape
|
730 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
731 |
+
|
732 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
733 |
+
pts = pts * 2 - 1
|
734 |
+
|
735 |
+
if self.use_selector:
|
736 |
+
pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
737 |
+
else:
|
738 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
739 |
+
|
740 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
741 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
742 |
+
|
743 |
+
if not last:
|
744 |
+
if isinstance(self.sdf_network, list):
|
745 |
+
tot_new_sdf = []
|
746 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
747 |
+
cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
748 |
+
tot_new_sdf.append(cur_new_sdf)
|
749 |
+
tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
|
750 |
+
new_sdf, _ = torch.min(tot_new_sdf, dim=-1) #
|
751 |
+
else:
|
752 |
+
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
753 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
754 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
755 |
+
index = index.reshape(-1)
|
756 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
757 |
+
|
758 |
+
return z_vals, sdf
|
759 |
+
|
760 |
+
|
761 |
+
|
762 |
+
def render_core(self,
|
763 |
+
rays_o,
|
764 |
+
rays_d,
|
765 |
+
z_vals,
|
766 |
+
sample_dist,
|
767 |
+
sdf_network,
|
768 |
+
deviation_network,
|
769 |
+
color_network,
|
770 |
+
background_alpha=None,
|
771 |
+
background_sampled_color=None,
|
772 |
+
background_rgb=None,
|
773 |
+
cos_anneal_ratio=0.0,
|
774 |
+
pts_ts=0):
|
775 |
+
batch_size, n_samples = z_vals.shape
|
776 |
+
|
777 |
+
# Section length
|
778 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
779 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
780 |
+
mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
|
781 |
+
|
782 |
+
# Section midpoints
|
783 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
784 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
785 |
+
|
786 |
+
pts = pts.reshape(-1, 3) # pts, nn_ou
|
787 |
+
dirs = dirs.reshape(-1, 3)
|
788 |
+
|
789 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
790 |
+
|
791 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
792 |
+
pts = pts * 2 - 1
|
793 |
+
|
794 |
+
if self.use_selector:
|
795 |
+
pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
796 |
+
else:
|
797 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
798 |
+
|
799 |
+
if isinstance(sdf_network, list):
|
800 |
+
tot_sdf = []
|
801 |
+
tot_feature_vector = []
|
802 |
+
tot_obj_sel = []
|
803 |
+
tot_gradients = []
|
804 |
+
for i_obj, cur_sdf_network in enumerate(sdf_network):
|
805 |
+
cur_sdf_nn_output = cur_sdf_network(pts)
|
806 |
+
cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
|
807 |
+
tot_sdf.append(cur_sdf)
|
808 |
+
tot_feature_vector.append(cur_feature_vector)
|
809 |
+
|
810 |
+
gradients = cur_sdf_network.gradient(pts).squeeze()
|
811 |
+
tot_gradients.append(gradients)
|
812 |
+
tot_sdf = torch.stack(tot_sdf, dim=-1)
|
813 |
+
sdf, obj_sel = torch.min(tot_sdf, dim=-1)
|
814 |
+
feature_vector = torch.stack(tot_feature_vector, dim=1)
|
815 |
+
|
816 |
+
# batched_index_select
|
817 |
+
# print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
|
818 |
+
feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
|
819 |
+
|
820 |
+
|
821 |
+
# feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
|
822 |
+
# print(f"after sel: {feature_vector.size()}")
|
823 |
+
tot_gradients = torch.stack(tot_gradients, dim=1)
|
824 |
+
# gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
|
825 |
+
gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
|
826 |
+
# print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
|
827 |
+
|
828 |
+
else:
|
829 |
+
sdf_nn_output = sdf_network(pts)
|
830 |
+
sdf = sdf_nn_output[:, :1]
|
831 |
+
feature_vector = sdf_nn_output[:, 1:]
|
832 |
+
gradients = sdf_network.gradient(pts).squeeze()
|
833 |
+
|
834 |
+
sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
|
835 |
+
|
836 |
+
# deviation network #
|
837 |
+
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
|
838 |
+
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
839 |
+
|
840 |
+
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
841 |
+
|
842 |
+
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
843 |
+
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
844 |
+
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
845 |
+
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
846 |
+
|
847 |
+
# Estimate signed distances at section points
|
848 |
+
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
849 |
+
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
850 |
+
|
851 |
+
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
852 |
+
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
853 |
+
|
854 |
+
p = prev_cdf - next_cdf
|
855 |
+
c = prev_cdf
|
856 |
+
|
857 |
+
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
|
858 |
+
|
859 |
+
pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
860 |
+
inside_sphere = (pts_norm < 1.0).float().detach()
|
861 |
+
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
862 |
+
|
863 |
+
# Render with background
|
864 |
+
if background_alpha is not None:
|
865 |
+
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
|
866 |
+
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
867 |
+
sampled_color = sampled_color * inside_sphere[:, :, None] +\
|
868 |
+
background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
|
869 |
+
sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
|
870 |
+
|
871 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
872 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
873 |
+
|
874 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
875 |
+
if background_rgb is not None: # Fixed background, usually black
|
876 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
877 |
+
|
878 |
+
# Eikonal loss
|
879 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
|
880 |
+
dim=-1) - 1.0) ** 2
|
881 |
+
gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
|
882 |
+
|
883 |
+
return {
|
884 |
+
'color': color,
|
885 |
+
'sdf': sdf,
|
886 |
+
'dists': dists,
|
887 |
+
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
888 |
+
's_val': 1.0 / inv_s,
|
889 |
+
'mid_z_vals': mid_z_vals,
|
890 |
+
'weights': weights,
|
891 |
+
'cdf': c.reshape(batch_size, n_samples),
|
892 |
+
'gradient_error': gradient_error,
|
893 |
+
'inside_sphere': inside_sphere
|
894 |
+
}
|
895 |
+
|
896 |
+
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
|
897 |
+
batch_size = len(rays_o)
|
898 |
+
sample_dist = 2.0 / self.n_samples # in a unit sphere # # Assuming the region of interest is a unit sphere
|
899 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples) # linspace #
|
900 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
901 |
+
|
902 |
+
z_vals_outside = None
|
903 |
+
if self.n_outside > 0:
|
904 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
905 |
+
|
906 |
+
n_samples = self.n_samples
|
907 |
+
perturb = self.perturb
|
908 |
+
|
909 |
+
if perturb_overwrite >= 0:
|
910 |
+
perturb = perturb_overwrite
|
911 |
+
if perturb > 0:
|
912 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
913 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
914 |
+
|
915 |
+
if self.n_outside > 0: # z values output # n_outside #
|
916 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
917 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
918 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
919 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
920 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
921 |
+
|
922 |
+
if self.n_outside > 0:
|
923 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
924 |
+
|
925 |
+
background_alpha = None
|
926 |
+
background_sampled_color = None
|
927 |
+
|
928 |
+
# Up sample
|
929 |
+
if self.n_importance > 0:
|
930 |
+
with torch.no_grad():
|
931 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
932 |
+
|
933 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
934 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
935 |
+
# gt_sdf #
|
936 |
+
|
937 |
+
#
|
938 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
939 |
+
|
940 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
941 |
+
pts = pts * 2 - 1
|
942 |
+
|
943 |
+
if self.use_selector:
|
944 |
+
pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
945 |
+
else:
|
946 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts) # give nthe pts
|
947 |
+
|
948 |
+
pts_exp = pts.reshape(-1, 3)
|
949 |
+
# minn_pts, _ = torch.min(pts_exp, dim=0)
|
950 |
+
# maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
|
951 |
+
# print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
|
952 |
+
|
953 |
+
# pts_to_near = pts - near.unsqueeze(1)
|
954 |
+
# maxx_pts = 1.5; minn_pts = -1.5
|
955 |
+
# # maxx_pts = 3; minn_pts = -3
|
956 |
+
# # maxx_pts = 1; minn_pts = -1
|
957 |
+
# pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
|
958 |
+
|
959 |
+
## render and iamges ####
|
960 |
+
if use_gt_sdf:
|
961 |
+
### use the GT sdf field ####
|
962 |
+
# print(f"Using gt sdf :")
|
963 |
+
sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
|
964 |
+
sdf = torch.from_numpy(sdf).float().cuda()
|
965 |
+
sdf = sdf.reshape(batch_size, self.n_samples)
|
966 |
+
### use the GT sdf field ####
|
967 |
+
else:
|
968 |
+
# pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
|
969 |
+
#### use the optimized sdf field ####
|
970 |
+
|
971 |
+
# sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
972 |
+
|
973 |
+
if isinstance(self.sdf_network, list):
|
974 |
+
tot_sdf_values = []
|
975 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
976 |
+
cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
977 |
+
tot_sdf_values.append(cur_sdf_values)
|
978 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
979 |
+
tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
980 |
+
sdf = tot_sdf_values
|
981 |
+
else:
|
982 |
+
sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
983 |
+
|
984 |
+
#### use the optimized sdf field ####
|
985 |
+
|
986 |
+
for i in range(self.up_sample_steps):
|
987 |
+
new_z_vals = self.up_sample(rays_o,
|
988 |
+
rays_d,
|
989 |
+
z_vals,
|
990 |
+
sdf,
|
991 |
+
self.n_importance // self.up_sample_steps,
|
992 |
+
64 * 2**i,
|
993 |
+
pts_ts=pts_ts)
|
994 |
+
z_vals, sdf = self.cat_z_vals(rays_o,
|
995 |
+
rays_d,
|
996 |
+
z_vals,
|
997 |
+
new_z_vals,
|
998 |
+
sdf,
|
999 |
+
last=(i + 1 == self.up_sample_steps),
|
1000 |
+
pts_ts=pts_ts)
|
1001 |
+
|
1002 |
+
n_samples = self.n_samples + self.n_importance
|
1003 |
+
|
1004 |
+
# Background model
|
1005 |
+
if self.n_outside > 0:
|
1006 |
+
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
1007 |
+
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
1008 |
+
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
|
1009 |
+
|
1010 |
+
background_sampled_color = ret_outside['sampled_color']
|
1011 |
+
background_alpha = ret_outside['alpha']
|
1012 |
+
|
1013 |
+
# Render core
|
1014 |
+
ret_fine = self.render_core(rays_o, #
|
1015 |
+
rays_d,
|
1016 |
+
z_vals,
|
1017 |
+
sample_dist,
|
1018 |
+
self.sdf_network,
|
1019 |
+
self.deviation_network,
|
1020 |
+
self.color_network,
|
1021 |
+
background_rgb=background_rgb,
|
1022 |
+
background_alpha=background_alpha,
|
1023 |
+
background_sampled_color=background_sampled_color,
|
1024 |
+
cos_anneal_ratio=cos_anneal_ratio,
|
1025 |
+
pts_ts=pts_ts)
|
1026 |
+
|
1027 |
+
color_fine = ret_fine['color']
|
1028 |
+
weights = ret_fine['weights']
|
1029 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
1030 |
+
gradients = ret_fine['gradients']
|
1031 |
+
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
|
1032 |
+
|
1033 |
+
return {
|
1034 |
+
'color_fine': color_fine,
|
1035 |
+
's_val': s_val,
|
1036 |
+
'cdf_fine': ret_fine['cdf'],
|
1037 |
+
'weight_sum': weights_sum,
|
1038 |
+
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
1039 |
+
'gradients': gradients,
|
1040 |
+
'weights': weights,
|
1041 |
+
'gradient_error': ret_fine['gradient_error'],
|
1042 |
+
'inside_sphere': ret_fine['inside_sphere']
|
1043 |
+
}
|
1044 |
+
|
1045 |
+
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
|
1046 |
+
return extract_geometry(bound_min, # extract geometry #
|
1047 |
+
bound_max,
|
1048 |
+
resolution=resolution,
|
1049 |
+
threshold=threshold,
|
1050 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts),
|
1051 |
+
query_func=lambda pts: -self.query_func_sdf(pts)
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
# if self.deform_pts_with_selector:
|
1055 |
+
# pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1056 |
+
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1057 |
+
if wdef:
|
1058 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1059 |
+
bound_max,
|
1060 |
+
resolution=resolution,
|
1061 |
+
threshold=threshold,
|
1062 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1063 |
+
def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts))
|
1064 |
+
else:
|
1065 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1066 |
+
bound_max,
|
1067 |
+
resolution=resolution,
|
1068 |
+
threshold=threshold,
|
1069 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1070 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1071 |
+
)
|
1072 |
+
|
1073 |
+
def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1074 |
+
if wdef:
|
1075 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1076 |
+
bound_max,
|
1077 |
+
resolution=resolution,
|
1078 |
+
threshold=threshold,
|
1079 |
+
query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1080 |
+
def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts))
|
1081 |
+
else:
|
1082 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1083 |
+
bound_max,
|
1084 |
+
resolution=resolution,
|
1085 |
+
threshold=threshold,
|
1086 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1087 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1088 |
+
)
|
models/renderer_def_multi_objs_compositional.py
ADDED
@@ -0,0 +1,1510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import logging
|
6 |
+
import mcubes
|
7 |
+
from icecream import ic
|
8 |
+
import os
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
from pysdf import SDF
|
12 |
+
|
13 |
+
import models.fields as fields
|
14 |
+
|
15 |
+
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
|
16 |
+
|
17 |
+
def batched_index_select(values, indices, dim = 1):
|
18 |
+
value_dims = values.shape[(dim + 1):]
|
19 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
20 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
21 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
22 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
23 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
24 |
+
|
25 |
+
value_expand_shape = [-1] * len(values.shape)
|
26 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
27 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
28 |
+
values = values.expand(*value_expand_shape)
|
29 |
+
|
30 |
+
dim += value_expand_len
|
31 |
+
return values.gather(dim, indices)
|
32 |
+
|
33 |
+
|
34 |
+
def create_mt_variable(device):
|
35 |
+
triangle_table = torch.tensor(
|
36 |
+
[
|
37 |
+
[-1, -1, -1, -1, -1, -1],
|
38 |
+
[1, 0, 2, -1, -1, -1],
|
39 |
+
[4, 0, 3, -1, -1, -1],
|
40 |
+
[1, 4, 2, 1, 3, 4],
|
41 |
+
[3, 1, 5, -1, -1, -1],
|
42 |
+
[2, 3, 0, 2, 5, 3],
|
43 |
+
[1, 4, 0, 1, 5, 4],
|
44 |
+
[4, 2, 5, -1, -1, -1],
|
45 |
+
[4, 5, 2, -1, -1, -1],
|
46 |
+
[4, 1, 0, 4, 5, 1],
|
47 |
+
[3, 2, 0, 3, 5, 2],
|
48 |
+
[1, 3, 5, -1, -1, -1],
|
49 |
+
[4, 1, 2, 4, 3, 1],
|
50 |
+
[3, 0, 4, -1, -1, -1],
|
51 |
+
[2, 0, 1, -1, -1, -1],
|
52 |
+
[-1, -1, -1, -1, -1, -1]
|
53 |
+
], dtype=torch.long, device=device)
|
54 |
+
|
55 |
+
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
56 |
+
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
57 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
58 |
+
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
|
63 |
+
# load tet via resolution #
|
64 |
+
# scale them via bounds #
|
65 |
+
# extract the geometry #
|
66 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
67 |
+
device = bound_min.device
|
68 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
69 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
70 |
+
# else:
|
71 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
72 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
73 |
+
if not os.path.exists(tet_fn):
|
74 |
+
tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
|
75 |
+
tets = np.load(tet_fn)
|
76 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
77 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
78 |
+
# split #
|
79 |
+
# verts; verts; #
|
80 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
81 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
82 |
+
# scale_verts = maxx_verts - minn_verts
|
83 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
84 |
+
|
85 |
+
### scale the vertices ###
|
86 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
87 |
+
|
88 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
89 |
+
|
90 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
91 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
92 |
+
|
93 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
94 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
95 |
+
|
96 |
+
sdf_values = []
|
97 |
+
N = 64
|
98 |
+
query_bundles = N ** 3 ### N^3
|
99 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
100 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
101 |
+
query_NNs += 1
|
102 |
+
for i_query in range(query_NNs):
|
103 |
+
cur_bundle_st = i_query * query_bundles
|
104 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
105 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
106 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
107 |
+
if def_func is not None:
|
108 |
+
cur_query_pts = def_func(cur_query_pts)
|
109 |
+
cur_query_vals = query_func(cur_query_pts)
|
110 |
+
sdf_values.append(cur_query_vals)
|
111 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
112 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
113 |
+
|
114 |
+
# GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
115 |
+
gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
|
116 |
+
if not os.path.exists(gt_sdf_fn):
|
117 |
+
gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
|
118 |
+
GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
|
119 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
120 |
+
|
121 |
+
# intrinsic, tet values, pts values, sdf network #
|
122 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
123 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
124 |
+
|
125 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
126 |
+
|
127 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
128 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
129 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
130 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
131 |
+
# marching_tets_tetmesh ##
|
132 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
133 |
+
### use the GT sdf values for the marching tets ###
|
134 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
135 |
+
|
136 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
137 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
138 |
+
|
139 |
+
|
140 |
+
def extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=None):
|
141 |
+
# load tet via resolution #
|
142 |
+
# scale them via bounds #
|
143 |
+
# extract the geometry #
|
144 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
145 |
+
device = bound_min.device
|
146 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
147 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
148 |
+
# else:
|
149 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
150 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
151 |
+
if not os.path.exists(tet_fn):
|
152 |
+
tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
|
153 |
+
tets = np.load(tet_fn)
|
154 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
155 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
156 |
+
# split #
|
157 |
+
# verts; verts; #
|
158 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
159 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
160 |
+
# scale_verts = maxx_verts - minn_verts
|
161 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
162 |
+
|
163 |
+
### scale the vertices ###
|
164 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
165 |
+
|
166 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
167 |
+
|
168 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
169 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
170 |
+
|
171 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
172 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
173 |
+
|
174 |
+
sdf_values = []
|
175 |
+
N = 64
|
176 |
+
query_bundles = N ** 3 ### N^3
|
177 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
178 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
179 |
+
query_NNs += 1
|
180 |
+
for i_query in range(query_NNs):
|
181 |
+
cur_bundle_st = i_query * query_bundles
|
182 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
183 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
184 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
185 |
+
if def_func is not None:
|
186 |
+
cur_query_pts, _ = def_func(cur_query_pts)
|
187 |
+
cur_query_vals = query_func(cur_query_pts)
|
188 |
+
sdf_values.append(cur_query_vals)
|
189 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
190 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
191 |
+
|
192 |
+
# GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
193 |
+
gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
|
194 |
+
if not os.path.exists(gt_sdf_fn):
|
195 |
+
gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
|
196 |
+
GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
|
197 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
198 |
+
|
199 |
+
# intrinsic, tet values, pts values, sdf network #
|
200 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
201 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
202 |
+
|
203 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
204 |
+
|
205 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
206 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
207 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
208 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
209 |
+
# marching_tets_tetmesh ##
|
210 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
211 |
+
### use the GT sdf values for the marching tets ###
|
212 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
213 |
+
|
214 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
215 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
216 |
+
|
217 |
+
|
218 |
+
def extract_fields(bound_min, bound_max, resolution, query_func):
|
219 |
+
N = 64
|
220 |
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
221 |
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
222 |
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
223 |
+
|
224 |
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
225 |
+
with torch.no_grad():
|
226 |
+
for xi, xs in enumerate(X):
|
227 |
+
for yi, ys in enumerate(Y):
|
228 |
+
for zi, zs in enumerate(Z):
|
229 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
230 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
231 |
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
232 |
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
233 |
+
# should save u here #
|
234 |
+
# save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
|
235 |
+
# np.save(save_u_path, u) #
|
236 |
+
# print(f"u saved to {save_u_path}")
|
237 |
+
return u
|
238 |
+
|
239 |
+
|
240 |
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
241 |
+
print('threshold: {}'.format(threshold))
|
242 |
+
|
243 |
+
## using maching cubes ###
|
244 |
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
245 |
+
vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
246 |
+
b_max_np = bound_max.detach().cpu().numpy()
|
247 |
+
b_min_np = bound_min.detach().cpu().numpy()
|
248 |
+
|
249 |
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
250 |
+
### using maching cubes ###
|
251 |
+
|
252 |
+
### using marching tets ###
|
253 |
+
# vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
254 |
+
# vertices = vertices.detach().cpu().numpy()
|
255 |
+
# triangles = triangles.detach().cpu().numpy()
|
256 |
+
### using marching tets ###
|
257 |
+
|
258 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
259 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
260 |
+
|
261 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
262 |
+
return vertices, triangles
|
263 |
+
|
264 |
+
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None, selector=False):
|
265 |
+
# print('threshold: {}'.format(threshold))
|
266 |
+
|
267 |
+
### using maching cubes ###
|
268 |
+
# u = extract_fields(bound_min, bound_max, resolution, query_func)
|
269 |
+
# vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
270 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
271 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
272 |
+
|
273 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
274 |
+
### using maching cubes ###
|
275 |
+
|
276 |
+
##
|
277 |
+
### using marching tets ### fiels from tets ##
|
278 |
+
if selector:
|
279 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=def_func)
|
280 |
+
else:
|
281 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
|
282 |
+
# vertices = vertices.detach().cpu().numpy()
|
283 |
+
# triangles = triangles.detach().cpu().numpy()
|
284 |
+
### using marching tets ###
|
285 |
+
|
286 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
287 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
288 |
+
#
|
289 |
+
|
290 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
291 |
+
return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
|
292 |
+
|
293 |
+
|
294 |
+
### sample pdfs ###
|
295 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
296 |
+
# This implementation is from NeRF
|
297 |
+
# Get pdf
|
298 |
+
weights = weights + 1e-5 # prevent nans
|
299 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
300 |
+
cdf = torch.cumsum(pdf, -1)
|
301 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
302 |
+
# Take uniform samples
|
303 |
+
if det:
|
304 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
|
305 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
306 |
+
else:
|
307 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
|
308 |
+
|
309 |
+
# Invert CDF # invert cdf #
|
310 |
+
u = u.contiguous()
|
311 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
312 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
313 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
314 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
315 |
+
|
316 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
317 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
318 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
319 |
+
|
320 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
321 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
322 |
+
t = (u - cdf_g[..., 0]) / denom
|
323 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
324 |
+
|
325 |
+
return samples
|
326 |
+
|
327 |
+
|
328 |
+
def load_GT_vertices(GT_meshes_folder):
|
329 |
+
tot_meshes_fns = os.listdir(GT_meshes_folder)
|
330 |
+
tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
|
331 |
+
tot_mesh_verts = []
|
332 |
+
tot_mesh_faces = []
|
333 |
+
n_tot_verts = 0
|
334 |
+
for fn in tot_meshes_fns:
|
335 |
+
cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
|
336 |
+
obj_mesh = trimesh.load(cur_mesh_fn, process=False)
|
337 |
+
# obj_mesh.remove_degenerate_faces(height=1e-06)
|
338 |
+
|
339 |
+
verts_obj = np.array(obj_mesh.vertices)
|
340 |
+
faces_obj = np.array(obj_mesh.faces)
|
341 |
+
|
342 |
+
tot_mesh_verts.append(verts_obj)
|
343 |
+
tot_mesh_faces.append(faces_obj + n_tot_verts)
|
344 |
+
n_tot_verts += verts_obj.shape[0]
|
345 |
+
|
346 |
+
# tot_mesh_faces.append(faces_obj)
|
347 |
+
tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
|
348 |
+
tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
|
349 |
+
return tot_mesh_verts, tot_mesh_faces
|
350 |
+
|
351 |
+
|
352 |
+
class NeuSRenderer:
|
353 |
+
def __init__(self,
|
354 |
+
nerf,
|
355 |
+
sdf_network,
|
356 |
+
deviation_network,
|
357 |
+
color_network,
|
358 |
+
n_samples,
|
359 |
+
n_importance,
|
360 |
+
n_outside,
|
361 |
+
up_sample_steps,
|
362 |
+
perturb):
|
363 |
+
self.nerf = nerf #
|
364 |
+
self.sdf_network = sdf_network
|
365 |
+
self.deviation_network = deviation_network
|
366 |
+
self.color_network = color_network
|
367 |
+
self.n_samples = n_samples
|
368 |
+
self.n_importance = n_importance
|
369 |
+
self.n_outside = n_outside
|
370 |
+
self.up_sample_steps = up_sample_steps
|
371 |
+
self.perturb = perturb
|
372 |
+
|
373 |
+
GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
|
374 |
+
if not os.path.exists(GT_meshes_folder):
|
375 |
+
GT_meshes_folder = "/data/xueyi/diffsim/DiffHand/assets/hand"
|
376 |
+
self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
|
377 |
+
maxx_pts = 25.
|
378 |
+
minn_pts = -15.
|
379 |
+
self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
|
380 |
+
f = SDF(self.mesh_vertices, self.mesh_faces)
|
381 |
+
self.gt_sdf = f ## a unite sphere or box
|
382 |
+
|
383 |
+
self.minn_pts = 0
|
384 |
+
self.maxx_pts = 1.
|
385 |
+
|
386 |
+
# self.minn_pts = -1.5 #
|
387 |
+
# self.maxx_pts = 1.5 #
|
388 |
+
self.bkg_pts = ... # TODO
|
389 |
+
self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
|
390 |
+
self.dist_interp_thres = ... # TODO: set the cur_bkg_pts_defs #
|
391 |
+
|
392 |
+
self.bending_network = ... # TODO: add the bending network #
|
393 |
+
self.use_bending_network = ... # TODO: set the property #
|
394 |
+
self.use_delta_bending = ... # TODO
|
395 |
+
self.prev_sdf_network = ... # TODO
|
396 |
+
self.use_selector = False
|
397 |
+
self.timestep_to_passive_mesh = ... # TODO
|
398 |
+
self.timestep_to_active_mesh = ... # TODO
|
399 |
+
|
400 |
+
|
401 |
+
|
402 |
+
def deform_pts(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
|
403 |
+
|
404 |
+
if self.use_bending_network:
|
405 |
+
if len(pts.size()) == 3:
|
406 |
+
nnb, nns = pts.size(0), pts.size(1)
|
407 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
408 |
+
else:
|
409 |
+
pts_exp = pts
|
410 |
+
# pts_ts #
|
411 |
+
if self.use_delta_bending:
|
412 |
+
|
413 |
+
if isinstance(self.bending_network, list):
|
414 |
+
pts_offsets = []
|
415 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
416 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
417 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
418 |
+
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
419 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
420 |
+
cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
421 |
+
else:
|
422 |
+
raise ValueError('Encountered with unexpected bending network class...')
|
423 |
+
pts_offsets.append(cur_pts_exp - pts_exp)
|
424 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
425 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
426 |
+
pts_exp = pts_exp + pts_offsets
|
427 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
428 |
+
# if isinstance(self.bending_network, list): # pts ts #
|
429 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
430 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
431 |
+
# else:
|
432 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
433 |
+
else:
|
434 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
435 |
+
pts_offsets = []
|
436 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
437 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
438 |
+
pts_offsets.append(bended_pts_exp - pts_exp)
|
439 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
440 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
441 |
+
pts_exp = pts_exp + pts_offsets
|
442 |
+
else:
|
443 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
444 |
+
if len(pts.size()) == 3:
|
445 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
446 |
+
else:
|
447 |
+
pts = pts_exp
|
448 |
+
return pts
|
449 |
+
|
450 |
+
# pts: nn_batch x nn_samples x 3
|
451 |
+
if len(pts.size()) == 3:
|
452 |
+
nnb, nns = pts.size(0), pts.size(1)
|
453 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
454 |
+
else:
|
455 |
+
pts_exp = pts
|
456 |
+
# print(f"prior to deforming: {pts.size()}")
|
457 |
+
|
458 |
+
dist_pts_to_bkg_pts = torch.sum(
|
459 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
460 |
+
)
|
461 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
462 |
+
dist_mask_float = dist_mask.float()
|
463 |
+
|
464 |
+
# dist_mask_float #
|
465 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
466 |
+
cur_fr_pts_def = torch.sum(
|
467 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
468 |
+
)
|
469 |
+
dist_mask_float_summ = torch.sum(
|
470 |
+
dist_mask_float, dim=1
|
471 |
+
)
|
472 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
473 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
474 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
475 |
+
if len(pts.size()) == 3:
|
476 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
477 |
+
else:
|
478 |
+
pts = pts_exp
|
479 |
+
return pts #
|
480 |
+
|
481 |
+
|
482 |
+
def deform_pts_with_selector(self, pts, pts_ts=0, update_tot_def=True): # deform pts #
|
483 |
+
|
484 |
+
if self.use_bending_network:
|
485 |
+
if len(pts.size()) == 3:
|
486 |
+
nnb, nns = pts.size(0), pts.size(1)
|
487 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
488 |
+
else:
|
489 |
+
pts_exp = pts
|
490 |
+
# pts_ts #
|
491 |
+
if self.use_delta_bending:
|
492 |
+
if isinstance(self.bending_network, list):
|
493 |
+
bended_pts = []
|
494 |
+
queries_sdfs_selector = []
|
495 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
496 |
+
if cur_bending_network.use_opt_rigid_translations:
|
497 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
498 |
+
else:
|
499 |
+
# bended_pts_exp = pts_exp.clone()
|
500 |
+
if i_obj == 1 and pts_ts == 0:
|
501 |
+
bended_pts_exp = pts_exp
|
502 |
+
elif i_obj == 1:
|
503 |
+
for cur_pts_ts in range(pts_ts, 0, -1): ### before 0 ###
|
504 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
505 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
|
506 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForceForward) or isinstance(cur_bending_network, fields.BendingNetworkRigidTransForward):
|
507 |
+
# input_pts, input_pts_ts, timestep_to_passive_mesh, act_sdf_net=None, details=None, special_loss_return=False
|
508 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh)
|
509 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForceFieldForward):
|
510 |
+
# input_pts, input_pts_ts, timestep_to_passive_mesh, passive_sdf_net, details=None, special_loss_return=False
|
511 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1])
|
512 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForward):
|
513 |
+
# active_bending_net, active_sdf_net,
|
514 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
515 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV2):
|
516 |
+
# active_bending_net, active_sdf_net,
|
517 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
518 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV3):
|
519 |
+
# active_bending_net, active_sdf_net,
|
520 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
521 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV4):
|
522 |
+
# active_bending_net, active_sdf_net,
|
523 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
524 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV5):
|
525 |
+
# active_bending_net, active_sdf_net,
|
526 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
527 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV6):
|
528 |
+
# active_bending_net, active_sdf_net,
|
529 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
530 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV7):
|
531 |
+
# active_bending_net, active_sdf_net,
|
532 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
533 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV8):
|
534 |
+
# active_bending_net, active_sdf_net,
|
535 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
|
536 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkActiveForceFieldForwardV9):
|
537 |
+
# active_bending_net, active_sdf_net,
|
538 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0], update_tot_def=update_tot_def)
|
539 |
+
else:
|
540 |
+
raise ValueError(f"Unrecognized bending network type: {type(cur_bending_network)}")
|
541 |
+
else:
|
542 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
543 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
|
544 |
+
# _, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
545 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector_ndelta(bended_pts_exp, i_net=i_obj)
|
546 |
+
bended_pts.append(bended_pts_exp)
|
547 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
548 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts # # bended_pts #
|
549 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
550 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
551 |
+
#### get the final sdf_selector from queries_sdfs_selector ####
|
552 |
+
# sdf_selector = queries_sdfs_selector[:, -1]
|
553 |
+
# neg_neg = ((queries_sdfs_selector[:, 0] == 0).float() + (queries_sdfs_selector[:, -1] == 1).float()) > 1.5 #### both inside of the object
|
554 |
+
sdf_selector = 1 - queries_sdfs_selector[:, 0]
|
555 |
+
# neg_neg
|
556 |
+
# sdf_selector = queries_sdfs_selector
|
557 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
558 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
559 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
560 |
+
# pts_exp = bended_pts.squeeze(1)
|
561 |
+
pts_exp = bended_pts
|
562 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
563 |
+
# if isinstance(self.bending_network, list):
|
564 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
565 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
566 |
+
# else:
|
567 |
+
|
568 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
569 |
+
else:
|
570 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
571 |
+
# pts_offsets = []
|
572 |
+
bended_pts = []
|
573 |
+
queries_sdfs_selector = []
|
574 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
575 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
576 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
577 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
578 |
+
bended_pts.append(bended_pts_exp)
|
579 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
580 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
|
581 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
582 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
583 |
+
sdf_selector = queries_sdfs_selector[:, -1]
|
584 |
+
# sdf_selector = queries_sdfs_selector
|
585 |
+
|
586 |
+
|
587 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
588 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
589 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
590 |
+
pts_exp = bended_pts.squeeze(1)
|
591 |
+
|
592 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
593 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
594 |
+
# pts_exp = pts_exp + pts_offsets
|
595 |
+
else:
|
596 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
597 |
+
if len(pts.size()) == 3:
|
598 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
599 |
+
else:
|
600 |
+
pts = pts_exp
|
601 |
+
return pts, sdf_selector
|
602 |
+
else:
|
603 |
+
# pts: nn_batch x nn_samples x 3
|
604 |
+
if len(pts.size()) == 3:
|
605 |
+
nnb, nns = pts.size(0), pts.size(1)
|
606 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
607 |
+
else:
|
608 |
+
pts_exp = pts
|
609 |
+
# print(f"prior to deforming: {pts.size()}")
|
610 |
+
|
611 |
+
dist_pts_to_bkg_pts = torch.sum(
|
612 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
613 |
+
)
|
614 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
615 |
+
dist_mask_float = dist_mask.float()
|
616 |
+
|
617 |
+
# dist_mask_float #
|
618 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
619 |
+
cur_fr_pts_def = torch.sum(
|
620 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
621 |
+
)
|
622 |
+
dist_mask_float_summ = torch.sum(
|
623 |
+
dist_mask_float, dim=1
|
624 |
+
)
|
625 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
626 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
627 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
628 |
+
if len(pts.size()) == 3:
|
629 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
630 |
+
else:
|
631 |
+
pts = pts_exp
|
632 |
+
return pts #
|
633 |
+
|
634 |
+
|
635 |
+
def deform_pts_passive(self, pts, pts_ts=0):
|
636 |
+
|
637 |
+
if self.use_bending_network:
|
638 |
+
if len(pts.size()) == 3:
|
639 |
+
nnb, nns = pts.size(0), pts.size(1)
|
640 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
641 |
+
else:
|
642 |
+
pts_exp = pts
|
643 |
+
# pts_ts #
|
644 |
+
if self.use_delta_bending:
|
645 |
+
if pts_ts > 0:
|
646 |
+
for cur_pts_ts in range(pts_ts, 0, -1):
|
647 |
+
# if isinstance(self.bending_network, list):
|
648 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
649 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
650 |
+
# else:
|
651 |
+
if isinstance(self.bending_network[-1], fields.BendingNetwork):
|
652 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts)
|
653 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkForceForward) or isinstance(self.bending_network[-1], fields.BendingNetworkRigidTransForward):
|
654 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh)
|
655 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkForceFieldForward):
|
656 |
+
# input_pts, input_pts_ts, timestep_to_passive_mesh, passive_sdf_net, details=None, special_loss_return=False
|
657 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1])
|
658 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForward):
|
659 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
660 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV2):
|
661 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
662 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV3):
|
663 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
664 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV4):
|
665 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
666 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV5):
|
667 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
668 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV6):
|
669 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
670 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV7):
|
671 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
672 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV8):
|
673 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
674 |
+
elif isinstance(self.bending_network[-1], fields.BendingNetworkActiveForceFieldForwardV9):
|
675 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts, timestep_to_active_mesh=self.timestep_to_active_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, passive_sdf_net=self.sdf_network[1], active_bending_net=self.bending_network[0], active_sdf_net=self.sdf_network[0])
|
676 |
+
else:
|
677 |
+
raise ValueError(f"Unrecognized bending network type: {type(self.bending_network[-1])}")
|
678 |
+
# pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=cur_pts_ts)
|
679 |
+
else:
|
680 |
+
# if isinstance(self.bending_network, list):
|
681 |
+
# pts_offsets = []
|
682 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
683 |
+
# bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
684 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
685 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
686 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
687 |
+
# pts_exp = pts_exp + pts_offsets
|
688 |
+
# else:
|
689 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
|
690 |
+
if len(pts.size()) == 3:
|
691 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
692 |
+
else:
|
693 |
+
pts = pts_exp
|
694 |
+
return pts
|
695 |
+
|
696 |
+
# pts: nn_batch x nn_samples x 3
|
697 |
+
if len(pts.size()) == 3:
|
698 |
+
nnb, nns = pts.size(0), pts.size(1)
|
699 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
700 |
+
else:
|
701 |
+
pts_exp = pts
|
702 |
+
# print(f"prior to deforming: {pts.size()}")
|
703 |
+
|
704 |
+
dist_pts_to_bkg_pts = torch.sum(
|
705 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
706 |
+
)
|
707 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
708 |
+
dist_mask_float = dist_mask.float()
|
709 |
+
|
710 |
+
# dist_mask_float #
|
711 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
712 |
+
cur_fr_pts_def = torch.sum(
|
713 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
714 |
+
)
|
715 |
+
dist_mask_float_summ = torch.sum(
|
716 |
+
dist_mask_float, dim=1
|
717 |
+
)
|
718 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
719 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
720 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
721 |
+
if len(pts.size()) == 3:
|
722 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
723 |
+
else:
|
724 |
+
pts = pts_exp
|
725 |
+
return pts #
|
726 |
+
|
727 |
+
# delta mesh as passive mesh #
|
728 |
+
def query_pts_sdf_fn_for_selector(self, pts):
|
729 |
+
# for negative
|
730 |
+
# 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
|
731 |
+
# 2) negative in thie field and also negative in the previous field --->
|
732 |
+
# 2) for positive values of this current field --->
|
733 |
+
cur_sdf = self.sdf_network.sdf(pts)
|
734 |
+
prev_sdf = self.prev_sdf_network.sdf(pts)
|
735 |
+
neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
736 |
+
neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
737 |
+
|
738 |
+
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
|
739 |
+
|
740 |
+
pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
741 |
+
pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
742 |
+
res_sdf = torch.zeros_like(cur_sdf)
|
743 |
+
res_sdf[neg_neg] = 1. #
|
744 |
+
res_sdf[neg_pos] = cur_sdf[neg_pos]
|
745 |
+
res_sdf[pos_neg] = cur_sdf[pos_neg]
|
746 |
+
|
747 |
+
# inside the residual mesh -> must be neg and pos
|
748 |
+
res_sdf_selector = torch.zeros_like(cur_sdf).long() #
|
749 |
+
# res_sdf_selector[neg_pos] = 1 # is the residual mesh
|
750 |
+
res_sdf_selector[neg_weq_pos] = 1
|
751 |
+
# res_sdf_selector[]
|
752 |
+
|
753 |
+
cat_cur_prev_sdf = torch.stack(
|
754 |
+
[cur_sdf, prev_sdf], dim=-1
|
755 |
+
)
|
756 |
+
minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
|
757 |
+
res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
|
758 |
+
|
759 |
+
return res_sdf, res_sdf_selector
|
760 |
+
|
761 |
+
def query_pts_sdf_fn_for_selector_ndelta(self, pts, i_net):
|
762 |
+
# for negative
|
763 |
+
# 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
|
764 |
+
# 2) negative in thie field and also negative in the previous field --->
|
765 |
+
# 2) for positive values of this current field --->
|
766 |
+
passive_sdf = self.sdf_network[i_net].sdf(pts).squeeze(-1)
|
767 |
+
passive_sdf_selector = torch.zeros_like(passive_sdf).long()
|
768 |
+
passive_sdf_selector[passive_sdf <= 0.] = 1.
|
769 |
+
return passive_sdf, passive_sdf_selector
|
770 |
+
|
771 |
+
cur_sdf = self.sdf_network.sdf(pts)
|
772 |
+
prev_sdf = self.prev_sdf_network.sdf(pts)
|
773 |
+
neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
774 |
+
neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
775 |
+
|
776 |
+
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
|
777 |
+
|
778 |
+
pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
779 |
+
pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
780 |
+
res_sdf = torch.zeros_like(cur_sdf)
|
781 |
+
res_sdf[neg_neg] = 1. #
|
782 |
+
res_sdf[neg_pos] = cur_sdf[neg_pos]
|
783 |
+
res_sdf[pos_neg] = cur_sdf[pos_neg]
|
784 |
+
|
785 |
+
# inside the residual mesh -> must be neg and pos
|
786 |
+
res_sdf_selector = torch.zeros_like(cur_sdf).long() #
|
787 |
+
# res_sdf_selector[neg_pos] = 1 # is the residual mesh
|
788 |
+
res_sdf_selector[neg_weq_pos] = 1
|
789 |
+
# res_sdf_selector[]
|
790 |
+
|
791 |
+
cat_cur_prev_sdf = torch.stack(
|
792 |
+
[cur_sdf, prev_sdf], dim=-1
|
793 |
+
)
|
794 |
+
minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
|
795 |
+
res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
|
796 |
+
|
797 |
+
return res_sdf, res_sdf_selector
|
798 |
+
|
799 |
+
|
800 |
+
def query_func_sdf(self, pts):
|
801 |
+
if isinstance(self.sdf_network, list):
|
802 |
+
tot_sdf_values = []
|
803 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
804 |
+
cur_sdf_values = cur_sdf_network.sdf(pts)
|
805 |
+
tot_sdf_values.append(cur_sdf_values)
|
806 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
807 |
+
tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
808 |
+
sdf = tot_sdf_values
|
809 |
+
else:
|
810 |
+
sdf = self.sdf_network.sdf(pts)
|
811 |
+
return sdf
|
812 |
+
|
813 |
+
def query_func_sdf_passive(self, pts):
|
814 |
+
# if isinstance(self.sdf_network, list):
|
815 |
+
# tot_sdf_values = []
|
816 |
+
# for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
817 |
+
# cur_sdf_values = cur_sdf_network.sdf(pts)
|
818 |
+
# tot_sdf_values.append(cur_sdf_values)
|
819 |
+
# tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
820 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
821 |
+
# sdf = tot_sdf_values
|
822 |
+
# else:
|
823 |
+
sdf = self.sdf_network[-1].sdf(pts)
|
824 |
+
|
825 |
+
return sdf
|
826 |
+
|
827 |
+
|
828 |
+
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
|
829 |
+
"""
|
830 |
+
Render background
|
831 |
+
"""
|
832 |
+
batch_size, n_samples = z_vals.shape
|
833 |
+
|
834 |
+
# Section length
|
835 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
836 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
837 |
+
mid_z_vals = z_vals + dists * 0.5
|
838 |
+
|
839 |
+
# Section midpoints #
|
840 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
|
841 |
+
|
842 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
843 |
+
pts = pts * 2 - 1
|
844 |
+
|
845 |
+
if self.use_selector:
|
846 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
847 |
+
else:
|
848 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
849 |
+
|
850 |
+
dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
851 |
+
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
|
852 |
+
|
853 |
+
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
854 |
+
|
855 |
+
pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) ### deformed_pts ###
|
856 |
+
dirs = dirs.reshape(-1, 3)
|
857 |
+
|
858 |
+
if self.use_selector:
|
859 |
+
tot_density, tot_sampled_color = [], []
|
860 |
+
for i_nerf, cur_nerf in enumerate(nerf):
|
861 |
+
cur_density, cur_sampled_color = cur_nerf(pts, dirs)
|
862 |
+
tot_density.append(cur_density)
|
863 |
+
tot_sampled_color.append(cur_sampled_color)
|
864 |
+
tot_density = torch.stack(tot_density, dim=1)
|
865 |
+
tot_sampled_color = torch.stack(tot_sampled_color, dim=1) ### sampled colors
|
866 |
+
# print(f"tot_density: {tot_density.size()}, tot_sampled_color: {tot_sampled_color.size()}, sdf_selector: {sdf_selector.size()}")
|
867 |
+
density = batched_index_select(values=tot_density, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
868 |
+
sampled_color = batched_index_select(values=tot_sampled_color, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
869 |
+
else:
|
870 |
+
density, sampled_color = nerf(pts, dirs)
|
871 |
+
sampled_color = torch.sigmoid(sampled_color)
|
872 |
+
alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
873 |
+
alpha = alpha.reshape(batch_size, n_samples)
|
874 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
875 |
+
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
876 |
+
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
877 |
+
if background_rgb is not None:
|
878 |
+
color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
|
879 |
+
|
880 |
+
return {
|
881 |
+
'color': color,
|
882 |
+
'sampled_color': sampled_color,
|
883 |
+
'alpha': alpha,
|
884 |
+
'weights': weights,
|
885 |
+
}
|
886 |
+
|
887 |
+
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
|
888 |
+
"""
|
889 |
+
Up sampling give a fixed inv_s
|
890 |
+
"""
|
891 |
+
batch_size, n_samples = z_vals.shape
|
892 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
893 |
+
|
894 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
895 |
+
pts = pts * 2 - 1
|
896 |
+
|
897 |
+
if self.use_selector:
|
898 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
899 |
+
else:
|
900 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
901 |
+
|
902 |
+
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
903 |
+
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
904 |
+
sdf = sdf.reshape(batch_size, n_samples)
|
905 |
+
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
906 |
+
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
907 |
+
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
908 |
+
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
909 |
+
|
910 |
+
# ----------------------------------------------------------------------------------------------------------
|
911 |
+
# Use min value of [ cos, prev_cos ]
|
912 |
+
# Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
|
913 |
+
# robust when meeting situations like below:
|
914 |
+
#
|
915 |
+
# SDF
|
916 |
+
# ^
|
917 |
+
# |\ -----x----...
|
918 |
+
# | \ /
|
919 |
+
# | x x
|
920 |
+
# |---\----/-------------> 0 level
|
921 |
+
# | \ /
|
922 |
+
# | \/
|
923 |
+
# |
|
924 |
+
# ----------------------------------------------------------------------------------------------------------
|
925 |
+
prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
|
926 |
+
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
927 |
+
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
928 |
+
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
929 |
+
|
930 |
+
dist = (next_z_vals - prev_z_vals)
|
931 |
+
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
932 |
+
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
933 |
+
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
934 |
+
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
935 |
+
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
936 |
+
weights = alpha * torch.cumprod(
|
937 |
+
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
938 |
+
|
939 |
+
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
940 |
+
return z_samples
|
941 |
+
|
942 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
|
943 |
+
batch_size, n_samples = z_vals.shape
|
944 |
+
_, n_importance = new_z_vals.shape
|
945 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
946 |
+
|
947 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
948 |
+
pts = pts * 2 - 1
|
949 |
+
|
950 |
+
if self.use_selector:
|
951 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
952 |
+
else:
|
953 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
954 |
+
|
955 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
956 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
957 |
+
|
958 |
+
if not last:
|
959 |
+
if isinstance(self.sdf_network, list):
|
960 |
+
tot_new_sdf = []
|
961 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
962 |
+
cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
963 |
+
tot_new_sdf.append(cur_new_sdf)
|
964 |
+
tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
|
965 |
+
new_sdf, _ = torch.min(tot_new_sdf, dim=-1) #
|
966 |
+
else:
|
967 |
+
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
968 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
969 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
970 |
+
index = index.reshape(-1)
|
971 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
972 |
+
|
973 |
+
return z_vals, sdf
|
974 |
+
|
975 |
+
|
976 |
+
|
977 |
+
def render_core(self,
|
978 |
+
rays_o,
|
979 |
+
rays_d,
|
980 |
+
z_vals,
|
981 |
+
sample_dist,
|
982 |
+
sdf_network,
|
983 |
+
deviation_network,
|
984 |
+
color_network,
|
985 |
+
background_alpha=None,
|
986 |
+
background_sampled_color=None,
|
987 |
+
background_rgb=None,
|
988 |
+
cos_anneal_ratio=0.0,
|
989 |
+
pts_ts=0):
|
990 |
+
batch_size, n_samples = z_vals.shape
|
991 |
+
|
992 |
+
# Section length
|
993 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
994 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
995 |
+
mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
|
996 |
+
|
997 |
+
# Section midpoints
|
998 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
999 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
1000 |
+
|
1001 |
+
pts = pts.reshape(-1, 3) # pts, nn_ou
|
1002 |
+
dirs = dirs.reshape(-1, 3)
|
1003 |
+
|
1004 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
1005 |
+
|
1006 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1007 |
+
pts = pts * 2 - 1
|
1008 |
+
|
1009 |
+
if self.use_selector:
|
1010 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1011 |
+
else:
|
1012 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
1013 |
+
|
1014 |
+
if isinstance(sdf_network, list):
|
1015 |
+
tot_sdf = []
|
1016 |
+
tot_feature_vector = []
|
1017 |
+
tot_obj_sel = []
|
1018 |
+
tot_gradients = []
|
1019 |
+
for i_obj, cur_sdf_network in enumerate(sdf_network):
|
1020 |
+
cur_sdf_nn_output = cur_sdf_network(pts)
|
1021 |
+
cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
|
1022 |
+
tot_sdf.append(cur_sdf)
|
1023 |
+
tot_feature_vector.append(cur_feature_vector)
|
1024 |
+
|
1025 |
+
gradients = cur_sdf_network.gradient(pts).squeeze()
|
1026 |
+
tot_gradients.append(gradients)
|
1027 |
+
tot_sdf = torch.stack(tot_sdf, dim=-1)
|
1028 |
+
|
1029 |
+
#
|
1030 |
+
if self.use_selector:
|
1031 |
+
sdf = batched_index_select(tot_sdf, sdf_selector.unsqueeze(1).unsqueeze(1), dim=2).squeeze(-1)
|
1032 |
+
obj_sel = sdf_selector.unsqueeze(1)
|
1033 |
+
else:
|
1034 |
+
sdf, obj_sel = torch.min(tot_sdf, dim=-1)
|
1035 |
+
feature_vector = torch.stack(tot_feature_vector, dim=1)
|
1036 |
+
|
1037 |
+
# batched_index_select
|
1038 |
+
# print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
|
1039 |
+
feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
|
1040 |
+
|
1041 |
+
# feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
|
1042 |
+
# print(f"after sel: {feature_vector.size()}")
|
1043 |
+
tot_gradients = torch.stack(tot_gradients, dim=1)
|
1044 |
+
# gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
|
1045 |
+
gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
|
1046 |
+
# print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
|
1047 |
+
|
1048 |
+
else:
|
1049 |
+
sdf_nn_output = sdf_network(pts)
|
1050 |
+
sdf = sdf_nn_output[:, :1]
|
1051 |
+
feature_vector = sdf_nn_output[:, 1:]
|
1052 |
+
gradients = sdf_network.gradient(pts).squeeze()
|
1053 |
+
|
1054 |
+
if self.use_selector:
|
1055 |
+
tot_sampled_color = []
|
1056 |
+
for i_color_net, cur_color_network in enumerate(color_network):
|
1057 |
+
cur_sampled_color = cur_color_network(pts, gradients, dirs, feature_vector) # .reshape(batch_size, n_samples, 3)
|
1058 |
+
tot_sampled_color.append(cur_sampled_color)
|
1059 |
+
# print(f"tot_density: {tot_density.size()}, tot_sampled_color: {tot_sampled_color.size()}, sdf_selector: {sdf_selector.size()}")
|
1060 |
+
tot_sampled_color = torch.stack(tot_sampled_color, dim=1)
|
1061 |
+
sampled_color = batched_index_select(values=tot_sampled_color, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1).reshape(batch_size, n_samples, 3)
|
1062 |
+
else:
|
1063 |
+
sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
|
1064 |
+
|
1065 |
+
if self.use_selector and isinstance(deviation_network, list):
|
1066 |
+
tot_inv_s = []
|
1067 |
+
for i_dev_net, cur_deviation_network in enumerate(deviation_network):
|
1068 |
+
cur_inv_s = cur_deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6)
|
1069 |
+
tot_inv_s.append(cur_inv_s)
|
1070 |
+
tot_inv_s = torch.stack(tot_inv_s, dim=1)
|
1071 |
+
inv_s = batched_index_select(values=tot_inv_s, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1072 |
+
# inv_s =
|
1073 |
+
else:
|
1074 |
+
# deviation network #
|
1075 |
+
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
|
1076 |
+
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
1077 |
+
|
1078 |
+
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
1079 |
+
|
1080 |
+
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
1081 |
+
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
1082 |
+
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
1083 |
+
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
1084 |
+
|
1085 |
+
# Estimate signed distances at section points
|
1086 |
+
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
1087 |
+
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
1088 |
+
|
1089 |
+
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
1090 |
+
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
1091 |
+
|
1092 |
+
p = prev_cdf - next_cdf
|
1093 |
+
c = prev_cdf
|
1094 |
+
|
1095 |
+
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
|
1096 |
+
|
1097 |
+
pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
1098 |
+
inside_sphere = (pts_norm < 1.0).float().detach()
|
1099 |
+
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
1100 |
+
|
1101 |
+
# Render with background
|
1102 |
+
if background_alpha is not None:
|
1103 |
+
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
|
1104 |
+
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
1105 |
+
sampled_color = sampled_color * inside_sphere[:, :, None] +\
|
1106 |
+
background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
|
1107 |
+
sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
|
1108 |
+
|
1109 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
1110 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
1111 |
+
|
1112 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
1113 |
+
if background_rgb is not None: # Fixed background, usually black
|
1114 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
1115 |
+
|
1116 |
+
# Eikonal loss
|
1117 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
|
1118 |
+
dim=-1) - 1.0) ** 2
|
1119 |
+
gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
|
1120 |
+
|
1121 |
+
return {
|
1122 |
+
'color': color,
|
1123 |
+
'sdf': sdf,
|
1124 |
+
'dists': dists,
|
1125 |
+
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
1126 |
+
's_val': 1.0 / inv_s,
|
1127 |
+
'mid_z_vals': mid_z_vals,
|
1128 |
+
'weights': weights,
|
1129 |
+
'cdf': c.reshape(batch_size, n_samples),
|
1130 |
+
'gradient_error': gradient_error,
|
1131 |
+
'inside_sphere': inside_sphere
|
1132 |
+
}
|
1133 |
+
|
1134 |
+
def per_sdf_query(self, pts):
|
1135 |
+
tot_sdfs = []
|
1136 |
+
for i_sdf_net, cur_sdf_network in enumerate(self.sdf_network):
|
1137 |
+
cur_sdf_value = cur_sdf_network.sdf(pts).squeeze(-1)
|
1138 |
+
tot_sdfs.append(cur_sdf_value)
|
1139 |
+
tot_sdfs = torch.stack(tot_sdfs, dim=1)
|
1140 |
+
return tot_sdfs
|
1141 |
+
|
1142 |
+
|
1143 |
+
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
|
1144 |
+
batch_size = len(rays_o)
|
1145 |
+
sample_dist = 2.0 / self.n_samples # in a unit sphere # # Assuming the region of interest is a unit sphere
|
1146 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples) # linspace #
|
1147 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
1148 |
+
|
1149 |
+
z_vals_outside = None
|
1150 |
+
if self.n_outside > 0:
|
1151 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
1152 |
+
|
1153 |
+
n_samples = self.n_samples
|
1154 |
+
perturb = self.perturb
|
1155 |
+
|
1156 |
+
if perturb_overwrite >= 0:
|
1157 |
+
perturb = perturb_overwrite
|
1158 |
+
if perturb > 0:
|
1159 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
1160 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
1161 |
+
|
1162 |
+
if self.n_outside > 0: # z values output # n_outside #
|
1163 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
1164 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
1165 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
1166 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
1167 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
1168 |
+
|
1169 |
+
if self.n_outside > 0:
|
1170 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
1171 |
+
|
1172 |
+
background_alpha = None
|
1173 |
+
background_sampled_color = None
|
1174 |
+
|
1175 |
+
# Up sample
|
1176 |
+
if self.n_importance > 0:
|
1177 |
+
with torch.no_grad():
|
1178 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
1179 |
+
|
1180 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
1181 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
1182 |
+
# gt_sdf #
|
1183 |
+
|
1184 |
+
#
|
1185 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
1186 |
+
|
1187 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1188 |
+
pts = pts * 2 - 1
|
1189 |
+
|
1190 |
+
if self.use_selector:
|
1191 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1192 |
+
else:
|
1193 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts) # give nthe pts
|
1194 |
+
|
1195 |
+
pts_exp = pts.reshape(-1, 3)
|
1196 |
+
# minn_pts, _ = torch.min(pts_exp, dim=0)
|
1197 |
+
# maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
|
1198 |
+
# print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
|
1199 |
+
|
1200 |
+
# pts_to_near = pts - near.unsqueeze(1)
|
1201 |
+
# maxx_pts = 1.5; minn_pts = -1.5
|
1202 |
+
# # maxx_pts = 3; minn_pts = -3
|
1203 |
+
# # maxx_pts = 1; minn_pts = -1
|
1204 |
+
# pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
|
1205 |
+
|
1206 |
+
## render and iamges ####
|
1207 |
+
# if use_gt_sdf:
|
1208 |
+
# ### use the GT sdf field ####
|
1209 |
+
# # print(f"Using gt sdf :")
|
1210 |
+
# sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
|
1211 |
+
# sdf = torch.from_numpy(sdf).float().cuda()
|
1212 |
+
# sdf = sdf.reshape(batch_size, self.n_samples)
|
1213 |
+
# ### use the GT sdf field ####
|
1214 |
+
# else:
|
1215 |
+
# # pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
|
1216 |
+
# #### use the optimized sdf field ####
|
1217 |
+
|
1218 |
+
# # sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1219 |
+
|
1220 |
+
if isinstance(self.sdf_network, list):
|
1221 |
+
if self.use_selector:
|
1222 |
+
tot_sdf_values = []
|
1223 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
1224 |
+
cur_sdf_values = cur_sdf_network.sdf(pts_exp).squeeze(-1)
|
1225 |
+
tot_sdf_values.append(cur_sdf_values)
|
1226 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=1)
|
1227 |
+
tot_sdf_values = batched_index_select(tot_sdf_values, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
|
1228 |
+
sdf = tot_sdf_values.reshape(batch_size, self.n_samples)
|
1229 |
+
else:
|
1230 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
1231 |
+
tot_sdf_values = []
|
1232 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
1233 |
+
cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1234 |
+
tot_sdf_values.append(cur_sdf_values)
|
1235 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
1236 |
+
tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
1237 |
+
sdf = tot_sdf_values
|
1238 |
+
else:
|
1239 |
+
sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1240 |
+
|
1241 |
+
#### use the optimized sdf field ####
|
1242 |
+
|
1243 |
+
for i in range(self.up_sample_steps):
|
1244 |
+
new_z_vals = self.up_sample(rays_o,
|
1245 |
+
rays_d,
|
1246 |
+
z_vals,
|
1247 |
+
sdf,
|
1248 |
+
self.n_importance // self.up_sample_steps,
|
1249 |
+
64 * 2**i,
|
1250 |
+
pts_ts=pts_ts)
|
1251 |
+
z_vals, sdf = self.cat_z_vals(rays_o,
|
1252 |
+
rays_d,
|
1253 |
+
z_vals,
|
1254 |
+
new_z_vals,
|
1255 |
+
sdf,
|
1256 |
+
last=(i + 1 == self.up_sample_steps),
|
1257 |
+
pts_ts=pts_ts)
|
1258 |
+
|
1259 |
+
n_samples = self.n_samples + self.n_importance
|
1260 |
+
|
1261 |
+
# Background model
|
1262 |
+
if self.n_outside > 0:
|
1263 |
+
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
1264 |
+
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
1265 |
+
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
|
1266 |
+
|
1267 |
+
background_sampled_color = ret_outside['sampled_color']
|
1268 |
+
background_alpha = ret_outside['alpha']
|
1269 |
+
|
1270 |
+
tot_sdfs = self.per_sdf_query(pts_exp)
|
1271 |
+
|
1272 |
+
# Render core
|
1273 |
+
ret_fine = self.render_core(rays_o, #
|
1274 |
+
rays_d,
|
1275 |
+
z_vals,
|
1276 |
+
sample_dist,
|
1277 |
+
self.sdf_network,
|
1278 |
+
self.deviation_network,
|
1279 |
+
self.color_network,
|
1280 |
+
background_rgb=background_rgb,
|
1281 |
+
background_alpha=background_alpha,
|
1282 |
+
background_sampled_color=background_sampled_color,
|
1283 |
+
cos_anneal_ratio=cos_anneal_ratio,
|
1284 |
+
pts_ts=pts_ts)
|
1285 |
+
|
1286 |
+
color_fine = ret_fine['color']
|
1287 |
+
weights = ret_fine['weights']
|
1288 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
1289 |
+
gradients = ret_fine['gradients']
|
1290 |
+
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
|
1291 |
+
|
1292 |
+
return {
|
1293 |
+
'color_fine': color_fine,
|
1294 |
+
's_val': s_val,
|
1295 |
+
'cdf_fine': ret_fine['cdf'],
|
1296 |
+
'weight_sum': weights_sum,
|
1297 |
+
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
1298 |
+
'gradients': gradients,
|
1299 |
+
'weights': weights,
|
1300 |
+
'gradient_error': ret_fine['gradient_error'],
|
1301 |
+
'inside_sphere': ret_fine['inside_sphere'],
|
1302 |
+
'tot_sdfs': tot_sdfs,
|
1303 |
+
}
|
1304 |
+
|
1305 |
+
|
1306 |
+
|
1307 |
+
def render_def(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False, update_tot_def=True):
|
1308 |
+
batch_size = len(rays_o)
|
1309 |
+
# sample_dist = 2.0 / self.n_samples # in a unit sphere # # Assuming the region of interest is a unit sphere
|
1310 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples)
|
1311 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
1312 |
+
|
1313 |
+
z_vals_outside = None
|
1314 |
+
if self.n_outside > 0:
|
1315 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
1316 |
+
|
1317 |
+
n_samples = self.n_samples
|
1318 |
+
perturb = self.perturb
|
1319 |
+
|
1320 |
+
if perturb_overwrite >= 0:
|
1321 |
+
perturb = perturb_overwrite
|
1322 |
+
if perturb > 0:
|
1323 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
1324 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
1325 |
+
|
1326 |
+
if self.n_outside > 0: # z values output # n_outside #
|
1327 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
1328 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
1329 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
1330 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
1331 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
1332 |
+
|
1333 |
+
if self.n_outside > 0:
|
1334 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
1335 |
+
|
1336 |
+
background_alpha = None
|
1337 |
+
background_sampled_color = None
|
1338 |
+
|
1339 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
1340 |
+
|
1341 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
1342 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
1343 |
+
# gt_sdf #
|
1344 |
+
|
1345 |
+
#
|
1346 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
1347 |
+
|
1348 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1349 |
+
pts = pts * 2 - 1
|
1350 |
+
|
1351 |
+
if self.use_selector:
|
1352 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def)
|
1353 |
+
else:
|
1354 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts, update_tot_def=update_tot_def) # give nthe pts
|
1355 |
+
|
1356 |
+
return {
|
1357 |
+
'defed_pts': pts
|
1358 |
+
}
|
1359 |
+
#
|
1360 |
+
|
1361 |
+
|
1362 |
+
def extract_fields_from_tets_selector_self(self, bound_min, bound_max, resolution, i_ts, passive=False):
|
1363 |
+
# load tet via resolution #
|
1364 |
+
# scale them via bounds #
|
1365 |
+
# extract the geometry #
|
1366 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
1367 |
+
device = bound_min.device
|
1368 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
1369 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
1370 |
+
# else:
|
1371 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
1372 |
+
if not os.path.exists(tet_fn):
|
1373 |
+
tet_fn = f"/data/xueyi/NeuS/data/tets/{100}_compress.npz"
|
1374 |
+
tets = np.load(tet_fn)
|
1375 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
1376 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
1377 |
+
# split #
|
1378 |
+
# verts; verts; #
|
1379 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
1380 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
1381 |
+
# scale_verts = maxx_verts - minn_verts
|
1382 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
1383 |
+
|
1384 |
+
### scale the vertices ###
|
1385 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
1386 |
+
|
1387 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
1388 |
+
|
1389 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
1390 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
1391 |
+
|
1392 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
1393 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
1394 |
+
|
1395 |
+
sdf_values = []
|
1396 |
+
N = 64
|
1397 |
+
query_bundles = N ** 3 ### N^3
|
1398 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
1399 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
1400 |
+
query_NNs += 1
|
1401 |
+
for i_query in range(query_NNs):
|
1402 |
+
cur_bundle_st = i_query * query_bundles
|
1403 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
1404 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
1405 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
1406 |
+
# if def_func is not None:
|
1407 |
+
cur_query_pts, sdf_selector = self.deform_pts_with_selector(cur_query_pts, pts_ts=i_ts)
|
1408 |
+
# cur_query_pts, _
|
1409 |
+
# cur_query_vals = query_func(cur_query_pts)
|
1410 |
+
|
1411 |
+
|
1412 |
+
|
1413 |
+
if passive:
|
1414 |
+
cur_query_vals = self.sdf_network[1].sdf(cur_query_pts) # .squeeze(-1)
|
1415 |
+
else:
|
1416 |
+
tot_sdf_values = []
|
1417 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
1418 |
+
cur_sdf_values = cur_sdf_network.sdf(cur_query_pts).squeeze(-1)
|
1419 |
+
tot_sdf_values.append(cur_sdf_values)
|
1420 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=1)
|
1421 |
+
tot_sdf_values = batched_index_select(tot_sdf_values, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1)
|
1422 |
+
cur_query_vals = tot_sdf_values.unsqueeze(1)
|
1423 |
+
# sdf = tot_sdf_values.reshape(batch_size, self.n_samples)
|
1424 |
+
# for i_obj,
|
1425 |
+
sdf_values.append(cur_query_vals)
|
1426 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
1427 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
1428 |
+
|
1429 |
+
gt_sdf_fn = "/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy"
|
1430 |
+
if not os.path.exists(gt_sdf_fn):
|
1431 |
+
gt_sdf_fn = "/data/xueyi/NeuS/data/100_sdf_values.npy"
|
1432 |
+
GT_sdf_values = np.load(gt_sdf_fn, allow_pickle=True)
|
1433 |
+
|
1434 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
1435 |
+
|
1436 |
+
# intrinsic, tet values, pts values, sdf network #
|
1437 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
1438 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
1439 |
+
|
1440 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
1441 |
+
|
1442 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
1443 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
1444 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
1445 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
1446 |
+
# marching_tets_tetmesh ##
|
1447 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
1448 |
+
### use the GT sdf values for the marching tets ###
|
1449 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
1450 |
+
|
1451 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
1452 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
1453 |
+
|
1454 |
+
|
1455 |
+
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
|
1456 |
+
return extract_geometry(bound_min, # extract geometry #
|
1457 |
+
bound_max,
|
1458 |
+
resolution=resolution,
|
1459 |
+
threshold=threshold,
|
1460 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts),
|
1461 |
+
query_func=lambda pts: -self.query_func_sdf(pts)
|
1462 |
+
)
|
1463 |
+
|
1464 |
+
# if self.deform_pts_with_selector:
|
1465 |
+
# pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1466 |
+
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1467 |
+
if wdef:
|
1468 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1469 |
+
bound_max,
|
1470 |
+
resolution=resolution,
|
1471 |
+
threshold=threshold,
|
1472 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1473 |
+
def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts),
|
1474 |
+
selector=True)
|
1475 |
+
else:
|
1476 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1477 |
+
bound_max,
|
1478 |
+
resolution=resolution,
|
1479 |
+
threshold=threshold,
|
1480 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1481 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1482 |
+
selector=True
|
1483 |
+
)
|
1484 |
+
|
1485 |
+
def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1486 |
+
if wdef:
|
1487 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1488 |
+
bound_max,
|
1489 |
+
resolution=resolution,
|
1490 |
+
threshold=threshold,
|
1491 |
+
query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1492 |
+
def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts),
|
1493 |
+
selector=False
|
1494 |
+
)
|
1495 |
+
# return extract_geometry_tets(bound_min, # extract geometry #
|
1496 |
+
# bound_max,
|
1497 |
+
# resolution=resolution,
|
1498 |
+
# threshold=threshold,
|
1499 |
+
# query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1500 |
+
# def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts),
|
1501 |
+
# selector=True)
|
1502 |
+
else:
|
1503 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1504 |
+
bound_max,
|
1505 |
+
resolution=resolution,
|
1506 |
+
threshold=threshold,
|
1507 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1508 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1509 |
+
selector=False
|
1510 |
+
)
|
models/renderer_def_multi_objs_rigidtrans_forward.py
ADDED
@@ -0,0 +1,1603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import logging
|
6 |
+
import mcubes
|
7 |
+
from icecream import ic
|
8 |
+
import os
|
9 |
+
|
10 |
+
import trimesh
|
11 |
+
from pysdf import SDF
|
12 |
+
|
13 |
+
import models.fields as fields
|
14 |
+
|
15 |
+
from uni_rep.rep_3d.dmtet import marching_tets_tetmesh, create_tetmesh_variables
|
16 |
+
|
17 |
+
def batched_index_select(values, indices, dim = 1):
|
18 |
+
value_dims = values.shape[(dim + 1):]
|
19 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
20 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
21 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
22 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
23 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
24 |
+
|
25 |
+
value_expand_shape = [-1] * len(values.shape)
|
26 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
27 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
28 |
+
values = values.expand(*value_expand_shape)
|
29 |
+
|
30 |
+
dim += value_expand_len
|
31 |
+
return values.gather(dim, indices)
|
32 |
+
|
33 |
+
|
34 |
+
def create_mt_variable(device):
|
35 |
+
triangle_table = torch.tensor(
|
36 |
+
[
|
37 |
+
[-1, -1, -1, -1, -1, -1],
|
38 |
+
[1, 0, 2, -1, -1, -1],
|
39 |
+
[4, 0, 3, -1, -1, -1],
|
40 |
+
[1, 4, 2, 1, 3, 4],
|
41 |
+
[3, 1, 5, -1, -1, -1],
|
42 |
+
[2, 3, 0, 2, 5, 3],
|
43 |
+
[1, 4, 0, 1, 5, 4],
|
44 |
+
[4, 2, 5, -1, -1, -1],
|
45 |
+
[4, 5, 2, -1, -1, -1],
|
46 |
+
[4, 1, 0, 4, 5, 1],
|
47 |
+
[3, 2, 0, 3, 5, 2],
|
48 |
+
[1, 3, 5, -1, -1, -1],
|
49 |
+
[4, 1, 2, 4, 3, 1],
|
50 |
+
[3, 0, 4, -1, -1, -1],
|
51 |
+
[2, 0, 1, -1, -1, -1],
|
52 |
+
[-1, -1, -1, -1, -1, -1]
|
53 |
+
], dtype=torch.long, device=device)
|
54 |
+
|
55 |
+
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device)
|
56 |
+
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device)
|
57 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device))
|
58 |
+
return triangle_table, num_triangles_table, base_tet_edges, v_id
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
def extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=None):
|
63 |
+
# load tet via resolution #
|
64 |
+
# scale them via bounds #
|
65 |
+
# extract the geometry #
|
66 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
67 |
+
device = bound_min.device
|
68 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
69 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
70 |
+
# else:
|
71 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
72 |
+
tets = np.load(tet_fn)
|
73 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
74 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
75 |
+
# split #
|
76 |
+
# verts; verts; #
|
77 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
78 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
79 |
+
# scale_verts = maxx_verts - minn_verts
|
80 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
81 |
+
|
82 |
+
### scale the vertices ###
|
83 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
84 |
+
|
85 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
86 |
+
|
87 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
88 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
89 |
+
|
90 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
91 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
92 |
+
|
93 |
+
sdf_values = []
|
94 |
+
N = 64
|
95 |
+
query_bundles = N ** 3 ### N^3
|
96 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
97 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
98 |
+
query_NNs += 1
|
99 |
+
for i_query in range(query_NNs):
|
100 |
+
cur_bundle_st = i_query * query_bundles
|
101 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
102 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
103 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
104 |
+
if def_func is not None:
|
105 |
+
# cur_query_pts = def_func(cur_query_pts)
|
106 |
+
cur_query_pts, _ = def_func(cur_query_pts)
|
107 |
+
cur_query_vals = query_func(cur_query_pts)
|
108 |
+
sdf_values.append(cur_query_vals)
|
109 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
110 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
111 |
+
|
112 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
113 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
114 |
+
|
115 |
+
# intrinsic, tet values, pts values, sdf network #
|
116 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
117 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
118 |
+
|
119 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
120 |
+
|
121 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
122 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
123 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
124 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
125 |
+
# marching_tets_tetmesh ##
|
126 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
127 |
+
### use the GT sdf values for the marching tets ###
|
128 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
129 |
+
|
130 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
131 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
132 |
+
|
133 |
+
|
134 |
+
def extract_fields_from_tets_selector(bound_min, bound_max, resolution, query_func, def_func=None):
|
135 |
+
# load tet via resolution #
|
136 |
+
# scale them via bounds #
|
137 |
+
# extract the geometry #
|
138 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
139 |
+
device = bound_min.device
|
140 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
141 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
142 |
+
# else:
|
143 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
144 |
+
tets = np.load(tet_fn)
|
145 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
146 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
147 |
+
# split #
|
148 |
+
# verts; verts; #
|
149 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
150 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
151 |
+
# scale_verts = maxx_verts - minn_verts
|
152 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
153 |
+
|
154 |
+
### scale the vertices ###
|
155 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
156 |
+
|
157 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
158 |
+
|
159 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
160 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
161 |
+
|
162 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
163 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
164 |
+
|
165 |
+
sdf_values = []
|
166 |
+
N = 64
|
167 |
+
query_bundles = N ** 3 ### N^3
|
168 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
169 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
170 |
+
query_NNs += 1
|
171 |
+
for i_query in range(query_NNs):
|
172 |
+
cur_bundle_st = i_query * query_bundles
|
173 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
174 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
175 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
176 |
+
if def_func is not None:
|
177 |
+
# cur_query_pts = def_func(cur_query_pts)
|
178 |
+
cur_query_pts, _ = def_func(cur_query_pts)
|
179 |
+
cur_query_vals = query_func(cur_query_pts)
|
180 |
+
sdf_values.append(cur_query_vals)
|
181 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
182 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
183 |
+
|
184 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
185 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
186 |
+
|
187 |
+
# intrinsic, tet values, pts values, sdf network #
|
188 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
189 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
190 |
+
|
191 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
192 |
+
|
193 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
194 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
195 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
196 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
197 |
+
# marching_tets_tetmesh ##
|
198 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
199 |
+
### use the GT sdf values for the marching tets ###
|
200 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
201 |
+
|
202 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
203 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
204 |
+
|
205 |
+
|
206 |
+
def extract_fields(bound_min, bound_max, resolution, query_func):
|
207 |
+
N = 64
|
208 |
+
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
|
209 |
+
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
|
210 |
+
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
|
211 |
+
|
212 |
+
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
|
213 |
+
with torch.no_grad():
|
214 |
+
for xi, xs in enumerate(X):
|
215 |
+
for yi, ys in enumerate(Y):
|
216 |
+
for zi, zs in enumerate(Z):
|
217 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs)
|
218 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
219 |
+
val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
220 |
+
u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
|
221 |
+
# should save u here #
|
222 |
+
# save_u_path = os.path.join("/data2/datasets/diffsim/neus/exp/hand_test/womask_sphere_reverse_value/other_saved", "sdf_values.npy")
|
223 |
+
# np.save(save_u_path, u)
|
224 |
+
# print(f"u saved to {save_u_path}")
|
225 |
+
return u
|
226 |
+
|
227 |
+
|
228 |
+
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
|
229 |
+
print('threshold: {}'.format(threshold))
|
230 |
+
|
231 |
+
## using maching cubes ###
|
232 |
+
u = extract_fields(bound_min, bound_max, resolution, query_func)
|
233 |
+
vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
234 |
+
b_max_np = bound_max.detach().cpu().numpy()
|
235 |
+
b_min_np = bound_min.detach().cpu().numpy()
|
236 |
+
|
237 |
+
vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
238 |
+
### using maching cubes ###
|
239 |
+
|
240 |
+
### using marching tets ###
|
241 |
+
# vertices, triangles = extract_fields_from_tets(bound_min, bound_max, resolution, query_func)
|
242 |
+
# vertices = vertices.detach().cpu().numpy()
|
243 |
+
# triangles = triangles.detach().cpu().numpy()
|
244 |
+
### using marching tets ###
|
245 |
+
|
246 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
247 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
248 |
+
|
249 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
250 |
+
return vertices, triangles
|
251 |
+
|
252 |
+
def extract_geometry_tets(bound_min, bound_max, resolution, threshold, query_func, def_func=None):
|
253 |
+
# print('threshold: {}'.format(threshold))
|
254 |
+
|
255 |
+
### using maching cubes ###
|
256 |
+
# u = extract_fields(bound_min, bound_max, resolution, query_func)
|
257 |
+
# vertices, triangles = mcubes.marching_cubes(u, threshold) # grid sdf and marching cubes #
|
258 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
259 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
260 |
+
|
261 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
262 |
+
### using maching cubes ###
|
263 |
+
|
264 |
+
##
|
265 |
+
### using marching tets ### fiels from tets ##
|
266 |
+
vertices, triangles, tet_sdf_values, GT_verts, GT_faces = extract_fields_from_tets(bound_min, bound_max, resolution, query_func, def_func=def_func)
|
267 |
+
# vertices = vertices.detach().cpu().numpy()
|
268 |
+
# triangles = triangles.detach().cpu().numpy()
|
269 |
+
### using marching tets ###
|
270 |
+
|
271 |
+
# b_max_np = bound_max.detach().cpu().numpy()
|
272 |
+
# b_min_np = bound_min.detach().cpu().numpy()
|
273 |
+
#
|
274 |
+
|
275 |
+
# vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
|
276 |
+
return vertices, triangles, tet_sdf_values, GT_verts, GT_faces
|
277 |
+
|
278 |
+
|
279 |
+
def sample_pdf(bins, weights, n_samples, det=False):
|
280 |
+
# This implementation is from NeRF
|
281 |
+
# Get pdf
|
282 |
+
weights = weights + 1e-5 # prevent nans
|
283 |
+
pdf = weights / torch.sum(weights, -1, keepdim=True)
|
284 |
+
cdf = torch.cumsum(pdf, -1)
|
285 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
286 |
+
# Take uniform samples
|
287 |
+
if det:
|
288 |
+
u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
|
289 |
+
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
|
290 |
+
else:
|
291 |
+
u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
|
292 |
+
|
293 |
+
# Invert CDF # invert cdf #
|
294 |
+
u = u.contiguous()
|
295 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
296 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
297 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
298 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
299 |
+
|
300 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
301 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
302 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
303 |
+
|
304 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
305 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
306 |
+
t = (u - cdf_g[..., 0]) / denom
|
307 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
308 |
+
|
309 |
+
return samples
|
310 |
+
|
311 |
+
|
312 |
+
def load_GT_vertices(GT_meshes_folder):
|
313 |
+
tot_meshes_fns = os.listdir(GT_meshes_folder)
|
314 |
+
tot_meshes_fns = [fn for fn in tot_meshes_fns if fn.endswith(".obj")]
|
315 |
+
tot_mesh_verts = []
|
316 |
+
tot_mesh_faces = []
|
317 |
+
n_tot_verts = 0
|
318 |
+
for fn in tot_meshes_fns:
|
319 |
+
cur_mesh_fn = os.path.join(GT_meshes_folder, fn)
|
320 |
+
obj_mesh = trimesh.load(cur_mesh_fn, process=False)
|
321 |
+
# obj_mesh.remove_degenerate_faces(height=1e-06)
|
322 |
+
|
323 |
+
verts_obj = np.array(obj_mesh.vertices)
|
324 |
+
faces_obj = np.array(obj_mesh.faces)
|
325 |
+
|
326 |
+
tot_mesh_verts.append(verts_obj)
|
327 |
+
tot_mesh_faces.append(faces_obj + n_tot_verts)
|
328 |
+
n_tot_verts += verts_obj.shape[0]
|
329 |
+
|
330 |
+
# tot_mesh_faces.append(faces_obj)
|
331 |
+
tot_mesh_verts = np.concatenate(tot_mesh_verts, axis=0)
|
332 |
+
tot_mesh_faces = np.concatenate(tot_mesh_faces, axis=0)
|
333 |
+
return tot_mesh_verts, tot_mesh_faces
|
334 |
+
|
335 |
+
|
336 |
+
class NeuSRenderer:
|
337 |
+
def __init__(self,
|
338 |
+
nerf,
|
339 |
+
sdf_network,
|
340 |
+
deviation_network,
|
341 |
+
color_network,
|
342 |
+
n_samples,
|
343 |
+
n_importance,
|
344 |
+
n_outside,
|
345 |
+
up_sample_steps,
|
346 |
+
perturb):
|
347 |
+
self.nerf = nerf # multiple sdf networks and deviation networks and xxx #
|
348 |
+
self.sdf_network = sdf_network
|
349 |
+
self.deviation_network = deviation_network
|
350 |
+
self.color_network = color_network
|
351 |
+
self.n_samples = n_samples
|
352 |
+
self.n_importance = n_importance
|
353 |
+
self.n_outside = n_outside
|
354 |
+
self.up_sample_steps = up_sample_steps
|
355 |
+
self.perturb = perturb
|
356 |
+
|
357 |
+
GT_meshes_folder = "/home/xueyi/diffsim/DiffHand/assets/hand"
|
358 |
+
self.mesh_vertices, self.mesh_faces = load_GT_vertices(GT_meshes_folder=GT_meshes_folder)
|
359 |
+
maxx_pts = 25.
|
360 |
+
minn_pts = -15.
|
361 |
+
self.mesh_vertices = (self.mesh_vertices - minn_pts) / (maxx_pts - minn_pts)
|
362 |
+
f = SDF(self.mesh_vertices, self.mesh_faces)
|
363 |
+
self.gt_sdf = f ## a unite sphere or box
|
364 |
+
|
365 |
+
self.minn_pts = 0
|
366 |
+
self.maxx_pts = 1.
|
367 |
+
|
368 |
+
# self.minn_pts = -1.5 # gorudn-truth states with the deformation -> update the sdf value fiedl
|
369 |
+
# self.maxx_pts = 1.5 #
|
370 |
+
self.bkg_pts = ... # TODO: the bkg pts # bkg_pts; # bkg_pts_defs #
|
371 |
+
self.cur_fr_bkg_pts_defs = ... # TODO: set the cur_bkg_pts_defs for each frame #
|
372 |
+
self.dist_interp_thres = ... # TODO: set the cur_bkg_pts_defs #
|
373 |
+
|
374 |
+
self.bending_network = ... # TODO: add the bending network #
|
375 |
+
self.use_bending_network = ... # TODO: set the property #
|
376 |
+
self.use_delta_bending = ... # TODO
|
377 |
+
self.prev_sdf_network = ... # TODO
|
378 |
+
self.use_selector = False
|
379 |
+
# self.bending_network_rigidtrans_forward = ... ## TODO: set the rigidjtrans forward ###
|
380 |
+
# timestep_to_mesh, timestep_to_passive_mesh, bending_net, bending_net_passive, act_sdf_net, details=None, special_loss_return=False
|
381 |
+
self.timestep_to_mesh = ... ## TODO
|
382 |
+
self.timestep_to_passive_mesh = ... ### TODO
|
383 |
+
self.bending_net = ... ## TODO
|
384 |
+
self.bending_net_passive = ... ### TODO
|
385 |
+
# self.act_sdf_net = ... ### TODO
|
386 |
+
self.bending_net_kinematic = ... ### TODO
|
387 |
+
self.time_to_act_joints = ... # ## TODO
|
388 |
+
# use bending network #
|
389 |
+
# two bending netwrok
|
390 |
+
# two sdf networks # deform pts kinematic #
|
391 |
+
|
392 |
+
def deform_pts_kinematic(self, pts, pts_ts=0): # deform pts #
|
393 |
+
|
394 |
+
if self.use_bending_network:
|
395 |
+
if len(pts.size()) == 3:
|
396 |
+
nnb, nns = pts.size(0), pts.size(1)
|
397 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
398 |
+
else:
|
399 |
+
pts_exp = pts
|
400 |
+
# pts_ts #
|
401 |
+
if self.use_delta_bending:
|
402 |
+
|
403 |
+
if isinstance(self.bending_net_kinematic, list):
|
404 |
+
pts_offsets = [] #
|
405 |
+
for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
|
406 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
407 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
408 |
+
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
409 |
+
# elif isinstance(cur_bending_network, fields.BendingNetworkForward):
|
410 |
+
# cur_pts_exp = cur_bending_network(input_pts=pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
411 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
412 |
+
cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
413 |
+
else:
|
414 |
+
raise ValueError('Encountered with unexpected bending network class...')
|
415 |
+
pts_offsets.append(cur_pts_exp - pts_exp)
|
416 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
417 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
418 |
+
pts_exp = pts_exp + pts_offsets
|
419 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
420 |
+
# if isinstance(self.bending_network, list): # pts ts #
|
421 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
422 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
423 |
+
# else:
|
424 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
425 |
+
else:
|
426 |
+
if isinstance(self.bending_net_kinematic, list): # prev sdf network #
|
427 |
+
pts_offsets = []
|
428 |
+
for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
|
429 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
430 |
+
pts_offsets.append(bended_pts_exp - pts_exp)
|
431 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
432 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
433 |
+
pts_exp = pts_exp + pts_offsets
|
434 |
+
else:
|
435 |
+
pts_exp = self.bending_net_kinematic(pts_exp, input_pts_ts=pts_ts)
|
436 |
+
if len(pts.size()) == 3:
|
437 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
438 |
+
else:
|
439 |
+
pts = pts_exp
|
440 |
+
return pts
|
441 |
+
|
442 |
+
# pts: nn_batch x nn_samples x 3
|
443 |
+
if len(pts.size()) == 3:
|
444 |
+
nnb, nns = pts.size(0), pts.size(1)
|
445 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
446 |
+
else:
|
447 |
+
pts_exp = pts
|
448 |
+
# print(f"prior to deforming: {pts.size()}")
|
449 |
+
|
450 |
+
dist_pts_to_bkg_pts = torch.sum(
|
451 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
452 |
+
)
|
453 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
454 |
+
dist_mask_float = dist_mask.float()
|
455 |
+
|
456 |
+
# dist_mask_float #
|
457 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
458 |
+
cur_fr_pts_def = torch.sum(
|
459 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
460 |
+
)
|
461 |
+
dist_mask_float_summ = torch.sum(
|
462 |
+
dist_mask_float, dim=1
|
463 |
+
)
|
464 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
465 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
466 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
467 |
+
if len(pts.size()) == 3:
|
468 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
469 |
+
else:
|
470 |
+
pts = pts_exp
|
471 |
+
return pts #
|
472 |
+
|
473 |
+
|
474 |
+
def deform_pts_kinematic_active(self, pts, pts_ts=0): # deform pts #
|
475 |
+
|
476 |
+
if self.use_bending_network:
|
477 |
+
if len(pts.size()) == 3:
|
478 |
+
nnb, nns = pts.size(0), pts.size(1)
|
479 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
480 |
+
else:
|
481 |
+
pts_exp = pts
|
482 |
+
# pts_ts #
|
483 |
+
if self.use_delta_bending:
|
484 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
485 |
+
pts_exp = self.bending_net(pts_exp, input_pts_ts=cur_pts_ts)
|
486 |
+
# if isinstance(self.bending_net_kinematic, list):
|
487 |
+
# pts_offsets = [] #
|
488 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
|
489 |
+
# if isinstance(cur_bending_network, fields.BendingNetwork):
|
490 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
491 |
+
# cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
492 |
+
# # elif isinstance(cur_bending_network, fields.BendingNetworkForward):
|
493 |
+
# # cur_pts_exp = cur_bending_network(input_pts=pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
494 |
+
# elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
495 |
+
# cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
496 |
+
# else:
|
497 |
+
# raise ValueError('Encountered with unexpected bending network class...')
|
498 |
+
# pts_offsets.append(cur_pts_exp - pts_exp)
|
499 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
500 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
501 |
+
# pts_exp = pts_exp + pts_offsets
|
502 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
503 |
+
# if isinstance(self.bending_network, list): # pts ts #
|
504 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
505 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
506 |
+
# else:
|
507 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
508 |
+
else:
|
509 |
+
pts_exp = self.bending_net(pts_exp, input_pts_ts=pts_ts)
|
510 |
+
# if isinstance(self.bending_net_kinematic, list): # prev sdf network #
|
511 |
+
# pts_offsets = []
|
512 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_net_kinematic):
|
513 |
+
# bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
514 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
515 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
516 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
517 |
+
# pts_exp = pts_exp + pts_offsets
|
518 |
+
# else:
|
519 |
+
# pts_exp = self.bending_net_kinematic(pts_exp, input_pts_ts=pts_ts)
|
520 |
+
if len(pts.size()) == 3:
|
521 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
522 |
+
else:
|
523 |
+
pts = pts_exp
|
524 |
+
return pts
|
525 |
+
|
526 |
+
# pts: nn_batch x nn_samples x 3
|
527 |
+
if len(pts.size()) == 3:
|
528 |
+
nnb, nns = pts.size(0), pts.size(1)
|
529 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
530 |
+
else:
|
531 |
+
pts_exp = pts
|
532 |
+
# print(f"prior to deforming: {pts.size()}")
|
533 |
+
|
534 |
+
dist_pts_to_bkg_pts = torch.sum(
|
535 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
536 |
+
)
|
537 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
538 |
+
dist_mask_float = dist_mask.float()
|
539 |
+
|
540 |
+
# dist_mask_float #
|
541 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
542 |
+
cur_fr_pts_def = torch.sum(
|
543 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
544 |
+
)
|
545 |
+
dist_mask_float_summ = torch.sum(
|
546 |
+
dist_mask_float, dim=1
|
547 |
+
)
|
548 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
549 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
550 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
551 |
+
if len(pts.size()) == 3:
|
552 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
553 |
+
else:
|
554 |
+
pts = pts_exp
|
555 |
+
return pts #
|
556 |
+
|
557 |
+
# get the pts and render the pts #
|
558 |
+
# pts and the rendering pts #
|
559 |
+
def deform_pts(self, pts, pts_ts=0): # deform pts #
|
560 |
+
|
561 |
+
if self.use_bending_network:
|
562 |
+
if len(pts.size()) == 3:
|
563 |
+
nnb, nns = pts.size(0), pts.size(1)
|
564 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
565 |
+
else:
|
566 |
+
pts_exp = pts
|
567 |
+
# pts_ts #
|
568 |
+
if self.use_delta_bending:
|
569 |
+
|
570 |
+
if isinstance(self.bending_network, list):
|
571 |
+
pts_offsets = []
|
572 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
573 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
574 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
575 |
+
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
576 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForward):
|
577 |
+
for cur_pts_ts in range(pts_ts-1, -1, -1):
|
578 |
+
cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
579 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
|
580 |
+
# for cur_pts_ts in range(pts_ts-1, -1, -1):
|
581 |
+
for cur_pts_ts in range(pts_ts, 0, -1):
|
582 |
+
cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
583 |
+
# elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
584 |
+
# cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
585 |
+
else:
|
586 |
+
raise ValueError('Encountered with unexpected bending network class...')
|
587 |
+
pts_offsets.append(cur_pts_exp - pts_exp)
|
588 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
589 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
590 |
+
pts_exp = pts_exp + pts_offsets
|
591 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
592 |
+
# if isinstance(self.bending_network, list): # pts ts #
|
593 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
594 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
595 |
+
# else:
|
596 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
597 |
+
else:
|
598 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
599 |
+
pts_offsets = []
|
600 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
601 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
602 |
+
pts_offsets.append(bended_pts_exp - pts_exp)
|
603 |
+
pts_offsets = torch.stack(pts_offsets, dim=0)
|
604 |
+
pts_offsets = torch.sum(pts_offsets, dim=0)
|
605 |
+
pts_exp = pts_exp + pts_offsets
|
606 |
+
else:
|
607 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
608 |
+
if len(pts.size()) == 3:
|
609 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
610 |
+
else:
|
611 |
+
pts = pts_exp
|
612 |
+
return pts
|
613 |
+
|
614 |
+
# pts: nn_batch x nn_samples x 3
|
615 |
+
if len(pts.size()) == 3:
|
616 |
+
nnb, nns = pts.size(0), pts.size(1)
|
617 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
618 |
+
else:
|
619 |
+
pts_exp = pts
|
620 |
+
# print(f"prior to deforming: {pts.size()}")
|
621 |
+
|
622 |
+
dist_pts_to_bkg_pts = torch.sum(
|
623 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
624 |
+
)
|
625 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
626 |
+
dist_mask_float = dist_mask.float()
|
627 |
+
|
628 |
+
# dist_mask_float #
|
629 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
630 |
+
cur_fr_pts_def = torch.sum(
|
631 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
632 |
+
)
|
633 |
+
dist_mask_float_summ = torch.sum(
|
634 |
+
dist_mask_float, dim=1
|
635 |
+
)
|
636 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
637 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
638 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
639 |
+
if len(pts.size()) == 3:
|
640 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
641 |
+
else:
|
642 |
+
pts = pts_exp
|
643 |
+
return pts #
|
644 |
+
|
645 |
+
|
646 |
+
def deform_pts_with_selector(self, pts, pts_ts=0): # deform pts #
|
647 |
+
|
648 |
+
if self.use_bending_network:
|
649 |
+
if len(pts.size()) == 3:
|
650 |
+
nnb, nns = pts.size(0), pts.size(1)
|
651 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
652 |
+
else:
|
653 |
+
pts_exp = pts
|
654 |
+
# pts_ts #
|
655 |
+
if self.use_delta_bending:
|
656 |
+
if isinstance(self.bending_network, list):
|
657 |
+
bended_pts = []
|
658 |
+
queries_sdfs_selector = []
|
659 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
660 |
+
# if cur_bending_network.use_opt_rigid_translations:
|
661 |
+
# bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
662 |
+
# else:
|
663 |
+
# # bended_pts_exp = pts_exp.clone()
|
664 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
665 |
+
# bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
|
666 |
+
|
667 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
668 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
669 |
+
bended_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts)
|
670 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForward):
|
671 |
+
for cur_pts_ts in range(pts_ts-1, -1, -1):
|
672 |
+
bended_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
673 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
|
674 |
+
# for cur_pts_ts in range(pts_ts-1, -1, -1):
|
675 |
+
for cur_pts_ts in range(pts_ts, 0, -1):
|
676 |
+
bended_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else bended_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
677 |
+
|
678 |
+
if pts_ts == 0:
|
679 |
+
bended_pts_exp = pts_exp.clone()
|
680 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
681 |
+
bended_pts.append(bended_pts_exp)
|
682 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
683 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
|
684 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
685 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
686 |
+
sdf_selector = queries_sdfs_selector[:, -1]
|
687 |
+
# sdf_selector = queries_sdfs_selector
|
688 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
689 |
+
# print(f"bended_pts: {bended_pts.size()}, sdf_selector: {sdf_selector.size()}, maxx_sdf_selector: {torch.max(sdf_selector)}, minn_sdf_selector: {torch.min(sdf_selector)}")
|
690 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
691 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
692 |
+
# pts_exp = bended_pts.squeeze(1)
|
693 |
+
pts_exp = bended_pts
|
694 |
+
|
695 |
+
|
696 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
697 |
+
# if isinstance(self.bending_network, list):
|
698 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
699 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
700 |
+
# else:
|
701 |
+
|
702 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
703 |
+
else:
|
704 |
+
if isinstance(self.bending_network, list): # prev sdf network #
|
705 |
+
# pts_offsets = []
|
706 |
+
bended_pts = []
|
707 |
+
queries_sdfs_selector = []
|
708 |
+
for i_obj, cur_bending_network in enumerate(self.bending_network):
|
709 |
+
bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
710 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
711 |
+
_, cur_bended_pts_selecotr = self.query_pts_sdf_fn_for_selector(bended_pts_exp)
|
712 |
+
bended_pts.append(bended_pts_exp)
|
713 |
+
queries_sdfs_selector.append(cur_bended_pts_selecotr)
|
714 |
+
bended_pts = torch.stack(bended_pts, dim=1) # nn_pts x 2 x 3 for bended pts #
|
715 |
+
queries_sdfs_selector = torch.stack(queries_sdfs_selector, dim=1) # nn_pts x 2
|
716 |
+
# queries_sdfs_selector = (queries_sdfs_selector.sum(dim=1) > 0.5).float().long()
|
717 |
+
sdf_selector = queries_sdfs_selector[:, -1]
|
718 |
+
# sdf_selector = queries_sdfs_selector
|
719 |
+
|
720 |
+
|
721 |
+
# delta_sdf, sdf_selector = self.query_pts_sdf_fn_for_selector(pts_exp)
|
722 |
+
bended_pts = batched_index_select(values=bended_pts, indices=sdf_selector.unsqueeze(1), dim=1).squeeze(1) # nn_pts x 3 #
|
723 |
+
# print(f"bended_pts: {bended_pts.size()}, pts_exp: {pts_exp.size()}")
|
724 |
+
pts_exp = bended_pts.squeeze(1)
|
725 |
+
|
726 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
727 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
728 |
+
# pts_exp = pts_exp + pts_offsets
|
729 |
+
else:
|
730 |
+
pts_exp = self.bending_network(pts_exp, input_pts_ts=pts_ts)
|
731 |
+
if len(pts.size()) == 3:
|
732 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
733 |
+
else:
|
734 |
+
pts = pts_exp
|
735 |
+
return pts, sdf_selector
|
736 |
+
|
737 |
+
# pts: nn_batch x nn_samples x 3
|
738 |
+
if len(pts.size()) == 3:
|
739 |
+
nnb, nns = pts.size(0), pts.size(1)
|
740 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
741 |
+
else:
|
742 |
+
pts_exp = pts
|
743 |
+
# print(f"prior to deforming: {pts.size()}")
|
744 |
+
|
745 |
+
dist_pts_to_bkg_pts = torch.sum(
|
746 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
747 |
+
)
|
748 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
749 |
+
dist_mask_float = dist_mask.float()
|
750 |
+
|
751 |
+
# dist_mask_float #
|
752 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
753 |
+
cur_fr_pts_def = torch.sum(
|
754 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
755 |
+
)
|
756 |
+
dist_mask_float_summ = torch.sum(
|
757 |
+
dist_mask_float, dim=1
|
758 |
+
)
|
759 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
760 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
761 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
762 |
+
if len(pts.size()) == 3:
|
763 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
764 |
+
else:
|
765 |
+
pts = pts_exp
|
766 |
+
return pts #
|
767 |
+
|
768 |
+
|
769 |
+
def deform_pts_passive(self, pts, pts_ts=0):
|
770 |
+
|
771 |
+
if self.use_bending_network:
|
772 |
+
if len(pts.size()) == 3:
|
773 |
+
nnb, nns = pts.size(0), pts.size(1)
|
774 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
775 |
+
else:
|
776 |
+
pts_exp = pts
|
777 |
+
# pts_ts #
|
778 |
+
|
779 |
+
|
780 |
+
|
781 |
+
if self.use_delta_bending:
|
782 |
+
cur_bending_network = self.bending_network[-1]
|
783 |
+
if isinstance(cur_bending_network, fields.BendingNetwork):
|
784 |
+
for cur_pts_ts in range(pts_ts, -1, -1):
|
785 |
+
cur_pts_exp = cur_bending_network(pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts)
|
786 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForward):
|
787 |
+
for cur_pts_ts in range(pts_ts-1, -1, -1):
|
788 |
+
cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
789 |
+
elif isinstance(cur_bending_network, fields.BendingNetworkForwardJointDyn):
|
790 |
+
# for cur_pts_ts in range(pts_ts-1, -1, -1):
|
791 |
+
for cur_pts_ts in range(pts_ts, 0, -1):
|
792 |
+
cur_pts_exp = cur_bending_network(input_pts=pts_exp if cur_pts_ts == pts_ts else cur_pts_exp, input_pts_ts=cur_pts_ts, timestep_to_mesh=self.timestep_to_mesh, timestep_to_passive_mesh=self.timestep_to_passive_mesh, timestep_to_joints_pts=self.time_to_act_joints, bending_net=self.bending_net, bending_net_passive=self.bending_net_passive, act_sdf_net=self.prev_sdf_network, details=None, special_loss_return=False)
|
793 |
+
# elif isinstance(cur_bending_network, fields.BendingNetworkRigidTrans):
|
794 |
+
# cur_pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
795 |
+
else:
|
796 |
+
raise ValueError('Encountered with unexpected bending network class...')
|
797 |
+
# for cur_pts_ts in range(pts_ts, -1, -1):
|
798 |
+
# if isinstance(self.bending_network, list):
|
799 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
800 |
+
# pts_exp = cur_bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
801 |
+
# else:
|
802 |
+
# pts_exp = self.bending_network(pts_exp, input_pts_ts=cur_pts_ts)
|
803 |
+
# time_to_offset = {
|
804 |
+
# 0:[-0.03338804, 0.07566567, 0.0958022 ],
|
805 |
+
# 1:[-0.05909395, 0.05454276, 0.09974975],
|
806 |
+
# 2: [-0.07214502, -0.00118192, 0.09003166],
|
807 |
+
# 3: [-0.10040219, -0.01334709, 0.08493543],
|
808 |
+
# 4: [-0.10047092, -0.01264334, 0.05320398],
|
809 |
+
# 5: [-0.09152254, 0.00722668, 0.0101514 ],
|
810 |
+
# }
|
811 |
+
|
812 |
+
if pts_ts > 0:
|
813 |
+
pts_exp = cur_pts_exp
|
814 |
+
else:
|
815 |
+
pts_exp = pts_exp
|
816 |
+
else:
|
817 |
+
# if isinstance(self.bending_network, list):
|
818 |
+
# pts_offsets = []
|
819 |
+
# for i_obj, cur_bending_network in enumerate(self.bending_network):
|
820 |
+
# bended_pts_exp = cur_bending_network(pts_exp, input_pts_ts=pts_ts)
|
821 |
+
# pts_offsets.append(bended_pts_exp - pts_exp)
|
822 |
+
# pts_offsets = torch.stack(pts_offsets, dim=0)
|
823 |
+
# pts_offsets = torch.sum(pts_offsets, dim=0)
|
824 |
+
# pts_exp = pts_exp + pts_offsets
|
825 |
+
# else:
|
826 |
+
pts_exp = self.bending_network[-1](pts_exp, input_pts_ts=pts_ts)
|
827 |
+
if len(pts.size()) == 3:
|
828 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
829 |
+
else:
|
830 |
+
pts = pts_exp
|
831 |
+
return pts
|
832 |
+
|
833 |
+
# pts: nn_batch x nn_samples x 3
|
834 |
+
if len(pts.size()) == 3:
|
835 |
+
nnb, nns = pts.size(0), pts.size(1)
|
836 |
+
pts_exp = pts.contiguous().view(nnb * nns, -1).contiguous()
|
837 |
+
else:
|
838 |
+
pts_exp = pts
|
839 |
+
# print(f"prior to deforming: {pts.size()}")
|
840 |
+
|
841 |
+
dist_pts_to_bkg_pts = torch.sum(
|
842 |
+
(pts_exp.unsqueeze(1) - self.bkg_pts.unsqueeze(0)) ** 2, dim=-1 ## nn_pts_exp x nn_bkg_pts
|
843 |
+
)
|
844 |
+
dist_mask = dist_pts_to_bkg_pts <= self.dist_interp_thres #
|
845 |
+
dist_mask_float = dist_mask.float()
|
846 |
+
|
847 |
+
# dist_mask_float #
|
848 |
+
cur_fr_bkg_def_exp = self.cur_fr_bkg_pts_defs.unsqueeze(0).repeat(pts_exp.size(0), 1, 1).contiguous()
|
849 |
+
cur_fr_pts_def = torch.sum(
|
850 |
+
cur_fr_bkg_def_exp * dist_mask_float.unsqueeze(-1), dim=1
|
851 |
+
)
|
852 |
+
dist_mask_float_summ = torch.sum(
|
853 |
+
dist_mask_float, dim=1
|
854 |
+
)
|
855 |
+
dist_mask_float_summ = torch.clamp(dist_mask_float_summ, min=1)
|
856 |
+
cur_fr_pts_def = cur_fr_pts_def / dist_mask_float_summ.unsqueeze(-1) # bkg pts deformation #
|
857 |
+
pts_exp = pts_exp - cur_fr_pts_def
|
858 |
+
if len(pts.size()) == 3:
|
859 |
+
pts = pts_exp.contiguous().view(nnb, nns, -1).contiguous()
|
860 |
+
else:
|
861 |
+
pts = pts_exp
|
862 |
+
return pts #
|
863 |
+
|
864 |
+
|
865 |
+
def query_pts_sdf_fn_for_selector(self, pts):
|
866 |
+
# for negative
|
867 |
+
# 1) inside the current mesh but outside the previous mesh ---> negative sdf for this field but positive for another field
|
868 |
+
# 2) negative in thie field and also negative in the previous field --->
|
869 |
+
# 2) for positive values of this current field --->
|
870 |
+
# maxx_pts, _ = torch.max(pts, dim=0)
|
871 |
+
# minn_pts, _ = torch.min(pts, dim=0)
|
872 |
+
|
873 |
+
cur_sdf = self.sdf_network.sdf(pts).squeeze(-1)
|
874 |
+
prev_sdf = self.prev_sdf_network.sdf(pts).squeeze(-1)
|
875 |
+
neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
876 |
+
neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
877 |
+
|
878 |
+
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
|
879 |
+
|
880 |
+
pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
881 |
+
pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
882 |
+
res_sdf = torch.zeros_like(cur_sdf)
|
883 |
+
|
884 |
+
# print(f"res_sdf: {res_sdf.size()}, neg_neg: {neg_neg.size()}, pts: {pts.size()}, maxx_pts: {maxx_pts}, minn_pts: {minn_pts}")
|
885 |
+
# if torch.sum(neg_neg.float()).item() > 0.:
|
886 |
+
res_sdf[neg_neg] = 1. #
|
887 |
+
# if torch.sum(neg_pos.float()).item() > 0.:
|
888 |
+
res_sdf[neg_pos] = cur_sdf[neg_pos]
|
889 |
+
# if torch.sum(pos_neg.float()).item() > 0.:
|
890 |
+
res_sdf[pos_neg] = cur_sdf[pos_neg]
|
891 |
+
|
892 |
+
# inside the residual mesh -> must be neg and pos
|
893 |
+
res_sdf_selector = torch.zeros_like(cur_sdf).long() #
|
894 |
+
# res_sdf_selector[neg_pos] = 1 # is the residual mesh
|
895 |
+
# if torch.sum(neg_weq_pos.float()).item() > 0.:
|
896 |
+
res_sdf_selector[neg_weq_pos] = 1
|
897 |
+
# res_sdf_selector[]
|
898 |
+
|
899 |
+
cat_cur_prev_sdf = torch.stack(
|
900 |
+
[cur_sdf, prev_sdf], dim=-1
|
901 |
+
)
|
902 |
+
minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
|
903 |
+
|
904 |
+
if torch.sum(pos_pos.float()).item() > 0.:
|
905 |
+
res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
|
906 |
+
|
907 |
+
return res_sdf, res_sdf_selector
|
908 |
+
|
909 |
+
def query_func_sdf(self, pts):
|
910 |
+
# if isinstance(self.sdf_network, list):
|
911 |
+
# tot_sdf_values = []
|
912 |
+
# for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
913 |
+
# cur_sdf_values = cur_sdf_network.sdf(pts)
|
914 |
+
# tot_sdf_values.append(cur_sdf_values)
|
915 |
+
# tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
916 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
917 |
+
# sdf = tot_sdf_values
|
918 |
+
# else:
|
919 |
+
# sdf = self.sdf_network.sdf(pts)
|
920 |
+
|
921 |
+
|
922 |
+
cur_sdf = self.sdf_network.sdf(pts)
|
923 |
+
prev_sdf = self.prev_sdf_network.sdf(pts)
|
924 |
+
neg_neg = ((cur_sdf < 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
925 |
+
neg_pos = ((cur_sdf < 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
926 |
+
|
927 |
+
neg_weq_pos = ((cur_sdf <= 0.).float() + (prev_sdf > 0.).float()) > 1.5
|
928 |
+
|
929 |
+
pos_neg = ((cur_sdf >= 0.).float() + (prev_sdf < 0.).float()) > 1.5
|
930 |
+
pos_pos = ((cur_sdf >= 0.).float() + (prev_sdf >= 0.).float()) > 1.5
|
931 |
+
res_sdf = torch.zeros_like(cur_sdf)
|
932 |
+
|
933 |
+
# print(f"res_sdf: {res_sdf.size()}, neg_neg: {neg_neg.size()}, pts: {pts.size()}, maxx_pts: {maxx_pts}, minn_pts: {minn_pts}")
|
934 |
+
# if torch.sum(neg_neg.float()).item() > 0.:
|
935 |
+
res_sdf[neg_neg] = prev_sdf[neg_neg]
|
936 |
+
# if torch.sum(neg_pos.float()).item() > 0.:
|
937 |
+
res_sdf[neg_pos] = cur_sdf[neg_pos]
|
938 |
+
# if torch.sum(pos_neg.float()).item() > 0.:
|
939 |
+
res_sdf[pos_neg] = cur_sdf[pos_neg]
|
940 |
+
|
941 |
+
# inside the residual mesh -> must be neg and pos
|
942 |
+
# res_sdf_selector = torch.zeros_like(cur_sdf).long() #
|
943 |
+
# res_sdf_selector[neg_pos] = 1 # is the residual mesh
|
944 |
+
# if torch.sum(neg_weq_pos.float()).item() > 0.:
|
945 |
+
# res_sdf_selector[neg_weq_pos] = 1
|
946 |
+
# res_sdf_selector[]
|
947 |
+
|
948 |
+
cat_cur_prev_sdf = torch.stack(
|
949 |
+
[cur_sdf, prev_sdf], dim=-1
|
950 |
+
)
|
951 |
+
minn_cur_prev_sdf, _ = torch.min(cat_cur_prev_sdf, dim=-1)
|
952 |
+
|
953 |
+
# if torch.sum(pos_pos.float()).item() > 0.:
|
954 |
+
res_sdf[pos_pos] = minn_cur_prev_sdf[pos_pos]
|
955 |
+
|
956 |
+
return res_sdf
|
957 |
+
|
958 |
+
def query_func_active(self, pts):
|
959 |
+
# if isinstance(self.sdf_network, list):
|
960 |
+
# tot_sdf_values = []
|
961 |
+
# for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
962 |
+
# cur_sdf_values = cur_sdf_network.sdf(pts)
|
963 |
+
# tot_sdf_values.append(cur_sdf_values)
|
964 |
+
# tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
965 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
966 |
+
# sdf = tot_sdf_values
|
967 |
+
# else:
|
968 |
+
sdf = self.prev_sdf_network.sdf(pts)
|
969 |
+
return sdf
|
970 |
+
|
971 |
+
def query_func_sdf_passive(self, pts):
|
972 |
+
# if isinstance(self.sdf_network, list):
|
973 |
+
# tot_sdf_values = []
|
974 |
+
# for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
975 |
+
# cur_sdf_values = cur_sdf_network.sdf(pts)
|
976 |
+
# tot_sdf_values.append(cur_sdf_values)
|
977 |
+
# tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
978 |
+
# tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
979 |
+
# sdf = tot_sdf_values
|
980 |
+
# else:
|
981 |
+
sdf = self.sdf_network[-1].sdf(pts)
|
982 |
+
|
983 |
+
return sdf
|
984 |
+
|
985 |
+
|
986 |
+
def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None, pts_ts=0):
|
987 |
+
"""
|
988 |
+
Render background
|
989 |
+
"""
|
990 |
+
batch_size, n_samples = z_vals.shape
|
991 |
+
|
992 |
+
# Section length
|
993 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
994 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
995 |
+
mid_z_vals = z_vals + dists * 0.5
|
996 |
+
|
997 |
+
# Section midpoints #
|
998 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 #
|
999 |
+
|
1000 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1001 |
+
pts = pts * 2 - 1
|
1002 |
+
|
1003 |
+
if self.use_selector:
|
1004 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1005 |
+
else:
|
1006 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
1007 |
+
|
1008 |
+
dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
|
1009 |
+
pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 #
|
1010 |
+
|
1011 |
+
dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
|
1012 |
+
|
1013 |
+
pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
|
1014 |
+
dirs = dirs.reshape(-1, 3)
|
1015 |
+
|
1016 |
+
density, sampled_color = nerf(pts, dirs)
|
1017 |
+
sampled_color = torch.sigmoid(sampled_color)
|
1018 |
+
alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
|
1019 |
+
alpha = alpha.reshape(batch_size, n_samples)
|
1020 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
1021 |
+
sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
|
1022 |
+
color = (weights[:, :, None] * sampled_color).sum(dim=1)
|
1023 |
+
if background_rgb is not None:
|
1024 |
+
color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
|
1025 |
+
|
1026 |
+
return {
|
1027 |
+
'color': color,
|
1028 |
+
'sampled_color': sampled_color,
|
1029 |
+
'alpha': alpha,
|
1030 |
+
'weights': weights,
|
1031 |
+
}
|
1032 |
+
|
1033 |
+
def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s, pts_ts=0):
|
1034 |
+
"""
|
1035 |
+
Up sampling give a fixed inv_s
|
1036 |
+
"""
|
1037 |
+
batch_size, n_samples = z_vals.shape
|
1038 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
|
1039 |
+
|
1040 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1041 |
+
pts = pts * 2 - 1
|
1042 |
+
|
1043 |
+
if self.use_selector:
|
1044 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1045 |
+
else:
|
1046 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
1047 |
+
|
1048 |
+
radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
|
1049 |
+
inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
|
1050 |
+
sdf = sdf.reshape(batch_size, n_samples)
|
1051 |
+
prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
|
1052 |
+
prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
|
1053 |
+
mid_sdf = (prev_sdf + next_sdf) * 0.5
|
1054 |
+
cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
|
1055 |
+
|
1056 |
+
# ----------------------------------------------------------------------------------------------------------
|
1057 |
+
# Use min value of [ cos, prev_cos ]
|
1058 |
+
# Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
|
1059 |
+
# robust when meeting situations like below:
|
1060 |
+
#
|
1061 |
+
# SDF
|
1062 |
+
# ^
|
1063 |
+
# |\ -----x----...
|
1064 |
+
# | \ /
|
1065 |
+
# | x x
|
1066 |
+
# |---\----/-------------> 0 level
|
1067 |
+
# | \ /
|
1068 |
+
# | \/
|
1069 |
+
# |
|
1070 |
+
# ----------------------------------------------------------------------------------------------------------
|
1071 |
+
prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
|
1072 |
+
cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
|
1073 |
+
cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
|
1074 |
+
cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
|
1075 |
+
|
1076 |
+
dist = (next_z_vals - prev_z_vals)
|
1077 |
+
prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
|
1078 |
+
next_esti_sdf = mid_sdf + cos_val * dist * 0.5
|
1079 |
+
prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
|
1080 |
+
next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
|
1081 |
+
alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
|
1082 |
+
weights = alpha * torch.cumprod(
|
1083 |
+
torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
1084 |
+
|
1085 |
+
z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
|
1086 |
+
return z_samples
|
1087 |
+
|
1088 |
+
def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False, pts_ts=0):
|
1089 |
+
batch_size, n_samples = z_vals.shape
|
1090 |
+
_, n_importance = new_z_vals.shape
|
1091 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
|
1092 |
+
|
1093 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1094 |
+
pts = pts * 2 - 1
|
1095 |
+
|
1096 |
+
if self.use_selector:
|
1097 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1098 |
+
else:
|
1099 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
1100 |
+
|
1101 |
+
z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
|
1102 |
+
z_vals, index = torch.sort(z_vals, dim=-1)
|
1103 |
+
|
1104 |
+
if not last:
|
1105 |
+
if isinstance(self.sdf_network, list):
|
1106 |
+
tot_new_sdf = []
|
1107 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
1108 |
+
cur_new_sdf = cur_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
1109 |
+
tot_new_sdf.append(cur_new_sdf)
|
1110 |
+
tot_new_sdf = torch.stack(tot_new_sdf, dim=-1)
|
1111 |
+
new_sdf, _ = torch.min(tot_new_sdf, dim=-1) #
|
1112 |
+
else:
|
1113 |
+
if self.use_selector:
|
1114 |
+
new_sdf_cur = self.sdf_network.sdf(pts.reshape(-1, 3)) # .reshape(batch_size, n_importance)
|
1115 |
+
new_sdf_prev = self.prev_sdf_network.sdf(pts.reshape(-1, 3)) # .reshape(batch_size, n_importance)
|
1116 |
+
new_sdf = torch.stack([new_sdf_prev, new_sdf_cur], dim=1)
|
1117 |
+
new_sdf = batched_index_select(new_sdf, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1118 |
+
new_sdf = new_sdf.reshape(batch_size, n_importance)
|
1119 |
+
else:
|
1120 |
+
new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
1121 |
+
sdf = torch.cat([sdf, new_sdf], dim=-1)
|
1122 |
+
xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
|
1123 |
+
index = index.reshape(-1)
|
1124 |
+
sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
|
1125 |
+
|
1126 |
+
return z_vals, sdf
|
1127 |
+
|
1128 |
+
|
1129 |
+
|
1130 |
+
def render_core(self,
|
1131 |
+
rays_o,
|
1132 |
+
rays_d,
|
1133 |
+
z_vals,
|
1134 |
+
sample_dist,
|
1135 |
+
sdf_network,
|
1136 |
+
deviation_network,
|
1137 |
+
color_network,
|
1138 |
+
background_alpha=None,
|
1139 |
+
background_sampled_color=None,
|
1140 |
+
background_rgb=None,
|
1141 |
+
cos_anneal_ratio=0.0,
|
1142 |
+
pts_ts=0):
|
1143 |
+
batch_size, n_samples = z_vals.shape
|
1144 |
+
|
1145 |
+
# Section length
|
1146 |
+
dists = z_vals[..., 1:] - z_vals[..., :-1]
|
1147 |
+
dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
|
1148 |
+
mid_z_vals = z_vals + dists * 0.5 # z_vals and dists * 0.5 #
|
1149 |
+
|
1150 |
+
# Section midpoints
|
1151 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
|
1152 |
+
dirs = rays_d[:, None, :].expand(pts.shape)
|
1153 |
+
|
1154 |
+
pts = pts.reshape(-1, 3) # pts, nn_ou
|
1155 |
+
dirs = dirs.reshape(-1, 3)
|
1156 |
+
|
1157 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
1158 |
+
|
1159 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1160 |
+
pts = pts * 2 - 1
|
1161 |
+
|
1162 |
+
if self.use_selector:
|
1163 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1164 |
+
else:
|
1165 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts)
|
1166 |
+
|
1167 |
+
if isinstance(sdf_network, list):
|
1168 |
+
tot_sdf = []
|
1169 |
+
tot_feature_vector = []
|
1170 |
+
tot_obj_sel = []
|
1171 |
+
tot_gradients = []
|
1172 |
+
for i_obj, cur_sdf_network in enumerate(sdf_network):
|
1173 |
+
cur_sdf_nn_output = cur_sdf_network(pts)
|
1174 |
+
cur_sdf, cur_feature_vector = cur_sdf_nn_output[:, :1], cur_sdf_nn_output[:, 1:]
|
1175 |
+
tot_sdf.append(cur_sdf)
|
1176 |
+
tot_feature_vector.append(cur_feature_vector)
|
1177 |
+
|
1178 |
+
gradients = cur_sdf_network.gradient(pts).squeeze()
|
1179 |
+
tot_gradients.append(gradients)
|
1180 |
+
tot_sdf = torch.stack(tot_sdf, dim=-1)
|
1181 |
+
sdf, obj_sel = torch.min(tot_sdf, dim=-1)
|
1182 |
+
feature_vector = torch.stack(tot_feature_vector, dim=1)
|
1183 |
+
|
1184 |
+
# batched_index_select
|
1185 |
+
# print(f"before sel: {feature_vector.size()}, obj_sel: {obj_sel.size()}")
|
1186 |
+
feature_vector = batched_index_select(values=feature_vector, indices=obj_sel, dim=1).squeeze(1)
|
1187 |
+
|
1188 |
+
|
1189 |
+
# feature_vector = feature_vector[obj_sel.unsqueeze(-1), :].squeeze(1)
|
1190 |
+
# print(f"after sel: {feature_vector.size()}")
|
1191 |
+
tot_gradients = torch.stack(tot_gradients, dim=1)
|
1192 |
+
# gradients = tot_gradients[obj_sel.unsqueeze(-1)].squeeze(1)
|
1193 |
+
gradients = batched_index_select(values=tot_gradients, indices=obj_sel, dim=1).squeeze(1)
|
1194 |
+
# print(f"gradients: {gradients.size()}, tot_gradients: {tot_gradients.size()}")
|
1195 |
+
|
1196 |
+
else:
|
1197 |
+
# sdf_nn_output = sdf_network(pts)
|
1198 |
+
# sdf = sdf_nn_output[:, :1]
|
1199 |
+
# feature_vector = sdf_nn_output[:, 1:]
|
1200 |
+
# gradients = sdf_network.gradient(pts).squeeze()
|
1201 |
+
|
1202 |
+
if self.use_selector:
|
1203 |
+
prev_sdf_nn_output = self.prev_sdf_network(pts)
|
1204 |
+
prev_gradients = self.prev_sdf_network.gradient(pts).squeeze()
|
1205 |
+
cur_sdf_nn_output = self.sdf_network(pts)
|
1206 |
+
cur_gradients = self.sdf_network.gradient(pts).squeeze()
|
1207 |
+
|
1208 |
+
sdf_nn_output = torch.stack([prev_sdf_nn_output, cur_sdf_nn_output], dim=1)
|
1209 |
+
sdf_nn_output = batched_index_select(sdf_nn_output, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1210 |
+
|
1211 |
+
sdf = sdf_nn_output[:, :1]
|
1212 |
+
feature_vector = sdf_nn_output[:, 1:]
|
1213 |
+
|
1214 |
+
gradients = torch.stack([prev_gradients, cur_gradients], dim=1)
|
1215 |
+
gradients = batched_index_select(gradients, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1216 |
+
else:
|
1217 |
+
sdf_nn_output = sdf_network(pts)
|
1218 |
+
sdf = sdf_nn_output[:, :1]
|
1219 |
+
feature_vector = sdf_nn_output[:, 1:]
|
1220 |
+
gradients = sdf_network.gradient(pts).squeeze()
|
1221 |
+
# new_sdf_cur = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
1222 |
+
# new_sdf_prev = self.prev_sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
|
1223 |
+
# new_sdf = torch.stack([new_sdf_prev, new_sdf_cur], dim=1)
|
1224 |
+
# new_sdf = batched_index_select(new_sdf, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1225 |
+
|
1226 |
+
sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
|
1227 |
+
|
1228 |
+
# deviation network #
|
1229 |
+
inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
|
1230 |
+
inv_s = inv_s.expand(batch_size * n_samples, 1)
|
1231 |
+
|
1232 |
+
true_cos = (dirs * gradients).sum(-1, keepdim=True)
|
1233 |
+
|
1234 |
+
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
1235 |
+
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
1236 |
+
iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
|
1237 |
+
F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
|
1238 |
+
|
1239 |
+
# Estimate signed distances at section points
|
1240 |
+
estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
|
1241 |
+
estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
|
1242 |
+
|
1243 |
+
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
|
1244 |
+
next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
|
1245 |
+
|
1246 |
+
p = prev_cdf - next_cdf
|
1247 |
+
c = prev_cdf
|
1248 |
+
|
1249 |
+
alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
|
1250 |
+
|
1251 |
+
pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
|
1252 |
+
inside_sphere = (pts_norm < 1.0).float().detach()
|
1253 |
+
relax_inside_sphere = (pts_norm < 1.2).float().detach()
|
1254 |
+
|
1255 |
+
# Render with background
|
1256 |
+
if background_alpha is not None:
|
1257 |
+
alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
|
1258 |
+
alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
|
1259 |
+
sampled_color = sampled_color * inside_sphere[:, :, None] +\
|
1260 |
+
background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
|
1261 |
+
sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
|
1262 |
+
|
1263 |
+
weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
|
1264 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
1265 |
+
|
1266 |
+
color = (sampled_color * weights[:, :, None]).sum(dim=1)
|
1267 |
+
if background_rgb is not None: # Fixed background, usually black
|
1268 |
+
color = color + background_rgb * (1.0 - weights_sum)
|
1269 |
+
|
1270 |
+
# Eikonal loss
|
1271 |
+
gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
|
1272 |
+
dim=-1) - 1.0) ** 2
|
1273 |
+
gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
|
1274 |
+
|
1275 |
+
return {
|
1276 |
+
'color': color,
|
1277 |
+
'sdf': sdf,
|
1278 |
+
'dists': dists,
|
1279 |
+
'gradients': gradients.reshape(batch_size, n_samples, 3),
|
1280 |
+
's_val': 1.0 / inv_s,
|
1281 |
+
'mid_z_vals': mid_z_vals,
|
1282 |
+
'weights': weights,
|
1283 |
+
'cdf': c.reshape(batch_size, n_samples),
|
1284 |
+
'gradient_error': gradient_error,
|
1285 |
+
'inside_sphere': inside_sphere
|
1286 |
+
}
|
1287 |
+
|
1288 |
+
def render(self, rays_o, rays_d, near, far, pts_ts=0, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0, use_gt_sdf=False):
|
1289 |
+
batch_size = len(rays_o)
|
1290 |
+
sample_dist = 2.0 / self.n_samples # in a unit sphere # # Assuming the region of interest is a unit sphere
|
1291 |
+
z_vals = torch.linspace(0.0, 1.0, self.n_samples) # linspace #
|
1292 |
+
z_vals = near + (far - near) * z_vals[None, :]
|
1293 |
+
|
1294 |
+
z_vals_outside = None
|
1295 |
+
if self.n_outside > 0:
|
1296 |
+
z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
|
1297 |
+
|
1298 |
+
n_samples = self.n_samples
|
1299 |
+
perturb = self.perturb
|
1300 |
+
|
1301 |
+
if perturb_overwrite >= 0:
|
1302 |
+
perturb = perturb_overwrite
|
1303 |
+
if perturb > 0:
|
1304 |
+
t_rand = (torch.rand([batch_size, 1]) - 0.5)
|
1305 |
+
z_vals = z_vals + t_rand * 2.0 / self.n_samples
|
1306 |
+
|
1307 |
+
if self.n_outside > 0: # z values output # n_outside #
|
1308 |
+
mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
|
1309 |
+
upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
|
1310 |
+
lower = torch.cat([z_vals_outside[..., :1], mids], -1)
|
1311 |
+
t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
|
1312 |
+
z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
|
1313 |
+
|
1314 |
+
if self.n_outside > 0:
|
1315 |
+
z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
|
1316 |
+
|
1317 |
+
background_alpha = None
|
1318 |
+
background_sampled_color = None
|
1319 |
+
|
1320 |
+
# Up sample
|
1321 |
+
if self.n_importance > 0:
|
1322 |
+
with torch.no_grad():
|
1323 |
+
pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
|
1324 |
+
|
1325 |
+
pts = (pts - self.minn_pts) / (self.maxx_pts - self.minn_pts)
|
1326 |
+
# sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
|
1327 |
+
# gt_sdf #
|
1328 |
+
|
1329 |
+
#
|
1330 |
+
# pts = ((pts - xyz_min) / (xyz_max - xyz_min)).flip((-1,)) * 2 - 1
|
1331 |
+
|
1332 |
+
# pts = pts.flip((-1,)) * 2 - 1
|
1333 |
+
pts = pts * 2 - 1
|
1334 |
+
|
1335 |
+
if self.use_selector:
|
1336 |
+
pts, sdf_selector = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1337 |
+
else:
|
1338 |
+
pts = self.deform_pts(pts=pts, pts_ts=pts_ts) # give nthe pts
|
1339 |
+
|
1340 |
+
pts_exp = pts.reshape(-1, 3)
|
1341 |
+
# minn_pts, _ = torch.min(pts_exp, dim=0)
|
1342 |
+
# maxx_pts, _ = torch.max(pts_exp, dim=0) # deformation field (not a rigid one) -> the meshes #
|
1343 |
+
# print(f"minn_pts: {minn_pts}, maxx_pts: {maxx_pts}")
|
1344 |
+
|
1345 |
+
# pts_to_near = pts - near.unsqueeze(1)
|
1346 |
+
# maxx_pts = 1.5; minn_pts = -1.5
|
1347 |
+
# # maxx_pts = 3; minn_pts = -3
|
1348 |
+
# # maxx_pts = 1; minn_pts = -1
|
1349 |
+
# pts_exp = (pts_exp - minn_pts) / (maxx_pts - minn_pts)
|
1350 |
+
|
1351 |
+
## render and iamges ####
|
1352 |
+
if use_gt_sdf:
|
1353 |
+
### use the GT sdf field ####
|
1354 |
+
# print(f"Using gt sdf :")
|
1355 |
+
sdf = self.gt_sdf(pts_exp.reshape(-1, 3).detach().cpu().numpy())
|
1356 |
+
sdf = torch.from_numpy(sdf).float().cuda()
|
1357 |
+
sdf = sdf.reshape(batch_size, self.n_samples)
|
1358 |
+
### use the GT sdf field ####
|
1359 |
+
else:
|
1360 |
+
# pts_exp: (bsz x nn_s) x 3 -> (sdf_network) -> (bsz x nn_s)
|
1361 |
+
#### use the optimized sdf field ####
|
1362 |
+
|
1363 |
+
# sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1364 |
+
|
1365 |
+
if isinstance(self.sdf_network, list):
|
1366 |
+
tot_sdf_values = []
|
1367 |
+
for i_obj, cur_sdf_network in enumerate(self.sdf_network):
|
1368 |
+
cur_sdf_values = cur_sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1369 |
+
tot_sdf_values.append(cur_sdf_values)
|
1370 |
+
tot_sdf_values = torch.stack(tot_sdf_values, dim=-1)
|
1371 |
+
tot_sdf_values, _ = torch.min(tot_sdf_values, dim=-1) # totsdf values #
|
1372 |
+
sdf = tot_sdf_values
|
1373 |
+
else:
|
1374 |
+
# sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1375 |
+
|
1376 |
+
if self.use_selector:
|
1377 |
+
prev_sdf = self.prev_sdf_network.sdf(pts_exp) # .reshape(batch_size, self.n_samples)
|
1378 |
+
cur_sdf = self.sdf_network.sdf(pts_exp) # .reshape(batch_size, self.n_samples)
|
1379 |
+
sdf = torch.stack([prev_sdf, cur_sdf], dim=1)
|
1380 |
+
sdf = batched_index_select(sdf, indices=sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1381 |
+
sdf = sdf.reshape(batch_size, self.n_samples)
|
1382 |
+
else:
|
1383 |
+
sdf = self.sdf_network.sdf(pts_exp).reshape(batch_size, self.n_samples)
|
1384 |
+
|
1385 |
+
#### use the optimized sdf field ####
|
1386 |
+
|
1387 |
+
for i in range(self.up_sample_steps):
|
1388 |
+
new_z_vals = self.up_sample(rays_o,
|
1389 |
+
rays_d,
|
1390 |
+
z_vals,
|
1391 |
+
sdf,
|
1392 |
+
self.n_importance // self.up_sample_steps,
|
1393 |
+
64 * 2**i,
|
1394 |
+
pts_ts=pts_ts)
|
1395 |
+
z_vals, sdf = self.cat_z_vals(rays_o,
|
1396 |
+
rays_d,
|
1397 |
+
z_vals,
|
1398 |
+
new_z_vals,
|
1399 |
+
sdf,
|
1400 |
+
last=(i + 1 == self.up_sample_steps),
|
1401 |
+
pts_ts=pts_ts)
|
1402 |
+
|
1403 |
+
n_samples = self.n_samples + self.n_importance
|
1404 |
+
|
1405 |
+
# Background model
|
1406 |
+
if self.n_outside > 0:
|
1407 |
+
z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
|
1408 |
+
z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
|
1409 |
+
ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf, pts_ts=pts_ts)
|
1410 |
+
|
1411 |
+
background_sampled_color = ret_outside['sampled_color']
|
1412 |
+
background_alpha = ret_outside['alpha']
|
1413 |
+
|
1414 |
+
# Render core
|
1415 |
+
ret_fine = self.render_core(rays_o, #
|
1416 |
+
rays_d,
|
1417 |
+
z_vals,
|
1418 |
+
sample_dist,
|
1419 |
+
self.sdf_network,
|
1420 |
+
self.deviation_network,
|
1421 |
+
self.color_network,
|
1422 |
+
background_rgb=background_rgb,
|
1423 |
+
background_alpha=background_alpha,
|
1424 |
+
background_sampled_color=background_sampled_color,
|
1425 |
+
cos_anneal_ratio=cos_anneal_ratio,
|
1426 |
+
pts_ts=pts_ts)
|
1427 |
+
|
1428 |
+
color_fine = ret_fine['color']
|
1429 |
+
weights = ret_fine['weights']
|
1430 |
+
weights_sum = weights.sum(dim=-1, keepdim=True)
|
1431 |
+
gradients = ret_fine['gradients']
|
1432 |
+
s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
|
1433 |
+
|
1434 |
+
return {
|
1435 |
+
'color_fine': color_fine,
|
1436 |
+
's_val': s_val,
|
1437 |
+
'cdf_fine': ret_fine['cdf'],
|
1438 |
+
'weight_sum': weights_sum,
|
1439 |
+
'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
|
1440 |
+
'gradients': gradients,
|
1441 |
+
'weights': weights,
|
1442 |
+
'gradient_error': ret_fine['gradient_error'],
|
1443 |
+
'inside_sphere': ret_fine['inside_sphere']
|
1444 |
+
}
|
1445 |
+
|
1446 |
+
# def
|
1447 |
+
def extract_fields_from_tets_with_selector(self, bound_min, bound_max, resolution, pts_ts ):
|
1448 |
+
# load tet via resolution #
|
1449 |
+
# scale them via bounds #
|
1450 |
+
# extract the geometry #
|
1451 |
+
# /home/xueyi/gen/DeepMetaHandles/data/tets/100_compress.npz # strange #
|
1452 |
+
device = bound_min.device
|
1453 |
+
# if resolution in [64, 70, 80, 90, 100]:
|
1454 |
+
# tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{resolution}_compress.npz"
|
1455 |
+
# else:
|
1456 |
+
tet_fn = f"/home/xueyi/gen/DeepMetaHandles/data/tets/{100}_compress.npz"
|
1457 |
+
tets = np.load(tet_fn)
|
1458 |
+
verts = torch.from_numpy(tets['vertices']).float().to(device) # verts positions
|
1459 |
+
indices = torch.from_numpy(tets['tets']).long().to(device) # .to(self.device)
|
1460 |
+
# split #
|
1461 |
+
# verts; verts; #
|
1462 |
+
minn_verts, _ = torch.min(verts, dim=0)
|
1463 |
+
maxx_verts, _ = torch.max(verts, dim=0) # (3, ) # exporting the
|
1464 |
+
# scale_verts = maxx_verts - minn_verts
|
1465 |
+
scale_bounds = bound_max - bound_min # scale bounds #
|
1466 |
+
|
1467 |
+
### scale the vertices ###
|
1468 |
+
scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
1469 |
+
|
1470 |
+
# scaled_verts = (verts - minn_verts.unsqueeze(0)) / (maxx_verts - minn_verts).unsqueeze(0) ### the maxx and minn verts scales ###
|
1471 |
+
|
1472 |
+
scaled_verts = scaled_verts * 2. - 1. # init the sdf filed viathe tet mesh vertices and the sdf values ##
|
1473 |
+
# scaled_verts = (scaled_verts * scale_bounds.unsqueeze(0)) + bound_min.unsqueeze(0) ## the scaled verts ###
|
1474 |
+
|
1475 |
+
# scaled_verts = scaled_verts - scale_bounds.unsqueeze(0) / 2. #
|
1476 |
+
# scaled_verts = scaled_verts - bound_min.unsqueeze(0) - scale_bounds.unsqueeze(0) / 2.
|
1477 |
+
|
1478 |
+
sdf_values = []
|
1479 |
+
N = 64
|
1480 |
+
query_bundles = N ** 3 ### N^3
|
1481 |
+
query_NNs = scaled_verts.size(0) // query_bundles
|
1482 |
+
if query_NNs * query_bundles < scaled_verts.size(0):
|
1483 |
+
query_NNs += 1
|
1484 |
+
for i_query in range(query_NNs):
|
1485 |
+
cur_bundle_st = i_query * query_bundles
|
1486 |
+
cur_bundle_ed = (i_query + 1) * query_bundles
|
1487 |
+
cur_bundle_ed = min(cur_bundle_ed, scaled_verts.size(0))
|
1488 |
+
cur_query_pts = scaled_verts[cur_bundle_st: cur_bundle_ed]
|
1489 |
+
# if def_func is not None:
|
1490 |
+
cur_query_pts, sdf_selector = self.deform_pts_with_selector(pts=cur_query_pts, pts_ts=pts_ts)
|
1491 |
+
|
1492 |
+
prev_query_vals = -self.prev_sdf_network.sdf(cur_query_pts)
|
1493 |
+
cur_query_vals = -self.sdf_network.sdf(cur_query_pts)
|
1494 |
+
cur_query_vals = torch.stack([prev_query_vals, cur_query_vals], dim=1)
|
1495 |
+
cur_query_vals = batched_index_select(cur_query_vals, sdf_selector.unsqueeze(-1), dim=1).squeeze(1)
|
1496 |
+
|
1497 |
+
# cur_query_vals = query_func(cur_query_pts)
|
1498 |
+
sdf_values.append(cur_query_vals)
|
1499 |
+
sdf_values = torch.cat(sdf_values, dim=0)
|
1500 |
+
# print(f"queryed sdf values: {sdf_values.size()}") #
|
1501 |
+
|
1502 |
+
GT_sdf_values = np.load("/home/xueyi/diffsim/DiffHand/assets/hand/100_sdf_values.npy", allow_pickle=True)
|
1503 |
+
GT_sdf_values = torch.from_numpy(GT_sdf_values).float().to(device)
|
1504 |
+
|
1505 |
+
# intrinsic, tet values, pts values, sdf network #
|
1506 |
+
triangle_table, num_triangles_table, base_tet_edges, v_id = create_mt_variable(device)
|
1507 |
+
tet_table, num_tets_table = create_tetmesh_variables(device)
|
1508 |
+
|
1509 |
+
sdf_values = sdf_values.squeeze(-1) # how the rendering #
|
1510 |
+
|
1511 |
+
# print(f"GT_sdf_values: {GT_sdf_values.size()}, sdf_values: {sdf_values.size()}, scaled_verts: {scaled_verts.size()}")
|
1512 |
+
# print(f"scaled_verts: {scaled_verts.size()}, ")
|
1513 |
+
# pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id,
|
1514 |
+
# return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None):
|
1515 |
+
# marching_tets_tetmesh ##
|
1516 |
+
verts, faces, tet_verts, tets = marching_tets_tetmesh(scaled_verts, sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
1517 |
+
### use the GT sdf values for the marching tets ###
|
1518 |
+
GT_verts, GT_faces, GT_tet_verts, GT_tets = marching_tets_tetmesh(scaled_verts, GT_sdf_values, indices, triangle_table, num_triangles_table, base_tet_edges, v_id, return_tet_mesh=True, ori_v=scaled_verts, num_tets_table=num_tets_table, tet_table=tet_table)
|
1519 |
+
|
1520 |
+
# print(f"After tet marching with verts: {verts.size()}, faces: {faces.size()}")
|
1521 |
+
return verts, faces, sdf_values, GT_verts, GT_faces # verts, faces #
|
1522 |
+
|
1523 |
+
|
1524 |
+
def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
|
1525 |
+
return extract_geometry(bound_min, # extract geometry #
|
1526 |
+
bound_max,
|
1527 |
+
resolution=resolution,
|
1528 |
+
threshold=threshold,
|
1529 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts),
|
1530 |
+
query_func=lambda pts: -self.query_func_sdf(pts)
|
1531 |
+
)
|
1532 |
+
|
1533 |
+
# if self.deform_pts_with_selector:
|
1534 |
+
# pts = self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts)
|
1535 |
+
def extract_geometry_tets(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1536 |
+
if wdef:
|
1537 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1538 |
+
bound_max,
|
1539 |
+
resolution=resolution,
|
1540 |
+
threshold=threshold,
|
1541 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1542 |
+
def_func=lambda pts: self.deform_pts(pts, pts_ts=pts_ts) if not self.use_selector else self.deform_pts_with_selector(pts=pts, pts_ts=pts_ts))
|
1543 |
+
else:
|
1544 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1545 |
+
bound_max,
|
1546 |
+
resolution=resolution,
|
1547 |
+
threshold=threshold,
|
1548 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1549 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1550 |
+
)
|
1551 |
+
|
1552 |
+
|
1553 |
+
def extract_geometry_tets_kinematic(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1554 |
+
if wdef:
|
1555 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1556 |
+
bound_max,
|
1557 |
+
resolution=resolution,
|
1558 |
+
threshold=threshold,
|
1559 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1560 |
+
def_func=lambda pts: self.deform_pts_kinematic(pts, pts_ts=pts_ts))
|
1561 |
+
else:
|
1562 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1563 |
+
bound_max,
|
1564 |
+
resolution=resolution,
|
1565 |
+
threshold=threshold,
|
1566 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1567 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1568 |
+
)
|
1569 |
+
|
1570 |
+
|
1571 |
+
def extract_geometry_tets_active(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1572 |
+
if wdef:
|
1573 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1574 |
+
bound_max,
|
1575 |
+
resolution=resolution,
|
1576 |
+
threshold=threshold,
|
1577 |
+
query_func=lambda pts: -self.query_func_active(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1578 |
+
def_func=lambda pts: self.deform_pts_kinematic_active(pts, pts_ts=pts_ts))
|
1579 |
+
else:
|
1580 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1581 |
+
bound_max,
|
1582 |
+
resolution=resolution,
|
1583 |
+
threshold=threshold,
|
1584 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1585 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1586 |
+
)
|
1587 |
+
|
1588 |
+
def extract_geometry_tets_passive(self, bound_min, bound_max, resolution, pts_ts=0, threshold=0.0, wdef=False):
|
1589 |
+
if wdef:
|
1590 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1591 |
+
bound_max,
|
1592 |
+
resolution=resolution,
|
1593 |
+
threshold=threshold,
|
1594 |
+
query_func=lambda pts: -self.query_func_sdf_passive(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1595 |
+
def_func=lambda pts: self.deform_pts_passive(pts, pts_ts=pts_ts))
|
1596 |
+
else:
|
1597 |
+
return extract_geometry_tets(bound_min, # extract geometry #
|
1598 |
+
bound_max,
|
1599 |
+
resolution=resolution,
|
1600 |
+
threshold=threshold,
|
1601 |
+
# query_func=lambda pts: -self.sdf_network.sdf(pts)
|
1602 |
+
query_func=lambda pts: -self.query_func_sdf(pts), # lambda pts: -self.sdf_network.sdf(pts),
|
1603 |
+
)
|
models/test.js
ADDED
File without changes
|
pre-requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
pip==23.3.2
|
2 |
+
torch==2.2.0
|
requirements.txt
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-f https://download.pytorch.org/whl/cpu/torch_stable.html
|
2 |
+
-f https://data.pyg.org/whl/torch-2.2.0%2Bcpu.html
|
3 |
+
# pip==20.2.4
|
4 |
+
torch==2.2.0
|
5 |
+
# torchvision==0.13.1
|
6 |
+
# torchaudio==0.12.1
|
7 |
+
scipy
|
8 |
+
trimesh
|
9 |
+
icecream
|
10 |
+
tqdm
|
11 |
+
pyhocon
|
12 |
+
open3d
|
13 |
+
tensorboard
|
14 |
+
|
15 |
+
# blobfile==2.0.1
|
16 |
+
# manopth @ git+https://github.com/hassony2/manopth.git
|
17 |
+
# numpy==1.23.1
|
18 |
+
# psutil==5.9.2
|
19 |
+
# scikit-learn
|
20 |
+
# scipy==1.9.3
|
21 |
+
# tensorboard
|
22 |
+
# tensorboardx
|
23 |
+
# tqdm
|
24 |
+
# trimesh
|
25 |
+
# clip
|
26 |
+
# chumpy
|
27 |
+
# opencv-python
|
scripts_demo/train_grab_pointset_points_dyn_s1.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn_s1.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
export data_path="./data/102_grab_all_data.npy"
|
22 |
+
# bash scripts_new/train_grab_pointset_points_dyn_s1.sh
|
23 |
+
|
24 |
+
export cuda_ids="0"
|
25 |
+
|
26 |
+
|
27 |
+
CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case} --data_path=${data_path}
|
28 |
+
|
scripts_new/train_grab_mano.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
export conf=wmask_refine_passive_rigidtrans_forward.conf
|
5 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
export trainer=exp_runner_stage_1.py
|
10 |
+
|
11 |
+
export mode="train_dyn_mano_model"
|
12 |
+
|
13 |
+
|
14 |
+
export conf=dyn_grab_pointset_mano.conf
|
15 |
+
|
16 |
+
export conf_root="./confs_new"
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
# bash scripts_new/train_grab_mano.sh
|
21 |
+
|
22 |
+
export cuda_ids="0"
|
23 |
+
#
|
24 |
+
|
25 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
26 |
+
|
scripts_new/train_grab_mano_wreact.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
export trainer=exp_runner_stage_1.py
|
11 |
+
|
12 |
+
export mode="train_dyn_mano_model_wreact" ## wreact ## wreact ##
|
13 |
+
|
14 |
+
export conf=dyn_grab_pointset_mano_dyn.conf
|
15 |
+
|
16 |
+
|
17 |
+
export conf_root="./confs_new"
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
export cuda_ids="0"
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
26 |
+
|
scripts_new/train_grab_mano_wreact_optacts.sh
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
export trainer=exp_runner_stage_1.py
|
11 |
+
|
12 |
+
export mode="train_dyn_mano_model_wreact" ## wreact ## wreact ##
|
13 |
+
|
14 |
+
export conf=dyn_grab_pointset_mano_dyn.conf
|
15 |
+
|
16 |
+
|
17 |
+
export conf_root="./confs_new"
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
export cuda_ids="0"
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
26 |
+
|
scripts_new/train_grab_pointset.sh
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
export cuda_ids="4"
|
5 |
+
|
6 |
+
# export cuda_ids="5"
|
7 |
+
|
8 |
+
|
9 |
+
export trainer=exp_runner_arti_forward.py
|
10 |
+
export conf=wmask_refine_passive_rigidtrans_forward.conf
|
11 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
12 |
+
|
13 |
+
|
14 |
+
# export trainer=exp_runner_arti_multi_objs_compositional.py
|
15 |
+
export trainer=exp_runner_arti_multi_objs_compositional_ks.py
|
16 |
+
# export trainer=exp_runner_arti_multi_objs_dyn.py
|
17 |
+
export trainer=exp_runner_arti_multi_objs_pointset.py
|
18 |
+
# /home/xueyi/diffsim/NeuS/confs/wmask_refine_passive_compositional.conf
|
19 |
+
export conf=wmask_refine_passive_compositional.conf
|
20 |
+
export conf=dyn_arctic_ks.conf
|
21 |
+
|
22 |
+
export conf=dyn_arctic_ks_robohand.conf
|
23 |
+
# /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf
|
24 |
+
export conf=dyn_arctic_ks_robohand_from_mano_model_rules.conf
|
25 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules.conf
|
26 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions.conf
|
27 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2.conf
|
28 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
|
29 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
|
30 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
|
31 |
+
|
32 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
33 |
+
export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
34 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
35 |
+
# export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
36 |
+
|
37 |
+
export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
38 |
+
# export conf=dyn_arctic_ks_robohand_from_mano_model_rules_arti.conf
|
39 |
+
# a very stiff system for #
|
40 |
+
export conf=dyn_grab_pointset_mano.conf
|
41 |
+
|
42 |
+
export mode="train_from_model_rules"
|
43 |
+
export mode="train_from_model_rules"
|
44 |
+
|
45 |
+
export mode="train_sdf_from_model_rules"
|
46 |
+
export mode="train_actions_from_model_rules"
|
47 |
+
# export mode="train_actions_from_sim_rules"
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
export mode="train_def"
|
52 |
+
# export mode="train_actions_from_model_rules"
|
53 |
+
export mode="train_mano_actions_from_model_rules"
|
54 |
+
|
55 |
+
# virtual force # # ## virtual forces ##
|
56 |
+
# /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
|
57 |
+
export mode="train_actions_from_model_rules"
|
58 |
+
|
59 |
+
export mode="train_actions_from_mano_model_rules"
|
60 |
+
|
61 |
+
export mode="train_real_robot_actions_from_mano_model_rules"
|
62 |
+
export mode="train_real_robot_actions_from_mano_model_rules_diffhand"
|
63 |
+
|
64 |
+
export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
|
65 |
+
export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest"
|
66 |
+
|
67 |
+
|
68 |
+
export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
|
73 |
+
|
74 |
+
###### diffsim ######
|
75 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
|
76 |
+
|
77 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
|
78 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
|
79 |
+
|
80 |
+
export mode="train_dyn_mano_model_states"
|
81 |
+
|
82 |
+
|
83 |
+
## train dyn mano model states wreact ##
|
84 |
+
export mode="train_dyn_mano_model_states_wreact"
|
85 |
+
export conf=dyn_grab_pointset_mano_dyn.conf
|
86 |
+
## train dyn mano model states wreact ##
|
87 |
+
|
88 |
+
|
89 |
+
export conf_root="./confs_new"
|
90 |
+
|
91 |
+
|
92 |
+
#
|
93 |
+
# bash scripts_new/train_grab_pointset_dyn.sh
|
94 |
+
|
95 |
+
export cuda_ids="2"
|
96 |
+
#
|
97 |
+
|
98 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
99 |
+
|
scripts_new/train_grab_pointset_points_dyn.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
# bash scripts_new/train_grab_pointset_points_dyn.sh
|
22 |
+
|
23 |
+
export cuda_ids="0"
|
24 |
+
|
25 |
+
|
26 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
27 |
+
|
scripts_new/train_grab_pointset_points_dyn_retar.sh
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
### stage 1 -> tracking MANO expanded set using Shadow's object mesh points ###
|
16 |
+
export mode="train_point_set_retar"
|
17 |
+
export conf=dyn_grab_pointset_points_dyn_retar.conf
|
18 |
+
|
19 |
+
# ### stage 2 -> tracking MANO expanded set using Shadow's expanded points ###
|
20 |
+
# export mode="train_expanded_set_motions_retar_pts"
|
21 |
+
# export conf=dyn_grab_pointset_points_dyn_retar_pts.conf
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
export conf_root="./confs_new"
|
27 |
+
|
28 |
+
|
29 |
+
# bash scripts_new/train_grab_pointset_points_dyn_retar.sh
|
30 |
+
|
31 |
+
export cuda_ids="0"
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
36 |
+
|
scripts_new/train_grab_pointset_points_dyn_retar_pts.sh
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# ### stage 1 -> tracking MANO expanded set using Shadow's object mesh points ###
|
16 |
+
# export mode="train_expanded_set_motions_retar"
|
17 |
+
# export conf=dyn_grab_pointset_points_dyn_retar.conf
|
18 |
+
|
19 |
+
### stage 2 -> tracking MANO expanded set using Shadow's expanded points ###
|
20 |
+
export mode="train_point_set_retar_pts"
|
21 |
+
export conf=dyn_grab_pointset_points_dyn_retar_pts.conf
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
export conf_root="./confs_new"
|
27 |
+
|
28 |
+
|
29 |
+
# bash scripts_new/train_grab_pointset_points_dyn_retar.sh
|
30 |
+
|
31 |
+
export cuda_ids="0"
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
36 |
+
|
scripts_new/train_grab_pointset_points_dyn_s1.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn_s1.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
# bash scripts_new/train_grab_pointset_points_dyn_s1.sh
|
22 |
+
|
23 |
+
export cuda_ids="0"
|
24 |
+
|
25 |
+
|
26 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
27 |
+
|
scripts_new/train_grab_pointset_points_dyn_s2.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn_s2.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
# bash scripts_new/train_grab_pointset_points_dyn_s2.sh
|
22 |
+
|
23 |
+
export cuda_ids="0"
|
24 |
+
|
25 |
+
|
26 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
27 |
+
|
scripts_new/train_grab_pointset_points_dyn_s3.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn_s3.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
# bash scripts_new/train_grab_pointset_points_dyn_s3.sh
|
22 |
+
|
23 |
+
export cuda_ids="0"
|
24 |
+
|
25 |
+
|
26 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
27 |
+
|
scripts_new/train_grab_pointset_points_dyn_s4.sh
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
export trainer=exp_runner_stage_1.py
|
12 |
+
|
13 |
+
|
14 |
+
export mode="train_point_set"
|
15 |
+
|
16 |
+
export conf=dyn_grab_pointset_points_dyn_s4.conf
|
17 |
+
|
18 |
+
export conf_root="./confs_new"
|
19 |
+
|
20 |
+
|
21 |
+
# bash scripts_new/train_grab_pointset_points_dyn_s4.sh
|
22 |
+
|
23 |
+
export cuda_ids="0"
|
24 |
+
|
25 |
+
|
26 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
27 |
+
|
scripts_new/train_grab_shadow_multistages.sh
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
export cuda_ids="3"
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
export trainer=exp_runner_arti_forward.py
|
9 |
+
export conf=wmask_refine_passive_rigidtrans_forward.conf
|
10 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
11 |
+
|
12 |
+
|
13 |
+
# export trainer=exp_runner_arti_multi_objs_compositional.py
|
14 |
+
export trainer=exp_runner_arti_multi_objs_compositional_ks.py
|
15 |
+
export trainer=exp_runner_arti_multi_objs_arti_dyn.py
|
16 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
|
17 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
|
18 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
|
19 |
+
|
20 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
21 |
+
export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
22 |
+
|
23 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_grab_mano_model_states.conf
|
24 |
+
export conf=dyn_grab_mano_model_states.conf
|
25 |
+
|
26 |
+
export conf=dyn_grab_shadow_model_states.conf
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
export mode="train_def"
|
32 |
+
# export mode="train_actions_from_model_rules"
|
33 |
+
export mode="train_mano_actions_from_model_rules"
|
34 |
+
|
35 |
+
# virtual force #
|
36 |
+
# /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
|
37 |
+
|
38 |
+
export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
|
43 |
+
|
44 |
+
|
45 |
+
### using the diffsim ###
|
46 |
+
|
47 |
+
###### diffsim ######
|
48 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
|
49 |
+
|
50 |
+
|
51 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
|
52 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
|
53 |
+
|
54 |
+
export mode="train_dyn_mano_model_states"
|
55 |
+
|
56 |
+
#
|
57 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab"
|
58 |
+
## dyn grab shadow
|
59 |
+
|
60 |
+
|
61 |
+
### optimize for the manipulatable hand actions ###
|
62 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab_redmax_acts"
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
# ## acts ##
|
67 |
+
# export conf=dyn_grab_shadow_model_states_224.conf
|
68 |
+
# export conf=dyn_grab_shadow_model_states_89.conf
|
69 |
+
# export conf=dyn_grab_shadow_model_states_102.conf
|
70 |
+
# export conf=dyn_grab_shadow_model_states_7.conf
|
71 |
+
# export conf=dyn_grab_shadow_model_states_47.conf
|
72 |
+
# export conf=dyn_grab_shadow_model_states_67.conf
|
73 |
+
# export conf=dyn_grab_shadow_model_states_76.conf
|
74 |
+
# export conf=dyn_grab_shadow_model_states_85.conf
|
75 |
+
# # export conf=dyn_grab_shadow_model_states_91.conf
|
76 |
+
# # export conf=dyn_grab_shadow_model_states_167.conf
|
77 |
+
# export conf=dyn_grab_shadow_model_states_107.conf
|
78 |
+
# export conf=dyn_grab_shadow_model_states_306.conf
|
79 |
+
# export conf=dyn_grab_shadow_model_states_313.conf
|
80 |
+
# export conf=dyn_grab_shadow_model_states_322.conf
|
81 |
+
|
82 |
+
|
83 |
+
# /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_multi_stages.conf
|
84 |
+
export conf=dyn_grab_arti_shadow_multi_stages.conf
|
85 |
+
# export conf=dyn_grab_shadow_model_states_398.conf
|
86 |
+
# export conf=dyn_grab_shadow_model_states_363.conf
|
87 |
+
# export conf=dyn_grab_shadow_model_states_358.conf
|
88 |
+
|
89 |
+
# bash scripts_new/train_grab_shadow_multistages.sh
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
export conf_root="./confs_new"
|
96 |
+
|
97 |
+
|
98 |
+
export cuda_ids="4"
|
99 |
+
|
100 |
+
|
101 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
102 |
+
|
scripts_new/train_grab_shadow_singlestage.sh
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
export PYTHONPATH=.
|
3 |
+
|
4 |
+
export cuda_ids="3"
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
export trainer=exp_runner_arti_forward.py
|
9 |
+
export conf=wmask_refine_passive_rigidtrans_forward.conf
|
10 |
+
export data_case=hand_test_routine_2_light_color_wtime_active_passive
|
11 |
+
|
12 |
+
|
13 |
+
# export trainer=exp_runner_arti_multi_objs_compositional.py
|
14 |
+
export trainer=exp_runner_arti_multi_objs_compositional_ks.py
|
15 |
+
export trainer=exp_runner_arti_multi_objs_arti_dyn.py
|
16 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand.conf
|
17 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v2.conf
|
18 |
+
export conf=dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v4.conf
|
19 |
+
|
20 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
21 |
+
export conf=dyn_arctic_robohand_from_mano_model_train_mano_dyn_model_states.conf
|
22 |
+
|
23 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_grab_mano_model_states.conf
|
24 |
+
export conf=dyn_grab_mano_model_states.conf
|
25 |
+
|
26 |
+
export conf=dyn_grab_shadow_model_states.conf
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
export mode="train_def"
|
32 |
+
# export mode="train_actions_from_model_rules"
|
33 |
+
export mode="train_mano_actions_from_model_rules"
|
34 |
+
|
35 |
+
# virtual force #
|
36 |
+
# /data/xueyi/diffsim/NeuS/confs/dyn_arctic_ks_robohand_from_mano_model_rules.conf ##
|
37 |
+
|
38 |
+
export mode="train_real_robot_actions_from_mano_model_rules_diffhand_fortest"
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
export mode="train_real_robot_actions_from_mano_model_rules_manohand_fortest_states"
|
43 |
+
|
44 |
+
|
45 |
+
###### diffsim ######
|
46 |
+
# /home/xueyi/diffsim/NeuS/confs/dyn_arctic_robohand_from_mano_model_rules_actions_f2_diffhand_v3.conf
|
47 |
+
|
48 |
+
|
49 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_world"
|
50 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_manohand_fortest_states_res_rl"
|
51 |
+
|
52 |
+
export mode="train_dyn_mano_model_states"
|
53 |
+
|
54 |
+
#
|
55 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab"
|
56 |
+
## dyn grab shadow
|
57 |
+
|
58 |
+
### optimze for the redmax hand actions from joint states ###
|
59 |
+
export mode="train_real_robot_actions_from_mano_model_rules_shadowhand"
|
60 |
+
|
61 |
+
### optimize for the manipulatable hand actions ###
|
62 |
+
export mode="train_real_robot_actions_from_mano_model_rules_v5_shadowhand_fortest_states_grab_redmax_acts"
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
# ## acts ##
|
67 |
+
# export conf=dyn_grab_shadow_model_states_224.conf
|
68 |
+
# export conf=dyn_grab_shadow_model_states_89.conf
|
69 |
+
# export conf=dyn_grab_shadow_model_states_102.conf
|
70 |
+
# export conf=dyn_grab_shadow_model_states_7.conf
|
71 |
+
# export conf=dyn_grab_shadow_model_states_47.conf
|
72 |
+
# export conf=dyn_grab_shadow_model_states_67.conf
|
73 |
+
# export conf=dyn_grab_shadow_model_states_76.conf
|
74 |
+
# export conf=dyn_grab_shadow_model_states_85.conf
|
75 |
+
# # export conf=dyn_grab_shadow_model_states_91.conf
|
76 |
+
# # export conf=dyn_grab_shadow_model_states_167.conf
|
77 |
+
# export conf=dyn_grab_shadow_model_states_107.conf
|
78 |
+
# export conf=dyn_grab_shadow_model_states_306.conf
|
79 |
+
# export conf=dyn_grab_shadow_model_states_313.conf
|
80 |
+
# export conf=dyn_grab_shadow_model_states_322.conf
|
81 |
+
|
82 |
+
# /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_multi_stages.conf
|
83 |
+
# /home/xueyi/diffsim/NeuS/confs_new/dyn_grab_arti_shadow_single_stage.conf
|
84 |
+
export conf=dyn_grab_arti_shadow_single_stage.conf
|
85 |
+
# export conf=dyn_grab_shadow_model_states_398.conf
|
86 |
+
# export conf=dyn_grab_shadow_model_states_363.conf
|
87 |
+
# export conf=dyn_grab_shadow_model_states_358.conf
|
88 |
+
|
89 |
+
# bash scripts_new/train_grab_shadow_singlestage.sh
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
export conf_root="./confs_new"
|
95 |
+
|
96 |
+
|
97 |
+
export cuda_ids="3"
|
98 |
+
|
99 |
+
|
100 |
+
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python CUDA_VISIBLE_DEVICES=${cuda_ids} python ${trainer} --mode ${mode} --conf ${conf_root}/${conf} --case ${data_case}
|
101 |
+
|