Spaces:
Sleeping
Sleeping
kyleleey
commited on
Commit
·
98a77e0
1
Parent(s):
9df3c71
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +184 -0
- ckpts/configs.yml +354 -0
- ckpts/iter0800000.pth +3 -0
- video3d/__init__.py +6 -0
- video3d/cages/cages.py +218 -0
- video3d/cub_dataloaders.py +404 -0
- video3d/cub_dataloaders_ddp.py +434 -0
- video3d/dataloaders.py +375 -0
- video3d/dataloaders_ddp.py +1210 -0
- video3d/diffusion/sd.py +252 -0
- video3d/diffusion/sd_utils.py +123 -0
- video3d/diffusion/vsd.py +323 -0
- video3d/discriminator_architecture.py +83 -0
- video3d/flow/__init__.py +0 -0
- video3d/flow/flow.py +51 -0
- video3d/flow/utils.py +23 -0
- video3d/geometry/dlmesh.py +85 -0
- video3d/geometry/dmtet.py +361 -0
- video3d/model.py +1526 -0
- video3d/model_ddp.py +0 -0
- video3d/networks.py +1724 -0
- video3d/render/light.py +191 -0
- video3d/render/material.py +282 -0
- video3d/render/mesh.py +377 -0
- video3d/render/mlptexture.py +122 -0
- video3d/render/obj.py +288 -0
- video3d/render/regularizer.py +93 -0
- video3d/render/render.py +369 -0
- video3d/render/renderutils/__init__.py +11 -0
- video3d/render/renderutils/bsdf.py +151 -0
- video3d/render/renderutils/c_src/bsdf.cu +710 -0
- video3d/render/renderutils/c_src/bsdf.h +84 -0
- video3d/render/renderutils/c_src/common.cpp +74 -0
- video3d/render/renderutils/c_src/common.h +41 -0
- video3d/render/renderutils/c_src/cubemap.cu +350 -0
- video3d/render/renderutils/c_src/cubemap.h +38 -0
- video3d/render/renderutils/c_src/loss.cu +210 -0
- video3d/render/renderutils/c_src/loss.h +38 -0
- video3d/render/renderutils/c_src/mesh.cu +94 -0
- video3d/render/renderutils/c_src/mesh.h +23 -0
- video3d/render/renderutils/c_src/normal.cu +182 -0
- video3d/render/renderutils/c_src/normal.h +27 -0
- video3d/render/renderutils/c_src/tensor.h +92 -0
- video3d/render/renderutils/c_src/torch_bindings.cpp +1062 -0
- video3d/render/renderutils/c_src/vec3f.h +109 -0
- video3d/render/renderutils/c_src/vec4f.h +25 -0
- video3d/render/renderutils/loss.py +41 -0
- video3d/render/renderutils/ops.py +554 -0
- video3d/render/renderutils/tests/test_bsdf.py +296 -0
- video3d/render/renderutils/tests/test_cubemap.py +47 -0
.gitignore
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
data
|
3 |
+
data/*/
|
4 |
+
data/*/*
|
5 |
+
!data/preprocessing/
|
6 |
+
pretrained/*/
|
7 |
+
results
|
8 |
+
neural_renderer
|
9 |
+
*.zip
|
10 |
+
unchanged/
|
11 |
+
cvpr23_results/
|
12 |
+
# slurm.bash
|
13 |
+
results
|
14 |
+
results/*/
|
15 |
+
results/*
|
16 |
+
results/*/*
|
17 |
+
results/dor_checkpoints/*
|
18 |
+
results/dor_checkpoints/*/*
|
19 |
+
results/dor_checkpoints/*/*/*
|
20 |
+
|
21 |
+
|
22 |
+
.vscode
|
23 |
+
.vscode/
|
24 |
+
|
25 |
+
dor_bash_files/
|
26 |
+
zzli_bash_files/
|
27 |
+
ray_bash_files/
|
28 |
+
|
29 |
+
config/dor_exp/
|
30 |
+
config/zzli_exp/
|
31 |
+
config/ray_exp/
|
32 |
+
|
33 |
+
wandb
|
34 |
+
wandb/*/
|
35 |
+
wandb/*/*
|
36 |
+
wandb/*/*/*
|
37 |
+
canon/out/*
|
38 |
+
canon/out/
|
39 |
+
# Byte-compiled / optimized / DLL files
|
40 |
+
__pycache__/
|
41 |
+
*.py[cod]
|
42 |
+
*$py.class
|
43 |
+
|
44 |
+
# C extensions
|
45 |
+
*.so
|
46 |
+
|
47 |
+
# Distribution / packaging
|
48 |
+
.Python
|
49 |
+
build/
|
50 |
+
develop-eggs/
|
51 |
+
dist/
|
52 |
+
downloads/
|
53 |
+
eggs/
|
54 |
+
.eggs/
|
55 |
+
lib/
|
56 |
+
lib64/
|
57 |
+
parts/
|
58 |
+
sdist/
|
59 |
+
var/
|
60 |
+
wheels/
|
61 |
+
pip-wheel-metadata/
|
62 |
+
share/python-wheels/
|
63 |
+
*.egg-info/
|
64 |
+
.installed.cfg
|
65 |
+
*.egg
|
66 |
+
MANIFEST
|
67 |
+
|
68 |
+
# PyInstaller
|
69 |
+
# Usually these files are written by a python script from a template
|
70 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
71 |
+
*.manifest
|
72 |
+
*.spec
|
73 |
+
|
74 |
+
# Installer logs
|
75 |
+
pip-log.txt
|
76 |
+
pip-delete-this-directory.txt
|
77 |
+
|
78 |
+
# Unit test / coverage reports
|
79 |
+
htmlcov/
|
80 |
+
.tox/
|
81 |
+
.nox/
|
82 |
+
.coverage
|
83 |
+
.coverage.*
|
84 |
+
.cache
|
85 |
+
nosetests.xml
|
86 |
+
coverage.xml
|
87 |
+
*.cover
|
88 |
+
*.py,cover
|
89 |
+
.hypothesis/
|
90 |
+
.pytest_cache/
|
91 |
+
|
92 |
+
# Translations
|
93 |
+
*.mo
|
94 |
+
*.pot
|
95 |
+
|
96 |
+
# Django stuff:
|
97 |
+
*.log
|
98 |
+
local_settings.py
|
99 |
+
db.sqlite3
|
100 |
+
db.sqlite3-journal
|
101 |
+
|
102 |
+
# Flask stuff:
|
103 |
+
instance/
|
104 |
+
.webassets-cache
|
105 |
+
|
106 |
+
# Scrapy stuff:
|
107 |
+
.scrapy
|
108 |
+
|
109 |
+
# Sphinx documentation
|
110 |
+
docs/_build/
|
111 |
+
|
112 |
+
# PyBuilder
|
113 |
+
target/
|
114 |
+
|
115 |
+
# Jupyter Notebook
|
116 |
+
.ipynb_checkpoints
|
117 |
+
|
118 |
+
# IPython
|
119 |
+
profile_default/
|
120 |
+
ipython_config.py
|
121 |
+
|
122 |
+
# pyenv
|
123 |
+
.python-version
|
124 |
+
|
125 |
+
# pipenv
|
126 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
127 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
128 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
129 |
+
# install all needed dependencies.
|
130 |
+
#Pipfile.lock
|
131 |
+
|
132 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
133 |
+
__pypackages__/
|
134 |
+
|
135 |
+
# Celery stuff
|
136 |
+
celerybeat-schedule
|
137 |
+
celerybeat.pid
|
138 |
+
|
139 |
+
# SageMath parsed files
|
140 |
+
*.sage.py
|
141 |
+
|
142 |
+
# Environments
|
143 |
+
.env
|
144 |
+
.venv
|
145 |
+
env/
|
146 |
+
venv/
|
147 |
+
ENV/
|
148 |
+
env.bak/
|
149 |
+
venv.bak/
|
150 |
+
|
151 |
+
# Spyder project settings
|
152 |
+
.spyderproject
|
153 |
+
.spyproject
|
154 |
+
|
155 |
+
# Rope project settings
|
156 |
+
.ropeproject
|
157 |
+
|
158 |
+
# mkdocs documentation
|
159 |
+
/site
|
160 |
+
|
161 |
+
# mypy
|
162 |
+
.mypy_cache/
|
163 |
+
.dmypy.json
|
164 |
+
dmypy.json
|
165 |
+
|
166 |
+
# Pyre type checker
|
167 |
+
.pyre/
|
168 |
+
/.idea
|
169 |
+
|
170 |
+
# dependencies
|
171 |
+
# nvdiffrast/
|
172 |
+
data/preprocessing/videos/RAFT/
|
173 |
+
preprocessing_data/RAFT/
|
174 |
+
preprocessing_data/RAFT/*
|
175 |
+
preprocessing_data/preprocessing/videos/RAFT/
|
176 |
+
# debug
|
177 |
+
|
178 |
+
|
179 |
+
DINO_v2_check/out_dor
|
180 |
+
DINO_v2_check/out_dor/*
|
181 |
+
|
182 |
+
eval/*/
|
183 |
+
scripts/vis/
|
184 |
+
eval/
|
ckpts/configs.yml
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
amb_diff_max:
|
2 |
+
- 1.0
|
3 |
+
- 1.0
|
4 |
+
amb_diff_min:
|
5 |
+
- 0.0
|
6 |
+
- 0.5
|
7 |
+
arti_reg_loss_epochs:
|
8 |
+
- 8
|
9 |
+
- 276
|
10 |
+
arti_reg_loss_weight: 0.2
|
11 |
+
articulation_arch: attention
|
12 |
+
articulation_epochs:
|
13 |
+
- 2
|
14 |
+
- 276
|
15 |
+
articulation_feature_mode: sample+global
|
16 |
+
articulation_multiplier: 0.1
|
17 |
+
attach_legs_to_body_epochs:
|
18 |
+
- 8
|
19 |
+
- 276
|
20 |
+
avg_seqshape_epochs:
|
21 |
+
- 0
|
22 |
+
- 0
|
23 |
+
avg_texture_epochs:
|
24 |
+
- 0
|
25 |
+
- 0
|
26 |
+
background_mode: none
|
27 |
+
backward_prior: true
|
28 |
+
bank_mean_dist_loss_weight: 0.0
|
29 |
+
batch_size: 6
|
30 |
+
best_pose_start_iter: 10000
|
31 |
+
blur_mask: false
|
32 |
+
body_bone_idx_preset:
|
33 |
+
0:
|
34 |
+
- 0
|
35 |
+
- 0
|
36 |
+
- 0
|
37 |
+
- 0
|
38 |
+
500000:
|
39 |
+
- 0
|
40 |
+
- 0
|
41 |
+
- 0
|
42 |
+
- 0
|
43 |
+
body_bones_type: z_minmax_y+
|
44 |
+
body_rotate_reg_mode: all-bones
|
45 |
+
bone_y_thresh: 0.4
|
46 |
+
bsdf: diffuse
|
47 |
+
cam_pos_z_offset: 10
|
48 |
+
checkpoint_dir: /viscam/u/zzli/workspace/4DAnimalKingdom_dev/results/paper_exp/same_dino_1109/mb_all_data_1k_artiID_r500k
|
49 |
+
clip_tex: false
|
50 |
+
clip_tex_loss_weight: 0.0
|
51 |
+
combine_dataset: true
|
52 |
+
config: config/zzli_exp/same_dino_1109/mb_data1k_artiID_r500k.yml
|
53 |
+
constrain_legs: false
|
54 |
+
crop_fov_approx: 25
|
55 |
+
data_loader_mode: n_frame
|
56 |
+
dataset: video
|
57 |
+
debug_seq: false
|
58 |
+
deform_epochs:
|
59 |
+
- 0
|
60 |
+
- 276
|
61 |
+
deformation_reg_loss_weight: 10.0
|
62 |
+
device: cuda:0
|
63 |
+
diffusion_albedo_ratio: 0.2
|
64 |
+
diffusion_angle_front: 60
|
65 |
+
diffusion_angle_overhead: 30
|
66 |
+
diffusion_append_prompt_directions: true
|
67 |
+
diffusion_guidance_scale: 100
|
68 |
+
diffusion_light_ambient: 0.5
|
69 |
+
diffusion_light_diffuse: 0.8
|
70 |
+
diffusion_loss_weight: 0.0001
|
71 |
+
diffusion_max_step: 0.6
|
72 |
+
diffusion_num_random_cameras: 1
|
73 |
+
diffusion_phi_offset: 180
|
74 |
+
diffusion_precision: float16
|
75 |
+
diffusion_prompt: an elephant
|
76 |
+
diffusion_radius_range:
|
77 |
+
- 9
|
78 |
+
- 11
|
79 |
+
diffusion_random_light: true
|
80 |
+
diffusion_resolution: 256
|
81 |
+
diffusion_shading_ratio: 0.4
|
82 |
+
diffusion_theta_range:
|
83 |
+
- 0
|
84 |
+
- 100
|
85 |
+
diffusion_uniform_sphere_rate: 1
|
86 |
+
dim_of_classes: 128
|
87 |
+
dino_feat_im_loss_weight:
|
88 |
+
0: 10.0
|
89 |
+
300000: 1.0
|
90 |
+
dino_feature_dim: 16
|
91 |
+
dino_feature_input: false
|
92 |
+
dino_feature_recon_dim: 16
|
93 |
+
dino_max: 1.0
|
94 |
+
dino_min: 0.0
|
95 |
+
disable_fewshot: false
|
96 |
+
disc_gt: false
|
97 |
+
disc_iv: true
|
98 |
+
disc_iv_label: Real
|
99 |
+
disc_reg_mul: 10.0
|
100 |
+
discriminator_loss_weight: 1.0
|
101 |
+
dmtet_grid: 256
|
102 |
+
dmtet_grid_smaller: 256
|
103 |
+
dmtet_grid_smaller_epoch: 1
|
104 |
+
embed_concat_pts: true
|
105 |
+
embedder_freq_arti: 8
|
106 |
+
embedder_freq_deform: 10
|
107 |
+
embedder_freq_dino: 8
|
108 |
+
embedder_freq_shape: 8
|
109 |
+
embedder_freq_tex: 10
|
110 |
+
enable_articulation: true
|
111 |
+
enable_articulation_bone_threshold: true
|
112 |
+
enable_articulation_idadd: true
|
113 |
+
enable_deform: true
|
114 |
+
enable_disc: true
|
115 |
+
enable_encoder: true
|
116 |
+
enable_lighting: true
|
117 |
+
enable_mask_distribution: true
|
118 |
+
enable_memory_bank: true
|
119 |
+
enable_pose: true
|
120 |
+
enable_prior: true
|
121 |
+
enable_sds: false
|
122 |
+
encoder_arch: vit
|
123 |
+
encoder_frozen: true
|
124 |
+
encoder_pretrained: true
|
125 |
+
enhance_back_view: true
|
126 |
+
enhance_back_view_path: /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data
|
127 |
+
extra_renders:
|
128 |
+
instance:
|
129 |
+
- geo_normal
|
130 |
+
- diffuse
|
131 |
+
- gray
|
132 |
+
faces_per_pixel: 10
|
133 |
+
few_shot_category_num: -1
|
134 |
+
few_shot_class_vector_init: copy
|
135 |
+
few_shot_data_dir:
|
136 |
+
- /viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all
|
137 |
+
- /viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered
|
138 |
+
few_shot_iteration_save: true
|
139 |
+
few_shot_iteration_save_freq: 2000
|
140 |
+
few_shot_lr: 0.0001
|
141 |
+
few_shot_optimize: exp
|
142 |
+
few_shot_optimize_bank: all
|
143 |
+
few_shot_original_classes_num: 7
|
144 |
+
few_shot_resume: true
|
145 |
+
few_shot_test_category_names:
|
146 |
+
- caracal
|
147 |
+
- impala
|
148 |
+
- ox
|
149 |
+
- squirrel
|
150 |
+
- wolf
|
151 |
+
few_shot_test_category_num: 5
|
152 |
+
few_shot_val_image_num: 5
|
153 |
+
fix_viz_batch: false
|
154 |
+
flow_loss_epochs:
|
155 |
+
- 0
|
156 |
+
- 0
|
157 |
+
flow_loss_weight: 0.0
|
158 |
+
forbid_leg_rotate: true
|
159 |
+
fov_w: 60
|
160 |
+
full_size_h: 1080
|
161 |
+
full_size_w: 1920
|
162 |
+
gamma: 1e-6
|
163 |
+
gan_tex: false
|
164 |
+
grid_scale: 7
|
165 |
+
hidden_size: 256
|
166 |
+
in_image_size: 256
|
167 |
+
init_sdf: ellipsoid
|
168 |
+
is_dry_run: false
|
169 |
+
iter_arti_reg_loss_start: 60000
|
170 |
+
iter_articulation_start: 20000
|
171 |
+
iter_attach_leg_to_body_start: 60000
|
172 |
+
iter_deformation_start: 500000
|
173 |
+
iter_leg_rotation_start: 300000
|
174 |
+
iter_nozeroy_start: 20000
|
175 |
+
jitter_grid: 0.05
|
176 |
+
kd_max:
|
177 |
+
- 1.0
|
178 |
+
- 1.0
|
179 |
+
- 1.0
|
180 |
+
- 1.0
|
181 |
+
kd_min:
|
182 |
+
- 0.0
|
183 |
+
- 0.0
|
184 |
+
- 0.0
|
185 |
+
- 0.0
|
186 |
+
keep_num_checkpoint: 1
|
187 |
+
ks_max:
|
188 |
+
- 0.0
|
189 |
+
- 0.0
|
190 |
+
- 0.0
|
191 |
+
ks_min:
|
192 |
+
- 0.0
|
193 |
+
- 0.0
|
194 |
+
- 0.0
|
195 |
+
latent_dim: 256
|
196 |
+
load_dino_cluster: false
|
197 |
+
load_dino_feature: true
|
198 |
+
log_freq_images: 501
|
199 |
+
log_freq_losses: 50
|
200 |
+
log_train_images: true
|
201 |
+
logit_loss_dino_feat_im_loss_multiplier:
|
202 |
+
0: 50.0
|
203 |
+
300000: 500.0
|
204 |
+
logit_loss_weight: 1.0
|
205 |
+
lookat_init:
|
206 |
+
- 0.0
|
207 |
+
- 0.0
|
208 |
+
- 0.0
|
209 |
+
lookat_zeroy: true
|
210 |
+
lr: 6.0e-05
|
211 |
+
mask_disc_loss_feat_condition: true
|
212 |
+
mask_disc_loss_weight: 0.1
|
213 |
+
mask_discriminator_iter:
|
214 |
+
- 80000
|
215 |
+
- 300000
|
216 |
+
mask_distribution_loss_freq: 1
|
217 |
+
mask_distribution_loss_weight: 0.0
|
218 |
+
mask_distribution_path: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/mask_distribution
|
219 |
+
max_arti_angle: 60
|
220 |
+
max_trans_xy_range_ratio: 0.5
|
221 |
+
max_trans_z_range_ratio: 0.5
|
222 |
+
memory_bank_init: copy
|
223 |
+
memory_bank_size: 60
|
224 |
+
memory_bank_topk: 10
|
225 |
+
memory_encoder: DINO
|
226 |
+
memory_retrieve: cos-linear
|
227 |
+
mesh_edge_length_loss_weight: 0.0
|
228 |
+
mesh_normal_consistency_loss_weight: 0.0
|
229 |
+
min_seq_len: 1
|
230 |
+
nrm_max:
|
231 |
+
- 1.0
|
232 |
+
- 1.0
|
233 |
+
- 1.0
|
234 |
+
nrm_min:
|
235 |
+
- -1.0
|
236 |
+
- -1.0
|
237 |
+
- 0.0
|
238 |
+
num_body_bones: 8
|
239 |
+
num_epochs: 1375
|
240 |
+
num_iterations: 10000000
|
241 |
+
num_layers_arti: 4
|
242 |
+
num_layers_deform: 5
|
243 |
+
num_layers_dino: 5
|
244 |
+
num_layers_light: 5
|
245 |
+
num_layers_tex: 8
|
246 |
+
num_leg_bones: 3
|
247 |
+
num_legs: 4
|
248 |
+
num_sample_frames: 1
|
249 |
+
num_workers: 8
|
250 |
+
out_image_size: 256
|
251 |
+
perturb_articulation_epochs:
|
252 |
+
- 0
|
253 |
+
- 0
|
254 |
+
perturb_normal: false
|
255 |
+
perturb_sdf: false
|
256 |
+
pose_arch: encoder_dino_patch_key
|
257 |
+
pose_entropy_loss_weight: 0.0
|
258 |
+
pose_epochs:
|
259 |
+
- 0
|
260 |
+
- 0
|
261 |
+
pose_xflip_recon_epochs:
|
262 |
+
- 0
|
263 |
+
- 0
|
264 |
+
pose_xflip_reg_loss_weight: 0.0
|
265 |
+
prior_condition_choice: mod
|
266 |
+
prior_lr: 0.0006
|
267 |
+
prior_sdf_mode: mlp
|
268 |
+
pyplot_metrics: false
|
269 |
+
random_flip_train: true
|
270 |
+
random_mask_law: random_azimuth
|
271 |
+
random_sample_train_frames: false
|
272 |
+
random_sample_val_frames: true
|
273 |
+
rank: 0
|
274 |
+
reg_body_rotate_mult: 0.1
|
275 |
+
render_dino_mode: feature_mlp
|
276 |
+
renderer_spp: 4
|
277 |
+
resume: true
|
278 |
+
resume_prior_optim: true
|
279 |
+
rgb_loss_weight: 1.0
|
280 |
+
rgb_suffix: .png
|
281 |
+
root_dir: /viscam/u/zzli
|
282 |
+
rot_all_quad_epochs:
|
283 |
+
- 0
|
284 |
+
- 276
|
285 |
+
rot_rand_quad_epochs:
|
286 |
+
- 0
|
287 |
+
- 0
|
288 |
+
rot_rep: quadlookat
|
289 |
+
rot_temp_scalar: 1.0
|
290 |
+
run_few_shot: true
|
291 |
+
run_train: true
|
292 |
+
save_checkpoint_freq: 1
|
293 |
+
save_result_freq: 501
|
294 |
+
sdf_bce_reg_loss_min_weight: 0
|
295 |
+
sdf_bce_reg_loss_weight: 0
|
296 |
+
sdf_gradient_reg_loss_min_weight: 0.1
|
297 |
+
sdf_gradient_reg_loss_weight: 0.1
|
298 |
+
sdf_inflate_reg_loss_epochs:
|
299 |
+
- 0
|
300 |
+
- 0
|
301 |
+
sdf_reg_decay_start_iter: 10000
|
302 |
+
seed: 0
|
303 |
+
seqshape_epochs:
|
304 |
+
- 0
|
305 |
+
- 0
|
306 |
+
shuffle_train_seqs: true
|
307 |
+
sigma: 1e-6
|
308 |
+
silhouette_dt_loss_weight: 0.0
|
309 |
+
silhouette_inv_dt_loss_weight: 50.0
|
310 |
+
silhouette_loss_weight: 5.0
|
311 |
+
skinning_temperature: 0.05
|
312 |
+
skip_beginning: 0
|
313 |
+
skip_end: 0
|
314 |
+
small_leg_angle: true
|
315 |
+
smooth_deformation_loss_weight: 10.0
|
316 |
+
static_root_bones: false
|
317 |
+
sym_deform: true
|
318 |
+
sym_dino: false
|
319 |
+
sym_prior_shape: true
|
320 |
+
sym_texture: true
|
321 |
+
temp_clip_high: 10.0
|
322 |
+
temp_clip_low: 1.0
|
323 |
+
tex_im_size: 256
|
324 |
+
texture_epochs:
|
325 |
+
- 0
|
326 |
+
- 276
|
327 |
+
texture_mode: mlp
|
328 |
+
train_data_dir:
|
329 |
+
bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/train
|
330 |
+
cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/train
|
331 |
+
elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/train
|
332 |
+
giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/train
|
333 |
+
horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/train
|
334 |
+
sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/train
|
335 |
+
zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/train
|
336 |
+
train_with_cub: false
|
337 |
+
use_logger: true
|
338 |
+
use_scheduler: false
|
339 |
+
use_wandb: false
|
340 |
+
val_data_dir:
|
341 |
+
bear: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/bear_comb_dinov2_new/val
|
342 |
+
cow: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/cow_comb_dinov2_new/val
|
343 |
+
elephant: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/elephant_comb_dinov2_new/val
|
344 |
+
giraffe: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/giraffe_comb_dinov2_new/val
|
345 |
+
horse: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/horse_comb_dinov2_new/val
|
346 |
+
sheep: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/sheep_comb_dinov2_new/val
|
347 |
+
zebra: /viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new/zebra_comb_dinov2_new/val
|
348 |
+
visualize_validation: true
|
349 |
+
vit_final_layer_type: conv
|
350 |
+
which_vit: dino_vits8
|
351 |
+
world_size: 1
|
352 |
+
zflip_epochs:
|
353 |
+
- 0
|
354 |
+
- 0
|
ckpts/iter0800000.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c7b090f1ff3e76e2ba608a25a2bd79af2892d6bb307132c9d038082395c1d57
|
3 |
+
size 306560367
|
video3d/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .utils.misc import setup_runtime
|
2 |
+
from .trainer import Trainer
|
3 |
+
from .trainer_ddp import TrainerDDP
|
4 |
+
from .model import Unsup3D
|
5 |
+
from .model_ddp import Unsup3DDDP
|
6 |
+
from .trainer_few_shot import Fewshot_Trainer
|
video3d/cages/cages.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cages code used from https://github.com/yifita/deep_cage
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import trimesh
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def deform_with_MVC(cage, cage_deformed, cage_face, query, verbose=False):
|
9 |
+
"""
|
10 |
+
cage (B,C,3)
|
11 |
+
cage_deformed (B,C,3)
|
12 |
+
cage_face (B,F,3) int64
|
13 |
+
query (B,Q,3)
|
14 |
+
"""
|
15 |
+
weights, weights_unnormed = mean_value_coordinates_3D(query, cage, cage_face, verbose=True)
|
16 |
+
# weights = weights.detach()
|
17 |
+
deformed = torch.sum(weights.unsqueeze(-1)*cage_deformed.unsqueeze(1), dim=2)
|
18 |
+
if verbose:
|
19 |
+
return deformed, weights, weights_unnormed
|
20 |
+
return deformed
|
21 |
+
|
22 |
+
|
23 |
+
def loadInitCage(template):
|
24 |
+
init_cage_V, init_cage_F = read_trimesh(template)
|
25 |
+
init_cage_V = torch.from_numpy(init_cage_V[:,:3].astype(np.float32)).unsqueeze(0)*2.0
|
26 |
+
init_cage_F = torch.from_numpy(init_cage_F[:,:3].astype(np.int64)).unsqueeze(0)
|
27 |
+
return init_cage_V, init_cage_F
|
28 |
+
|
29 |
+
|
30 |
+
def read_trimesh(path):
|
31 |
+
mesh = trimesh.load(path)
|
32 |
+
return mesh.vertices, mesh.faces
|
33 |
+
|
34 |
+
|
35 |
+
# util functions from pytorch_points
|
36 |
+
PI = 3.1415927
|
37 |
+
|
38 |
+
def normalize_to_box(input):
|
39 |
+
"""
|
40 |
+
normalize point cloud to unit bounding box
|
41 |
+
center = (max - min)/2
|
42 |
+
scale = max(abs(x))
|
43 |
+
input: pc [N, P, dim] or [P, dim]
|
44 |
+
output: pc, centroid, furthest_distance
|
45 |
+
"""
|
46 |
+
if len(input.shape) == 2:
|
47 |
+
axis = 0
|
48 |
+
P = input.shape[0]
|
49 |
+
D = input.shape[1]
|
50 |
+
elif len(input.shape) == 3:
|
51 |
+
axis = 1
|
52 |
+
P = input.shape[1]
|
53 |
+
D = input.shape[2]
|
54 |
+
if isinstance(input, np.ndarray):
|
55 |
+
maxP = np.amax(input, axis=axis, keepdims=True)
|
56 |
+
minP = np.amin(input, axis=axis, keepdims=True)
|
57 |
+
centroid = (maxP+minP)/2
|
58 |
+
input = input - centroid
|
59 |
+
furthest_distance = np.amax(np.abs(input), axis=(axis, -1), keepdims=True)
|
60 |
+
input = input / furthest_distance
|
61 |
+
elif isinstance(input, torch.Tensor):
|
62 |
+
maxP = torch.max(input, dim=axis, keepdim=True)[0]
|
63 |
+
minP = torch.min(input, dim=axis, keepdim=True)[0]
|
64 |
+
centroid = (maxP+minP)/2
|
65 |
+
input = input - centroid
|
66 |
+
in_shape = list(input.shape[:axis])+[P*D]
|
67 |
+
furthest_distance = torch.max(torch.abs(input).view(in_shape), dim=axis, keepdim=True)[0]
|
68 |
+
furthest_distance = furthest_distance.unsqueeze(-1)
|
69 |
+
input = input / furthest_distance
|
70 |
+
|
71 |
+
return input, centroid, furthest_distance
|
72 |
+
|
73 |
+
def normalize(tensor, dim=-1):
|
74 |
+
"""normalize tensor in specified dimension"""
|
75 |
+
return torch.nn.functional.normalize(tensor, p=2, dim=dim, eps=1e-12, out=None)
|
76 |
+
|
77 |
+
|
78 |
+
def check_values(tensor):
|
79 |
+
"""return true if tensor doesn't contain NaN or Inf"""
|
80 |
+
return not (torch.any(torch.isnan(tensor)).item() or torch.any(torch.isinf(tensor)).item())
|
81 |
+
|
82 |
+
|
83 |
+
class ScatterAdd(torch.autograd.Function):
|
84 |
+
@staticmethod
|
85 |
+
def forward(ctx, src, idx, dim, out_size, fill=0.0):
|
86 |
+
out = torch.full(out_size, fill, device=src.device, dtype=src.dtype)
|
87 |
+
ctx.save_for_backward(idx)
|
88 |
+
out.scatter_add_(dim, idx, src)
|
89 |
+
ctx.mark_non_differentiable(idx)
|
90 |
+
ctx.dim = dim
|
91 |
+
return out
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
def backward(ctx, ograd):
|
95 |
+
idx, = ctx.saved_tensors
|
96 |
+
grad = torch.gather(ograd, ctx.dim, idx)
|
97 |
+
return grad, None, None, None, None
|
98 |
+
|
99 |
+
|
100 |
+
_scatter_add = ScatterAdd.apply
|
101 |
+
|
102 |
+
|
103 |
+
def scatter_add(src, idx, dim, out_size=None, fill=0.0):
|
104 |
+
if out_size is None:
|
105 |
+
out_size = list(src.size())
|
106 |
+
dim_size = idx.max().item()+1
|
107 |
+
out_size[dim] = dim_size
|
108 |
+
return _scatter_add(src, idx, dim, out_size, fill)
|
109 |
+
|
110 |
+
|
111 |
+
def mean_value_coordinates_3D(query, vertices, faces, verbose=False):
|
112 |
+
"""
|
113 |
+
Tao Ju et.al. MVC for 3D triangle meshes
|
114 |
+
params:
|
115 |
+
query (B,P,3)
|
116 |
+
vertices (B,N,3)
|
117 |
+
faces (B,F,3)
|
118 |
+
return:
|
119 |
+
wj (B,P,N)
|
120 |
+
"""
|
121 |
+
B, F, _ = faces.shape
|
122 |
+
_, P, _ = query.shape
|
123 |
+
_, N, _ = vertices.shape
|
124 |
+
# u_i = p_i - x (B,P,N,3)
|
125 |
+
uj = vertices.unsqueeze(1) - query.unsqueeze(2)
|
126 |
+
# \|u_i\| (B,P,N,1)
|
127 |
+
dj = torch.norm(uj, dim=-1, p=2, keepdim=True)
|
128 |
+
uj = normalize(uj, dim=-1)
|
129 |
+
# gather triangle B,P,F,3,3
|
130 |
+
ui = torch.gather(uj.unsqueeze(2).expand(-1,-1,F,-1,-1),
|
131 |
+
3,
|
132 |
+
faces.unsqueeze(1).unsqueeze(-1).expand(-1,P,-1,-1,3))
|
133 |
+
# li = \|u_{i+1}-u_{i-1}\| (B,P,F,3)
|
134 |
+
li = torch.norm(ui[:,:,:,[1, 2, 0],:] - ui[:, :, :,[2, 0, 1],:], dim=-1, p=2)
|
135 |
+
eps = 2e-5
|
136 |
+
li = torch.where(li>=2, li-(li.detach()-(2-eps)), li)
|
137 |
+
li = torch.where(li<=-2, li-(li.detach()+(2-eps)), li)
|
138 |
+
# asin(x) is inf at +/-1
|
139 |
+
# θi = 2arcsin[li/2] (B,P,F,3)
|
140 |
+
theta_i = 2*torch.asin(li/2)
|
141 |
+
assert(check_values(theta_i))
|
142 |
+
# B,P,F,1
|
143 |
+
h = torch.sum(theta_i, dim=-1, keepdim=True)/2
|
144 |
+
# wi← sin[θi]d{i−1}d{i+1}
|
145 |
+
# (B,P,F,3) ci ← (2sin[h]sin[h−θi])/(sin[θ_{i+1}]sin[θ_{i−1}])−1
|
146 |
+
ci = 2*torch.sin(h)*torch.sin(h-theta_i)/(torch.sin(theta_i[:,:,:,[1, 2, 0]])*torch.sin(theta_i[:,:,:,[2, 0, 1]]))-1
|
147 |
+
|
148 |
+
# NOTE: because of floating point ci can be slightly larger than 1, causing problem with sqrt(1-ci^2)
|
149 |
+
# NOTE: sqrt(x)' is nan for x=0, hence use eps
|
150 |
+
eps = 1e-5
|
151 |
+
ci = torch.where(ci>=1, ci-(ci.detach()-(1-eps)), ci)
|
152 |
+
ci = torch.where(ci<=-1, ci-(ci.detach()+(1-eps)), ci)
|
153 |
+
# si← sign[det[u1,u2,u3]]sqrt(1-ci^2)
|
154 |
+
# (B,P,F)*(B,P,F,3)
|
155 |
+
|
156 |
+
si = torch.sign(torch.det(ui)).unsqueeze(-1)*torch.sqrt(1-ci**2) # sqrt gradient nan for 0
|
157 |
+
assert(check_values(si))
|
158 |
+
# (B,P,F,3)
|
159 |
+
di = torch.gather(dj.unsqueeze(2).squeeze(-1).expand(-1,-1,F,-1), 3,
|
160 |
+
faces.unsqueeze(1).expand(-1,P,-1,-1))
|
161 |
+
assert(check_values(di))
|
162 |
+
# if si.requires_grad:
|
163 |
+
# vertices.register_hook(save_grad("mvc/dv"))
|
164 |
+
# li.register_hook(save_grad("mvc/dli"))
|
165 |
+
# theta_i.register_hook(save_grad("mvc/dtheta"))
|
166 |
+
# ci.register_hook(save_grad("mvc/dci"))
|
167 |
+
# si.register_hook(save_grad("mvc/dsi"))
|
168 |
+
# di.register_hook(save_grad("mvc/ddi"))
|
169 |
+
|
170 |
+
# wi← (θi −c[i+1]θ[i−1] −c[i−1]θ[i+1])/(disin[θi+1]s[i−1])
|
171 |
+
# B,P,F,3
|
172 |
+
# CHECK is there a 2* in the denominator
|
173 |
+
wi = (theta_i-ci[:,:,:,[1,2,0]]*theta_i[:,:,:,[2,0,1]]-ci[:,:,:,[2,0,1]]*theta_i[:,:,:,[1,2,0]])/(di*torch.sin(theta_i[:,:,:,[1,2,0]])*si[:,:,:,[2,0,1]])
|
174 |
+
# if ∃i,|si| ≤ ε, set wi to 0. coplaner with T but outside
|
175 |
+
# ignore coplaner outside triangle
|
176 |
+
# alternative check
|
177 |
+
# (B,F,3,3)
|
178 |
+
# triangle_points = torch.gather(vertices.unsqueeze(1).expand(-1,F,-1,-1), 2, faces.unsqueeze(-1).expand(-1,-1,-1,3))
|
179 |
+
# # (B,P,F,3), (B,1,F,3) -> (B,P,F,1)
|
180 |
+
# determinant = dot_product(triangle_points[:,:,:,0].unsqueeze(1)-query.unsqueeze(2),
|
181 |
+
# torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0],
|
182 |
+
# triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1).unsqueeze(1), dim=-1, keepdim=True).detach()
|
183 |
+
# # (B,P,F,1)
|
184 |
+
# sqrdist = determinant*determinant / (4 * sqrNorm(torch.cross(triangle_points[:,:,:,1]-triangle_points[:,:,:,0], triangle_points[:,:,:,2]-triangle_points[:,:,:,0], dim=-1), keepdim=True))
|
185 |
+
|
186 |
+
wi = torch.where(torch.any(torch.abs(si) <= 1e-5, keepdim=True, dim=-1), torch.zeros_like(wi), wi)
|
187 |
+
# wi = torch.where(sqrdist <= 1e-5, torch.zeros_like(wi), wi)
|
188 |
+
|
189 |
+
# if π −h < ε, x lies on t, use 2D barycentric coordinates
|
190 |
+
# inside triangle
|
191 |
+
inside_triangle = (PI-h).squeeze(-1)<1e-4
|
192 |
+
# set all F for this P to zero
|
193 |
+
wi = torch.where(torch.any(inside_triangle, dim=-1, keepdim=True).unsqueeze(-1), torch.zeros_like(wi), wi)
|
194 |
+
# CHECK is it di https://www.cse.wustl.edu/~taoju/research/meanvalue.pdf or li http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.516.1856&rep=rep1&type=pdf
|
195 |
+
wi = torch.where(inside_triangle.unsqueeze(-1).expand(-1,-1,-1,wi.shape[-1]), torch.sin(theta_i)*di[:,:,:,[2,0,1]]*di[:,:,:,[1,2,0]], wi)
|
196 |
+
|
197 |
+
# sum over all faces face -> vertex (B,P,F*3) -> (B,P,N)
|
198 |
+
wj = scatter_add(wi.reshape(B,P,-1).contiguous(), faces.unsqueeze(1).expand(-1,P,-1,-1).reshape(B,P,-1), 2, out_size=(B,P,N))
|
199 |
+
|
200 |
+
# close to vertex (B,P,N)
|
201 |
+
close_to_point = dj.squeeze(-1) < 1e-8
|
202 |
+
# set all F for this P to zero
|
203 |
+
wj = torch.where(torch.any(close_to_point, dim=-1, keepdim=True), torch.zeros_like(wj), wj)
|
204 |
+
wj = torch.where(close_to_point, torch.ones_like(wj), wj)
|
205 |
+
|
206 |
+
# (B,P,1)
|
207 |
+
sumWj = torch.sum(wj, dim=-1, keepdim=True)
|
208 |
+
sumWj = torch.where(sumWj==0, torch.ones_like(sumWj), sumWj)
|
209 |
+
|
210 |
+
wj_normalised = wj / sumWj
|
211 |
+
# if wj.requires_grad:
|
212 |
+
# saved_variables["mvc/wi"] = wi
|
213 |
+
# wi.register_hook(save_grad("mvc/dwi"))
|
214 |
+
# wj.register_hook(save_grad("mvc/dwj"))
|
215 |
+
if verbose:
|
216 |
+
return wj_normalised, wi
|
217 |
+
else:
|
218 |
+
return wj_normalised
|
video3d/cub_dataloaders.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import scipy.io as sio
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from types import SimpleNamespace
|
9 |
+
|
10 |
+
|
11 |
+
def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
|
12 |
+
opts = SimpleNamespace()
|
13 |
+
opts.data_dir = data_dir
|
14 |
+
opts.padding_frac = 0.05
|
15 |
+
opts.jitter_frac = 0.05
|
16 |
+
opts.input_size = image_size
|
17 |
+
opts.split = split
|
18 |
+
|
19 |
+
dataset = CUBDataset(opts)
|
20 |
+
loader = torch.utils.data.DataLoader(
|
21 |
+
dataset,
|
22 |
+
batch_size=batch_size,
|
23 |
+
shuffle=not is_validation,
|
24 |
+
num_workers=num_workers,
|
25 |
+
pin_memory=True
|
26 |
+
)
|
27 |
+
return loader
|
28 |
+
|
29 |
+
|
30 |
+
class CUBDataset(Dataset):
|
31 |
+
def __init__(self, opts):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.opts = opts
|
35 |
+
self.img_size = opts.input_size
|
36 |
+
self.jitter_frac = opts.jitter_frac
|
37 |
+
self.padding_frac = opts.padding_frac
|
38 |
+
self.split = opts.split
|
39 |
+
self.data_dir = opts.data_dir
|
40 |
+
self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub')
|
41 |
+
self.img_dir = osp.join(self.data_dir, 'images')
|
42 |
+
|
43 |
+
self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split)
|
44 |
+
self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split)
|
45 |
+
|
46 |
+
if not osp.exists(self.anno_path):
|
47 |
+
print('%s doesnt exist!' % self.anno_path)
|
48 |
+
import pdb; pdb.set_trace()
|
49 |
+
|
50 |
+
# Load the annotation file.
|
51 |
+
print('loading %s' % self.anno_path)
|
52 |
+
self.anno = sio.loadmat(
|
53 |
+
self.anno_path, struct_as_record=False, squeeze_me=True)['images']
|
54 |
+
self.anno_sfm = sio.loadmat(
|
55 |
+
self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno']
|
56 |
+
|
57 |
+
self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1;
|
58 |
+
|
59 |
+
self.num_imgs = len(self.anno)
|
60 |
+
print('%d images' % self.num_imgs)
|
61 |
+
|
62 |
+
def forward_img(self, index):
|
63 |
+
data = self.anno[index]
|
64 |
+
data_sfm = self.anno_sfm[0]
|
65 |
+
|
66 |
+
# sfm_pose = (sfm_c, sfm_t, sfm_r)
|
67 |
+
sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)]
|
68 |
+
|
69 |
+
sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant')
|
70 |
+
sfm_rot[3, 3] = 1
|
71 |
+
sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True)
|
72 |
+
|
73 |
+
img_path = osp.join(self.img_dir, str(data.rel_path))
|
74 |
+
#img_path = img_path.replace("JPEG", "jpg")
|
75 |
+
img = np.array(Image.open(img_path))
|
76 |
+
|
77 |
+
# Some are grayscale:
|
78 |
+
if len(img.shape) == 2:
|
79 |
+
img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
|
80 |
+
mask = data.mask
|
81 |
+
mask = np.expand_dims(mask, 2)
|
82 |
+
h,w,_ = mask.shape
|
83 |
+
|
84 |
+
# Adjust to 0 indexing
|
85 |
+
bbox = np.array(
|
86 |
+
[data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2],
|
87 |
+
float) - 1
|
88 |
+
|
89 |
+
parts = data.parts.T.astype(float)
|
90 |
+
kp = np.copy(parts)
|
91 |
+
vis = kp[:, 2] > 0
|
92 |
+
kp[vis, :2] -= 1
|
93 |
+
|
94 |
+
# Peturb bbox
|
95 |
+
if self.split == 'train':
|
96 |
+
bbox = peturb_bbox(
|
97 |
+
bbox, pf=self.padding_frac, jf=self.jitter_frac)
|
98 |
+
else:
|
99 |
+
bbox = peturb_bbox(
|
100 |
+
bbox, pf=self.padding_frac, jf=0)
|
101 |
+
bbox = square_bbox(bbox)
|
102 |
+
|
103 |
+
# crop image around bbox, translate kps
|
104 |
+
img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose)
|
105 |
+
|
106 |
+
# scale image, and mask. And scale kps.
|
107 |
+
img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose)
|
108 |
+
|
109 |
+
# Mirror image on random.
|
110 |
+
if self.split == 'train':
|
111 |
+
img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose)
|
112 |
+
|
113 |
+
# Normalize kp to be [-1, 1]
|
114 |
+
img_h, img_w = img.shape[:2]
|
115 |
+
kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w)
|
116 |
+
|
117 |
+
# img = Image.fromarray(np.asarray(img, np.uint8))
|
118 |
+
mask = np.asarray(mask, np.float32)
|
119 |
+
return img, kp_norm, mask, sfm_pose, img_path
|
120 |
+
|
121 |
+
def normalize_kp(self, kp, sfm_pose, img_h, img_w):
|
122 |
+
vis = kp[:, 2, None] > 0
|
123 |
+
new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
|
124 |
+
2 * (kp[:, 1] / img_h) - 1,
|
125 |
+
kp[:, 2]]).T
|
126 |
+
sfm_pose[0] *= (1.0/img_w + 1.0/img_h)
|
127 |
+
sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1
|
128 |
+
sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
|
129 |
+
new_kp = vis * new_kp
|
130 |
+
|
131 |
+
return new_kp, sfm_pose
|
132 |
+
|
133 |
+
def crop_image(self, img, mask, bbox, kp, vis, sfm_pose):
|
134 |
+
# crop image and mask and translate kps
|
135 |
+
img = crop(img, bbox, bgval=1)
|
136 |
+
mask = crop(mask, bbox, bgval=0)
|
137 |
+
kp[vis, 0] -= bbox[0]
|
138 |
+
kp[vis, 1] -= bbox[1]
|
139 |
+
sfm_pose[1][0] -= bbox[0]
|
140 |
+
sfm_pose[1][1] -= bbox[1]
|
141 |
+
return img, mask, kp, sfm_pose
|
142 |
+
|
143 |
+
def scale_image(self, img, mask, kp, vis, sfm_pose):
|
144 |
+
# Scale image so largest bbox size is img_size
|
145 |
+
bwidth = np.shape(img)[0]
|
146 |
+
bheight = np.shape(img)[1]
|
147 |
+
scale = self.img_size / float(max(bwidth, bheight))
|
148 |
+
img_scale, _ = resize_img(img, scale)
|
149 |
+
# if img_scale.shape[0] != self.img_size:
|
150 |
+
# print('bad!')
|
151 |
+
# import ipdb; ipdb.set_trace()
|
152 |
+
# mask_scale, _ = resize_img(mask, scale)
|
153 |
+
# mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST)
|
154 |
+
mask_scale, _ = resize_img(mask, scale)
|
155 |
+
kp[vis, :2] *= scale
|
156 |
+
sfm_pose[0] *= scale
|
157 |
+
sfm_pose[1] *= scale
|
158 |
+
|
159 |
+
return img_scale, mask_scale, kp, sfm_pose
|
160 |
+
|
161 |
+
def mirror_image(self, img, mask, kp, sfm_pose):
|
162 |
+
kp_perm = self.kp_perm
|
163 |
+
if np.random.rand(1) > 0.5:
|
164 |
+
# Need copy bc torch collate doesnt like neg strides
|
165 |
+
img_flip = img[:, ::-1, :].copy()
|
166 |
+
mask_flip = mask[:, ::-1].copy()
|
167 |
+
|
168 |
+
# Flip kps.
|
169 |
+
new_x = img.shape[1] - kp[:, 0] - 1
|
170 |
+
kp_flip = np.hstack((new_x[:, None], kp[:, 1:]))
|
171 |
+
kp_flip = kp_flip[kp_perm, :]
|
172 |
+
# Flip sfm_pose Rot.
|
173 |
+
R = quaternion_matrix(sfm_pose[2])
|
174 |
+
flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1])))
|
175 |
+
sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True)
|
176 |
+
# Flip tx
|
177 |
+
tx = img.shape[1] - sfm_pose[1][0] - 1
|
178 |
+
sfm_pose[1][0] = tx
|
179 |
+
return img_flip, mask_flip, kp_flip, sfm_pose
|
180 |
+
else:
|
181 |
+
return img, mask, kp, sfm_pose
|
182 |
+
|
183 |
+
def __len__(self):
|
184 |
+
return self.num_imgs
|
185 |
+
|
186 |
+
def __getitem__(self, index):
|
187 |
+
img, kp, mask, sfm_pose, img_path = self.forward_img(index)
|
188 |
+
sfm_pose[0].shape = 1
|
189 |
+
mask = np.expand_dims(mask, 2)
|
190 |
+
|
191 |
+
images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0)
|
192 |
+
masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1)
|
193 |
+
mask_dt = compute_distance_transform(masks)
|
194 |
+
# flows = torch.zeros(1,2, self.img_size, self.img_size)
|
195 |
+
flows = torch.zeros(1)
|
196 |
+
bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness
|
197 |
+
bg_image = images[0]
|
198 |
+
seq_idx = torch.LongTensor([index])
|
199 |
+
frame_idx = torch.LongTensor([0])
|
200 |
+
return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
|
201 |
+
|
202 |
+
|
203 |
+
def compute_distance_transform(mask):
|
204 |
+
mask_dt = []
|
205 |
+
for m in mask:
|
206 |
+
dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
207 |
+
inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
208 |
+
mask_dt += [torch.stack([dt, inv_dt], 0)]
|
209 |
+
return torch.stack(mask_dt, 0) # Bx2xHxW
|
210 |
+
|
211 |
+
|
212 |
+
def resize_img(img, scale_factor):
|
213 |
+
new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int)
|
214 |
+
new_img = cv2.resize(img, (new_size[1], new_size[0]))
|
215 |
+
# This is scale factor of [height, width] i.e. [y, x]
|
216 |
+
actual_factor = [new_size[0] / float(img.shape[0]),
|
217 |
+
new_size[1] / float(img.shape[1])]
|
218 |
+
return new_img, actual_factor
|
219 |
+
|
220 |
+
|
221 |
+
def peturb_bbox(bbox, pf=0, jf=0):
|
222 |
+
'''
|
223 |
+
Jitters and pads the input bbox.
|
224 |
+
Args:
|
225 |
+
bbox: Zero-indexed tight bbox.
|
226 |
+
pf: padding fraction.
|
227 |
+
jf: jittering fraction.
|
228 |
+
Returns:
|
229 |
+
pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates
|
230 |
+
'''
|
231 |
+
pet_bbox = [coord for coord in bbox]
|
232 |
+
bwidth = bbox[2] - bbox[0] + 1
|
233 |
+
bheight = bbox[3] - bbox[1] + 1
|
234 |
+
|
235 |
+
pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
|
236 |
+
pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight
|
237 |
+
pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
|
238 |
+
pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight
|
239 |
+
|
240 |
+
return pet_bbox
|
241 |
+
|
242 |
+
|
243 |
+
def square_bbox(bbox):
|
244 |
+
'''
|
245 |
+
Converts a bbox to have a square shape by increasing size along non-max dimension.
|
246 |
+
'''
|
247 |
+
sq_bbox = [int(round(coord)) for coord in bbox]
|
248 |
+
bwidth = sq_bbox[2] - sq_bbox[0] + 1
|
249 |
+
bheight = sq_bbox[3] - sq_bbox[1] + 1
|
250 |
+
maxdim = float(max(bwidth, bheight))
|
251 |
+
|
252 |
+
dw_b_2 = int(round((maxdim-bwidth)/2.0))
|
253 |
+
dh_b_2 = int(round((maxdim-bheight)/2.0))
|
254 |
+
|
255 |
+
sq_bbox[0] -= dw_b_2
|
256 |
+
sq_bbox[1] -= dh_b_2
|
257 |
+
sq_bbox[2] = sq_bbox[0] + maxdim - 1
|
258 |
+
sq_bbox[3] = sq_bbox[1] + maxdim - 1
|
259 |
+
|
260 |
+
return sq_bbox
|
261 |
+
|
262 |
+
|
263 |
+
def crop(img, bbox, bgval=0):
|
264 |
+
'''
|
265 |
+
Crops a region from the image corresponding to the bbox.
|
266 |
+
If some regions specified go outside the image boundaries, the pixel values are set to bgval.
|
267 |
+
Args:
|
268 |
+
img: image to crop
|
269 |
+
bbox: bounding box to crop
|
270 |
+
bgval: default background for regions outside image
|
271 |
+
'''
|
272 |
+
bbox = [int(round(c)) for c in bbox]
|
273 |
+
bwidth = bbox[2] - bbox[0] + 1
|
274 |
+
bheight = bbox[3] - bbox[1] + 1
|
275 |
+
|
276 |
+
im_shape = np.shape(img)
|
277 |
+
im_h, im_w = im_shape[0], im_shape[1]
|
278 |
+
|
279 |
+
nc = 1 if len(im_shape) < 3 else im_shape[2]
|
280 |
+
|
281 |
+
img_out = np.ones((bheight, bwidth, nc))*bgval
|
282 |
+
x_min_src = max(0, bbox[0])
|
283 |
+
x_max_src = min(im_w, bbox[2]+1)
|
284 |
+
y_min_src = max(0, bbox[1])
|
285 |
+
y_max_src = min(im_h, bbox[3]+1)
|
286 |
+
|
287 |
+
x_min_trg = x_min_src - bbox[0]
|
288 |
+
x_max_trg = x_max_src - x_min_src + x_min_trg
|
289 |
+
y_min_trg = y_min_src - bbox[1]
|
290 |
+
y_max_trg = y_max_src - y_min_src + y_min_trg
|
291 |
+
|
292 |
+
img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :]
|
293 |
+
return img_out
|
294 |
+
|
295 |
+
|
296 |
+
# https://github.com/akanazawa/cmr/blob/master/utils/transformations.py
|
297 |
+
import math
|
298 |
+
import numpy
|
299 |
+
_EPS = numpy.finfo(float).eps * 4.0
|
300 |
+
|
301 |
+
def quaternion_matrix(quaternion):
|
302 |
+
"""Return homogeneous rotation matrix from quaternion.
|
303 |
+
>>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
|
304 |
+
>>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
|
305 |
+
True
|
306 |
+
>>> M = quaternion_matrix([1, 0, 0, 0])
|
307 |
+
>>> numpy.allclose(M, numpy.identity(4))
|
308 |
+
True
|
309 |
+
>>> M = quaternion_matrix([0, 1, 0, 0])
|
310 |
+
>>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
|
311 |
+
True
|
312 |
+
"""
|
313 |
+
q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
|
314 |
+
n = numpy.dot(q, q)
|
315 |
+
if n < _EPS:
|
316 |
+
return numpy.identity(4)
|
317 |
+
q *= math.sqrt(2.0 / n)
|
318 |
+
q = numpy.outer(q, q)
|
319 |
+
return numpy.array([
|
320 |
+
[1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
|
321 |
+
[ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
|
322 |
+
[ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
|
323 |
+
[ 0.0, 0.0, 0.0, 1.0]])
|
324 |
+
|
325 |
+
def quaternion_from_matrix(matrix, isprecise=False):
|
326 |
+
"""Return quaternion from rotation matrix.
|
327 |
+
If isprecise is True, the input matrix is assumed to be a precise rotation
|
328 |
+
matrix and a faster algorithm is used.
|
329 |
+
>>> q = quaternion_from_matrix(numpy.identity(4), True)
|
330 |
+
>>> numpy.allclose(q, [1, 0, 0, 0])
|
331 |
+
True
|
332 |
+
>>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
|
333 |
+
>>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
|
334 |
+
True
|
335 |
+
>>> R = rotation_matrix(0.123, (1, 2, 3))
|
336 |
+
>>> q = quaternion_from_matrix(R, True)
|
337 |
+
>>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
|
338 |
+
True
|
339 |
+
>>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
|
340 |
+
... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
|
341 |
+
>>> q = quaternion_from_matrix(R)
|
342 |
+
>>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
|
343 |
+
True
|
344 |
+
>>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
|
345 |
+
... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
|
346 |
+
>>> q = quaternion_from_matrix(R)
|
347 |
+
>>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
|
348 |
+
True
|
349 |
+
>>> R = random_rotation_matrix()
|
350 |
+
>>> q = quaternion_from_matrix(R)
|
351 |
+
>>> is_same_transform(R, quaternion_matrix(q))
|
352 |
+
True
|
353 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
354 |
+
... quaternion_from_matrix(R, isprecise=True))
|
355 |
+
True
|
356 |
+
>>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
|
357 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
358 |
+
... quaternion_from_matrix(R, isprecise=True))
|
359 |
+
True
|
360 |
+
"""
|
361 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
|
362 |
+
if isprecise:
|
363 |
+
q = numpy.empty((4, ))
|
364 |
+
t = numpy.trace(M)
|
365 |
+
if t > M[3, 3]:
|
366 |
+
q[0] = t
|
367 |
+
q[3] = M[1, 0] - M[0, 1]
|
368 |
+
q[2] = M[0, 2] - M[2, 0]
|
369 |
+
q[1] = M[2, 1] - M[1, 2]
|
370 |
+
else:
|
371 |
+
i, j, k = 0, 1, 2
|
372 |
+
if M[1, 1] > M[0, 0]:
|
373 |
+
i, j, k = 1, 2, 0
|
374 |
+
if M[2, 2] > M[i, i]:
|
375 |
+
i, j, k = 2, 0, 1
|
376 |
+
t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
|
377 |
+
q[i] = t
|
378 |
+
q[j] = M[i, j] + M[j, i]
|
379 |
+
q[k] = M[k, i] + M[i, k]
|
380 |
+
q[3] = M[k, j] - M[j, k]
|
381 |
+
q = q[[3, 0, 1, 2]]
|
382 |
+
q *= 0.5 / math.sqrt(t * M[3, 3])
|
383 |
+
else:
|
384 |
+
m00 = M[0, 0]
|
385 |
+
m01 = M[0, 1]
|
386 |
+
m02 = M[0, 2]
|
387 |
+
m10 = M[1, 0]
|
388 |
+
m11 = M[1, 1]
|
389 |
+
m12 = M[1, 2]
|
390 |
+
m20 = M[2, 0]
|
391 |
+
m21 = M[2, 1]
|
392 |
+
m22 = M[2, 2]
|
393 |
+
# symmetric matrix K
|
394 |
+
K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
|
395 |
+
[m01+m10, m11-m00-m22, 0.0, 0.0],
|
396 |
+
[m02+m20, m12+m21, m22-m00-m11, 0.0],
|
397 |
+
[m21-m12, m02-m20, m10-m01, m00+m11+m22]])
|
398 |
+
K /= 3.0
|
399 |
+
# quaternion is eigenvector of K that corresponds to largest eigenvalue
|
400 |
+
w, V = numpy.linalg.eigh(K)
|
401 |
+
q = V[[3, 0, 1, 2], numpy.argmax(w)]
|
402 |
+
if q[0] < 0.0:
|
403 |
+
numpy.negative(q, q)
|
404 |
+
return q
|
video3d/cub_dataloaders_ddp.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import scipy.io as sio
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from torch.utils.data import Dataset
|
8 |
+
from types import SimpleNamespace
|
9 |
+
|
10 |
+
|
11 |
+
def get_cub_loader(data_dir, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
|
12 |
+
opts = SimpleNamespace()
|
13 |
+
opts.data_dir = data_dir
|
14 |
+
opts.padding_frac = 0.05
|
15 |
+
opts.jitter_frac = 0.05
|
16 |
+
opts.input_size = image_size
|
17 |
+
opts.split = split
|
18 |
+
|
19 |
+
dataset = CUBDataset(opts)
|
20 |
+
loader = torch.utils.data.DataLoader(
|
21 |
+
dataset,
|
22 |
+
batch_size=batch_size,
|
23 |
+
shuffle=not is_validation,
|
24 |
+
num_workers=num_workers,
|
25 |
+
pin_memory=True
|
26 |
+
)
|
27 |
+
return loader
|
28 |
+
|
29 |
+
|
30 |
+
def get_cub_loader_ddp(data_dir, world_size, rank, split='test', is_validation=False, batch_size=256, num_workers=4, image_size=256):
|
31 |
+
opts = SimpleNamespace()
|
32 |
+
opts.data_dir = data_dir
|
33 |
+
opts.padding_frac = 0.05
|
34 |
+
opts.jitter_frac = 0.05
|
35 |
+
opts.input_size = image_size
|
36 |
+
opts.split = split
|
37 |
+
|
38 |
+
dataset = CUBDataset(opts)
|
39 |
+
|
40 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
41 |
+
dataset,
|
42 |
+
num_replicas=world_size,
|
43 |
+
rank=rank,
|
44 |
+
)
|
45 |
+
|
46 |
+
loader = torch.utils.data.DataLoader(
|
47 |
+
dataset,
|
48 |
+
sampler=sampler,
|
49 |
+
batch_size=batch_size,
|
50 |
+
shuffle=not is_validation,
|
51 |
+
drop_last=True,
|
52 |
+
num_workers=num_workers,
|
53 |
+
pin_memory=True
|
54 |
+
)
|
55 |
+
return loader
|
56 |
+
|
57 |
+
|
58 |
+
class CUBDataset(Dataset):
|
59 |
+
def __init__(self, opts):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
self.opts = opts
|
63 |
+
self.img_size = opts.input_size
|
64 |
+
self.jitter_frac = opts.jitter_frac
|
65 |
+
self.padding_frac = opts.padding_frac
|
66 |
+
self.split = opts.split
|
67 |
+
self.data_dir = opts.data_dir
|
68 |
+
self.data_cache_dir = osp.join(self.data_dir, 'cachedir/cub')
|
69 |
+
self.img_dir = osp.join(self.data_dir, 'images')
|
70 |
+
|
71 |
+
self.anno_path = osp.join(self.data_cache_dir, 'data', '%s_cub_cleaned.mat' % self.split)
|
72 |
+
self.anno_sfm_path = osp.join(self.data_cache_dir, 'sfm', 'anno_%s.mat' % self.split)
|
73 |
+
|
74 |
+
if not osp.exists(self.anno_path):
|
75 |
+
print('%s doesnt exist!' % self.anno_path)
|
76 |
+
import pdb; pdb.set_trace()
|
77 |
+
|
78 |
+
# Load the annotation file.
|
79 |
+
print('loading %s' % self.anno_path)
|
80 |
+
self.anno = sio.loadmat(
|
81 |
+
self.anno_path, struct_as_record=False, squeeze_me=True)['images']
|
82 |
+
self.anno_sfm = sio.loadmat(
|
83 |
+
self.anno_sfm_path, struct_as_record=False, squeeze_me=True)['sfm_anno']
|
84 |
+
|
85 |
+
self.kp_perm = np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10, 7, 8, 9, 14, 15]) - 1;
|
86 |
+
|
87 |
+
self.num_imgs = len(self.anno)
|
88 |
+
print('%d images' % self.num_imgs)
|
89 |
+
|
90 |
+
def forward_img(self, index):
|
91 |
+
data = self.anno[index]
|
92 |
+
data_sfm = self.anno_sfm[0]
|
93 |
+
|
94 |
+
# sfm_pose = (sfm_c, sfm_t, sfm_r)
|
95 |
+
sfm_pose = [np.copy(data_sfm.scale), np.copy(data_sfm.trans), np.copy(data_sfm.rot)]
|
96 |
+
|
97 |
+
sfm_rot = np.pad(sfm_pose[2], (0,1), 'constant')
|
98 |
+
sfm_rot[3, 3] = 1
|
99 |
+
sfm_pose[2] = quaternion_from_matrix(sfm_rot, isprecise=True)
|
100 |
+
|
101 |
+
img_path = osp.join(self.img_dir, str(data.rel_path))
|
102 |
+
#img_path = img_path.replace("JPEG", "jpg")
|
103 |
+
img = np.array(Image.open(img_path))
|
104 |
+
|
105 |
+
# Some are grayscale:
|
106 |
+
if len(img.shape) == 2:
|
107 |
+
img = np.repeat(np.expand_dims(img, 2), 3, axis=2)
|
108 |
+
mask = data.mask
|
109 |
+
mask = np.expand_dims(mask, 2)
|
110 |
+
h,w,_ = mask.shape
|
111 |
+
|
112 |
+
# Adjust to 0 indexing
|
113 |
+
bbox = np.array(
|
114 |
+
[data.bbox.x1, data.bbox.y1, data.bbox.x2, data.bbox.y2],
|
115 |
+
float) - 1
|
116 |
+
|
117 |
+
parts = data.parts.T.astype(float)
|
118 |
+
kp = np.copy(parts)
|
119 |
+
vis = kp[:, 2] > 0
|
120 |
+
kp[vis, :2] -= 1
|
121 |
+
|
122 |
+
# Peturb bbox
|
123 |
+
if self.split == 'train':
|
124 |
+
bbox = peturb_bbox(
|
125 |
+
bbox, pf=self.padding_frac, jf=self.jitter_frac)
|
126 |
+
else:
|
127 |
+
bbox = peturb_bbox(
|
128 |
+
bbox, pf=self.padding_frac, jf=0)
|
129 |
+
bbox = square_bbox(bbox)
|
130 |
+
|
131 |
+
# crop image around bbox, translate kps
|
132 |
+
img, mask, kp, sfm_pose = self.crop_image(img, mask, bbox, kp, vis, sfm_pose)
|
133 |
+
|
134 |
+
# scale image, and mask. And scale kps.
|
135 |
+
img, mask, kp, sfm_pose = self.scale_image(img, mask, kp, vis, sfm_pose)
|
136 |
+
|
137 |
+
# Mirror image on random.
|
138 |
+
if self.split == 'train':
|
139 |
+
img, mask, kp, sfm_pose = self.mirror_image(img, mask, kp, sfm_pose)
|
140 |
+
|
141 |
+
# Normalize kp to be [-1, 1]
|
142 |
+
img_h, img_w = img.shape[:2]
|
143 |
+
kp_norm, sfm_pose = self.normalize_kp(kp, sfm_pose, img_h, img_w)
|
144 |
+
|
145 |
+
# img = Image.fromarray(np.asarray(img, np.uint8))
|
146 |
+
mask = np.asarray(mask, np.float32)
|
147 |
+
return img, kp_norm, mask, sfm_pose, img_path
|
148 |
+
|
149 |
+
def normalize_kp(self, kp, sfm_pose, img_h, img_w):
|
150 |
+
vis = kp[:, 2, None] > 0
|
151 |
+
new_kp = np.stack([2 * (kp[:, 0] / img_w) - 1,
|
152 |
+
2 * (kp[:, 1] / img_h) - 1,
|
153 |
+
kp[:, 2]]).T
|
154 |
+
sfm_pose[0] *= (1.0/img_w + 1.0/img_h)
|
155 |
+
sfm_pose[1][0] = 2.0 * (sfm_pose[1][0] / img_w) - 1
|
156 |
+
sfm_pose[1][1] = 2.0 * (sfm_pose[1][1] / img_h) - 1
|
157 |
+
new_kp = vis * new_kp
|
158 |
+
|
159 |
+
return new_kp, sfm_pose
|
160 |
+
|
161 |
+
def crop_image(self, img, mask, bbox, kp, vis, sfm_pose):
|
162 |
+
# crop image and mask and translate kps
|
163 |
+
img = crop(img, bbox, bgval=1)
|
164 |
+
mask = crop(mask, bbox, bgval=0)
|
165 |
+
kp[vis, 0] -= bbox[0]
|
166 |
+
kp[vis, 1] -= bbox[1]
|
167 |
+
sfm_pose[1][0] -= bbox[0]
|
168 |
+
sfm_pose[1][1] -= bbox[1]
|
169 |
+
return img, mask, kp, sfm_pose
|
170 |
+
|
171 |
+
def scale_image(self, img, mask, kp, vis, sfm_pose):
|
172 |
+
# Scale image so largest bbox size is img_size
|
173 |
+
bwidth = np.shape(img)[0]
|
174 |
+
bheight = np.shape(img)[1]
|
175 |
+
scale = self.img_size / float(max(bwidth, bheight))
|
176 |
+
img_scale, _ = resize_img(img, scale)
|
177 |
+
# if img_scale.shape[0] != self.img_size:
|
178 |
+
# print('bad!')
|
179 |
+
# import ipdb; ipdb.set_trace()
|
180 |
+
# mask_scale, _ = resize_img(mask, scale)
|
181 |
+
# mask_scale, _ = resize_img(mask, scale, interpolation=cv2.INTER_NEAREST)
|
182 |
+
mask_scale, _ = resize_img(mask, scale)
|
183 |
+
kp[vis, :2] *= scale
|
184 |
+
sfm_pose[0] *= scale
|
185 |
+
sfm_pose[1] *= scale
|
186 |
+
|
187 |
+
return img_scale, mask_scale, kp, sfm_pose
|
188 |
+
|
189 |
+
def mirror_image(self, img, mask, kp, sfm_pose):
|
190 |
+
kp_perm = self.kp_perm
|
191 |
+
if np.random.rand(1) > 0.5:
|
192 |
+
# Need copy bc torch collate doesnt like neg strides
|
193 |
+
img_flip = img[:, ::-1, :].copy()
|
194 |
+
mask_flip = mask[:, ::-1].copy()
|
195 |
+
|
196 |
+
# Flip kps.
|
197 |
+
new_x = img.shape[1] - kp[:, 0] - 1
|
198 |
+
kp_flip = np.hstack((new_x[:, None], kp[:, 1:]))
|
199 |
+
kp_flip = kp_flip[kp_perm, :]
|
200 |
+
# Flip sfm_pose Rot.
|
201 |
+
R = quaternion_matrix(sfm_pose[2])
|
202 |
+
flip_R = np.diag([-1, 1, 1, 1]).dot(R.dot(np.diag([-1, 1, 1, 1])))
|
203 |
+
sfm_pose[2] = quaternion_from_matrix(flip_R, isprecise=True)
|
204 |
+
# Flip tx
|
205 |
+
tx = img.shape[1] - sfm_pose[1][0] - 1
|
206 |
+
sfm_pose[1][0] = tx
|
207 |
+
return img_flip, mask_flip, kp_flip, sfm_pose
|
208 |
+
else:
|
209 |
+
return img, mask, kp, sfm_pose
|
210 |
+
|
211 |
+
def __len__(self):
|
212 |
+
return self.num_imgs
|
213 |
+
|
214 |
+
def __getitem__(self, index):
|
215 |
+
img, kp, mask, sfm_pose, img_path = self.forward_img(index)
|
216 |
+
sfm_pose[0].shape = 1
|
217 |
+
mask = np.expand_dims(mask, 2)
|
218 |
+
|
219 |
+
images = torch.FloatTensor(img /255.).permute(2,0,1).unsqueeze(0)
|
220 |
+
masks = torch.FloatTensor(mask).permute(2,0,1).repeat(1,3,1,1)
|
221 |
+
mask_dt = compute_distance_transform(masks)
|
222 |
+
# flows = torch.zeros(1,2, self.img_size, self.img_size)
|
223 |
+
flows = torch.zeros(1)
|
224 |
+
bboxs = torch.FloatTensor([0, 0, 0, self.img_size, self.img_size, 1, 1, 0]).unsqueeze(0) # frame_id, crop_x0, crop_y0, crop_w, crop_h, resize_sx, resize_sy, sharpness
|
225 |
+
bg_image = images[0]
|
226 |
+
seq_idx = torch.LongTensor([index])
|
227 |
+
frame_idx = torch.LongTensor([0])
|
228 |
+
return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
|
229 |
+
|
230 |
+
|
231 |
+
def compute_distance_transform(mask):
|
232 |
+
mask_dt = []
|
233 |
+
for m in mask:
|
234 |
+
dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
235 |
+
inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
236 |
+
mask_dt += [torch.stack([dt, inv_dt], 0)]
|
237 |
+
return torch.stack(mask_dt, 0) # Bx2xHxW
|
238 |
+
|
239 |
+
|
240 |
+
def resize_img(img, scale_factor):
|
241 |
+
new_size = (np.round(np.array(img.shape[:2]) * scale_factor)).astype(int)
|
242 |
+
new_img = cv2.resize(img, (new_size[1], new_size[0]))
|
243 |
+
# This is scale factor of [height, width] i.e. [y, x]
|
244 |
+
actual_factor = [new_size[0] / float(img.shape[0]),
|
245 |
+
new_size[1] / float(img.shape[1])]
|
246 |
+
return new_img, actual_factor
|
247 |
+
|
248 |
+
|
249 |
+
def peturb_bbox(bbox, pf=0, jf=0):
|
250 |
+
'''
|
251 |
+
Jitters and pads the input bbox.
|
252 |
+
Args:
|
253 |
+
bbox: Zero-indexed tight bbox.
|
254 |
+
pf: padding fraction.
|
255 |
+
jf: jittering fraction.
|
256 |
+
Returns:
|
257 |
+
pet_bbox: Jittered and padded box. Might have -ve or out-of-image coordinates
|
258 |
+
'''
|
259 |
+
pet_bbox = [coord for coord in bbox]
|
260 |
+
bwidth = bbox[2] - bbox[0] + 1
|
261 |
+
bheight = bbox[3] - bbox[1] + 1
|
262 |
+
|
263 |
+
pet_bbox[0] -= (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
|
264 |
+
pet_bbox[1] -= (pf*bheight) + (1-2*np.random.random())*jf*bheight
|
265 |
+
pet_bbox[2] += (pf*bwidth) + (1-2*np.random.random())*jf*bwidth
|
266 |
+
pet_bbox[3] += (pf*bheight) + (1-2*np.random.random())*jf*bheight
|
267 |
+
|
268 |
+
return pet_bbox
|
269 |
+
|
270 |
+
|
271 |
+
def square_bbox(bbox):
|
272 |
+
'''
|
273 |
+
Converts a bbox to have a square shape by increasing size along non-max dimension.
|
274 |
+
'''
|
275 |
+
sq_bbox = [int(round(coord)) for coord in bbox]
|
276 |
+
bwidth = sq_bbox[2] - sq_bbox[0] + 1
|
277 |
+
bheight = sq_bbox[3] - sq_bbox[1] + 1
|
278 |
+
maxdim = float(max(bwidth, bheight))
|
279 |
+
|
280 |
+
dw_b_2 = int(round((maxdim-bwidth)/2.0))
|
281 |
+
dh_b_2 = int(round((maxdim-bheight)/2.0))
|
282 |
+
|
283 |
+
sq_bbox[0] -= dw_b_2
|
284 |
+
sq_bbox[1] -= dh_b_2
|
285 |
+
sq_bbox[2] = sq_bbox[0] + maxdim - 1
|
286 |
+
sq_bbox[3] = sq_bbox[1] + maxdim - 1
|
287 |
+
|
288 |
+
return sq_bbox
|
289 |
+
|
290 |
+
|
291 |
+
def crop(img, bbox, bgval=0):
|
292 |
+
'''
|
293 |
+
Crops a region from the image corresponding to the bbox.
|
294 |
+
If some regions specified go outside the image boundaries, the pixel values are set to bgval.
|
295 |
+
Args:
|
296 |
+
img: image to crop
|
297 |
+
bbox: bounding box to crop
|
298 |
+
bgval: default background for regions outside image
|
299 |
+
'''
|
300 |
+
bbox = [int(round(c)) for c in bbox]
|
301 |
+
bwidth = bbox[2] - bbox[0] + 1
|
302 |
+
bheight = bbox[3] - bbox[1] + 1
|
303 |
+
|
304 |
+
im_shape = np.shape(img)
|
305 |
+
im_h, im_w = im_shape[0], im_shape[1]
|
306 |
+
|
307 |
+
nc = 1 if len(im_shape) < 3 else im_shape[2]
|
308 |
+
|
309 |
+
img_out = np.ones((bheight, bwidth, nc))*bgval
|
310 |
+
x_min_src = max(0, bbox[0])
|
311 |
+
x_max_src = min(im_w, bbox[2]+1)
|
312 |
+
y_min_src = max(0, bbox[1])
|
313 |
+
y_max_src = min(im_h, bbox[3]+1)
|
314 |
+
|
315 |
+
x_min_trg = x_min_src - bbox[0]
|
316 |
+
x_max_trg = x_max_src - x_min_src + x_min_trg
|
317 |
+
y_min_trg = y_min_src - bbox[1]
|
318 |
+
y_max_trg = y_max_src - y_min_src + y_min_trg
|
319 |
+
|
320 |
+
img_out[y_min_trg:y_max_trg, x_min_trg:x_max_trg, :] = img[y_min_src:y_max_src, x_min_src:x_max_src, :]
|
321 |
+
return img_out
|
322 |
+
|
323 |
+
|
324 |
+
# https://github.com/akanazawa/cmr/blob/master/utils/transformations.py
|
325 |
+
import math
|
326 |
+
import numpy
|
327 |
+
_EPS = numpy.finfo(float).eps * 4.0
|
328 |
+
|
329 |
+
|
330 |
+
def quaternion_matrix(quaternion):
|
331 |
+
"""Return homogeneous rotation matrix from quaternion.
|
332 |
+
>>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
|
333 |
+
>>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
|
334 |
+
True
|
335 |
+
>>> M = quaternion_matrix([1, 0, 0, 0])
|
336 |
+
>>> numpy.allclose(M, numpy.identity(4))
|
337 |
+
True
|
338 |
+
>>> M = quaternion_matrix([0, 1, 0, 0])
|
339 |
+
>>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
|
340 |
+
True
|
341 |
+
"""
|
342 |
+
q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
|
343 |
+
n = numpy.dot(q, q)
|
344 |
+
if n < _EPS:
|
345 |
+
return numpy.identity(4)
|
346 |
+
q *= math.sqrt(2.0 / n)
|
347 |
+
q = numpy.outer(q, q)
|
348 |
+
return numpy.array([
|
349 |
+
[1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
|
350 |
+
[ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
|
351 |
+
[ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
|
352 |
+
[ 0.0, 0.0, 0.0, 1.0]])
|
353 |
+
|
354 |
+
|
355 |
+
def quaternion_from_matrix(matrix, isprecise=False):
|
356 |
+
"""Return quaternion from rotation matrix.
|
357 |
+
If isprecise is True, the input matrix is assumed to be a precise rotation
|
358 |
+
matrix and a faster algorithm is used.
|
359 |
+
>>> q = quaternion_from_matrix(numpy.identity(4), True)
|
360 |
+
>>> numpy.allclose(q, [1, 0, 0, 0])
|
361 |
+
True
|
362 |
+
>>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
|
363 |
+
>>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
|
364 |
+
True
|
365 |
+
>>> R = rotation_matrix(0.123, (1, 2, 3))
|
366 |
+
>>> q = quaternion_from_matrix(R, True)
|
367 |
+
>>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
|
368 |
+
True
|
369 |
+
>>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
|
370 |
+
... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
|
371 |
+
>>> q = quaternion_from_matrix(R)
|
372 |
+
>>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
|
373 |
+
True
|
374 |
+
>>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
|
375 |
+
... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
|
376 |
+
>>> q = quaternion_from_matrix(R)
|
377 |
+
>>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
|
378 |
+
True
|
379 |
+
>>> R = random_rotation_matrix()
|
380 |
+
>>> q = quaternion_from_matrix(R)
|
381 |
+
>>> is_same_transform(R, quaternion_matrix(q))
|
382 |
+
True
|
383 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
384 |
+
... quaternion_from_matrix(R, isprecise=True))
|
385 |
+
True
|
386 |
+
>>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
|
387 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
388 |
+
... quaternion_from_matrix(R, isprecise=True))
|
389 |
+
True
|
390 |
+
"""
|
391 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
|
392 |
+
if isprecise:
|
393 |
+
q = numpy.empty((4, ))
|
394 |
+
t = numpy.trace(M)
|
395 |
+
if t > M[3, 3]:
|
396 |
+
q[0] = t
|
397 |
+
q[3] = M[1, 0] - M[0, 1]
|
398 |
+
q[2] = M[0, 2] - M[2, 0]
|
399 |
+
q[1] = M[2, 1] - M[1, 2]
|
400 |
+
else:
|
401 |
+
i, j, k = 0, 1, 2
|
402 |
+
if M[1, 1] > M[0, 0]:
|
403 |
+
i, j, k = 1, 2, 0
|
404 |
+
if M[2, 2] > M[i, i]:
|
405 |
+
i, j, k = 2, 0, 1
|
406 |
+
t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
|
407 |
+
q[i] = t
|
408 |
+
q[j] = M[i, j] + M[j, i]
|
409 |
+
q[k] = M[k, i] + M[i, k]
|
410 |
+
q[3] = M[k, j] - M[j, k]
|
411 |
+
q = q[[3, 0, 1, 2]]
|
412 |
+
q *= 0.5 / math.sqrt(t * M[3, 3])
|
413 |
+
else:
|
414 |
+
m00 = M[0, 0]
|
415 |
+
m01 = M[0, 1]
|
416 |
+
m02 = M[0, 2]
|
417 |
+
m10 = M[1, 0]
|
418 |
+
m11 = M[1, 1]
|
419 |
+
m12 = M[1, 2]
|
420 |
+
m20 = M[2, 0]
|
421 |
+
m21 = M[2, 1]
|
422 |
+
m22 = M[2, 2]
|
423 |
+
# symmetric matrix K
|
424 |
+
K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
|
425 |
+
[m01+m10, m11-m00-m22, 0.0, 0.0],
|
426 |
+
[m02+m20, m12+m21, m22-m00-m11, 0.0],
|
427 |
+
[m21-m12, m02-m20, m10-m01, m00+m11+m22]])
|
428 |
+
K /= 3.0
|
429 |
+
# quaternion is eigenvector of K that corresponds to largest eigenvalue
|
430 |
+
w, V = numpy.linalg.eigh(K)
|
431 |
+
q = V[[3, 0, 1, 2], numpy.argmax(w)]
|
432 |
+
if q[0] < 0.0:
|
433 |
+
numpy.negative(q, q)
|
434 |
+
return q
|
video3d/dataloaders.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import torchvision.datasets.folder
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
|
14 |
+
def compute_distance_transform(mask):
|
15 |
+
mask_dt = []
|
16 |
+
for m in mask:
|
17 |
+
dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
18 |
+
inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
19 |
+
mask_dt += [torch.stack([dt, inv_dt], 0)]
|
20 |
+
return torch.stack(mask_dt, 0) # Bx2xHxW
|
21 |
+
|
22 |
+
|
23 |
+
def crop_image(image, boxs, size):
|
24 |
+
crops = []
|
25 |
+
for box in boxs:
|
26 |
+
crop_x0, crop_y0, crop_w, crop_h = box
|
27 |
+
crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size)
|
28 |
+
crop = transforms.functional.to_tensor(crop)
|
29 |
+
crops += [crop]
|
30 |
+
return torch.stack(crops, 0)
|
31 |
+
|
32 |
+
|
33 |
+
def box_loader(fpath):
|
34 |
+
box = np.loadtxt(fpath, 'str')
|
35 |
+
box[0] = box[0].split('_')[0]
|
36 |
+
return box.astype(np.float32)
|
37 |
+
|
38 |
+
|
39 |
+
def read_feat_from_img(path, n_channels):
|
40 |
+
feat = np.array(Image.open(path))
|
41 |
+
return dencode_feat_from_img(feat, n_channels)
|
42 |
+
|
43 |
+
|
44 |
+
def dencode_feat_from_img(img, n_channels):
|
45 |
+
n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
|
46 |
+
n_tiles = int((n_channels + n_addon_channels) / 3)
|
47 |
+
feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
|
48 |
+
feat = feat[:, :, :-n_addon_channels]
|
49 |
+
feat = feat.astype('float32') / 255
|
50 |
+
return feat.transpose(2, 0, 1)
|
51 |
+
|
52 |
+
|
53 |
+
def dino_loader(fpath, n_channels):
|
54 |
+
dino_map = read_feat_from_img(fpath, n_channels)
|
55 |
+
return dino_map
|
56 |
+
|
57 |
+
|
58 |
+
def get_valid_mask(boxs, image_size):
|
59 |
+
valid_masks = []
|
60 |
+
for box in boxs:
|
61 |
+
crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy()
|
62 |
+
# Discard a small margin near the boundary.
|
63 |
+
margin_w = int(crop_w * 0.02)
|
64 |
+
margin_h = int(crop_h * 0.02)
|
65 |
+
mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2)
|
66 |
+
mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0)
|
67 |
+
mask_full_crop = mask_full_pad[crop_y0+crop_h:crop_y0+crop_h*2, crop_x0+crop_w:crop_x0+crop_w*2]
|
68 |
+
mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0]
|
69 |
+
valid_masks += [mask_crop]
|
70 |
+
return torch.stack(valid_masks, 0) # NxHxW
|
71 |
+
|
72 |
+
|
73 |
+
def horizontal_flip_box(box):
|
74 |
+
frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1)
|
75 |
+
box[:,1] = full_w - crop_x0 - crop_w # x0
|
76 |
+
return box
|
77 |
+
|
78 |
+
|
79 |
+
def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
|
80 |
+
images = images.flip(3) # NxCxHxW
|
81 |
+
masks = masks.flip(3) # NxCxHxW
|
82 |
+
mask_dt = mask_dt.flip(3) # NxCxHxW
|
83 |
+
mask_valid = mask_valid.flip(2) # NxHxW
|
84 |
+
if flows.dim() > 1:
|
85 |
+
flows = flows.flip(3) # (N-1)x(x,y)xHxW
|
86 |
+
flows[:,0] *= -1 # invert delta x
|
87 |
+
bboxs = horizontal_flip_box(bboxs) # NxK
|
88 |
+
bg_images = bg_images.flip(3) # NxCxHxW
|
89 |
+
if dino_features.dim() > 1:
|
90 |
+
dino_features = dino_features.flip(3)
|
91 |
+
if dino_clusters.dim() > 1:
|
92 |
+
dino_clusters = dino_clusters.flip(3)
|
93 |
+
return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters
|
94 |
+
|
95 |
+
|
96 |
+
class BaseSequenceDataset(Dataset):
|
97 |
+
def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.skip_beginning = skip_beginning
|
101 |
+
self.skip_end = skip_end
|
102 |
+
self.min_seq_len = min_seq_len
|
103 |
+
# self.pattern = "{:07d}_{}"
|
104 |
+
self.sequences = self._make_sequences(root)
|
105 |
+
|
106 |
+
if debug_seq:
|
107 |
+
# self.sequences = [self.sequences[0][20:160]] * 100
|
108 |
+
seq_len = 0
|
109 |
+
while seq_len < min_seq_len:
|
110 |
+
i = np.random.randint(len(self.sequences))
|
111 |
+
rand_seq = self.sequences[i]
|
112 |
+
seq_len = len(rand_seq)
|
113 |
+
self.sequences = [rand_seq]
|
114 |
+
|
115 |
+
self.samples = []
|
116 |
+
|
117 |
+
def _make_sequences(self, path):
|
118 |
+
result = []
|
119 |
+
for d in sorted(os.scandir(path), key=lambda e: e.name):
|
120 |
+
if d.is_dir():
|
121 |
+
files = self._parse_folder(d)
|
122 |
+
if len(files) >= self.min_seq_len:
|
123 |
+
result.append(files)
|
124 |
+
return result
|
125 |
+
|
126 |
+
def _parse_folder(self, path):
|
127 |
+
result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
|
128 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
129 |
+
|
130 |
+
if len(result) <= self.skip_beginning + self.skip_end:
|
131 |
+
return []
|
132 |
+
if self.skip_end == 0:
|
133 |
+
return result[self.skip_beginning:]
|
134 |
+
return result[self.skip_beginning:-self.skip_end]
|
135 |
+
|
136 |
+
def _load_ids(self, path_patterns, loaders, transform=None):
|
137 |
+
result = []
|
138 |
+
for loader in loaders:
|
139 |
+
for p in path_patterns:
|
140 |
+
x = loader[1](p.format(loader[0]), *loader[2:])
|
141 |
+
if transform:
|
142 |
+
x = transform(x)
|
143 |
+
result.append(x)
|
144 |
+
return tuple(result)
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
return len(self.samples)
|
148 |
+
|
149 |
+
def __getitem__(self, index):
|
150 |
+
raise NotImplemented("This is a base class and should not be used directly")
|
151 |
+
|
152 |
+
|
153 |
+
class NFrameSequenceDataset(BaseSequenceDataset):
|
154 |
+
def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, **kwargs):
|
155 |
+
self.cat_name = cat_name
|
156 |
+
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
|
157 |
+
self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
|
158 |
+
self.bbox_loaders = [("box.txt", box_loader)]
|
159 |
+
super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
|
160 |
+
if num_sample_frames > 1:
|
161 |
+
self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
|
162 |
+
else:
|
163 |
+
self.flow_loaders = None
|
164 |
+
|
165 |
+
self.num_sample_frames = num_sample_frames
|
166 |
+
self.random_sample = random_sample
|
167 |
+
if self.random_sample:
|
168 |
+
if shuffle:
|
169 |
+
random.shuffle(self.sequences)
|
170 |
+
self.samples = self.sequences
|
171 |
+
else:
|
172 |
+
for i, s in enumerate(self.sequences):
|
173 |
+
stride = 1 if dense_sample else self.num_sample_frames
|
174 |
+
self.samples += [(i, k) for k in range(0, len(s), stride)]
|
175 |
+
if shuffle:
|
176 |
+
random.shuffle(self.samples)
|
177 |
+
|
178 |
+
self.in_image_size = in_image_size
|
179 |
+
self.out_image_size = out_image_size
|
180 |
+
self.load_background = load_background
|
181 |
+
self.color_jitter = color_jitter
|
182 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
|
183 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
184 |
+
if self.flow_loaders is not None:
|
185 |
+
self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1
|
186 |
+
self.random_flip = random_flip
|
187 |
+
self.load_dino_feature = load_dino_feature
|
188 |
+
if load_dino_feature:
|
189 |
+
self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
|
190 |
+
self.load_dino_cluster = load_dino_cluster
|
191 |
+
if load_dino_cluster:
|
192 |
+
self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]
|
193 |
+
|
194 |
+
def __getitem__(self, index):
|
195 |
+
if self.random_sample:
|
196 |
+
seq_idx = index % len(self.sequences)
|
197 |
+
seq = self.sequences[seq_idx]
|
198 |
+
if len(seq) < self.num_sample_frames:
|
199 |
+
start_frame_idx = 0
|
200 |
+
else:
|
201 |
+
start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1)
|
202 |
+
paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
|
203 |
+
else:
|
204 |
+
seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
|
205 |
+
seq = self.sequences[seq_idx]
|
206 |
+
# Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame
|
207 |
+
if len(seq) <= start_frame_idx +1:
|
208 |
+
start_frame_idx = max(0, start_frame_idx-1)
|
209 |
+
paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
|
210 |
+
|
211 |
+
masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
|
212 |
+
mask_dt = compute_distance_transform(masks)
|
213 |
+
jitter = False
|
214 |
+
if self.color_jitter is not None:
|
215 |
+
prob, b, h = self.color_jitter
|
216 |
+
if np.random.rand() < prob:
|
217 |
+
jitter = True
|
218 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
219 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
220 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
221 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
222 |
+
if jitter:
|
223 |
+
images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
|
224 |
+
images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
|
225 |
+
images = images_fg * masks + images_bg * (1-masks)
|
226 |
+
else:
|
227 |
+
images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
|
228 |
+
if len(paths) > 1:
|
229 |
+
flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1
|
230 |
+
flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
|
231 |
+
else:
|
232 |
+
flows = torch.zeros(1)
|
233 |
+
bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
|
234 |
+
mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
|
235 |
+
if self.load_background:
|
236 |
+
bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
|
237 |
+
if jitter:
|
238 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
239 |
+
bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
|
240 |
+
else:
|
241 |
+
bg_images = torch.zeros_like(images)
|
242 |
+
if self.load_dino_feature:
|
243 |
+
dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
|
244 |
+
else:
|
245 |
+
dino_features = torch.zeros(1)
|
246 |
+
if self.load_dino_cluster:
|
247 |
+
dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55
|
248 |
+
else:
|
249 |
+
dino_clusters = torch.zeros(1)
|
250 |
+
seq_idx = torch.LongTensor([seq_idx])
|
251 |
+
frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long()
|
252 |
+
|
253 |
+
if self.random_flip and np.random.rand() < 0.5:
|
254 |
+
images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
|
255 |
+
|
256 |
+
## pad shorter sequence
|
257 |
+
if len(paths) < self.num_sample_frames:
|
258 |
+
num_pad = self.num_sample_frames - len(paths)
|
259 |
+
images = torch.cat([images[:1]] *num_pad + [images], 0)
|
260 |
+
masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
|
261 |
+
mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
|
262 |
+
mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
|
263 |
+
if flows.dim() > 1:
|
264 |
+
flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
|
265 |
+
bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
|
266 |
+
bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
|
267 |
+
if dino_features.dim() > 1:
|
268 |
+
dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
|
269 |
+
if dino_clusters.dim() > 1:
|
270 |
+
dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
|
271 |
+
frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
|
272 |
+
|
273 |
+
return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
|
274 |
+
|
275 |
+
|
276 |
+
def get_sequence_loader(data_dir, **kwargs):
|
277 |
+
if isinstance(data_dir, dict):
|
278 |
+
loaders = []
|
279 |
+
for k, v in data_dir.items():
|
280 |
+
dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs)
|
281 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True)
|
282 |
+
loaders += [loader]
|
283 |
+
return loaders
|
284 |
+
else:
|
285 |
+
return [get_sequence_loader_single(data_dir, **kwargs)]
|
286 |
+
|
287 |
+
|
288 |
+
def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64):
|
289 |
+
if mode == 'n_frame':
|
290 |
+
dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim)
|
291 |
+
else:
|
292 |
+
raise NotImplementedError
|
293 |
+
loader = torch.utils.data.DataLoader(
|
294 |
+
dataset,
|
295 |
+
batch_size=batch_size,
|
296 |
+
shuffle=not is_validation,
|
297 |
+
num_workers=num_workers,
|
298 |
+
pin_memory=True
|
299 |
+
)
|
300 |
+
return loader
|
301 |
+
|
302 |
+
|
303 |
+
class ImageDataset(Dataset):
|
304 |
+
def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
|
305 |
+
super().__init__()
|
306 |
+
self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
|
307 |
+
self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
|
308 |
+
self.bbox_loader = ("box.txt", np.loadtxt, 'str')
|
309 |
+
self.samples = self._parse_folder(root)
|
310 |
+
self.image_size = image_size
|
311 |
+
self.color_jitter = color_jitter
|
312 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()])
|
313 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
314 |
+
|
315 |
+
def _parse_folder(self, path):
|
316 |
+
result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True))
|
317 |
+
result = [p.replace(self.image_loader[0], '{}') for p in result]
|
318 |
+
return result
|
319 |
+
|
320 |
+
def _load_ids(self, path, loader, transform=None):
|
321 |
+
x = loader[1](path.format(loader[0]), *loader[2:])
|
322 |
+
if transform:
|
323 |
+
x = transform(x)
|
324 |
+
return x
|
325 |
+
|
326 |
+
def __len__(self):
|
327 |
+
return len(self.samples)
|
328 |
+
|
329 |
+
def __getitem__(self, index):
|
330 |
+
path = self.samples[index % len(self.samples)]
|
331 |
+
masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
|
332 |
+
mask_dt = compute_distance_transform(masks)
|
333 |
+
jitter = False
|
334 |
+
if self.color_jitter is not None:
|
335 |
+
prob, b, h = self.color_jitter
|
336 |
+
if np.random.rand() < prob:
|
337 |
+
jitter = True
|
338 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
339 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
340 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
341 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
342 |
+
if jitter:
|
343 |
+
images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0)
|
344 |
+
images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0)
|
345 |
+
images = images_fg * masks + images_bg * (1-masks)
|
346 |
+
else:
|
347 |
+
images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0)
|
348 |
+
flows = torch.zeros(1)
|
349 |
+
bboxs = self._load_ids(path, self.bbox_loader, transform=None)
|
350 |
+
bboxs[0] = '0'
|
351 |
+
bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0)
|
352 |
+
bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg')
|
353 |
+
if os.path.isfile(bg_fpath):
|
354 |
+
bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
|
355 |
+
if jitter:
|
356 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
357 |
+
bg_image = transforms.ToTensor()(bg_image)
|
358 |
+
else:
|
359 |
+
bg_image = images[0]
|
360 |
+
seq_idx = torch.LongTensor([index])
|
361 |
+
frame_idx = torch.LongTensor([0])
|
362 |
+
return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
|
363 |
+
|
364 |
+
|
365 |
+
def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
|
366 |
+
dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
|
367 |
+
|
368 |
+
loader = torch.utils.data.DataLoader(
|
369 |
+
dataset,
|
370 |
+
batch_size=batch_size,
|
371 |
+
shuffle=False,
|
372 |
+
num_workers=num_workers,
|
373 |
+
pin_memory=True
|
374 |
+
)
|
375 |
+
return loader
|
video3d/dataloaders_ddp.py
ADDED
@@ -0,0 +1,1210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from glob import glob
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
import itertools
|
8 |
+
import torch
|
9 |
+
import copy
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
import torchvision.datasets.folder
|
12 |
+
import torchvision.transforms as transforms
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
|
16 |
+
def compute_distance_transform(mask):
|
17 |
+
mask_dt = []
|
18 |
+
for m in mask:
|
19 |
+
dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
20 |
+
inv_dt = torch.FloatTensor(cv2.distanceTransform(np.uint8(1 - m[0]), cv2.DIST_L2, cv2.DIST_MASK_PRECISE))
|
21 |
+
mask_dt += [torch.stack([dt, inv_dt], 0)]
|
22 |
+
return torch.stack(mask_dt, 0) # Bx2xHxW
|
23 |
+
|
24 |
+
|
25 |
+
def crop_image(image, boxs, size):
|
26 |
+
crops = []
|
27 |
+
for box in boxs:
|
28 |
+
crop_x0, crop_y0, crop_w, crop_h = box
|
29 |
+
crop = transforms.functional.resized_crop(image, crop_y0, crop_x0, crop_h, crop_w, size)
|
30 |
+
crop = transforms.functional.to_tensor(crop)
|
31 |
+
crops += [crop]
|
32 |
+
return torch.stack(crops, 0)
|
33 |
+
|
34 |
+
|
35 |
+
def box_loader(fpath):
|
36 |
+
box = np.loadtxt(fpath, 'str')
|
37 |
+
box[0] = box[0].split('_')[0]
|
38 |
+
return box.astype(np.float32)
|
39 |
+
|
40 |
+
|
41 |
+
def read_feat_from_img(path, n_channels):
|
42 |
+
feat = np.array(Image.open(path))
|
43 |
+
return dencode_feat_from_img(feat, n_channels)
|
44 |
+
|
45 |
+
|
46 |
+
def dencode_feat_from_img(img, n_channels):
|
47 |
+
n_addon_channels = int(np.ceil(n_channels / 3) * 3) - n_channels
|
48 |
+
n_tiles = int((n_channels + n_addon_channels) / 3)
|
49 |
+
feat = rearrange(img, 'h (t w) c -> h w (t c)', t=n_tiles, c=3)
|
50 |
+
if n_addon_channels != 0:
|
51 |
+
feat = feat[:, :, :-n_addon_channels]
|
52 |
+
feat = feat.astype('float32') / 255
|
53 |
+
return feat.transpose(2, 0, 1)
|
54 |
+
|
55 |
+
|
56 |
+
def dino_loader(fpath, n_channels):
|
57 |
+
dino_map = read_feat_from_img(fpath, n_channels)
|
58 |
+
return dino_map
|
59 |
+
|
60 |
+
|
61 |
+
def get_valid_mask(boxs, image_size):
|
62 |
+
valid_masks = []
|
63 |
+
for box in boxs:
|
64 |
+
crop_x0, crop_y0, crop_w, crop_h, full_w, full_h = box[1:7].int().numpy()
|
65 |
+
margin_w = int(crop_w * 0.02)
|
66 |
+
margin_h = int(crop_h * 0.02)
|
67 |
+
mask_full = torch.ones(full_h-margin_h*2, full_w-margin_w*2)
|
68 |
+
mask_full_pad = torch.nn.functional.pad(mask_full, (crop_w+margin_w, crop_w+margin_w, crop_h+margin_h, crop_h+margin_h), mode='constant', value=0.0)
|
69 |
+
mask_full_crop = mask_full_pad[(crop_y0+crop_h):crop_y0+(crop_h*2), (crop_x0+crop_w):crop_x0+(crop_w*2)]
|
70 |
+
mask_crop = torch.nn.functional.interpolate(mask_full_crop[None, None, :, :], image_size, mode='nearest')[0,0]
|
71 |
+
valid_masks += [mask_crop]
|
72 |
+
return torch.stack(valid_masks, 0) # NxHxW
|
73 |
+
|
74 |
+
|
75 |
+
def horizontal_flip_box(box):
|
76 |
+
frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = box.unbind(1)
|
77 |
+
box[:,1] = full_w - crop_x0 - crop_w # x0
|
78 |
+
return box
|
79 |
+
|
80 |
+
|
81 |
+
def horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features=None, dino_clusters=None):
|
82 |
+
images = images.flip(3) # NxCxHxW
|
83 |
+
masks = masks.flip(3) # NxCxHxW
|
84 |
+
mask_dt = mask_dt.flip(3) # NxCxHxW
|
85 |
+
mask_valid = mask_valid.flip(2) # NxHxW
|
86 |
+
if flows.dim() > 1:
|
87 |
+
flows = flows.flip(3) # (N-1)x(x,y)xHxW
|
88 |
+
flows[:,0] *= -1 # invert delta x
|
89 |
+
bboxs = horizontal_flip_box(bboxs) # NxK
|
90 |
+
bg_images = bg_images.flip(3) # NxCxHxW
|
91 |
+
if dino_features.dim() > 1:
|
92 |
+
dino_features = dino_features.flip(3)
|
93 |
+
if dino_clusters.dim() > 1:
|
94 |
+
dino_clusters = dino_clusters.flip(3)
|
95 |
+
return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters
|
96 |
+
|
97 |
+
|
98 |
+
def none_to_nan(x):
|
99 |
+
return torch.FloatTensor([float('nan')]) if x is None else x
|
100 |
+
|
101 |
+
|
102 |
+
class BaseSequenceDataset(Dataset):
|
103 |
+
def __init__(self, root, skip_beginning=4, skip_end=4, min_seq_len=10, debug_seq=False):
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
self.skip_beginning = skip_beginning
|
107 |
+
self.skip_end = skip_end
|
108 |
+
self.min_seq_len = min_seq_len
|
109 |
+
# self.pattern = "{:07d}_{}"
|
110 |
+
self.sequences = self._make_sequences(root)
|
111 |
+
|
112 |
+
if debug_seq:
|
113 |
+
# self.sequences = [self.sequences[0][20:160]] * 100
|
114 |
+
seq_len = 0
|
115 |
+
while seq_len < min_seq_len:
|
116 |
+
i = np.random.randint(len(self.sequences))
|
117 |
+
rand_seq = self.sequences[i]
|
118 |
+
seq_len = len(rand_seq)
|
119 |
+
self.sequences = [rand_seq]
|
120 |
+
|
121 |
+
self.samples = []
|
122 |
+
|
123 |
+
def _make_sequences(self, path):
|
124 |
+
result = []
|
125 |
+
for d in sorted(os.scandir(path), key=lambda e: e.name):
|
126 |
+
if d.is_dir():
|
127 |
+
files = self._parse_folder(d)
|
128 |
+
if len(files) >= self.min_seq_len:
|
129 |
+
result.append(files)
|
130 |
+
return result
|
131 |
+
|
132 |
+
def _parse_folder(self, path):
|
133 |
+
result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
|
134 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
135 |
+
|
136 |
+
if len(result) <= self.skip_beginning + self.skip_end:
|
137 |
+
return []
|
138 |
+
if self.skip_end == 0:
|
139 |
+
return result[self.skip_beginning:]
|
140 |
+
return result[self.skip_beginning:-self.skip_end]
|
141 |
+
|
142 |
+
def _load_ids(self, path_patterns, loaders, transform=None):
|
143 |
+
result = []
|
144 |
+
for loader in loaders:
|
145 |
+
for p in path_patterns:
|
146 |
+
x = loader[1](p.format(loader[0]), *loader[2:])
|
147 |
+
if transform:
|
148 |
+
x = transform(x)
|
149 |
+
result.append(x)
|
150 |
+
return tuple(result)
|
151 |
+
|
152 |
+
def __len__(self):
|
153 |
+
return len(self.samples)
|
154 |
+
|
155 |
+
def __getitem__(self, index):
|
156 |
+
raise NotImplemented("This is a base class and should not be used directly")
|
157 |
+
|
158 |
+
|
159 |
+
class NFrameSequenceDataset(BaseSequenceDataset):
|
160 |
+
def __init__(self, root, cat_name=None, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, in_image_size=256, out_image_size=256, debug_seq=False, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False, **kwargs):
|
161 |
+
self.cat_name = cat_name
|
162 |
+
self.flow_bool=flow_bool
|
163 |
+
|
164 |
+
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
|
165 |
+
self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
|
166 |
+
self.bbox_loaders = [("box.txt", box_loader)]
|
167 |
+
super().__init__(root, skip_beginning, skip_end, min_seq_len, debug_seq)
|
168 |
+
# from IPython import embed; embed()
|
169 |
+
if flow_bool and num_sample_frames > 1:
|
170 |
+
self.flow_loaders = [("flow.png", cv2.imread, cv2.IMREAD_UNCHANGED)]
|
171 |
+
else:
|
172 |
+
self.flow_loaders = None
|
173 |
+
|
174 |
+
self.num_sample_frames = num_sample_frames
|
175 |
+
self.random_sample = random_sample
|
176 |
+
if self.random_sample:
|
177 |
+
if shuffle:
|
178 |
+
random.shuffle(self.sequences)
|
179 |
+
self.samples = self.sequences
|
180 |
+
else:
|
181 |
+
|
182 |
+
for i, s in enumerate(self.sequences):
|
183 |
+
stride = 1 if dense_sample else self.num_sample_frames
|
184 |
+
self.samples += [(i, k) for k in range(0, len(s), stride)]
|
185 |
+
if shuffle:
|
186 |
+
random.shuffle(self.samples)
|
187 |
+
|
188 |
+
self.in_image_size = in_image_size
|
189 |
+
self.out_image_size = out_image_size
|
190 |
+
self.load_background = load_background
|
191 |
+
self.color_jitter = color_jitter
|
192 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
|
193 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
194 |
+
if self.flow_loaders is not None:
|
195 |
+
self.flow_transform = lambda x: (torch.FloatTensor(x.astype(np.float32)).flip(2)[:,:,:2] / 65535. ) *2 -1
|
196 |
+
self.random_flip = random_flip
|
197 |
+
self.load_dino_feature = load_dino_feature
|
198 |
+
if load_dino_feature:
|
199 |
+
self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
|
200 |
+
self.load_dino_cluster = load_dino_cluster
|
201 |
+
if load_dino_cluster:
|
202 |
+
self.dino_cluster_loaders = [("clusters.png", torchvision.datasets.folder.default_loader)]
|
203 |
+
|
204 |
+
def __getitem__(self, index):
|
205 |
+
if self.random_sample:
|
206 |
+
seq_idx = index % len(self.sequences)
|
207 |
+
seq = self.sequences[seq_idx]
|
208 |
+
if len(seq) < self.num_sample_frames:
|
209 |
+
start_frame_idx = 0
|
210 |
+
else:
|
211 |
+
start_frame_idx = np.random.randint(len(seq)-self.num_sample_frames+1)
|
212 |
+
paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
|
213 |
+
else:
|
214 |
+
seq_idx, start_frame_idx = self.samples[index % len(self.samples)]
|
215 |
+
seq = self.sequences[seq_idx]
|
216 |
+
# Handle edge case: when only last frame is left, sample last two frames, except if the sequence only has one frame
|
217 |
+
if len(seq) <= start_frame_idx +1:
|
218 |
+
start_frame_idx = max(0, start_frame_idx-1)
|
219 |
+
paths = seq[start_frame_idx:start_frame_idx+self.num_sample_frames]
|
220 |
+
|
221 |
+
masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
|
222 |
+
mask_dt = compute_distance_transform(masks)
|
223 |
+
jitter = False
|
224 |
+
if self.color_jitter is not None:
|
225 |
+
prob, b, h = self.color_jitter
|
226 |
+
if np.random.rand() < prob:
|
227 |
+
jitter = True
|
228 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
229 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
230 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
231 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
232 |
+
if jitter:
|
233 |
+
images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
|
234 |
+
images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
|
235 |
+
images = images_fg * masks + images_bg * (1-masks)
|
236 |
+
else:
|
237 |
+
images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
|
238 |
+
if self.flow_bool==True and len(paths) > 1:
|
239 |
+
flows = torch.stack(self._load_ids(paths[:-1], self.flow_loaders, transform=self.flow_transform), 0).permute(0,3,1,2) # load flow for first image, (N-1)x(x,y)xHxW, -1~1
|
240 |
+
flows = torch.nn.functional.interpolate(flows, size=self.out_image_size, mode="bilinear")
|
241 |
+
else:
|
242 |
+
flows = torch.zeros(1)
|
243 |
+
bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
|
244 |
+
mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
|
245 |
+
if self.load_background:
|
246 |
+
bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
|
247 |
+
if jitter:
|
248 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
249 |
+
bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
|
250 |
+
else:
|
251 |
+
bg_images = torch.zeros_like(images)
|
252 |
+
if self.load_dino_feature:
|
253 |
+
dino_paths = [
|
254 |
+
x.replace(
|
255 |
+
"/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
|
256 |
+
"/viscam/projects/articulated/zzli/data_dino_5000/7_cat"
|
257 |
+
)
|
258 |
+
for x in paths
|
259 |
+
]
|
260 |
+
dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
|
261 |
+
# dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
|
262 |
+
else:
|
263 |
+
dino_features = torch.zeros(1)
|
264 |
+
if self.load_dino_cluster:
|
265 |
+
dino_clusters = torch.stack(self._load_ids(paths, self.dino_cluster_loaders, transform=transforms.ToTensor()), 0) # BxFx3x55x55
|
266 |
+
else:
|
267 |
+
dino_clusters = torch.zeros(1)
|
268 |
+
seq_idx = torch.LongTensor([seq_idx])
|
269 |
+
frame_idx = torch.arange(start_frame_idx, start_frame_idx+len(paths)).long()
|
270 |
+
|
271 |
+
if self.random_flip and np.random.rand() < 0.5:
|
272 |
+
images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
|
273 |
+
|
274 |
+
## pad shorter sequence
|
275 |
+
if len(paths) < self.num_sample_frames:
|
276 |
+
num_pad = self.num_sample_frames - len(paths)
|
277 |
+
images = torch.cat([images[:1]] *num_pad + [images], 0)
|
278 |
+
masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
|
279 |
+
mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
|
280 |
+
mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
|
281 |
+
if flows.dim() > 1:
|
282 |
+
flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
|
283 |
+
bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
|
284 |
+
bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
|
285 |
+
if dino_features.dim() > 1:
|
286 |
+
dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
|
287 |
+
if dino_clusters.dim() > 1:
|
288 |
+
dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
|
289 |
+
frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
|
290 |
+
|
291 |
+
out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
|
292 |
+
return out
|
293 |
+
# return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
|
294 |
+
|
295 |
+
|
296 |
+
def few_shot_box_loader(fpath):
|
297 |
+
box = np.loadtxt(fpath, 'str')
|
298 |
+
# box[0] = box[0].split('_')[0]
|
299 |
+
return box.astype(np.float32)
|
300 |
+
|
301 |
+
|
302 |
+
class FewShotImageDataset(Dataset):
|
303 |
+
def __init__(self, root, cat_name=None, cat_num=0, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
|
304 |
+
super().__init__()
|
305 |
+
self.cat_name = cat_name
|
306 |
+
self.cat_num = cat_num # this is actually useless
|
307 |
+
self.flow_bool=flow_bool
|
308 |
+
|
309 |
+
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
|
310 |
+
self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
|
311 |
+
self.bbox_loaders = [("box.txt", few_shot_box_loader)]
|
312 |
+
self.flow_loaders = None
|
313 |
+
|
314 |
+
# get all the valid paths, since it's just image-wise, in get_item, we will make it like a len=1 sequence
|
315 |
+
result = sorted(glob(os.path.join(root, '*'+self.image_loaders[0][0])))
|
316 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
317 |
+
self.sequences = result
|
318 |
+
|
319 |
+
self.num_sample_frames = num_sample_frames
|
320 |
+
if shuffle:
|
321 |
+
random.shuffle(self.sequences)
|
322 |
+
self.samples = self.sequences
|
323 |
+
|
324 |
+
self.in_image_size = in_image_size
|
325 |
+
self.out_image_size = out_image_size
|
326 |
+
self.load_background = load_background
|
327 |
+
self.color_jitter = color_jitter
|
328 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
|
329 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
330 |
+
self.random_flip = random_flip
|
331 |
+
self.load_dino_feature = load_dino_feature
|
332 |
+
if load_dino_feature:
|
333 |
+
self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
|
334 |
+
|
335 |
+
def _load_ids(self, path_patterns, loaders, transform=None):
|
336 |
+
result = []
|
337 |
+
for loader in loaders:
|
338 |
+
for p in path_patterns:
|
339 |
+
x = loader[1](p.format(loader[0]), *loader[2:])
|
340 |
+
if transform:
|
341 |
+
x = transform(x)
|
342 |
+
result.append(x)
|
343 |
+
return tuple(result)
|
344 |
+
|
345 |
+
def __len__(self):
|
346 |
+
return len(self.samples)
|
347 |
+
|
348 |
+
def __getitem__(self, index):
|
349 |
+
paths = [self.samples[index]] # len 1 sequence
|
350 |
+
|
351 |
+
masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
|
352 |
+
mask_dt = compute_distance_transform(masks)
|
353 |
+
jitter = False
|
354 |
+
if self.color_jitter is not None:
|
355 |
+
prob, b, h = self.color_jitter
|
356 |
+
if np.random.rand() < prob:
|
357 |
+
jitter = True
|
358 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
359 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
360 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
361 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
362 |
+
if jitter:
|
363 |
+
images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
|
364 |
+
images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
|
365 |
+
images = images_fg * masks + images_bg * (1-masks)
|
366 |
+
else:
|
367 |
+
images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
|
368 |
+
|
369 |
+
flows = torch.zeros(1)
|
370 |
+
bboxs = torch.stack(self._load_ids(paths, self.bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
|
371 |
+
bboxs=torch.cat([bboxs, torch.Tensor([[self.cat_num]]).float()],dim=-1) # pad a label number
|
372 |
+
|
373 |
+
mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
|
374 |
+
if self.load_background:
|
375 |
+
bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
|
376 |
+
if jitter:
|
377 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
378 |
+
bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
|
379 |
+
else:
|
380 |
+
bg_images = torch.zeros_like(images)
|
381 |
+
if self.load_dino_feature:
|
382 |
+
dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
|
383 |
+
else:
|
384 |
+
dino_features = torch.zeros(1)
|
385 |
+
|
386 |
+
dino_clusters = torch.zeros(1)
|
387 |
+
|
388 |
+
# These are actually no use
|
389 |
+
seq_idx = 0
|
390 |
+
seq_idx = torch.LongTensor([seq_idx])
|
391 |
+
frame_idx = torch.arange(0, 1).long()
|
392 |
+
|
393 |
+
if self.random_flip and np.random.rand() < 0.5:
|
394 |
+
images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
|
395 |
+
|
396 |
+
## pad shorter sequence
|
397 |
+
if len(paths) < self.num_sample_frames:
|
398 |
+
num_pad = self.num_sample_frames - len(paths)
|
399 |
+
images = torch.cat([images[:1]] *num_pad + [images], 0)
|
400 |
+
masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
|
401 |
+
mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
|
402 |
+
mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
|
403 |
+
if flows.dim() > 1:
|
404 |
+
flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
|
405 |
+
bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
|
406 |
+
bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
|
407 |
+
if dino_features.dim() > 1:
|
408 |
+
dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
|
409 |
+
if dino_clusters.dim() > 1:
|
410 |
+
dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
|
411 |
+
frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
|
412 |
+
|
413 |
+
out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name)), )
|
414 |
+
return out
|
415 |
+
# return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, self.cat_name
|
416 |
+
|
417 |
+
|
418 |
+
class Quadrupeds_Image_Dataset(Dataset):
|
419 |
+
def __init__(self, original_data_dirs, few_shot_data_dirs, original_num=7, few_shot_num=93, num_sample_frames=2,
|
420 |
+
in_image_size=256, out_image_size=256, is_validation=False, val_image_num=5, shuffle=False, color_jitter=None,
|
421 |
+
load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64,
|
422 |
+
flow_bool=False, disable_fewshot=False, dataset_split_num=-1, **kwargs):
|
423 |
+
self.original_data_dirs = original_data_dirs
|
424 |
+
self.few_shot_data_dirs = few_shot_data_dirs
|
425 |
+
self.original_num = original_num
|
426 |
+
self.few_shot_num = few_shot_num
|
427 |
+
|
428 |
+
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
|
429 |
+
self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
|
430 |
+
self.original_bbox_loaders = [("box.txt", box_loader)]
|
431 |
+
self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
|
432 |
+
|
433 |
+
assert len(self.original_data_dirs.keys()) == self.original_num
|
434 |
+
assert len(self.few_shot_data_dirs.keys()) == self.few_shot_num
|
435 |
+
self.num_sample_frames = num_sample_frames
|
436 |
+
|
437 |
+
self.batch_size = kwargs['batch_size'] # a hack way here
|
438 |
+
|
439 |
+
# for debug, just use some categories
|
440 |
+
if "override_categories" in kwargs:
|
441 |
+
self.override_categories = kwargs["override_categories"]
|
442 |
+
else:
|
443 |
+
self.override_categories = None
|
444 |
+
|
445 |
+
# original dataset
|
446 |
+
original_data_paths = {}
|
447 |
+
for k,v in self.original_data_dirs.items():
|
448 |
+
|
449 |
+
# categories override
|
450 |
+
if self.override_categories is not None:
|
451 |
+
if k not in self.override_categories:
|
452 |
+
continue
|
453 |
+
|
454 |
+
sequences = self._make_sequences(v)
|
455 |
+
samples = []
|
456 |
+
for seq in sequences:
|
457 |
+
samples += seq
|
458 |
+
if shuffle:
|
459 |
+
random.shuffle(samples)
|
460 |
+
original_data_paths.update({k: samples})
|
461 |
+
|
462 |
+
# few-shot dataset
|
463 |
+
enhance_back_view = kwargs['enhance_back_view']
|
464 |
+
if enhance_back_view:
|
465 |
+
enhance_back_view_path = kwargs['enhance_back_view_path']
|
466 |
+
|
467 |
+
few_shot_data_paths = {}
|
468 |
+
for k,v in self.few_shot_data_dirs.items():
|
469 |
+
|
470 |
+
# categories override
|
471 |
+
if self.override_categories is not None:
|
472 |
+
if k not in self.override_categories:
|
473 |
+
continue
|
474 |
+
if k.startswith('_'):
|
475 |
+
# a boundary here for dealing with when in new data, we have same categories as in 7-cat
|
476 |
+
v = v.replace(k, k[1:])
|
477 |
+
|
478 |
+
if isinstance(v, str):
|
479 |
+
result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
|
480 |
+
elif isinstance(v, list):
|
481 |
+
result = []
|
482 |
+
for _v in v:
|
483 |
+
result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0])))
|
484 |
+
else:
|
485 |
+
raise NotImplementedError
|
486 |
+
|
487 |
+
# result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
|
488 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
489 |
+
sequences = result
|
490 |
+
|
491 |
+
# the original 7 categories are using pre-defined paths to separate train and test
|
492 |
+
# here the few-shot we use is_validation to decide if this dataset is train or test
|
493 |
+
# if use enhanced back view, we first pad the multiplied back view image paths at the front of seq
|
494 |
+
# i.e., we don't use back view images for validation
|
495 |
+
if enhance_back_view:
|
496 |
+
back_view_dir = os.path.join(enhance_back_view_path, k, 'train')
|
497 |
+
back_view_result = sorted(glob(os.path.join(back_view_dir, '*'+self.image_loaders[0][0])))
|
498 |
+
back_view_result = [p.replace(self.image_loaders[0][0], '{}') for p in back_view_result]
|
499 |
+
mul_bv_sequences = self._more_back_views(back_view_result, result)
|
500 |
+
sequences = mul_bv_sequences + sequences
|
501 |
+
|
502 |
+
if is_validation:
|
503 |
+
# sequences = sequences[-2:]
|
504 |
+
sequences = sequences[-val_image_num:]
|
505 |
+
else:
|
506 |
+
# sequences = sequences[:-2]
|
507 |
+
sequences = sequences[:-val_image_num]
|
508 |
+
|
509 |
+
if shuffle:
|
510 |
+
random.shuffle(sequences)
|
511 |
+
few_shot_data_paths.update({k: sequences})
|
512 |
+
|
513 |
+
# for visualization purpose
|
514 |
+
self.pure_ori_data_path = original_data_paths
|
515 |
+
self.pure_fs_data_path = few_shot_data_paths
|
516 |
+
|
517 |
+
self.few_shot_data_length = self._get_data_length(few_shot_data_paths) # get the original length of each few-shot category
|
518 |
+
|
519 |
+
if disable_fewshot:
|
520 |
+
few_shot_data_paths = {}
|
521 |
+
|
522 |
+
self.dataset_split_num = dataset_split_num # if -1 then pad to longest, otherwise follow this number to pad and split
|
523 |
+
if is_validation:
|
524 |
+
self.dataset_split_num = -1 # validation we don't split dataset
|
525 |
+
|
526 |
+
if self.dataset_split_num == -1:
|
527 |
+
self.all_data_paths, self.one_category_num = self._pad_paths(original_data_paths, few_shot_data_paths)
|
528 |
+
self.all_category_num = len(self.all_data_paths.keys())
|
529 |
+
self.all_category_names = list(self.all_data_paths.keys())
|
530 |
+
self.original_category_names = list(self.original_data_dirs.keys())
|
531 |
+
elif self.dataset_split_num > 0:
|
532 |
+
self.all_data_paths, self.one_category_num, self.original_category_names = self._pad_paths_withnum(original_data_paths, few_shot_data_paths, self.dataset_split_num)
|
533 |
+
self.all_category_num = len(self.all_data_paths.keys())
|
534 |
+
self.all_category_names = list(self.all_data_paths.keys())
|
535 |
+
else:
|
536 |
+
raise NotImplementedError
|
537 |
+
|
538 |
+
self.in_image_size = in_image_size
|
539 |
+
self.out_image_size = out_image_size
|
540 |
+
self.load_background = load_background
|
541 |
+
self.color_jitter = color_jitter
|
542 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
|
543 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
544 |
+
self.random_flip = random_flip
|
545 |
+
self.load_dino_feature = load_dino_feature
|
546 |
+
if load_dino_feature:
|
547 |
+
self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
|
548 |
+
|
549 |
+
def _more_back_views(self, back_view_seq, seq):
|
550 |
+
if len(back_view_seq) == 0:
|
551 |
+
# for category without back views
|
552 |
+
return []
|
553 |
+
factor = 5
|
554 |
+
# length = (len(seq) // factor) * factor
|
555 |
+
length = (len(seq) // factor) * (factor - 1)
|
556 |
+
mul_f = length // len(back_view_seq)
|
557 |
+
pad_f = length % len(back_view_seq)
|
558 |
+
new_seq = mul_f * back_view_seq + back_view_seq[:pad_f]
|
559 |
+
return new_seq
|
560 |
+
|
561 |
+
def _get_data_length(self, paths):
|
562 |
+
data_length = {}
|
563 |
+
for k,v in paths.items():
|
564 |
+
length = len(v)
|
565 |
+
data_length.update({k: length})
|
566 |
+
return data_length
|
567 |
+
|
568 |
+
def _make_sequences(self, path):
|
569 |
+
result = []
|
570 |
+
for d in sorted(os.scandir(path), key=lambda e: e.name):
|
571 |
+
if d.is_dir():
|
572 |
+
files = self._parse_folder(d)
|
573 |
+
if len(files) >= 1:
|
574 |
+
result.append(files)
|
575 |
+
return result
|
576 |
+
|
577 |
+
def _parse_folder(self, path):
|
578 |
+
result = sorted(glob(os.path.join(path, '*'+self.image_loaders[0][0])))
|
579 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
580 |
+
|
581 |
+
if len(result) <= 0:
|
582 |
+
return []
|
583 |
+
return result
|
584 |
+
|
585 |
+
def _pad_paths(self, ori_paths, fs_paths):
|
586 |
+
img_nums = []
|
587 |
+
all_paths = copy.deepcopy(ori_paths)
|
588 |
+
all_paths.update(fs_paths)
|
589 |
+
for _, v in all_paths.items():
|
590 |
+
img_nums.append(len(v))
|
591 |
+
|
592 |
+
img_num = max(img_nums)
|
593 |
+
img_num = (img_num // self.batch_size) * self.batch_size
|
594 |
+
|
595 |
+
for k,v in all_paths.items():
|
596 |
+
if len(v) < img_num:
|
597 |
+
mul_time = img_num // len(v)
|
598 |
+
pad_time = img_num % len(v)
|
599 |
+
# for each v, shuffle it
|
600 |
+
shuffle_v = copy.deepcopy(v)
|
601 |
+
new_v = []
|
602 |
+
for i in range(mul_time):
|
603 |
+
new_v = new_v + shuffle_v
|
604 |
+
random.shuffle(shuffle_v)
|
605 |
+
del shuffle_v
|
606 |
+
new_v = new_v + v[0:pad_time]
|
607 |
+
# new_v = mul_time * v + v[0:pad_time]
|
608 |
+
all_paths[k] = new_v
|
609 |
+
elif len(v) > img_num:
|
610 |
+
all_paths[k] = v[:img_num]
|
611 |
+
else:
|
612 |
+
continue
|
613 |
+
|
614 |
+
return all_paths, img_num
|
615 |
+
|
616 |
+
def _pad_paths_withnum(self, ori_paths, fs_paths, split_num=1000):
|
617 |
+
img_num = (split_num // self.batch_size) * self.batch_size
|
618 |
+
all_paths = {}
|
619 |
+
orig_cat_names = []
|
620 |
+
|
621 |
+
for k, v in ori_paths.items():
|
622 |
+
total_num = ((len(v) // img_num) + 1) * img_num
|
623 |
+
pad_num = total_num - len(v)
|
624 |
+
split_num = total_num // img_num
|
625 |
+
|
626 |
+
new_v = copy.deepcopy(v)
|
627 |
+
random.shuffle(new_v)
|
628 |
+
all_v = v + new_v[:pad_num]
|
629 |
+
del new_v
|
630 |
+
|
631 |
+
for sn in range(split_num):
|
632 |
+
split_cat_name = f'{k}_' + '%03d' % sn
|
633 |
+
all_paths.update({
|
634 |
+
split_cat_name: all_v[sn*img_num: (sn+1)*img_num]
|
635 |
+
})
|
636 |
+
orig_cat_names.append(split_cat_name)
|
637 |
+
|
638 |
+
for k, v in fs_paths.items():
|
639 |
+
if len(v) < img_num:
|
640 |
+
mul_time = img_num // len(v)
|
641 |
+
pad_time = img_num % len(v)
|
642 |
+
# for each v, shuffle it
|
643 |
+
shuffle_v = copy.deepcopy(v)
|
644 |
+
new_v = []
|
645 |
+
for i in range(mul_time):
|
646 |
+
new_v = new_v + shuffle_v
|
647 |
+
random.shuffle(shuffle_v)
|
648 |
+
del shuffle_v
|
649 |
+
new_v = new_v + v[0:pad_time]
|
650 |
+
# new_v = mul_time * v + v[0:pad_time]
|
651 |
+
all_paths.update({
|
652 |
+
k: new_v
|
653 |
+
})
|
654 |
+
elif len(v) > img_num:
|
655 |
+
all_paths.update({
|
656 |
+
k: v[:img_num]
|
657 |
+
})
|
658 |
+
else:
|
659 |
+
continue
|
660 |
+
|
661 |
+
return all_paths, img_num, orig_cat_names
|
662 |
+
|
663 |
+
|
664 |
+
def _load_ids(self, path_patterns, loaders, transform=None):
|
665 |
+
result = []
|
666 |
+
for loader in loaders:
|
667 |
+
for p in path_patterns:
|
668 |
+
x = loader[1](p.format(loader[0]), *loader[2:])
|
669 |
+
if transform:
|
670 |
+
x = transform(x)
|
671 |
+
result.append(x)
|
672 |
+
return tuple(result)
|
673 |
+
|
674 |
+
def _shuffle_all(self):
|
675 |
+
for k,v in self.all_data_paths.items():
|
676 |
+
new_v = copy.deepcopy(v)
|
677 |
+
random.shuffle(new_v)
|
678 |
+
self.all_data_paths[k] = new_v
|
679 |
+
return None
|
680 |
+
|
681 |
+
def __len__(self):
|
682 |
+
return self.all_category_num * self.one_category_num
|
683 |
+
|
684 |
+
def __getitem__(self, index):
|
685 |
+
'''
|
686 |
+
This dataset must have non-shuffled index!!
|
687 |
+
'''
|
688 |
+
category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size
|
689 |
+
path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size
|
690 |
+
category_name = self.all_category_names[category_idx]
|
691 |
+
paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence
|
692 |
+
|
693 |
+
if category_name in self.original_category_names:
|
694 |
+
bbox_loaders = self.original_bbox_loaders
|
695 |
+
use_original_bbox = True
|
696 |
+
else:
|
697 |
+
bbox_loaders = self.few_shot_bbox_loaders
|
698 |
+
use_original_bbox = False
|
699 |
+
|
700 |
+
masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
|
701 |
+
mask_dt = compute_distance_transform(masks)
|
702 |
+
jitter = False
|
703 |
+
if self.color_jitter is not None:
|
704 |
+
prob, b, h = self.color_jitter
|
705 |
+
if np.random.rand() < prob:
|
706 |
+
jitter = True
|
707 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
708 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
709 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
710 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
711 |
+
if jitter:
|
712 |
+
images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
|
713 |
+
images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
|
714 |
+
images = images_fg * masks + images_bg * (1-masks)
|
715 |
+
else:
|
716 |
+
images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
|
717 |
+
|
718 |
+
flows = torch.zeros(1)
|
719 |
+
bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
|
720 |
+
if not use_original_bbox:
|
721 |
+
bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number
|
722 |
+
|
723 |
+
mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
|
724 |
+
if self.load_background:
|
725 |
+
bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
|
726 |
+
if jitter:
|
727 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
728 |
+
bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
|
729 |
+
else:
|
730 |
+
bg_images = torch.zeros_like(images)
|
731 |
+
if self.load_dino_feature:
|
732 |
+
# print(paths)
|
733 |
+
new_dino_data_name = "data_dino_5000"
|
734 |
+
new_dino_data_path = os.path.join("/viscam/projects/articulated/dor/combine_all_data_for_ablation_magicpony", new_dino_data_name)
|
735 |
+
|
736 |
+
# TODO: use another version of DINO here by changing the path
|
737 |
+
if paths[0].startswith("/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new"):
|
738 |
+
# 7 cat data
|
739 |
+
new_dino_path = paths[0].replace(
|
740 |
+
"/viscam/projects/articulated/dor/AnimalsMotionDataset/splitted_data/Combine_data/dinov2_new",
|
741 |
+
"/viscam/projects/articulated/zzli/data_dino_5000/7_cat"
|
742 |
+
)
|
743 |
+
dino_paths = [new_dino_path]
|
744 |
+
elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all"):
|
745 |
+
# 100 cat
|
746 |
+
dino_path = paths[0].replace(
|
747 |
+
"/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/few_shot_data_all",
|
748 |
+
os.path.join(new_dino_data_path, "100_cat")
|
749 |
+
)
|
750 |
+
dino_path_list = dino_path.split("/")
|
751 |
+
new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
|
752 |
+
new_dino_path = '/'.join(new_dino_path)
|
753 |
+
dino_paths = [new_dino_path]
|
754 |
+
|
755 |
+
elif paths[0].startswith("/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all"):
|
756 |
+
# 100 cat
|
757 |
+
dino_path = paths[0].replace(
|
758 |
+
"/viscam/projects/articulated/zzli/fs_data/data_resize_update/few_shot_data_all",
|
759 |
+
os.path.join(new_dino_data_path, "100_cat")
|
760 |
+
)
|
761 |
+
dino_path_list = dino_path.split("/")
|
762 |
+
new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
|
763 |
+
new_dino_path = '/'.join(new_dino_path)
|
764 |
+
dino_paths = [new_dino_path]
|
765 |
+
|
766 |
+
elif paths[0].startswith("/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data"):
|
767 |
+
# back 100 cat
|
768 |
+
dino_path = paths[0].replace(
|
769 |
+
"/viscam/u/zzli/workspace/Animal-Data-Engine/data/data_resize_update/segmented_back_view_data",
|
770 |
+
os.path.join(new_dino_data_path, "back_100_cat")
|
771 |
+
)
|
772 |
+
dino_path_list = dino_path.split("/")
|
773 |
+
new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
|
774 |
+
new_dino_path = '/'.join(new_dino_path)
|
775 |
+
dino_paths = [new_dino_path]
|
776 |
+
|
777 |
+
elif paths[0].startswith("/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered"):
|
778 |
+
# animal3d
|
779 |
+
dino_path = paths[0].replace(
|
780 |
+
"/viscam/projects/articulated/dor/Animal-Data-Engine/data/data_resize_update/train_with_classes_filtered",
|
781 |
+
os.path.join(new_dino_data_path, "animal3D")
|
782 |
+
)
|
783 |
+
dino_path_list = dino_path.split("/")
|
784 |
+
new_dino_path = dino_path_list[:-2] + dino_path_list[-1:] # remove "/train/"
|
785 |
+
new_dino_path = '/'.join(new_dino_path)
|
786 |
+
dino_paths = [new_dino_path]
|
787 |
+
else:
|
788 |
+
raise NotImplementedError
|
789 |
+
dino_features = torch.stack(self._load_ids(dino_paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0)
|
790 |
+
# dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
|
791 |
+
else:
|
792 |
+
dino_features = torch.zeros(1)
|
793 |
+
|
794 |
+
dino_clusters = torch.zeros(1)
|
795 |
+
|
796 |
+
# These are actually no use
|
797 |
+
seq_idx = 0
|
798 |
+
seq_idx = torch.LongTensor([seq_idx])
|
799 |
+
frame_idx = torch.arange(0, 1).long()
|
800 |
+
|
801 |
+
if self.random_flip and np.random.rand() < 0.5:
|
802 |
+
images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
|
803 |
+
|
804 |
+
## pad shorter sequence
|
805 |
+
if len(paths) < self.num_sample_frames:
|
806 |
+
num_pad = self.num_sample_frames - len(paths)
|
807 |
+
images = torch.cat([images[:1]] *num_pad + [images], 0)
|
808 |
+
masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
|
809 |
+
mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
|
810 |
+
mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
|
811 |
+
if flows.dim() > 1:
|
812 |
+
flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
|
813 |
+
bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
|
814 |
+
bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
|
815 |
+
if dino_features.dim() > 1:
|
816 |
+
dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
|
817 |
+
if dino_clusters.dim() > 1:
|
818 |
+
dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
|
819 |
+
frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
|
820 |
+
|
821 |
+
out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
|
822 |
+
return out
|
823 |
+
# return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name
|
824 |
+
|
825 |
+
def get_sequence_loader_quadrupeds(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, rank, world_size, **kwargs):
|
826 |
+
dataset = Quadrupeds_Image_Dataset(original_data_dirs, few_shot_data_dirs, original_num, few_shot_num, **kwargs)
|
827 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
828 |
+
dataset,
|
829 |
+
num_replicas=world_size,
|
830 |
+
rank=rank,
|
831 |
+
shuffle=False
|
832 |
+
)
|
833 |
+
loaders = []
|
834 |
+
loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
|
835 |
+
|
836 |
+
return loaders
|
837 |
+
|
838 |
+
|
839 |
+
class Quadrupeds_Image_Test_Dataset(Dataset):
|
840 |
+
def __init__(self, test_data_dirs, num_sample_frames=2, in_image_size=256, out_image_size=256, shuffle=False, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.png', load_dino_feature=False, dino_feature_dim=64, flow_bool=False, **kwargs):
|
841 |
+
self.few_shot_data_dirs = test_data_dirs
|
842 |
+
|
843 |
+
self.image_loaders = [("rgb"+rgb_suffix, torchvision.datasets.folder.default_loader)]
|
844 |
+
self.mask_loaders = [("mask.png", torchvision.datasets.folder.default_loader)]
|
845 |
+
self.original_bbox_loaders = [("box.txt", box_loader)]
|
846 |
+
self.few_shot_bbox_loaders = [("box.txt", few_shot_box_loader)]
|
847 |
+
|
848 |
+
self.num_sample_frames = num_sample_frames
|
849 |
+
|
850 |
+
self.batch_size = kwargs['batch_size'] # a hack way here
|
851 |
+
|
852 |
+
few_shot_data_paths = {}
|
853 |
+
for k,v in self.few_shot_data_dirs.items():
|
854 |
+
|
855 |
+
if k.startswith('_'):
|
856 |
+
# a boundary here for dealing with when in new data, we have same categories as in 7-cat
|
857 |
+
v = v.replace(k, k[1:])
|
858 |
+
|
859 |
+
if isinstance(v, str):
|
860 |
+
result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
|
861 |
+
elif isinstance(v, list):
|
862 |
+
result = []
|
863 |
+
for _v in v:
|
864 |
+
result = result + sorted(glob(os.path.join(_v, '*'+self.image_loaders[0][0])))
|
865 |
+
else:
|
866 |
+
raise NotImplementedError
|
867 |
+
|
868 |
+
# result = sorted(glob(os.path.join(v, '*'+self.image_loaders[0][0])))
|
869 |
+
result = [p.replace(self.image_loaders[0][0], '{}') for p in result]
|
870 |
+
sequences = result
|
871 |
+
|
872 |
+
if shuffle:
|
873 |
+
random.shuffle(sequences)
|
874 |
+
few_shot_data_paths.update({k: sequences})
|
875 |
+
|
876 |
+
# for visualization purpose
|
877 |
+
self.pure_fs_data_path = few_shot_data_paths
|
878 |
+
|
879 |
+
self.all_data_paths, self.one_category_num = self._pad_paths(few_shot_data_paths)
|
880 |
+
self.all_category_num = len(self.all_data_paths.keys())
|
881 |
+
self.all_category_names = list(self.all_data_paths.keys())
|
882 |
+
|
883 |
+
self.in_image_size = in_image_size
|
884 |
+
self.out_image_size = out_image_size
|
885 |
+
self.load_background = load_background
|
886 |
+
self.color_jitter = color_jitter
|
887 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.in_image_size), transforms.ToTensor()])
|
888 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.out_image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
889 |
+
self.random_flip = random_flip
|
890 |
+
self.load_dino_feature = load_dino_feature
|
891 |
+
if load_dino_feature:
|
892 |
+
self.dino_feature_loaders = [(f"feat{dino_feature_dim}.png", dino_loader, dino_feature_dim)]
|
893 |
+
|
894 |
+
def _pad_paths(self, fs_paths):
|
895 |
+
img_nums = []
|
896 |
+
all_paths = copy.deepcopy(fs_paths)
|
897 |
+
for _, v in all_paths.items():
|
898 |
+
img_nums.append(len(v))
|
899 |
+
|
900 |
+
img_num = max(img_nums)
|
901 |
+
img_num = (img_num // self.batch_size) * self.batch_size
|
902 |
+
|
903 |
+
for k,v in all_paths.items():
|
904 |
+
if len(v) < img_num:
|
905 |
+
mul_time = img_num // len(v)
|
906 |
+
pad_time = img_num % len(v)
|
907 |
+
# for each v, shuffle it
|
908 |
+
shuffle_v = copy.deepcopy(v)
|
909 |
+
new_v = []
|
910 |
+
for i in range(mul_time):
|
911 |
+
new_v = new_v + shuffle_v
|
912 |
+
random.shuffle(shuffle_v)
|
913 |
+
del shuffle_v
|
914 |
+
new_v = new_v + v[0:pad_time]
|
915 |
+
# new_v = mul_time * v + v[0:pad_time]
|
916 |
+
all_paths[k] = new_v
|
917 |
+
elif len(v) > img_num:
|
918 |
+
all_paths[k] = v[:img_num]
|
919 |
+
else:
|
920 |
+
continue
|
921 |
+
|
922 |
+
return all_paths, img_num
|
923 |
+
|
924 |
+
def _load_ids(self, path_patterns, loaders, transform=None):
|
925 |
+
result = []
|
926 |
+
for loader in loaders:
|
927 |
+
for p in path_patterns:
|
928 |
+
x = loader[1](p.format(loader[0]), *loader[2:])
|
929 |
+
if transform:
|
930 |
+
x = transform(x)
|
931 |
+
result.append(x)
|
932 |
+
return tuple(result)
|
933 |
+
|
934 |
+
def _shuffle_all(self):
|
935 |
+
for k,v in self.all_data_paths.items():
|
936 |
+
new_v = copy.deepcopy(v)
|
937 |
+
random.shuffle(new_v)
|
938 |
+
self.all_data_paths[k] = new_v
|
939 |
+
return None
|
940 |
+
|
941 |
+
def __len__(self):
|
942 |
+
return self.all_category_num * self.one_category_num
|
943 |
+
|
944 |
+
def __getitem__(self, index):
|
945 |
+
'''
|
946 |
+
This dataset must have non-shuffled index!!
|
947 |
+
'''
|
948 |
+
category_idx = (index % (self.batch_size * self.all_category_num)) // self.batch_size
|
949 |
+
path_idx = (index // (self.batch_size * self.all_category_num)) * self.batch_size + (index % (self.batch_size * self.all_category_num)) - category_idx * self.batch_size
|
950 |
+
category_name = self.all_category_names[category_idx]
|
951 |
+
paths = [self.all_data_paths[category_name][path_idx]] # len 1 sequence
|
952 |
+
|
953 |
+
# if category_name in self.original_category_names:
|
954 |
+
# bbox_loaders = self.original_bbox_loaders
|
955 |
+
# use_original_bbox = True
|
956 |
+
# else:
|
957 |
+
bbox_loaders = self.few_shot_bbox_loaders
|
958 |
+
use_original_bbox = False
|
959 |
+
|
960 |
+
masks = torch.stack(self._load_ids(paths, self.mask_loaders, transform=self.mask_transform), 0) # load all images
|
961 |
+
mask_dt = compute_distance_transform(masks)
|
962 |
+
jitter = False
|
963 |
+
if self.color_jitter is not None:
|
964 |
+
prob, b, h = self.color_jitter
|
965 |
+
if np.random.rand() < prob:
|
966 |
+
jitter = True
|
967 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
968 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
969 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
970 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.in_image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
971 |
+
if jitter:
|
972 |
+
images_fg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_fg), 0) # load all images
|
973 |
+
images_bg = torch.stack(self._load_ids(paths, self.image_loaders, transform=image_transform_bg), 0) # load all images
|
974 |
+
images = images_fg * masks + images_bg * (1-masks)
|
975 |
+
else:
|
976 |
+
images = torch.stack(self._load_ids(paths, self.image_loaders, transform=self.image_transform), 0) # load all images
|
977 |
+
|
978 |
+
flows = torch.zeros(1)
|
979 |
+
bboxs = torch.stack(self._load_ids(paths, bbox_loaders, transform=torch.FloatTensor), 0) # load bounding boxes for all images
|
980 |
+
if not use_original_bbox:
|
981 |
+
bboxs=torch.cat([bboxs, torch.Tensor([[category_idx]]).float()],dim=-1) # pad a label number
|
982 |
+
|
983 |
+
mask_valid = get_valid_mask(bboxs, (self.out_image_size, self.out_image_size)) # exclude pixels cropped outside the original image
|
984 |
+
if self.load_background:
|
985 |
+
bg_image = torchvision.datasets.folder.default_loader(os.path.join(os.path.dirname(paths[0]), 'background_frame.jpg'))
|
986 |
+
if jitter:
|
987 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
988 |
+
bg_images = crop_image(bg_image, bboxs[:, 1:5].int().numpy(), (self.in_image_size, self.in_image_size))
|
989 |
+
else:
|
990 |
+
bg_images = torch.zeros_like(images)
|
991 |
+
if self.load_dino_feature:
|
992 |
+
dino_features = torch.stack(self._load_ids(paths, self.dino_feature_loaders, transform=torch.FloatTensor), 0) # BxFx64x224x224
|
993 |
+
else:
|
994 |
+
dino_features = torch.zeros(1)
|
995 |
+
|
996 |
+
dino_clusters = torch.zeros(1)
|
997 |
+
|
998 |
+
# These are actually no use
|
999 |
+
seq_idx = 0
|
1000 |
+
seq_idx = torch.LongTensor([seq_idx])
|
1001 |
+
frame_idx = torch.arange(0, 1).long()
|
1002 |
+
|
1003 |
+
if self.random_flip and np.random.rand() < 0.5:
|
1004 |
+
images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters = horizontal_flip_all(images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters)
|
1005 |
+
|
1006 |
+
## pad shorter sequence
|
1007 |
+
if len(paths) < self.num_sample_frames:
|
1008 |
+
num_pad = self.num_sample_frames - len(paths)
|
1009 |
+
images = torch.cat([images[:1]] *num_pad + [images], 0)
|
1010 |
+
masks = torch.cat([masks[:1]] *num_pad + [masks], 0)
|
1011 |
+
mask_dt = torch.cat([mask_dt[:1]] *num_pad + [mask_dt], 0)
|
1012 |
+
mask_valid = torch.cat([mask_valid[:1]] *num_pad + [mask_valid], 0)
|
1013 |
+
if flows.dim() > 1:
|
1014 |
+
flows = torch.cat([flows[:1]*0] *num_pad + [flows], 0)
|
1015 |
+
bboxs = torch.cat([bboxs[:1]] * num_pad + [bboxs], 0)
|
1016 |
+
bg_images = torch.cat([bg_images[:1]] *num_pad + [bg_images], 0)
|
1017 |
+
if dino_features.dim() > 1:
|
1018 |
+
dino_features = torch.cat([dino_features[:1]] *num_pad + [dino_features], 0)
|
1019 |
+
if dino_clusters.dim() > 1:
|
1020 |
+
dino_clusters = torch.cat([dino_clusters[:1]] *num_pad + [dino_clusters], 0)
|
1021 |
+
frame_idx = torch.cat([frame_idx[:1]] *num_pad + [frame_idx], 0)
|
1022 |
+
|
1023 |
+
out = (*map(none_to_nan, (images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name)), )
|
1024 |
+
return out
|
1025 |
+
# return images, masks, mask_dt, mask_valid, flows, bboxs, bg_images, dino_features, dino_clusters, seq_idx, frame_idx, category_name
|
1026 |
+
|
1027 |
+
|
1028 |
+
|
1029 |
+
def get_test_loader_quadrupeds(test_data_dirs, rank, world_size, **kwargs):
|
1030 |
+
dataset = Quadrupeds_Image_Test_Dataset(test_data_dirs, **kwargs)
|
1031 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
1032 |
+
dataset,
|
1033 |
+
num_replicas=world_size,
|
1034 |
+
rank=rank,
|
1035 |
+
shuffle=False
|
1036 |
+
)
|
1037 |
+
loaders = []
|
1038 |
+
loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
|
1039 |
+
|
1040 |
+
return loaders
|
1041 |
+
|
1042 |
+
def get_sequence_loader(data_dir, **kwargs):
|
1043 |
+
if isinstance(data_dir, dict):
|
1044 |
+
loaders = []
|
1045 |
+
for k, v in data_dir.items():
|
1046 |
+
dataset= NFrameSequenceDataset(v, cat_name=k, **kwargs)
|
1047 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=kwargs['batch_size'], shuffle=kwargs['shuffle'], num_workers=kwargs['num_workers'], pin_memory=True)
|
1048 |
+
loaders += [loader]
|
1049 |
+
return loaders
|
1050 |
+
else:
|
1051 |
+
return [get_sequence_loader_single(data_dir, **kwargs)]
|
1052 |
+
|
1053 |
+
|
1054 |
+
def get_sequence_loader_single(data_dir, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64):
|
1055 |
+
if mode == 'n_frame':
|
1056 |
+
dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim)
|
1057 |
+
else:
|
1058 |
+
raise NotImplementedError
|
1059 |
+
loader = torch.utils.data.DataLoader(
|
1060 |
+
dataset,
|
1061 |
+
batch_size=batch_size,
|
1062 |
+
shuffle=not is_validation,
|
1063 |
+
num_workers=num_workers,
|
1064 |
+
pin_memory=True
|
1065 |
+
)
|
1066 |
+
return loader
|
1067 |
+
|
1068 |
+
|
1069 |
+
def get_sequence_loader_ddp(data_dir, world_size, rank, use_few_shot=False, **kwargs):
|
1070 |
+
original_classes_num = 0
|
1071 |
+
use_few_shot = use_few_shot
|
1072 |
+
if isinstance(data_dir, list) and len(data_dir) == 2 and isinstance(data_dir[-1], dict):
|
1073 |
+
# a hack way for few shot experiment
|
1074 |
+
original_classes_num = data_dir[0]
|
1075 |
+
data_dir = data_dir[-1]
|
1076 |
+
if isinstance(data_dir, dict):
|
1077 |
+
loaders = []
|
1078 |
+
cnt = original_classes_num
|
1079 |
+
for k, v in data_dir.items():
|
1080 |
+
if use_few_shot:
|
1081 |
+
dataset = FewShotImageDataset(v, cat_name=k, cat_num=cnt, **kwargs)
|
1082 |
+
cnt += 1
|
1083 |
+
else:
|
1084 |
+
dataset = NFrameSequenceDataset(v, cat_name=k, **kwargs)
|
1085 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
1086 |
+
dataset,
|
1087 |
+
num_replicas=world_size,
|
1088 |
+
rank=rank,
|
1089 |
+
)
|
1090 |
+
loaders += [torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=kwargs['batch_size'], shuffle=False, drop_last=True, num_workers=kwargs['num_workers'], pin_memory=True)]
|
1091 |
+
return loaders
|
1092 |
+
else:
|
1093 |
+
return [get_sequence_loader_single_ddp(data_dir, world_size, rank, **kwargs)]
|
1094 |
+
|
1095 |
+
|
1096 |
+
def get_sequence_loader_single_ddp(data_dir, world_size, rank, mode='all_frame', is_validation=False, batch_size=256, num_workers=4, in_image_size=256, out_image_size=256, debug_seq=False, num_sample_frames=2, skip_beginning=4, skip_end=4, min_seq_len=10, max_seq_len=256, random_sample=False, shuffle=False, dense_sample=True, color_jitter=None, load_background=False, random_flip=False, rgb_suffix='.jpg', load_dino_feature=False, load_dino_cluster=False, dino_feature_dim=64, flow_bool=False):
|
1097 |
+
if mode == 'n_frame':
|
1098 |
+
dataset = NFrameSequenceDataset(data_dir, num_sample_frames=num_sample_frames, skip_beginning=skip_beginning, skip_end=skip_end, min_seq_len=min_seq_len, in_image_size=in_image_size, out_image_size=out_image_size, debug_seq=debug_seq, random_sample=random_sample, shuffle=shuffle, dense_sample=dense_sample, color_jitter=color_jitter, load_background=load_background, random_flip=random_flip, rgb_suffix=rgb_suffix, load_dino_feature=load_dino_feature, load_dino_cluster=load_dino_cluster, dino_feature_dim=dino_feature_dim, flow_bool=flow_bool)
|
1099 |
+
else:
|
1100 |
+
raise NotImplementedError
|
1101 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
1102 |
+
dataset,
|
1103 |
+
num_replicas=world_size,
|
1104 |
+
rank=rank,
|
1105 |
+
)
|
1106 |
+
loader = torch.utils.data.DataLoader(
|
1107 |
+
dataset,
|
1108 |
+
sampler=sampler,
|
1109 |
+
batch_size=batch_size,
|
1110 |
+
shuffle=False,
|
1111 |
+
drop_last=True,
|
1112 |
+
num_workers=num_workers,
|
1113 |
+
pin_memory=True
|
1114 |
+
)
|
1115 |
+
return loader
|
1116 |
+
|
1117 |
+
|
1118 |
+
class ImageDataset(Dataset):
|
1119 |
+
def __init__(self, root, is_validation=False, image_size=256, color_jitter=None):
|
1120 |
+
super().__init__()
|
1121 |
+
self.image_loader = ("rgb.jpg", torchvision.datasets.folder.default_loader)
|
1122 |
+
self.mask_loader = ("mask.png", torchvision.datasets.folder.default_loader)
|
1123 |
+
self.bbox_loader = ("box.txt", np.loadtxt, 'str')
|
1124 |
+
self.samples = self._parse_folder(root)
|
1125 |
+
self.image_size = image_size
|
1126 |
+
self.color_jitter = color_jitter
|
1127 |
+
self.image_transform = transforms.Compose([transforms.Resize(self.image_size), transforms.ToTensor()])
|
1128 |
+
self.mask_transform = transforms.Compose([transforms.Resize(self.image_size, interpolation=Image.NEAREST), transforms.ToTensor()])
|
1129 |
+
|
1130 |
+
def _parse_folder(self, path):
|
1131 |
+
result = sorted(glob(os.path.join(path, '**/*'+self.image_loader[0]), recursive=True))
|
1132 |
+
result = [p.replace(self.image_loader[0], '{}') for p in result]
|
1133 |
+
return result
|
1134 |
+
|
1135 |
+
def _load_ids(self, path, loader, transform=None):
|
1136 |
+
x = loader[1](path.format(loader[0]), *loader[2:])
|
1137 |
+
if transform:
|
1138 |
+
x = transform(x)
|
1139 |
+
return x
|
1140 |
+
|
1141 |
+
def __len__(self):
|
1142 |
+
return len(self.samples)
|
1143 |
+
|
1144 |
+
def __getitem__(self, index):
|
1145 |
+
path = self.samples[index % len(self.samples)]
|
1146 |
+
masks = self._load_ids(path, self.mask_loader, transform=self.mask_transform).unsqueeze(0)
|
1147 |
+
mask_dt = compute_distance_transform(masks)
|
1148 |
+
jitter = False
|
1149 |
+
if self.color_jitter is not None:
|
1150 |
+
prob, b, h = self.color_jitter
|
1151 |
+
if np.random.rand() < prob:
|
1152 |
+
jitter = True
|
1153 |
+
color_jitter_tsf_fg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
1154 |
+
image_transform_fg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_fg, transforms.ToTensor()])
|
1155 |
+
color_jitter_tsf_bg = transforms.ColorJitter.get_params(brightness=(1-b, 1+b), contrast=None, saturation=None, hue=(-h, h))
|
1156 |
+
image_transform_bg = transforms.Compose([transforms.Resize(self.image_size), color_jitter_tsf_bg, transforms.ToTensor()])
|
1157 |
+
if jitter:
|
1158 |
+
images_fg = self._load_ids(path, self.image_loader, transform=image_transform_fg).unsqueeze(0)
|
1159 |
+
images_bg = self._load_ids(path, self.image_loader, transform=image_transform_bg).unsqueeze(0)
|
1160 |
+
images = images_fg * masks + images_bg * (1-masks)
|
1161 |
+
else:
|
1162 |
+
images = self._load_ids(path, self.image_loader, transform=self.image_transform).unsqueeze(0)
|
1163 |
+
flows = torch.zeros(1)
|
1164 |
+
bboxs = self._load_ids(path, self.bbox_loader, transform=None)
|
1165 |
+
bboxs[0] = '0'
|
1166 |
+
bboxs = torch.FloatTensor(bboxs.astype('float')).unsqueeze(0)
|
1167 |
+
bg_fpath = os.path.join(os.path.dirname(path), 'background_frame.jpg')
|
1168 |
+
if os.path.isfile(bg_fpath):
|
1169 |
+
bg_image = torchvision.datasets.folder.default_loader(bg_fpath)
|
1170 |
+
if jitter:
|
1171 |
+
bg_image = color_jitter_tsf_bg(bg_image)
|
1172 |
+
bg_image = transforms.ToTensor()(bg_image)
|
1173 |
+
else:
|
1174 |
+
bg_image = images[0]
|
1175 |
+
seq_idx = torch.LongTensor([index])
|
1176 |
+
frame_idx = torch.LongTensor([0])
|
1177 |
+
return images, masks, mask_dt, flows, bboxs, bg_image, seq_idx, frame_idx
|
1178 |
+
|
1179 |
+
|
1180 |
+
def get_image_loader(data_dir, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
|
1181 |
+
dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
|
1182 |
+
|
1183 |
+
loader = torch.utils.data.DataLoader(
|
1184 |
+
dataset,
|
1185 |
+
batch_size=batch_size,
|
1186 |
+
shuffle=False,
|
1187 |
+
num_workers=num_workers,
|
1188 |
+
pin_memory=True
|
1189 |
+
)
|
1190 |
+
return loader
|
1191 |
+
|
1192 |
+
|
1193 |
+
def get_image_loader_ddp(data_dir, world_size, rank, is_validation=False, batch_size=256, num_workers=4, image_size=256, color_jitter=None):
|
1194 |
+
dataset = ImageDataset(data_dir, is_validation=is_validation, image_size=image_size, color_jitter=color_jitter)
|
1195 |
+
|
1196 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
1197 |
+
dataset,
|
1198 |
+
num_replicas=world_size,
|
1199 |
+
rank=rank,
|
1200 |
+
)
|
1201 |
+
loader = torch.utils.data.DataLoader(
|
1202 |
+
dataset,
|
1203 |
+
sampler=sampler,
|
1204 |
+
batch_size=batch_size,
|
1205 |
+
shuffle=False,
|
1206 |
+
drop_last=True,
|
1207 |
+
num_workers=num_workers,
|
1208 |
+
pin_memory=True
|
1209 |
+
)
|
1210 |
+
return loader
|
video3d/diffusion/sd.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
# os.environ['HUGGINGFACE_HUB_CACHE'] = '/work/tomj/cache/huggingface_hub'
|
3 |
+
# os.environ['HF_HOME'] = '/work/tomj/cache/huggingface_hub'
|
4 |
+
os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
|
5 |
+
os.environ['HF_HOME'] = '/viscam/u/zzli'
|
6 |
+
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
8 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
|
9 |
+
|
10 |
+
# Suppress partial model loading warning
|
11 |
+
logging.set_verbosity_error()
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
18 |
+
|
19 |
+
class SpecifyGradient(torch.autograd.Function):
|
20 |
+
@staticmethod
|
21 |
+
@custom_fwd
|
22 |
+
def forward(ctx, input_tensor, gt_grad):
|
23 |
+
ctx.save_for_backward(gt_grad)
|
24 |
+
return torch.zeros([1], device=input_tensor.device, dtype=input_tensor.dtype) # Dummy loss value
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
@custom_bwd
|
28 |
+
def backward(ctx, grad):
|
29 |
+
gt_grad, = ctx.saved_tensors
|
30 |
+
batch_size = len(gt_grad)
|
31 |
+
return gt_grad / batch_size, None
|
32 |
+
|
33 |
+
def seed_everything(seed):
|
34 |
+
torch.manual_seed(seed)
|
35 |
+
torch.cuda.manual_seed(seed)
|
36 |
+
|
37 |
+
|
38 |
+
class StableDiffusion(nn.Module):
|
39 |
+
def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32):
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.device = device
|
43 |
+
self.sd_version = sd_version
|
44 |
+
self.torch_dtype = torch_dtype
|
45 |
+
|
46 |
+
print(f'[INFO] loading stable diffusion...')
|
47 |
+
|
48 |
+
if hf_key is not None:
|
49 |
+
print(f'[INFO] using hugging face custom model key: {hf_key}')
|
50 |
+
model_key = hf_key
|
51 |
+
elif self.sd_version == '2.1':
|
52 |
+
model_key = "stabilityai/stable-diffusion-2-1-base"
|
53 |
+
elif self.sd_version == '2.0':
|
54 |
+
model_key = "stabilityai/stable-diffusion-2-base"
|
55 |
+
elif self.sd_version == '1.5':
|
56 |
+
model_key = "runwayml/stable-diffusion-v1-5"
|
57 |
+
else:
|
58 |
+
raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
|
59 |
+
|
60 |
+
# Create model
|
61 |
+
self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
|
62 |
+
self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
|
63 |
+
self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
|
64 |
+
self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
|
65 |
+
|
66 |
+
self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
67 |
+
# self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
68 |
+
|
69 |
+
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
|
70 |
+
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
|
71 |
+
|
72 |
+
print(f'[INFO] loaded stable diffusion!')
|
73 |
+
|
74 |
+
def get_text_embeds(self, prompt, negative_prompt):
|
75 |
+
# prompt, negative_prompt: [str]
|
76 |
+
|
77 |
+
# Tokenize text and get embeddings
|
78 |
+
text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
|
79 |
+
|
80 |
+
with torch.no_grad():
|
81 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
82 |
+
|
83 |
+
# Do the same for unconditional embeddings
|
84 |
+
uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
|
85 |
+
|
86 |
+
with torch.no_grad():
|
87 |
+
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
88 |
+
|
89 |
+
# Cat for final embeddings
|
90 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
91 |
+
return text_embeddings
|
92 |
+
|
93 |
+
def train_step(self, text_embeddings, pred_rgb,
|
94 |
+
guidance_scale=100, loss_weight=1.0, min_step_pct=0.02, max_step_pct=0.98, return_aux=False):
|
95 |
+
pred_rgb = pred_rgb.to(self.torch_dtype)
|
96 |
+
text_embeddings = text_embeddings.to(self.torch_dtype)
|
97 |
+
b = pred_rgb.shape[0]
|
98 |
+
|
99 |
+
# interp to 512x512 to be fed into vae.
|
100 |
+
|
101 |
+
# _t = time.time()
|
102 |
+
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
|
103 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
|
104 |
+
|
105 |
+
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
106 |
+
min_step = int(self.num_train_timesteps * min_step_pct)
|
107 |
+
max_step = int(self.num_train_timesteps * max_step_pct)
|
108 |
+
t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
|
109 |
+
|
110 |
+
# encode image into latents with vae, requires grad!
|
111 |
+
# _t = time.time()
|
112 |
+
latents = self.encode_imgs(pred_rgb_512)
|
113 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
|
114 |
+
|
115 |
+
# predict the noise residual with unet, NO grad!
|
116 |
+
# _t = time.time()
|
117 |
+
with torch.no_grad():
|
118 |
+
# add noise
|
119 |
+
noise = torch.randn_like(latents)
|
120 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
121 |
+
# pred noise
|
122 |
+
latent_model_input = torch.cat([latents_noisy] * 2)
|
123 |
+
t_input = torch.cat([t, t])
|
124 |
+
noise_pred = self.unet(latent_model_input, t_input, encoder_hidden_states=text_embeddings).sample
|
125 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s')
|
126 |
+
|
127 |
+
# perform guidance (high scale from paper!)
|
128 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
129 |
+
# noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
130 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
131 |
+
|
132 |
+
# w(t), sigma_t^2
|
133 |
+
w = (1 - self.alphas[t])
|
134 |
+
# w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
|
135 |
+
grad = loss_weight * w[:, None, None, None] * (noise_pred - noise)
|
136 |
+
|
137 |
+
# clip grad for stable training?
|
138 |
+
# grad = grad.clamp(-10, 10)
|
139 |
+
grad = torch.nan_to_num(grad)
|
140 |
+
|
141 |
+
# since we omitted an item in grad, we need to use the custom function to specify the gradient
|
142 |
+
# _t = time.time()
|
143 |
+
# loss = SpecifyGradient.apply(latents, grad)
|
144 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s')
|
145 |
+
|
146 |
+
targets = (latents - grad).detach()
|
147 |
+
loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
|
148 |
+
|
149 |
+
if return_aux:
|
150 |
+
aux = {'grad': grad, 't': t, 'w': w}
|
151 |
+
return loss, aux
|
152 |
+
else:
|
153 |
+
return loss
|
154 |
+
|
155 |
+
|
156 |
+
def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
|
157 |
+
|
158 |
+
if latents is None:
|
159 |
+
latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
|
160 |
+
|
161 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
162 |
+
|
163 |
+
with torch.autocast('cuda'):
|
164 |
+
for i, t in enumerate(self.scheduler.timesteps):
|
165 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
166 |
+
latent_model_input = torch.cat([latents] * 2)
|
167 |
+
|
168 |
+
# predict the noise residual
|
169 |
+
with torch.no_grad():
|
170 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
|
171 |
+
|
172 |
+
# perform guidance
|
173 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
174 |
+
noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
175 |
+
|
176 |
+
# compute the previous noisy sample x_t -> x_t-1
|
177 |
+
latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']
|
178 |
+
|
179 |
+
return latents
|
180 |
+
|
181 |
+
def decode_latents(self, latents):
|
182 |
+
|
183 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
184 |
+
|
185 |
+
with torch.no_grad():
|
186 |
+
imgs = self.vae.decode(latents).sample
|
187 |
+
|
188 |
+
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
189 |
+
|
190 |
+
return imgs
|
191 |
+
|
192 |
+
def encode_imgs(self, imgs):
|
193 |
+
# imgs: [B, 3, H, W]
|
194 |
+
|
195 |
+
imgs = 2 * imgs - 1
|
196 |
+
|
197 |
+
posterior = self.vae.encode(imgs).latent_dist
|
198 |
+
latents = posterior.sample() * self.vae.config.scaling_factor
|
199 |
+
|
200 |
+
return latents
|
201 |
+
|
202 |
+
def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
|
203 |
+
|
204 |
+
if isinstance(prompts, str):
|
205 |
+
prompts = [prompts]
|
206 |
+
|
207 |
+
if isinstance(negative_prompts, str):
|
208 |
+
negative_prompts = [negative_prompts]
|
209 |
+
|
210 |
+
# Prompts -> text embeds
|
211 |
+
text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768]
|
212 |
+
|
213 |
+
# Text embeds -> img latents
|
214 |
+
latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]
|
215 |
+
|
216 |
+
# Img latents -> imgs
|
217 |
+
imgs = self.decode_latents(latents) # [1, 3, 512, 512]
|
218 |
+
|
219 |
+
# Img to Numpy
|
220 |
+
imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
|
221 |
+
imgs = (imgs * 255).round().astype('uint8')
|
222 |
+
|
223 |
+
return imgs
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == '__main__':
|
227 |
+
import argparse
|
228 |
+
import matplotlib.pyplot as plt
|
229 |
+
|
230 |
+
parser = argparse.ArgumentParser()
|
231 |
+
parser.add_argument('prompt', type=str)
|
232 |
+
parser.add_argument('--negative', default='', type=str)
|
233 |
+
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
|
234 |
+
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
|
235 |
+
parser.add_argument('-H', type=int, default=512)
|
236 |
+
parser.add_argument('-W', type=int, default=512)
|
237 |
+
parser.add_argument('--seed', type=int, default=0)
|
238 |
+
parser.add_argument('--steps', type=int, default=50)
|
239 |
+
opt = parser.parse_args()
|
240 |
+
|
241 |
+
seed_everything(opt.seed)
|
242 |
+
|
243 |
+
device = torch.device('cuda')
|
244 |
+
|
245 |
+
sd = StableDiffusion(device, opt.sd_version, opt.hf_key)
|
246 |
+
|
247 |
+
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
|
248 |
+
|
249 |
+
# visualize image
|
250 |
+
plt.imshow(imgs[0])
|
251 |
+
plt.show()
|
252 |
+
plt.savefig(f'{opt.prompt}.png')
|
video3d/diffusion/sd_utils.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from ..render.light import DirectionalLight
|
7 |
+
|
8 |
+
def safe_normalize(x, eps=1e-20):
|
9 |
+
return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
|
10 |
+
|
11 |
+
def get_view_direction(thetas, phis, overhead, front, phi_offset=0):
|
12 |
+
# phis [B,]; thetas: [B,]
|
13 |
+
# front = 0 [360 - front / 2, front / 2)
|
14 |
+
# side (left) = 1 [front / 2, 180 - front / 2)
|
15 |
+
# back = 2 [180 - front / 2, 180 + front / 2)
|
16 |
+
# side (right) = 3 [180 + front / 2, 360 - front / 2)
|
17 |
+
# top = 4 [0, overhead]
|
18 |
+
# bottom = 5 [180-overhead, 180]
|
19 |
+
res = torch.zeros(thetas.shape[0], dtype=torch.long)
|
20 |
+
|
21 |
+
# first determine by phis
|
22 |
+
phi_offset = np.deg2rad(phi_offset)
|
23 |
+
phis = phis + phi_offset
|
24 |
+
phis = phis % (2 * np.pi)
|
25 |
+
half_front = front / 2
|
26 |
+
|
27 |
+
res[(phis >= (2*np.pi - half_front)) | (phis < half_front)] = 0
|
28 |
+
res[(phis >= half_front) & (phis < (np.pi - half_front))] = 1
|
29 |
+
res[(phis >= (np.pi - half_front)) & (phis < (np.pi + half_front))] = 2
|
30 |
+
res[(phis >= (np.pi + half_front)) & (phis < (2*np.pi - half_front))] = 3
|
31 |
+
|
32 |
+
# override by thetas
|
33 |
+
res[thetas <= overhead] = 4
|
34 |
+
res[thetas >= (np.pi - overhead)] = 5
|
35 |
+
return res
|
36 |
+
|
37 |
+
|
38 |
+
def view_direction_id_to_text(view_direction_id):
|
39 |
+
dir_texts = ['front', 'side', 'back', 'side', 'overhead', 'bottom']
|
40 |
+
return [dir_texts[i] for i in view_direction_id]
|
41 |
+
|
42 |
+
|
43 |
+
def append_text_direction(prompts, dir_texts):
|
44 |
+
return [f'{prompt}, {dir_text} view' for prompt, dir_text in zip(prompts, dir_texts)]
|
45 |
+
|
46 |
+
|
47 |
+
def rand_lights(camera_dir, fixed_ambient, fixed_diffuse):
|
48 |
+
size = camera_dir.shape[0]
|
49 |
+
device = camera_dir.device
|
50 |
+
random_fixed_dir = F.normalize(torch.randn_like(camera_dir) + camera_dir, dim=-1) # Centered around camera_dir
|
51 |
+
random_fixed_intensity = torch.tensor([fixed_ambient, fixed_diffuse], device=device)[None, :].repeat(size, 1) # ambient, diffuse
|
52 |
+
return DirectionalLight(mlp_in=1, mlp_layers=1, mlp_hidden_size=1, # Dummy values
|
53 |
+
intensity_min_max=[0.5, 1],fixed_dir=random_fixed_dir, fixed_intensity=random_fixed_intensity).to(device)
|
54 |
+
|
55 |
+
def rand_poses(size, device, radius_range=[1, 1], theta_range=[0, 120], phi_range=[0, 360], cam_z_offset=10, return_dirs=False, angle_overhead=30, angle_front=60, phi_offset=0, jitter=False, uniform_sphere_rate=0.5):
|
56 |
+
''' generate random poses from an orbit camera
|
57 |
+
Args:
|
58 |
+
size: batch size of generated poses.
|
59 |
+
device: where to allocate the output.
|
60 |
+
radius_range: [min, max]
|
61 |
+
theta_range: [min, max], should be in [0, pi]
|
62 |
+
phi_range: [min, max], should be in [0, 2 * pi]
|
63 |
+
Return:
|
64 |
+
poses: [size, 4, 4]
|
65 |
+
'''
|
66 |
+
|
67 |
+
theta_range = np.deg2rad(theta_range)
|
68 |
+
phi_range = np.deg2rad(phi_range)
|
69 |
+
angle_overhead = np.deg2rad(angle_overhead)
|
70 |
+
angle_front = np.deg2rad(angle_front)
|
71 |
+
|
72 |
+
radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0]
|
73 |
+
|
74 |
+
phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
|
75 |
+
if random.random() < uniform_sphere_rate:
|
76 |
+
# based on http://corysimon.github.io/articles/uniformdistn-on-sphere/
|
77 |
+
# acos takes in [-1, 1], first convert theta range to fit in [-1, 1]
|
78 |
+
theta_range = torch.from_numpy(np.array(theta_range)).to(device)
|
79 |
+
theta_amplitude_range = torch.cos(theta_range)
|
80 |
+
# sample uniformly in amplitude space range
|
81 |
+
thetas_amplitude = torch.rand(size, device=device) * (theta_amplitude_range[1] - theta_amplitude_range[0]) + theta_amplitude_range[0]
|
82 |
+
# convert back
|
83 |
+
thetas = torch.acos(thetas_amplitude)
|
84 |
+
else:
|
85 |
+
thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0]
|
86 |
+
|
87 |
+
centers = -torch.stack([
|
88 |
+
radius * torch.sin(thetas) * torch.sin(phis),
|
89 |
+
radius * torch.cos(thetas),
|
90 |
+
radius * torch.sin(thetas) * torch.cos(phis),
|
91 |
+
], dim=-1) # [B, 3]
|
92 |
+
|
93 |
+
targets = 0
|
94 |
+
|
95 |
+
# jitters
|
96 |
+
if jitter:
|
97 |
+
centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
|
98 |
+
targets = targets + torch.randn_like(centers) * 0.2
|
99 |
+
|
100 |
+
# lookat
|
101 |
+
forward_vector = safe_normalize(targets - centers)
|
102 |
+
up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
|
103 |
+
right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1))
|
104 |
+
|
105 |
+
if jitter:
|
106 |
+
up_noise = torch.randn_like(up_vector) * 0.02
|
107 |
+
else:
|
108 |
+
up_noise = 0
|
109 |
+
|
110 |
+
up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1) + up_noise)
|
111 |
+
|
112 |
+
poses = torch.stack([right_vector, up_vector, forward_vector], dim=-1)
|
113 |
+
radius = radius[..., None] - cam_z_offset
|
114 |
+
translations = torch.cat([torch.zeros_like(radius), torch.zeros_like(radius), radius], dim=-1)
|
115 |
+
poses = torch.cat([poses.view(-1, 9), translations], dim=-1)
|
116 |
+
|
117 |
+
if return_dirs:
|
118 |
+
dirs = get_view_direction(thetas, phis, angle_overhead, angle_front, phi_offset=phi_offset)
|
119 |
+
dirs = view_direction_id_to_text(dirs)
|
120 |
+
else:
|
121 |
+
dirs = None
|
122 |
+
|
123 |
+
return poses, dirs
|
video3d/diffusion/vsd.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['HUGGINGFACE_HUB_CACHE'] = '/viscam/u/zzli'
|
3 |
+
os.environ['HF_HOME'] = '/viscam/u/zzli'
|
4 |
+
|
5 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
6 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler
|
7 |
+
|
8 |
+
from diffusers.loaders import AttnProcsLayers
|
9 |
+
from diffusers.models.attention_processor import LoRAAttnProcessor
|
10 |
+
from diffusers.models.embeddings import TimestepEmbedding
|
11 |
+
from diffusers.utils.import_utils import is_xformers_available
|
12 |
+
|
13 |
+
# Suppress partial model loading warning
|
14 |
+
logging.set_verbosity_error()
|
15 |
+
|
16 |
+
import gc
|
17 |
+
import random
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import tinycudann as tcnn
|
22 |
+
from video3d.diffusion.sd import StableDiffusion
|
23 |
+
from torch.cuda.amp import custom_bwd, custom_fwd
|
24 |
+
|
25 |
+
|
26 |
+
def seed_everything(seed):
|
27 |
+
torch.manual_seed(seed)
|
28 |
+
torch.cuda.manual_seed(seed)
|
29 |
+
|
30 |
+
def cleanup():
|
31 |
+
gc.collect()
|
32 |
+
torch.cuda.empty_cache()
|
33 |
+
tcnn.free_temporary_memory()
|
34 |
+
|
35 |
+
class StableDiffusion_VSD(StableDiffusion):
|
36 |
+
def __init__(self, device, sd_version='2.1', hf_key=None, torch_dtype=torch.float32, lora_n_timestamp_samples=1):
|
37 |
+
super().__init__(device, sd_version=sd_version, hf_key=hf_key, torch_dtype=torch_dtype)
|
38 |
+
|
39 |
+
# self.device = device
|
40 |
+
# self.sd_version = sd_version
|
41 |
+
# self.torch_dtype = torch_dtype
|
42 |
+
|
43 |
+
if hf_key is not None:
|
44 |
+
print(f'[INFO] using hugging face custom model key: {hf_key}')
|
45 |
+
model_key = hf_key
|
46 |
+
elif self.sd_version == '2.1':
|
47 |
+
model_key = "stabilityai/stable-diffusion-2-1-base"
|
48 |
+
elif self.sd_version == '2.0':
|
49 |
+
model_key = "stabilityai/stable-diffusion-2-base"
|
50 |
+
elif self.sd_version == '1.5':
|
51 |
+
model_key = "runwayml/stable-diffusion-v1-5"
|
52 |
+
else:
|
53 |
+
raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
|
54 |
+
|
55 |
+
# # Create model
|
56 |
+
# self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", torch_dtype=torch_dtype).to(self.device)
|
57 |
+
# self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
|
58 |
+
# self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device)
|
59 |
+
# self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
|
60 |
+
|
61 |
+
# self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
62 |
+
# # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
63 |
+
|
64 |
+
# self.num_train_timesteps = self.scheduler.config.num_train_timesteps
|
65 |
+
# self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
|
66 |
+
|
67 |
+
print(f'[INFO] loading stable diffusion VSD modules...')
|
68 |
+
|
69 |
+
self.unet_lora = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", torch_dtype=torch_dtype).to(self.device)
|
70 |
+
cleanup()
|
71 |
+
|
72 |
+
for p in self.vae.parameters():
|
73 |
+
p.requires_grad_(False)
|
74 |
+
for p in self.text_encoder.parameters():
|
75 |
+
p.requires_grad_(False)
|
76 |
+
for p in self.unet.parameters():
|
77 |
+
p.requires_grad_(False)
|
78 |
+
for p in self.unet_lora.parameters():
|
79 |
+
p.requires_grad_(False)
|
80 |
+
|
81 |
+
# set up LoRA layers
|
82 |
+
lora_attn_procs = {}
|
83 |
+
for name in self.unet_lora.attn_processors.keys():
|
84 |
+
cross_attention_dim = (
|
85 |
+
None
|
86 |
+
if name.endswith("attn1.processor")
|
87 |
+
else self.unet_lora.config.cross_attention_dim
|
88 |
+
)
|
89 |
+
if name.startswith("mid_block"):
|
90 |
+
hidden_size = self.unet_lora.config.block_out_channels[-1]
|
91 |
+
elif name.startswith("up_blocks"):
|
92 |
+
block_id = int(name[len("up_blocks.")])
|
93 |
+
hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
|
94 |
+
block_id
|
95 |
+
]
|
96 |
+
elif name.startswith("down_blocks"):
|
97 |
+
block_id = int(name[len("down_blocks.")])
|
98 |
+
hidden_size = self.unet_lora.config.block_out_channels[block_id]
|
99 |
+
|
100 |
+
lora_attn_procs[name] = LoRAAttnProcessor(
|
101 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
102 |
+
)
|
103 |
+
|
104 |
+
self.unet_lora.set_attn_processor(lora_attn_procs)
|
105 |
+
|
106 |
+
self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
|
107 |
+
self.device
|
108 |
+
)
|
109 |
+
self.lora_layers._load_state_dict_pre_hooks.clear()
|
110 |
+
self.lora_layers._state_dict_hooks.clear()
|
111 |
+
self.lora_n_timestamp_samples = lora_n_timestamp_samples
|
112 |
+
self.scheduler_lora = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
|
113 |
+
|
114 |
+
print(f'[INFO] loaded stable diffusion VSD modules!')
|
115 |
+
|
116 |
+
def train_lora(
|
117 |
+
self,
|
118 |
+
latents,
|
119 |
+
text_embeddings,
|
120 |
+
camera_condition
|
121 |
+
):
|
122 |
+
B = latents.shape[0]
|
123 |
+
lora_n_timestamp_samples = self.lora_n_timestamp_samples
|
124 |
+
latents = latents.detach().repeat(lora_n_timestamp_samples, 1, 1, 1)
|
125 |
+
|
126 |
+
t = torch.randint(
|
127 |
+
int(self.num_train_timesteps * 0.0),
|
128 |
+
int(self.num_train_timesteps * 1.0),
|
129 |
+
[B * lora_n_timestamp_samples],
|
130 |
+
dtype=torch.long,
|
131 |
+
device=self.device,
|
132 |
+
)
|
133 |
+
|
134 |
+
noise = torch.randn_like(latents)
|
135 |
+
noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
|
136 |
+
if self.scheduler_lora.config.prediction_type == "epsilon":
|
137 |
+
target = noise
|
138 |
+
elif self.scheduler_lora.config.prediction_type == "v_prediction":
|
139 |
+
target = self.scheduler_lora.get_velocity(latents, noise, t)
|
140 |
+
else:
|
141 |
+
raise ValueError(
|
142 |
+
f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
|
143 |
+
)
|
144 |
+
|
145 |
+
# use view-independent text embeddings in LoRA
|
146 |
+
_, text_embeddings_cond = text_embeddings.chunk(2)
|
147 |
+
|
148 |
+
if random.random() < 0.1:
|
149 |
+
camera_condition = torch.zeros_like(camera_condition)
|
150 |
+
|
151 |
+
noise_pred = self.unet_lora(
|
152 |
+
noisy_latents,
|
153 |
+
t,
|
154 |
+
encoder_hidden_states=text_embeddings_cond.repeat(
|
155 |
+
lora_n_timestamp_samples, 1, 1
|
156 |
+
),
|
157 |
+
class_labels=camera_condition.reshape(B, -1).repeat(
|
158 |
+
lora_n_timestamp_samples, 1
|
159 |
+
),
|
160 |
+
cross_attention_kwargs={"scale": 1.0}
|
161 |
+
).sample
|
162 |
+
|
163 |
+
loss_lora = 0.5 * F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
164 |
+
return loss_lora
|
165 |
+
|
166 |
+
|
167 |
+
def train_step(
|
168 |
+
self,
|
169 |
+
text_embeddings,
|
170 |
+
text_embeddings_vd,
|
171 |
+
pred_rgb,
|
172 |
+
camera_condition,
|
173 |
+
im_features,
|
174 |
+
guidance_scale=7.5,
|
175 |
+
guidance_scale_lora=7.5,
|
176 |
+
loss_weight=1.0,
|
177 |
+
min_step_pct=0.02,
|
178 |
+
max_step_pct=0.98,
|
179 |
+
return_aux=False
|
180 |
+
):
|
181 |
+
pred_rgb = pred_rgb.to(self.torch_dtype)
|
182 |
+
text_embeddings = text_embeddings.to(self.torch_dtype)
|
183 |
+
text_embeddings_vd = text_embeddings_vd.to(self.torch_dtype)
|
184 |
+
camera_condition = camera_condition.to(self.torch_dtype)
|
185 |
+
im_features = im_features.to(self.torch_dtype)
|
186 |
+
|
187 |
+
# condition_label = camera_condition
|
188 |
+
condition_label = im_features
|
189 |
+
|
190 |
+
b = pred_rgb.shape[0]
|
191 |
+
|
192 |
+
# interp to 512x512 to be fed into vae.
|
193 |
+
# _t = time.time()
|
194 |
+
pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
|
195 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s')
|
196 |
+
|
197 |
+
# timestep ~ U(0.02, 0.98) to avoid very high/low noise level
|
198 |
+
min_step = int(self.num_train_timesteps * min_step_pct)
|
199 |
+
max_step = int(self.num_train_timesteps * max_step_pct)
|
200 |
+
t = torch.randint(min_step, max_step + 1, [b], dtype=torch.long, device=self.device)
|
201 |
+
|
202 |
+
# encode image into latents with vae, requires grad!
|
203 |
+
# _t = time.time()
|
204 |
+
latents = self.encode_imgs(pred_rgb_512)
|
205 |
+
# torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s')
|
206 |
+
|
207 |
+
# predict the noise residual with unet, NO grad!
|
208 |
+
# _t = time.time()
|
209 |
+
with torch.no_grad():
|
210 |
+
# add noise
|
211 |
+
noise = torch.randn_like(latents)
|
212 |
+
latents_noisy = self.scheduler.add_noise(latents, noise, t)
|
213 |
+
# pred noise
|
214 |
+
latent_model_input = torch.cat([latents_noisy] * 2)
|
215 |
+
|
216 |
+
# disable unet class embedding here
|
217 |
+
cls_embedding = self.unet.class_embedding
|
218 |
+
self.unet.class_embedding = None
|
219 |
+
|
220 |
+
cross_attention_kwargs = None
|
221 |
+
noise_pred_pretrain = self.unet(
|
222 |
+
latent_model_input,
|
223 |
+
torch.cat([t, t]),
|
224 |
+
encoder_hidden_states=text_embeddings_vd,
|
225 |
+
class_labels=None,
|
226 |
+
cross_attention_kwargs=cross_attention_kwargs
|
227 |
+
).sample
|
228 |
+
|
229 |
+
self.unet.class_embedding = cls_embedding
|
230 |
+
|
231 |
+
# use view-independent text embeddings in LoRA
|
232 |
+
_, text_embeddings_cond = text_embeddings.chunk(2)
|
233 |
+
|
234 |
+
noise_pred_est = self.unet_lora(
|
235 |
+
latent_model_input,
|
236 |
+
torch.cat([t, t]),
|
237 |
+
encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
|
238 |
+
class_labels=torch.cat(
|
239 |
+
[
|
240 |
+
condition_label.reshape(b, -1),
|
241 |
+
torch.zeros_like(condition_label.reshape(b, -1)),
|
242 |
+
],
|
243 |
+
dim=0,
|
244 |
+
),
|
245 |
+
cross_attention_kwargs={"scale": 1.0},
|
246 |
+
).sample
|
247 |
+
|
248 |
+
noise_pred_pretrain_uncond, noise_pred_pretrain_text = noise_pred_pretrain.chunk(2)
|
249 |
+
|
250 |
+
noise_pred_pretrain = noise_pred_pretrain_uncond + guidance_scale * (
|
251 |
+
noise_pred_pretrain_text - noise_pred_pretrain_uncond
|
252 |
+
)
|
253 |
+
|
254 |
+
assert self.scheduler.config.prediction_type == "epsilon"
|
255 |
+
if self.scheduler_lora.config.prediction_type == "v_prediction":
|
256 |
+
alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
|
257 |
+
device=latents_noisy.device, dtype=latents_noisy.dtype
|
258 |
+
)
|
259 |
+
alpha_t = alphas_cumprod[t] ** 0.5
|
260 |
+
sigma_t = (1 - alphas_cumprod[t]) ** 0.5
|
261 |
+
|
262 |
+
noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).reshape(
|
263 |
+
-1, 1, 1, 1
|
264 |
+
) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).reshape(-1, 1, 1, 1)
|
265 |
+
|
266 |
+
noise_pred_est_uncond, noise_pred_est_camera = noise_pred_est.chunk(2)
|
267 |
+
|
268 |
+
noise_pred_est = noise_pred_est_uncond + guidance_scale_lora * (
|
269 |
+
noise_pred_est_camera - noise_pred_est_uncond
|
270 |
+
)
|
271 |
+
|
272 |
+
# w(t), sigma_t^2
|
273 |
+
w = (1 - self.alphas[t])
|
274 |
+
# w = self.alphas[t] ** 0.5 * (1 - self.alphas[t])
|
275 |
+
grad = loss_weight * w[:, None, None, None] * (noise_pred_pretrain - noise_pred_est)
|
276 |
+
|
277 |
+
grad = torch.nan_to_num(grad)
|
278 |
+
|
279 |
+
targets = (latents - grad).detach()
|
280 |
+
loss_vsd = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]
|
281 |
+
|
282 |
+
loss_lora = self.train_lora(latents, text_embeddings, condition_label)
|
283 |
+
|
284 |
+
loss = {
|
285 |
+
'loss_vsd': loss_vsd,
|
286 |
+
'loss_lora': loss_lora
|
287 |
+
}
|
288 |
+
|
289 |
+
if return_aux:
|
290 |
+
aux = {'grad': grad, 't': t, 'w': w}
|
291 |
+
return loss, aux
|
292 |
+
else:
|
293 |
+
return loss
|
294 |
+
|
295 |
+
|
296 |
+
|
297 |
+
if __name__ == '__main__':
|
298 |
+
import argparse
|
299 |
+
import matplotlib.pyplot as plt
|
300 |
+
|
301 |
+
parser = argparse.ArgumentParser()
|
302 |
+
parser.add_argument('prompt', type=str)
|
303 |
+
parser.add_argument('--negative', default='', type=str)
|
304 |
+
parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
|
305 |
+
parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
|
306 |
+
parser.add_argument('-H', type=int, default=512)
|
307 |
+
parser.add_argument('-W', type=int, default=512)
|
308 |
+
parser.add_argument('--seed', type=int, default=0)
|
309 |
+
parser.add_argument('--steps', type=int, default=50)
|
310 |
+
opt = parser.parse_args()
|
311 |
+
|
312 |
+
seed_everything(opt.seed)
|
313 |
+
|
314 |
+
device = torch.device('cuda')
|
315 |
+
|
316 |
+
sd = StableDiffusion_VSD(device, opt.sd_version, opt.hf_key)
|
317 |
+
|
318 |
+
imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps)
|
319 |
+
|
320 |
+
# visualize image
|
321 |
+
plt.imshow(imgs[0])
|
322 |
+
plt.show()
|
323 |
+
plt.savefig(f'{opt.prompt}.png')
|
video3d/discriminator_architecture.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
from math import log2
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import autograd
|
6 |
+
|
7 |
+
|
8 |
+
class DCDiscriminator(nn.Module):
|
9 |
+
''' DC Discriminator class.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
in_dim (int): input dimension
|
13 |
+
n_feat (int): features of final hidden layer
|
14 |
+
img_size (int): input image size
|
15 |
+
'''
|
16 |
+
def __init__(self, in_dim=1, out_dim=1, n_feat=512, img_size=256, last_bias=False):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
self.in_dim = in_dim
|
20 |
+
self.out_dim = out_dim
|
21 |
+
n_layers = int(log2(img_size) - 2)
|
22 |
+
self.blocks = nn.ModuleList(
|
23 |
+
[nn.Conv2d(
|
24 |
+
in_dim,
|
25 |
+
int(n_feat / (2 ** (n_layers - 1))),
|
26 |
+
4, 2, 1, bias=False)] + [nn.Conv2d(
|
27 |
+
int(n_feat / (2 ** (n_layers - i))),
|
28 |
+
int(n_feat / (2 ** (n_layers - 1 - i))),
|
29 |
+
4, 2, 1, bias=False) for i in range(1, n_layers)])
|
30 |
+
|
31 |
+
self.conv_out = nn.Conv2d(n_feat, out_dim, 4, 1, 0, bias=last_bias)
|
32 |
+
self.actvn = nn.LeakyReLU(0.2, inplace=True)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
batch_size = x.shape[0]
|
36 |
+
if x.shape[1] != self.in_dim:
|
37 |
+
import ipdb; ipdb.set_trace()
|
38 |
+
x = x[:, :self.in_dim]
|
39 |
+
for layer in self.blocks:
|
40 |
+
x = self.actvn(layer(x))
|
41 |
+
|
42 |
+
out = self.conv_out(x)
|
43 |
+
out = out.reshape(batch_size, self.out_dim)
|
44 |
+
return out
|
45 |
+
|
46 |
+
|
47 |
+
# class ADADiscriminator(DCDiscriminator):
|
48 |
+
# def __init__(self, aug, aug_p, **kwargs):
|
49 |
+
# super().__init__(**kwargs)
|
50 |
+
# self.aug = build_from_config(aug)
|
51 |
+
# self.aug.p.copy_(torch.tensor(aug_p, dtype=torch.float32))
|
52 |
+
# self.resolution = kwargs['img_size']
|
53 |
+
|
54 |
+
# def get_resolution(self):
|
55 |
+
# return self.resolution
|
56 |
+
|
57 |
+
# def forward(self, x, **kwargs):
|
58 |
+
# x = self.aug(x)
|
59 |
+
# return super().forward(x, **kwargs)
|
60 |
+
|
61 |
+
|
62 |
+
# class ADADiscriminatorView(ADADiscriminator):
|
63 |
+
# def __init__(self, out_dim_position, out_dim_latent, **kwargs):
|
64 |
+
# self.out_dim_position = out_dim_position
|
65 |
+
# self.out_dim_latent = out_dim_latent
|
66 |
+
|
67 |
+
# super().__init__(**kwargs)
|
68 |
+
|
69 |
+
def bce_loss_target(d_out, target):
|
70 |
+
targets = d_out.new_full(size=d_out.size(), fill_value=target)
|
71 |
+
loss = F.binary_cross_entropy_with_logits(d_out, targets)
|
72 |
+
return loss.mean()
|
73 |
+
|
74 |
+
def compute_grad2(d_out, x_in):
|
75 |
+
batch_size = x_in.size(0)
|
76 |
+
grad_dout = autograd.grad(
|
77 |
+
outputs=d_out.sum(), inputs=x_in,
|
78 |
+
create_graph=True, retain_graph=True, only_inputs=True
|
79 |
+
)[0]
|
80 |
+
grad_dout2 = grad_dout.pow(2)
|
81 |
+
assert(grad_dout2.size() == x_in.size())
|
82 |
+
reg = grad_dout2.reshape(batch_size, -1).sum(1)
|
83 |
+
return reg.mean()
|
video3d/flow/__init__.py
ADDED
File without changes
|
video3d/flow/flow.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from numpy.lib.npyio import load
|
2 |
+
from torch._C import device
|
3 |
+
import sys
|
4 |
+
sys.path.append('/scratch/shared/beegfs/szwu/projects/video3d/RAFT')
|
5 |
+
from core.raft import RAFT
|
6 |
+
|
7 |
+
from .utils import InputPadder
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
class AttrDict(dict):
|
12 |
+
def __init__(self, *args, **kwargs):
|
13 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
14 |
+
self.__dict__ = self
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
class FlowModel():
|
19 |
+
def __init__(self, model, device):
|
20 |
+
args = AttrDict({'model': model, 'small': False, 'mixed_precision': False, 'alternate_corr': False})
|
21 |
+
self.model = self.load_model(args, device)
|
22 |
+
self.device = device
|
23 |
+
|
24 |
+
|
25 |
+
@staticmethod
|
26 |
+
def load_model(args, device):
|
27 |
+
model = torch.nn.DataParallel(RAFT(args))
|
28 |
+
model.load_state_dict(torch.load(args.model))
|
29 |
+
|
30 |
+
model = model.module
|
31 |
+
model.to(device)
|
32 |
+
model.eval()
|
33 |
+
return model
|
34 |
+
|
35 |
+
|
36 |
+
def preprocess_image(self, image):
|
37 |
+
# image = image[:, :, ::-1].copy()
|
38 |
+
image = torch.from_numpy(image).permute(2, 0, 1).float()
|
39 |
+
image = image.to(self.device)
|
40 |
+
image = image[None]
|
41 |
+
# size = [540, 960]
|
42 |
+
# image = torch.nn.functional.interpolate(image, size=size, mode='bilinear', align_corners=False)
|
43 |
+
padder = InputPadder(image.shape)
|
44 |
+
return padder.pad(image)[0], padder
|
45 |
+
|
46 |
+
|
47 |
+
def compute_flow(self, frame, next_frame, iters=20):
|
48 |
+
frame, padder = self.preprocess_image(frame)
|
49 |
+
next_frame, padder = self.preprocess_image(next_frame)
|
50 |
+
_, flow = self.model(frame, next_frame, iters=iters, test_mode=True)
|
51 |
+
return padder.unpad(flow)[0].permute(1, 2, 0).cpu().numpy()
|
video3d/flow/utils.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Taken from RAFT
|
2 |
+
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class InputPadder:
|
7 |
+
""" Pads images such that dimensions are divisible by 8 """
|
8 |
+
def __init__(self, dims, mode='sintel'):
|
9 |
+
self.ht, self.wd = dims[-2:]
|
10 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
11 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
12 |
+
if mode == 'sintel':
|
13 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
14 |
+
else:
|
15 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
16 |
+
|
17 |
+
def pad(self, *inputs):
|
18 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
19 |
+
|
20 |
+
def unpad(self,x):
|
21 |
+
ht, wd = x.shape[-2:]
|
22 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
23 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
video3d/geometry/dlmesh.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from ..render import mesh
|
13 |
+
from ..render import render
|
14 |
+
from ..render import regularizer
|
15 |
+
|
16 |
+
###############################################################################
|
17 |
+
# Geometry interface
|
18 |
+
###############################################################################
|
19 |
+
|
20 |
+
class DLMesh(torch.nn.Module):
|
21 |
+
def __init__(self, initial_guess, FLAGS):
|
22 |
+
super(DLMesh, self).__init__()
|
23 |
+
|
24 |
+
self.FLAGS = FLAGS
|
25 |
+
|
26 |
+
self.initial_guess = initial_guess
|
27 |
+
self.mesh = initial_guess.clone()
|
28 |
+
print("Base mesh has %d triangles and %d vertices." % (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0]))
|
29 |
+
|
30 |
+
self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True)
|
31 |
+
self.register_parameter('vertex_pos', self.mesh.v_pos)
|
32 |
+
|
33 |
+
@torch.no_grad()
|
34 |
+
def getAABB(self):
|
35 |
+
return mesh.aabb(self.mesh)
|
36 |
+
|
37 |
+
def getMesh(self, material):
|
38 |
+
self.mesh.material = material
|
39 |
+
|
40 |
+
imesh = mesh.Mesh(base=self.mesh)
|
41 |
+
# Compute normals and tangent space
|
42 |
+
imesh = mesh.auto_normals(imesh)
|
43 |
+
imesh = mesh.compute_tangents(imesh)
|
44 |
+
return imesh
|
45 |
+
|
46 |
+
def render(self, glctx, target, lgt, opt_material, bsdf=None):
|
47 |
+
opt_mesh = self.getMesh(opt_material)
|
48 |
+
return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
|
49 |
+
num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf)
|
50 |
+
|
51 |
+
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
|
52 |
+
|
53 |
+
# ==============================================================================================
|
54 |
+
# Render optimizable object with identical conditions
|
55 |
+
# ==============================================================================================
|
56 |
+
buffers = self.render(glctx, target, lgt, opt_material)
|
57 |
+
|
58 |
+
# ==============================================================================================
|
59 |
+
# Compute loss
|
60 |
+
# ==============================================================================================
|
61 |
+
t_iter = iteration / self.FLAGS.iter
|
62 |
+
|
63 |
+
# Image-space loss, split into a coverage component and a color component
|
64 |
+
color_ref = target['img']
|
65 |
+
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
|
66 |
+
img_loss += loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
|
67 |
+
|
68 |
+
reg_loss = torch.tensor([0], dtype=torch.float32, device="cuda")
|
69 |
+
|
70 |
+
# Compute regularizer.
|
71 |
+
if self.FLAGS.laplace == "absolute":
|
72 |
+
reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
|
73 |
+
elif self.FLAGS.laplace == "relative":
|
74 |
+
reg_loss += regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
|
75 |
+
|
76 |
+
# Albedo (k_d) smoothnesss regularizer
|
77 |
+
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
|
78 |
+
|
79 |
+
# Visibility regularizer
|
80 |
+
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)
|
81 |
+
|
82 |
+
# Light white balance regularizer
|
83 |
+
reg_loss = reg_loss + lgt.regularizer() * 0.005
|
84 |
+
|
85 |
+
return img_loss, reg_loss
|
video3d/geometry/dmtet.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
from multiprocessing.spawn import get_preparation_data
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from ..render import mesh
|
15 |
+
from ..render import render
|
16 |
+
from ..networks import MLPWithPositionalEncoding, MLPWithPositionalEncoding_Style
|
17 |
+
|
18 |
+
###############################################################################
|
19 |
+
# Marching tetrahedrons implementation (differentiable), adapted from
|
20 |
+
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
|
21 |
+
#
|
22 |
+
# Note this only supports batch size = 1.
|
23 |
+
###############################################################################
|
24 |
+
|
25 |
+
class DMTet:
|
26 |
+
def __init__(self):
|
27 |
+
self.triangle_table = torch.tensor([
|
28 |
+
[-1, -1, -1, -1, -1, -1],
|
29 |
+
[ 1, 0, 2, -1, -1, -1],
|
30 |
+
[ 4, 0, 3, -1, -1, -1],
|
31 |
+
[ 1, 4, 2, 1, 3, 4],
|
32 |
+
[ 3, 1, 5, -1, -1, -1],
|
33 |
+
[ 2, 3, 0, 2, 5, 3],
|
34 |
+
[ 1, 4, 0, 1, 5, 4],
|
35 |
+
[ 4, 2, 5, -1, -1, -1],
|
36 |
+
[ 4, 5, 2, -1, -1, -1],
|
37 |
+
[ 4, 1, 0, 4, 5, 1],
|
38 |
+
[ 3, 2, 0, 3, 5, 2],
|
39 |
+
[ 1, 3, 5, -1, -1, -1],
|
40 |
+
[ 4, 1, 2, 4, 3, 1],
|
41 |
+
[ 3, 0, 4, -1, -1, -1],
|
42 |
+
[ 2, 0, 1, -1, -1, -1],
|
43 |
+
[-1, -1, -1, -1, -1, -1]
|
44 |
+
], dtype=torch.long, device='cuda')
|
45 |
+
|
46 |
+
self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
|
47 |
+
self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
|
48 |
+
|
49 |
+
###############################################################################
|
50 |
+
# Utility functions
|
51 |
+
###############################################################################
|
52 |
+
|
53 |
+
def sort_edges(self, edges_ex2):
|
54 |
+
with torch.no_grad():
|
55 |
+
order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
|
56 |
+
order = order.unsqueeze(dim=1)
|
57 |
+
|
58 |
+
a = torch.gather(input=edges_ex2, index=order, dim=1)
|
59 |
+
b = torch.gather(input=edges_ex2, index=1-order, dim=1)
|
60 |
+
|
61 |
+
return torch.stack([a, b],-1)
|
62 |
+
|
63 |
+
def map_uv(self, faces, face_gidx, max_idx):
|
64 |
+
N = int(np.ceil(np.sqrt((max_idx+1)//2)))
|
65 |
+
tex_y, tex_x = torch.meshgrid(
|
66 |
+
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
67 |
+
torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
|
68 |
+
indexing='ij'
|
69 |
+
)
|
70 |
+
|
71 |
+
pad = 0.9 / N
|
72 |
+
|
73 |
+
uvs = torch.stack([
|
74 |
+
tex_x , tex_y,
|
75 |
+
tex_x + pad, tex_y,
|
76 |
+
tex_x + pad, tex_y + pad,
|
77 |
+
tex_x , tex_y + pad
|
78 |
+
], dim=-1).view(-1, 2)
|
79 |
+
|
80 |
+
def _idx(tet_idx, N):
|
81 |
+
x = tet_idx % N
|
82 |
+
y = torch.div(tet_idx, N, rounding_mode='trunc')
|
83 |
+
return y * N + x
|
84 |
+
|
85 |
+
tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
|
86 |
+
tri_idx = face_gidx % 2
|
87 |
+
|
88 |
+
uv_idx = torch.stack((
|
89 |
+
tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
|
90 |
+
), dim = -1). view(-1, 3)
|
91 |
+
|
92 |
+
return uvs, uv_idx
|
93 |
+
|
94 |
+
###############################################################################
|
95 |
+
# Marching tets implementation
|
96 |
+
###############################################################################
|
97 |
+
|
98 |
+
def __call__(self, pos_nx3, sdf_n, tet_fx4):
|
99 |
+
with torch.no_grad():
|
100 |
+
occ_n = sdf_n > 0
|
101 |
+
occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
|
102 |
+
occ_sum = torch.sum(occ_fx4, -1)
|
103 |
+
valid_tets = (occ_sum>0) & (occ_sum<4)
|
104 |
+
occ_sum = occ_sum[valid_tets]
|
105 |
+
|
106 |
+
# find all vertices
|
107 |
+
all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
|
108 |
+
all_edges = self.sort_edges(all_edges)
|
109 |
+
unique_edges, idx_map = torch.unique(all_edges,dim=0, return_inverse=True)
|
110 |
+
|
111 |
+
unique_edges = unique_edges.long()
|
112 |
+
mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
|
113 |
+
mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
|
114 |
+
mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long,device="cuda")
|
115 |
+
idx_map = mapping[idx_map] # map edges to verts
|
116 |
+
|
117 |
+
interp_v = unique_edges[mask_edges]
|
118 |
+
edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
|
119 |
+
edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
|
120 |
+
edges_to_interp_sdf[:,-1] *= -1
|
121 |
+
|
122 |
+
denominator = edges_to_interp_sdf.sum(1,keepdim = True)
|
123 |
+
|
124 |
+
edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1])/denominator
|
125 |
+
verts = (edges_to_interp * edges_to_interp_sdf).sum(1)
|
126 |
+
|
127 |
+
idx_map = idx_map.reshape(-1,6)
|
128 |
+
|
129 |
+
v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
|
130 |
+
tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
|
131 |
+
num_triangles = self.num_triangles_table[tetindex]
|
132 |
+
|
133 |
+
# Generate triangle indices
|
134 |
+
faces = torch.cat((
|
135 |
+
torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
|
136 |
+
torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
|
137 |
+
), dim=0)
|
138 |
+
|
139 |
+
# Get global face index (static, does not depend on topology)
|
140 |
+
num_tets = tet_fx4.shape[0]
|
141 |
+
tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
|
142 |
+
face_gidx = torch.cat((
|
143 |
+
tet_gidx[num_triangles == 1]*2,
|
144 |
+
torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
|
145 |
+
), dim=0)
|
146 |
+
|
147 |
+
uvs, uv_idx = self.map_uv(faces, face_gidx, num_tets*2)
|
148 |
+
|
149 |
+
return verts, faces, uvs, uv_idx
|
150 |
+
|
151 |
+
###############################################################################
|
152 |
+
# Regularizer
|
153 |
+
###############################################################################
|
154 |
+
|
155 |
+
def sdf_bce_reg_loss(sdf, all_edges):
|
156 |
+
sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
|
157 |
+
mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
|
158 |
+
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
159 |
+
sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
|
160 |
+
torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
|
161 |
+
if torch.isnan(sdf_diff).any():
|
162 |
+
import ipdb; ipdb.set_trace()
|
163 |
+
return sdf_diff
|
164 |
+
|
165 |
+
###############################################################################
|
166 |
+
# Geometry interface
|
167 |
+
###############################################################################
|
168 |
+
|
169 |
+
class DMTetGeometry(torch.nn.Module):
|
170 |
+
def __init__(self, grid_res, scale, sdf_mode, num_layers=None, hidden_size=None, embedder_freq=None, embed_concat_pts=True, init_sdf=None, jitter_grid=0., perturb_sdf_iter=10000, sym_prior_shape=False, dim_of_classes=0, condition_choice='concat'):
|
171 |
+
super(DMTetGeometry, self).__init__()
|
172 |
+
|
173 |
+
self.sdf_mode = sdf_mode
|
174 |
+
self.grid_res = grid_res
|
175 |
+
self.marching_tets = DMTet()
|
176 |
+
self.grid_scale = scale
|
177 |
+
self.init_sdf = init_sdf
|
178 |
+
self.jitter_grid = jitter_grid
|
179 |
+
self.perturb_sdf_iter = perturb_sdf_iter
|
180 |
+
self.sym_prior_shape = sym_prior_shape
|
181 |
+
self.load_tets(self.grid_res, self.grid_scale)
|
182 |
+
|
183 |
+
if sdf_mode == "param":
|
184 |
+
sdf = torch.rand_like(self.verts[:,0]) - 0.1 # Random init.
|
185 |
+
self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
|
186 |
+
self.register_parameter('sdf', self.sdf)
|
187 |
+
self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
|
188 |
+
self.register_parameter('deform', self.deform)
|
189 |
+
else:
|
190 |
+
embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
|
191 |
+
|
192 |
+
if dim_of_classes == 0 or (dim_of_classes != 0 and condition_choice == 'concat'):
|
193 |
+
self.mlp = MLPWithPositionalEncoding(
|
194 |
+
3,
|
195 |
+
1,
|
196 |
+
num_layers,
|
197 |
+
nf=hidden_size,
|
198 |
+
extra_dim=dim_of_classes,
|
199 |
+
dropout=0,
|
200 |
+
activation=None,
|
201 |
+
n_harmonic_functions=embedder_freq,
|
202 |
+
omega0=embedder_scaler,
|
203 |
+
embed_concat_pts=embed_concat_pts)
|
204 |
+
|
205 |
+
elif condition_choice == 'film' or condition_choice == 'mod':
|
206 |
+
self.mlp = MLPWithPositionalEncoding_Style(
|
207 |
+
3,
|
208 |
+
1,
|
209 |
+
num_layers,
|
210 |
+
nf=hidden_size,
|
211 |
+
extra_dim=dim_of_classes,
|
212 |
+
dropout=0,
|
213 |
+
activation=None,
|
214 |
+
n_harmonic_functions=embedder_freq,
|
215 |
+
omega0=embedder_scaler,
|
216 |
+
embed_concat_pts=embed_concat_pts,
|
217 |
+
style_choice=condition_choice)
|
218 |
+
|
219 |
+
else:
|
220 |
+
raise NotImplementedError
|
221 |
+
|
222 |
+
def load_tets(self, grid_res=None, scale=None):
|
223 |
+
if grid_res is None:
|
224 |
+
grid_res = self.grid_res
|
225 |
+
else:
|
226 |
+
self.grid_res = grid_res
|
227 |
+
if scale is None:
|
228 |
+
scale = self.grid_scale
|
229 |
+
else:
|
230 |
+
self.grid_scale = scale
|
231 |
+
tets = np.load('./data/tets/{}_tets.npz'.format(grid_res))
|
232 |
+
self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * scale # verts original scale (-0.5, 0.5)
|
233 |
+
self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')
|
234 |
+
self.generate_edges()
|
235 |
+
|
236 |
+
def get_sdf(self, pts=None, perturb_sdf=False, total_iter=0, class_vector=None):
|
237 |
+
if self.sdf_mode == 'param':
|
238 |
+
sdf = self.sdf
|
239 |
+
else:
|
240 |
+
if pts is None:
|
241 |
+
pts = self.verts
|
242 |
+
if self.sym_prior_shape:
|
243 |
+
xs, ys, zs = pts.unbind(-1)
|
244 |
+
pts = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
245 |
+
feat = None
|
246 |
+
if class_vector is not None:
|
247 |
+
feat = class_vector.unsqueeze(0).repeat(pts.shape[0], 1)
|
248 |
+
sdf = self.mlp(pts, feat=feat)
|
249 |
+
|
250 |
+
if self.init_sdf is None:
|
251 |
+
pass
|
252 |
+
elif type(self.init_sdf) in [float, int]:
|
253 |
+
sdf = sdf + self.init_sdf
|
254 |
+
elif self.init_sdf == 'sphere':
|
255 |
+
init_radius = self.grid_scale * 0.25
|
256 |
+
init_sdf = init_radius - pts.norm(dim=-1, keepdim=True) # init sdf is a sphere centered at origin
|
257 |
+
sdf = sdf + init_sdf
|
258 |
+
elif self.init_sdf == 'ellipsoid':
|
259 |
+
rxy = self.grid_scale * 0.15
|
260 |
+
xs, ys, zs = pts.unbind(-1)[:3]
|
261 |
+
init_sdf = rxy - torch.stack([xs, ys, zs/2], -1).norm(dim=-1, keepdim=True) # init sdf is approximately an ellipsoid centered at origin
|
262 |
+
sdf = sdf + init_sdf
|
263 |
+
else:
|
264 |
+
raise NotImplementedError
|
265 |
+
|
266 |
+
if perturb_sdf:
|
267 |
+
sdf = sdf + torch.randn_like(sdf) * 0.1 * max(0, 1-total_iter/self.perturb_sdf_iter)
|
268 |
+
return sdf
|
269 |
+
|
270 |
+
def get_sdf_gradient(self, class_vector=None):
|
271 |
+
assert self.sdf_mode == 'mlp', "Only MLP supports gradient computation."
|
272 |
+
num_samples = 5000
|
273 |
+
sample_points = (torch.rand(num_samples, 3, device=self.verts.device) - 0.5) * self.grid_scale
|
274 |
+
mesh_verts = self.mesh_verts.detach() + (torch.rand_like(self.mesh_verts) -0.5) * 0.1 * self.grid_scale
|
275 |
+
rand_idx = torch.randperm(len(mesh_verts), device=mesh_verts.device)[:5000]
|
276 |
+
mesh_verts = mesh_verts[rand_idx]
|
277 |
+
sample_points = torch.cat([sample_points, mesh_verts], 0)
|
278 |
+
sample_points.requires_grad = True
|
279 |
+
y = self.get_sdf(pts=sample_points, perturb_sdf=False, class_vector=class_vector)
|
280 |
+
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
281 |
+
try:
|
282 |
+
gradients = torch.autograd.grad(
|
283 |
+
outputs=[y],
|
284 |
+
inputs=sample_points,
|
285 |
+
grad_outputs=d_output,
|
286 |
+
create_graph=True,
|
287 |
+
retain_graph=True,
|
288 |
+
only_inputs=True)[0]
|
289 |
+
except RuntimeError: # For validation, we have disabled gradient calculation.
|
290 |
+
return torch.zeros_like(sample_points)
|
291 |
+
return gradients
|
292 |
+
|
293 |
+
def get_sdf_reg_loss(self, class_vector=None):
|
294 |
+
reg_loss = {"sdf_bce_reg_loss": sdf_bce_reg_loss(self.current_sdf, self.all_edges).mean()}
|
295 |
+
if self.sdf_mode == 'mlp':
|
296 |
+
reg_loss["sdf_gradient_reg_loss"] = ((self.get_sdf_gradient(class_vector=class_vector).norm(dim=-1) - 1) ** 2).mean()
|
297 |
+
reg_loss['sdf_inflate_reg_loss'] = -self.current_sdf.mean()
|
298 |
+
return reg_loss
|
299 |
+
|
300 |
+
def generate_edges(self):
|
301 |
+
with torch.no_grad():
|
302 |
+
edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda")
|
303 |
+
all_edges = self.indices[:,edges].reshape(-1,2)
|
304 |
+
all_edges_sorted = torch.sort(all_edges, dim=1)[0]
|
305 |
+
self.all_edges = torch.unique(all_edges_sorted, dim=0)
|
306 |
+
|
307 |
+
@torch.no_grad()
|
308 |
+
def getAABB(self):
|
309 |
+
return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
|
310 |
+
|
311 |
+
def getMesh(self, material=None, perturb_sdf=False, total_iter=0, jitter_grid=True, class_vector=None):
|
312 |
+
# Run DM tet to get a base mesh
|
313 |
+
v_deformed = self.verts
|
314 |
+
|
315 |
+
# if self.FLAGS.deform_grid:
|
316 |
+
# v_deformed = self.verts + 2 / (self.grid_res * 2) * torch.tanh(self.deform)
|
317 |
+
# else:
|
318 |
+
# v_deformed = self.verts
|
319 |
+
if jitter_grid and self.jitter_grid > 0:
|
320 |
+
jitter = (torch.rand(1, device=v_deformed.device)*2-1) * self.jitter_grid * self.grid_scale
|
321 |
+
v_deformed = v_deformed + jitter
|
322 |
+
|
323 |
+
self.current_sdf = self.get_sdf(v_deformed, perturb_sdf=perturb_sdf, total_iter=total_iter, class_vector=class_vector)
|
324 |
+
verts, faces, uvs, uv_idx = self.marching_tets(v_deformed, self.current_sdf, self.indices)
|
325 |
+
self.mesh_verts = verts
|
326 |
+
return mesh.make_mesh(verts[None], faces[None], uvs[None], uv_idx[None], material)
|
327 |
+
|
328 |
+
def render(self, glctx, target, lgt, opt_material, bsdf=None):
|
329 |
+
opt_mesh = self.getMesh(opt_material)
|
330 |
+
return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf)
|
331 |
+
|
332 |
+
def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
|
333 |
+
# ==============================================================================================
|
334 |
+
# Render optimizable object with identical conditions
|
335 |
+
# ==============================================================================================
|
336 |
+
buffers = self.render(glctx, target, lgt, opt_material)
|
337 |
+
|
338 |
+
# ==============================================================================================
|
339 |
+
# Compute loss
|
340 |
+
# ==============================================================================================
|
341 |
+
t_iter = iteration / 20000
|
342 |
+
|
343 |
+
# Image-space loss, split into a coverage component and a color component
|
344 |
+
color_ref = target['img']
|
345 |
+
img_loss = torch.nn.functional.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
|
346 |
+
img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
|
347 |
+
|
348 |
+
# SDF regularizer
|
349 |
+
# sdf_weight = self.sdf_regularizer - (self.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter) # Dropoff to 0.01
|
350 |
+
reg_loss = sum(self.get_sdf_reg_loss().values)
|
351 |
+
|
352 |
+
# Albedo (k_d) smoothnesss regularizer
|
353 |
+
reg_loss += torch.mean(buffers['kd_grad'][..., :-1] * buffers['kd_grad'][..., -1:]) * 0.03 * min(1.0, iteration / 500)
|
354 |
+
|
355 |
+
# Visibility regularizer
|
356 |
+
reg_loss += torch.mean(buffers['occlusion'][..., :-1] * buffers['occlusion'][..., -1:]) * 0.001 * min(1.0, iteration / 500)
|
357 |
+
|
358 |
+
# Light white balance regularizer
|
359 |
+
reg_loss = reg_loss + lgt.regularizer() * 0.005
|
360 |
+
|
361 |
+
return img_loss, reg_loss
|
video3d/model.py
ADDED
@@ -0,0 +1,1526 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from multiprocessing.spawn import prepare
|
2 |
+
from turtle import forward
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchvision.models as models
|
7 |
+
import nvdiffrast.torch as dr
|
8 |
+
import numpy as np
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import os
|
11 |
+
import os.path as osp
|
12 |
+
|
13 |
+
from video3d.render.regularizer import get_edge_length, normal_consistency
|
14 |
+
from . import networks
|
15 |
+
from .renderer import *
|
16 |
+
from .utils import misc, meters, flow_viz, arap, custom_loss
|
17 |
+
from .dataloaders import get_sequence_loader, get_image_loader
|
18 |
+
from .cub_dataloaders import get_cub_loader
|
19 |
+
from .utils.skinning_v4 import estimate_bones, skinning
|
20 |
+
import lpips
|
21 |
+
from einops import rearrange
|
22 |
+
|
23 |
+
from .geometry.dmtet import DMTetGeometry
|
24 |
+
from .geometry.dlmesh import DLMesh
|
25 |
+
|
26 |
+
from .render import renderutils as ru
|
27 |
+
from .render import material
|
28 |
+
from .render import mlptexture
|
29 |
+
from .render import util
|
30 |
+
from .render import mesh
|
31 |
+
from .render import light
|
32 |
+
from .render import render
|
33 |
+
|
34 |
+
EPS = 1e-7
|
35 |
+
|
36 |
+
|
37 |
+
def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0):
|
38 |
+
return torch.optim.Adam(
|
39 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
40 |
+
lr=lr, betas=betas, weight_decay=weight_decay)
|
41 |
+
|
42 |
+
|
43 |
+
def set_requires_grad(model, requires_grad):
|
44 |
+
if model is not None:
|
45 |
+
for param in model.parameters():
|
46 |
+
param.requires_grad = requires_grad
|
47 |
+
|
48 |
+
|
49 |
+
def forward_to_matrix(vec_forward, up=[0,1,0]):
|
50 |
+
up = torch.FloatTensor(up).to(vec_forward.device)
|
51 |
+
# vec_forward = nn.functional.normalize(vec_forward, p=2, dim=-1) # x right, y up, z forward
|
52 |
+
vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1)
|
53 |
+
vec_right = nn.functional.normalize(vec_right, p=2, dim=-1)
|
54 |
+
vec_up = vec_forward.cross(vec_right, dim=-1)
|
55 |
+
vec_up = nn.functional.normalize(vec_up, p=2, dim=-1)
|
56 |
+
rot_mat = torch.stack([vec_right, vec_up, vec_forward], -2)
|
57 |
+
return rot_mat
|
58 |
+
|
59 |
+
|
60 |
+
def sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, pose_xflip_recon=False, input_image_xflip_flag=None, rot_temp_scalar=1., num_hypos=4, naive_probs_iter=2000, best_pose_start_iter=6000, random_sample=True):
|
61 |
+
rots_pred = poses_raw[..., :num_hypos*4].view(-1, num_hypos, 4)
|
62 |
+
rots_logits = rots_pred[..., 0] # Nx4
|
63 |
+
temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, 1., 100.)
|
64 |
+
|
65 |
+
rots_probs = torch.nn.functional.softmax(-rots_logits / temp, dim=1) # N x K
|
66 |
+
# naive_probs = torch.FloatTensor([10] + [1] * (num_hypos - 1)).to(rots_logits.device)
|
67 |
+
naive_probs = torch.ones(num_hypos).to(rots_logits.device)
|
68 |
+
naive_probs = naive_probs / naive_probs.sum()
|
69 |
+
naive_probs_weight = np.clip(1 - (total_iter - naive_probs_iter) / 2000, 0, 1)
|
70 |
+
rots_probs = naive_probs.view(1, num_hypos) * naive_probs_weight + rots_probs * (1 - naive_probs_weight)
|
71 |
+
|
72 |
+
rots_pred = rots_pred[..., 1:4]
|
73 |
+
trans_pred = poses_raw[..., -3:]
|
74 |
+
best_rot_idx = torch.argmax(rots_probs, dim=1) # N
|
75 |
+
if random_sample:
|
76 |
+
# rand_rot_idx = torch.randint(0, 4, (batch_size * num_frames,), device=poses_raw.device) # N
|
77 |
+
rand_rot_idx = torch.randperm(batch_size * num_frames, device=poses_raw.device) % num_hypos # N
|
78 |
+
# rand_rot_idx = torch.randperm(batch_size, device=poses_raw.device)[:,None].repeat(1, num_frames).view(-1) % 4 # N
|
79 |
+
best_flag = (torch.randperm(batch_size * num_frames, device=poses_raw.device) / (batch_size * num_frames) < np.clip((total_iter - best_pose_start_iter)/2000, 0, 0.8)).long()
|
80 |
+
rand_flag = 1 - best_flag
|
81 |
+
# best_flag = torch.zeros_like(best_rot_idx)
|
82 |
+
rot_idx = best_rot_idx * best_flag + rand_rot_idx * (1 - best_flag)
|
83 |
+
else:
|
84 |
+
rand_flag = torch.zeros_like(best_rot_idx)
|
85 |
+
rot_idx = best_rot_idx
|
86 |
+
rot_pred = torch.gather(rots_pred, 1, rot_idx[:, None, None].expand(-1, 1, 3))[:, 0] # Nx3
|
87 |
+
pose_raw = torch.cat([rot_pred, trans_pred], -1)
|
88 |
+
rot_prob = torch.gather(rots_probs, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N
|
89 |
+
rot_logit = torch.gather(rots_logits, 1, rot_idx[:, None].expand(-1, 1))[:, 0] # N
|
90 |
+
|
91 |
+
if pose_xflip_recon:
|
92 |
+
raise NotImplementedError
|
93 |
+
rot_mat = forward_to_matrix(pose_raw[:, :3], up=[0, 1, 0])
|
94 |
+
pose = torch.cat([rot_mat.view(batch_size * num_frames, -1), pose_raw[:, 3:]], -1)
|
95 |
+
return pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_flag
|
96 |
+
|
97 |
+
|
98 |
+
class PriorPredictor(nn.Module):
|
99 |
+
def __init__(self, cfgs):
|
100 |
+
super().__init__()
|
101 |
+
dmtet_grid = cfgs.get('dmtet_grid', 64)
|
102 |
+
grid_scale = cfgs.get('grid_scale', 5)
|
103 |
+
prior_sdf_mode = cfgs.get('prior_sdf_mode', 'mlp')
|
104 |
+
num_layers_shape = cfgs.get('num_layers_shape', 5)
|
105 |
+
hidden_size = cfgs.get('hidden_size', 64)
|
106 |
+
embedder_freq_shape = cfgs.get('embedder_freq_shape', 8)
|
107 |
+
embed_concat_pts = cfgs.get('embed_concat_pts', True)
|
108 |
+
init_sdf = cfgs.get('init_sdf', None)
|
109 |
+
jitter_grid = cfgs.get('jitter_grid', 0.)
|
110 |
+
perturb_sdf_iter = cfgs.get('perturb_sdf_iter', 10000)
|
111 |
+
sym_prior_shape = cfgs.get('sym_prior_shape', False)
|
112 |
+
self.netShape = DMTetGeometry(dmtet_grid, grid_scale, prior_sdf_mode, num_layers=num_layers_shape, hidden_size=hidden_size, embedder_freq=embedder_freq_shape, embed_concat_pts=embed_concat_pts, init_sdf=init_sdf, jitter_grid=jitter_grid, perturb_sdf_iter=perturb_sdf_iter, sym_prior_shape=sym_prior_shape)
|
113 |
+
|
114 |
+
mlp_hidden_size = cfgs.get('hidden_size', 64)
|
115 |
+
tet_bbox = self.netShape.getAABB()
|
116 |
+
self.render_dino_mode = cfgs.get('render_dino_mode', None)
|
117 |
+
num_layers_dino = cfgs.get("num_layers_dino", 5)
|
118 |
+
dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64)
|
119 |
+
sym_dino = cfgs.get("sym_dino", False)
|
120 |
+
dino_min = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_min', 0.)
|
121 |
+
dino_max = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_max', 1.)
|
122 |
+
min_max = torch.stack((dino_min, dino_max), dim=0)
|
123 |
+
if self.render_dino_mode is None:
|
124 |
+
pass
|
125 |
+
elif self.render_dino_mode == 'feature_mlpnv':
|
126 |
+
self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_feature_recon_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=min_max, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
|
127 |
+
elif self.render_dino_mode == 'feature_mlp':
|
128 |
+
embedder_scaler = 2 * np.pi / grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
|
129 |
+
embed_concat_pts = cfgs.get('embed_concat_pts', True)
|
130 |
+
self.netDINO = networks.MLPTextureSimple(
|
131 |
+
3, # x, y, z coordinates
|
132 |
+
dino_feature_recon_dim,
|
133 |
+
num_layers_dino,
|
134 |
+
nf=mlp_hidden_size,
|
135 |
+
dropout=0,
|
136 |
+
activation="sigmoid",
|
137 |
+
min_max=min_max,
|
138 |
+
n_harmonic_functions=cfgs.get('embedder_freq_dino', 8),
|
139 |
+
omega0=embedder_scaler,
|
140 |
+
extra_dim=0,
|
141 |
+
embed_concat_pts=embed_concat_pts,
|
142 |
+
perturb_normal=False,
|
143 |
+
symmetrize=sym_dino
|
144 |
+
)
|
145 |
+
elif self.render_dino_mode == 'cluster':
|
146 |
+
num_layers_dino = cfgs.get("num_layers_dino", 5)
|
147 |
+
dino_cluster_dim = cfgs.get('dino_cluster_dim', 64)
|
148 |
+
self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_cluster_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=None, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
|
149 |
+
else:
|
150 |
+
raise NotImplementedError
|
151 |
+
|
152 |
+
def forward(self, perturb_sdf=False, total_iter=None, is_training=True):
|
153 |
+
prior_shape = self.netShape.getMesh(perturb_sdf=perturb_sdf, total_iter=total_iter, jitter_grid=is_training)
|
154 |
+
return prior_shape, self.netDINO
|
155 |
+
|
156 |
+
|
157 |
+
class InstancePredictor(nn.Module):
|
158 |
+
def __init__(self, cfgs, tet_bbox=None):
|
159 |
+
super().__init__()
|
160 |
+
self.cfgs = cfgs
|
161 |
+
self.grid_scale = cfgs.get('grid_scale', 5)
|
162 |
+
|
163 |
+
self.enable_encoder = cfgs.get('enable_encoder', False)
|
164 |
+
if self.enable_encoder:
|
165 |
+
encoder_latent_dim = cfgs.get('latent_dim', 256)
|
166 |
+
encoder_pretrained = cfgs.get('encoder_pretrained', False)
|
167 |
+
encoder_frozen = cfgs.get('encoder_frozen', False)
|
168 |
+
encoder_arch = cfgs.get('encoder_arch', 'simple')
|
169 |
+
in_image_size = cfgs.get('in_image_size', 256)
|
170 |
+
self.dino_feature_input = cfgs.get('dino_feature_input', False)
|
171 |
+
dino_feature_dim = cfgs.get('dino_feature_dim', 64)
|
172 |
+
if encoder_arch == 'simple':
|
173 |
+
if self.dino_feature_input:
|
174 |
+
self.netEncoder = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
|
175 |
+
else:
|
176 |
+
self.netEncoder = networks.Encoder(cin=3, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
|
177 |
+
elif encoder_arch == 'vgg':
|
178 |
+
self.netEncoder = networks.VGGEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
|
179 |
+
elif encoder_arch == 'resnet':
|
180 |
+
self.netEncoder = networks.ResnetEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
|
181 |
+
elif encoder_arch == 'vit':
|
182 |
+
which_vit = cfgs.get('which_vit', 'dino_vits8')
|
183 |
+
vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
|
184 |
+
self.netEncoder = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type)
|
185 |
+
else:
|
186 |
+
raise NotImplementedError
|
187 |
+
else:
|
188 |
+
encoder_latent_dim = 0
|
189 |
+
|
190 |
+
mlp_hidden_size = cfgs.get('hidden_size', 64)
|
191 |
+
|
192 |
+
bsdf = cfgs.get("bsdf", 'diffuse')
|
193 |
+
num_layers_tex = cfgs.get("num_layers_tex", 5)
|
194 |
+
feat_dim = cfgs.get("latent_dim", 64) if self.enable_encoder else 0
|
195 |
+
perturb_normal = cfgs.get("perturb_normal", False)
|
196 |
+
sym_texture = cfgs.get("sym_texture", False)
|
197 |
+
kd_min = torch.FloatTensor(cfgs.get('kd_min', [0., 0., 0., 0.]))
|
198 |
+
kd_max = torch.FloatTensor(cfgs.get('kd_max', [1., 1., 1., 1.]))
|
199 |
+
ks_min = torch.FloatTensor(cfgs.get('ks_min', [0., 0., 0.]))
|
200 |
+
ks_max = torch.FloatTensor(cfgs.get('ks_max', [0., 0., 0.]))
|
201 |
+
nrm_min = torch.FloatTensor(cfgs.get('nrm_min', [-1., -1., 0.]))
|
202 |
+
nrm_max = torch.FloatTensor(cfgs.get('nrm_max', [1., 1., 1.]))
|
203 |
+
mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
|
204 |
+
mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
|
205 |
+
min_max = torch.stack((mlp_min, mlp_max), dim=0)
|
206 |
+
out_chn = 9
|
207 |
+
# TODO: if the tet verts are deforming, we need to recompute tet_bbox
|
208 |
+
texture_mode = cfgs.get("texture_mode", 'mlp')
|
209 |
+
if texture_mode == 'mlpnv':
|
210 |
+
self.netTexture = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=mlp_hidden_size, hidden=num_layers_tex-1, feat_dim=feat_dim, min_max=min_max, bsdf=bsdf, perturb_normal=perturb_normal, symmetrize=sym_texture)
|
211 |
+
elif texture_mode == 'mlp':
|
212 |
+
embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
|
213 |
+
embed_concat_pts = cfgs.get('embed_concat_pts', True)
|
214 |
+
self.netTexture = networks.MLPTextureSimple(
|
215 |
+
3, # x, y, z coordinates
|
216 |
+
out_chn,
|
217 |
+
num_layers_tex,
|
218 |
+
nf=mlp_hidden_size,
|
219 |
+
dropout=0,
|
220 |
+
activation="sigmoid",
|
221 |
+
min_max=min_max,
|
222 |
+
n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
|
223 |
+
omega0=embedder_scaler,
|
224 |
+
extra_dim=feat_dim,
|
225 |
+
embed_concat_pts=embed_concat_pts,
|
226 |
+
perturb_normal=perturb_normal,
|
227 |
+
symmetrize=sym_texture
|
228 |
+
)
|
229 |
+
|
230 |
+
self.rot_rep = cfgs.get('rot_rep', 'euler_angle')
|
231 |
+
self.enable_pose = cfgs.get('enable_pose', False)
|
232 |
+
if self.enable_pose:
|
233 |
+
cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
|
234 |
+
fov = cfgs.get('crop_fov_approx', 25)
|
235 |
+
half_range = np.tan(fov /2 /180 * np.pi) * cam_pos_z_offset # 2.22
|
236 |
+
self.max_trans_xy_range = half_range * cfgs.get('max_trans_xy_range_ratio', 1.)
|
237 |
+
self.max_trans_z_range = half_range * cfgs.get('max_trans_z_range_ratio', 1.)
|
238 |
+
self.lookat_init = cfgs.get('lookat_init', None)
|
239 |
+
self.lookat_zeroy = cfgs.get('lookat_zeroy', False)
|
240 |
+
self.rot_temp_scalar = cfgs.get('rot_temp_scalar', 1.)
|
241 |
+
self.naive_probs_iter = cfgs.get('naive_probs_iter', 2000)
|
242 |
+
self.best_pose_start_iter = cfgs.get('best_pose_start_iter', 6000)
|
243 |
+
|
244 |
+
if self.rot_rep == 'euler_angle':
|
245 |
+
pose_cout = 6
|
246 |
+
elif self.rot_rep == 'quaternion':
|
247 |
+
pose_cout = 7
|
248 |
+
elif self.rot_rep == 'lookat':
|
249 |
+
pose_cout = 6
|
250 |
+
elif self.rot_rep == 'quadlookat':
|
251 |
+
self.num_pose_hypos = 4
|
252 |
+
pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 4 quadrants, 4 quadrant classification logits, 3 for translation
|
253 |
+
self.orthant_signs = torch.FloatTensor([[1,1,1], [-1,1,1], [-1,1,-1], [1,1,-1]])
|
254 |
+
elif self.rot_rep == 'octlookat':
|
255 |
+
self.num_pose_hypos = 8
|
256 |
+
pose_cout = (3 + 1) * self.num_pose_hypos + 3 # 4 forward vectors for 8 octants, 8 octant classification logits, 3 for translation
|
257 |
+
self.orthant_signs = torch.stack(torch.meshgrid([torch.arange(1, -2, -2)] *3), -1).view(-1, 3) # 8x3
|
258 |
+
else:
|
259 |
+
raise NotImplementedError
|
260 |
+
|
261 |
+
self.pose_arch = cfgs.get('pose_arch', 'mlp')
|
262 |
+
if self.pose_arch == 'mlp':
|
263 |
+
num_layers_pose = cfgs.get('num_layers_pose', 5)
|
264 |
+
self.netPose = networks.MLP(
|
265 |
+
encoder_latent_dim,
|
266 |
+
pose_cout,
|
267 |
+
num_layers_pose,
|
268 |
+
nf=mlp_hidden_size,
|
269 |
+
dropout=0,
|
270 |
+
activation=None
|
271 |
+
)
|
272 |
+
elif self.pose_arch == 'encoder':
|
273 |
+
if self.dino_feature_input:
|
274 |
+
dino_feature_dim = cfgs.get('dino_feature_dim', 64)
|
275 |
+
self.netPose = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
|
276 |
+
else:
|
277 |
+
self.netPose = networks.Encoder(cin=3, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
|
278 |
+
elif self.pose_arch in ['encoder_dino_patch_out', 'encoder_dino_patch_key']:
|
279 |
+
if which_vit == 'dino_vits8':
|
280 |
+
dino_feat_dim = 384
|
281 |
+
elif which_vit == 'dinov2_vits14':
|
282 |
+
dino_feat_dim = 384
|
283 |
+
elif which_vit == 'dino_vitb8':
|
284 |
+
dino_feat_dim = 768
|
285 |
+
self.netPose = networks.Encoder32(cin=dino_feat_dim, cout=pose_cout, nf=256, activation=None)
|
286 |
+
elif self.pose_arch == 'vit':
|
287 |
+
encoder_pretrained = cfgs.get('encoder_pretrained', False)
|
288 |
+
encoder_frozen = cfgs.get('encoder_frozen', False)
|
289 |
+
which_vit = cfgs.get('which_vit', 'dino_vits8')
|
290 |
+
vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
|
291 |
+
self.netPose = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type)
|
292 |
+
else:
|
293 |
+
raise NotImplementedError
|
294 |
+
|
295 |
+
self.enable_deform = cfgs.get('enable_deform', False)
|
296 |
+
if self.enable_deform:
|
297 |
+
embedder_scaler = 2 * np.pi / self.grid_scale * 0.9 # originally (-0.5*s, 0.5*s) rescale to (-pi, pi) * 0.9
|
298 |
+
embed_concat_pts = cfgs.get('embed_concat_pts', True)
|
299 |
+
num_layers_deform = cfgs.get('num_layers_deform', 5)
|
300 |
+
self.deform_epochs = np.arange(*cfgs.get('deform_epochs', [0, 0]))
|
301 |
+
sym_deform = cfgs.get("sym_deform", False)
|
302 |
+
self.netDeform = networks.MLPWithPositionalEncoding(
|
303 |
+
3, # x, y, z coordinates
|
304 |
+
3, # dx, dy, dz deformation
|
305 |
+
num_layers_deform,
|
306 |
+
nf=mlp_hidden_size,
|
307 |
+
dropout=0,
|
308 |
+
activation=None,
|
309 |
+
n_harmonic_functions=cfgs.get('embedder_freq_deform', 10),
|
310 |
+
omega0=embedder_scaler,
|
311 |
+
extra_dim=encoder_latent_dim,
|
312 |
+
embed_concat_pts=embed_concat_pts,
|
313 |
+
symmetrize=sym_deform
|
314 |
+
)
|
315 |
+
|
316 |
+
self.enable_articulation = cfgs.get('enable_articulation', False)
|
317 |
+
if self.enable_articulation:
|
318 |
+
self.num_body_bones = cfgs.get('num_body_bones', 4)
|
319 |
+
self.articulation_multiplier = cfgs.get('articulation_multiplier', 1)
|
320 |
+
self.static_root_bones = cfgs.get('static_root_bones', False)
|
321 |
+
self.skinning_temperature = cfgs.get('skinning_temperature', 1)
|
322 |
+
self.articulation_epochs = np.arange(*cfgs.get('articulation_epochs', [0, 0]))
|
323 |
+
self.num_legs = cfgs.get('num_legs', 0)
|
324 |
+
self.num_leg_bones = cfgs.get('num_leg_bones', 0)
|
325 |
+
self.body_bones_type = cfgs.get('body_bones_type', 'z_minmax')
|
326 |
+
self.perturb_articulation_epochs = np.arange(*cfgs.get('perturb_articulation_epochs', [0, 0]))
|
327 |
+
self.num_bones = self.num_body_bones + self.num_legs * self.num_leg_bones
|
328 |
+
self.constrain_legs = cfgs.get('constrain_legs', False)
|
329 |
+
self.attach_legs_to_body_epochs = np.arange(*cfgs.get('attach_legs_to_body_epochs', [0, 0]))
|
330 |
+
self.max_arti_angle = cfgs.get('max_arti_angle', 60)
|
331 |
+
|
332 |
+
num_layers_arti = cfgs.get('num_layers_arti', 5)
|
333 |
+
which_vit = cfgs.get('which_vit', 'dino_vits8')
|
334 |
+
if which_vit == 'dino_vits8':
|
335 |
+
dino_feat_dim = 384
|
336 |
+
elif which_vit == 'dino_vitb8':
|
337 |
+
dino_feat_dim = 768
|
338 |
+
self.articulation_arch = cfgs.get('articulation_arch', 'mlp')
|
339 |
+
self.articulation_feature_mode = cfgs.get('articulation_feature_mode', 'sample')
|
340 |
+
embedder_freq_arti = cfgs.get('embedder_freq_arti', 8)
|
341 |
+
if self.articulation_feature_mode == 'global':
|
342 |
+
feat_dim = encoder_latent_dim
|
343 |
+
elif self.articulation_feature_mode == 'sample':
|
344 |
+
feat_dim = dino_feat_dim
|
345 |
+
elif self.articulation_feature_mode == 'sample+global':
|
346 |
+
feat_dim = encoder_latent_dim + dino_feat_dim
|
347 |
+
if self.articulation_feature_mode == 'attention':
|
348 |
+
arti_feat_attn_zdim = cfgs.get('arti_feat_attn_zdim', 128)
|
349 |
+
pos_dim = 1 + 2 + 3*2
|
350 |
+
self.netFeatureAttn = networks.FeatureAttention(which_vit, pos_dim, embedder_freq_arti, arti_feat_attn_zdim, img_size=in_image_size)
|
351 |
+
embedder_scaler = np.pi * 0.9 # originally (-1, 1) rescale to (-pi, pi) * 0.9
|
352 |
+
self.netArticulation = networks.ArticulationNetwork(self.articulation_arch, feat_dim, 1+2+3*2, num_layers_arti, mlp_hidden_size, n_harmonic_functions=embedder_freq_arti, omega0=embedder_scaler)
|
353 |
+
self.kinematic_tree_epoch = -1
|
354 |
+
|
355 |
+
self.enable_lighting = cfgs.get('enable_lighting', False)
|
356 |
+
if self.enable_lighting:
|
357 |
+
num_layers_light = cfgs.get('num_layers_light', 5)
|
358 |
+
amb_diff_min = torch.FloatTensor(cfgs.get('amb_diff_min', [0., 0.]))
|
359 |
+
amb_diff_max = torch.FloatTensor(cfgs.get('amb_diff_max', [1., 1.]))
|
360 |
+
intensity_min_max = torch.stack((amb_diff_min, amb_diff_max), dim=0)
|
361 |
+
self.netLight = light.DirectionalLight(encoder_latent_dim, num_layers_light, mlp_hidden_size, intensity_min_max=intensity_min_max)
|
362 |
+
|
363 |
+
self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
|
364 |
+
self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
|
365 |
+
|
366 |
+
def forward_encoder(self, images, dino_features=None):
|
367 |
+
images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
|
368 |
+
patch_out = patch_key = None
|
369 |
+
if self.dino_feature_input and self.cfgs.get('encoder_arch', 'simple') != 'vit':
|
370 |
+
dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
|
371 |
+
feat_out = self.netEncoder(images_in, dino_features_in) # Shape: (B, latent_dim)
|
372 |
+
elif self.cfgs.get('encoder_arch', 'simple') == 'vit':
|
373 |
+
feat_out, feat_key, patch_out, patch_key = self.netEncoder(images_in, return_patches=True)
|
374 |
+
else:
|
375 |
+
feat_out = self.netEncoder(images_in) # Shape: (B, latent_dim)
|
376 |
+
return feat_out, feat_key, patch_out, patch_key
|
377 |
+
|
378 |
+
def forward_pose(self, images, feat, patch_out, patch_key, dino_features):
|
379 |
+
if self.pose_arch == 'mlp':
|
380 |
+
pose = self.netPose(feat)
|
381 |
+
elif self.pose_arch == 'encoder':
|
382 |
+
images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
|
383 |
+
if self.dino_feature_input:
|
384 |
+
dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1 # rescale to (-1, 1)
|
385 |
+
pose = self.netPose(images_in, dino_features_in) # Shape: (B, latent_dim)
|
386 |
+
else:
|
387 |
+
pose = self.netPose(images_in) # Shape: (B, latent_dim)
|
388 |
+
elif self.pose_arch == 'vit':
|
389 |
+
images_in = images.view(-1, *images.shape[2:]) * 2 - 1 # rescale to (-1, 1)
|
390 |
+
pose = self.netPose(images_in)
|
391 |
+
elif self.pose_arch == 'encoder_dino_patch_out':
|
392 |
+
pose = self.netPose(patch_out) # Shape: (B, latent_dim)
|
393 |
+
elif self.pose_arch == 'encoder_dino_patch_key':
|
394 |
+
pose = self.netPose(patch_key) # Shape: (B, latent_dim)
|
395 |
+
else:
|
396 |
+
raise NotImplementedError
|
397 |
+
trans_pred = pose[...,-3:].tanh() * torch.FloatTensor([self.max_trans_xy_range, self.max_trans_xy_range, self.max_trans_z_range]).to(pose.device)
|
398 |
+
if self.rot_rep == 'euler_angle':
|
399 |
+
multiplier = 1.
|
400 |
+
if self.gradually_expand_yaw:
|
401 |
+
# multiplier += (min(iteration, 20000) // 500) * 0.25
|
402 |
+
multiplier *= 1.2 ** (min(iteration, 20000) // 500) # 1.125^40 = 111.200
|
403 |
+
rot_pred = torch.cat([pose[...,:1], pose[...,1:2]*multiplier, pose[...,2:3]], -1).tanh()
|
404 |
+
rot_pred = rot_pred * torch.FloatTensor([self.max_rot_x_range, self.max_rot_y_range, self.max_rot_z_range]).to(pose.device) /180 * np.pi
|
405 |
+
|
406 |
+
elif self.rot_rep == 'quaternion':
|
407 |
+
quat_init = torch.FloatTensor([0.01,0,0,0]).to(pose.device)
|
408 |
+
rot_pred = pose[...,:4] + quat_init
|
409 |
+
rot_pred = nn.functional.normalize(rot_pred, p=2, dim=-1)
|
410 |
+
# rot_pred = torch.cat([rot_pred[...,:1].abs(), rot_pred[...,1:]], -1) # make real part non-negative
|
411 |
+
rot_pred = rot_pred * rot_pred[...,:1].sign() # make real part non-negative
|
412 |
+
|
413 |
+
elif self.rot_rep == 'lookat':
|
414 |
+
vec_forward_raw = pose[...,:3]
|
415 |
+
if self.lookat_init is not None:
|
416 |
+
vec_forward_raw = vec_forward_raw + torch.FloatTensor(self.lookat_init).to(pose.device)
|
417 |
+
if self.lookat_zeroy:
|
418 |
+
vec_forward_raw = vec_forward_raw * torch.FloatTensor([1,0,1]).to(pose.device)
|
419 |
+
vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
|
420 |
+
rot_pred = vec_forward_raw
|
421 |
+
|
422 |
+
elif self.rot_rep in ['quadlookat', 'octlookat']:
|
423 |
+
rots_pred = pose[..., :self.num_pose_hypos*4].view(-1, self.num_pose_hypos, 4) # (B, T, K, 4)
|
424 |
+
rots_logits = rots_pred[..., :1]
|
425 |
+
vec_forward_raw = rots_pred[..., 1:4]
|
426 |
+
xs, ys, zs = vec_forward_raw.unbind(-1)
|
427 |
+
margin = 0.
|
428 |
+
xs = nn.functional.softplus(xs, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
|
429 |
+
if self.rot_rep == 'octlookat':
|
430 |
+
ys = nn.functional.softplus(ys, beta=np.log(2)/(0.5+margin)) - margin # initialize to 0.5
|
431 |
+
if self.lookat_zeroy:
|
432 |
+
ys = ys * 0
|
433 |
+
zs = nn.functional.softplus(zs, beta=2*np.log(2)) # initialize to 0.5
|
434 |
+
vec_forward_raw = torch.stack([xs, ys, zs], -1)
|
435 |
+
vec_forward_raw = vec_forward_raw * self.orthant_signs.to(pose.device)
|
436 |
+
vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1) # x right, y up, z forward
|
437 |
+
rot_pred = torch.cat([rots_logits, vec_forward_raw], -1).view(-1, self.num_pose_hypos*4)
|
438 |
+
|
439 |
+
else:
|
440 |
+
raise NotImplementedError
|
441 |
+
|
442 |
+
pose = torch.cat([rot_pred, trans_pred], -1)
|
443 |
+
return pose
|
444 |
+
|
445 |
+
def forward_deformation(self, shape, feat=None):
|
446 |
+
original_verts = shape.v_pos
|
447 |
+
num_verts = original_verts.shape[1]
|
448 |
+
if feat is not None:
|
449 |
+
deform_feat = feat[:, None, :].repeat(1, num_verts, 1) # Shape: (B, num_verts, latent_dim)
|
450 |
+
original_verts = original_verts.repeat(len(feat),1,1)
|
451 |
+
deformation = self.netDeform(original_verts, deform_feat) * 0.1 # Shape: (B, num_verts, 3)
|
452 |
+
shape = shape.deform(deformation)
|
453 |
+
return shape, deformation
|
454 |
+
|
455 |
+
def forward_articulation(self, shape, feat, patch_feat, mvp, w2c, batch_size, num_frames, epoch):
|
456 |
+
"""
|
457 |
+
Forward propagation of articulation. For each bone, the network takes: 1) the 3D location of the bone; 2) the feature of the patch which
|
458 |
+
the bone is projected to; and 3) an encoding of the bone's index to predict the bone's rotation (represented by an Euler angle).
|
459 |
+
|
460 |
+
Args:
|
461 |
+
shape: a Mesh object, whose v_pos has batch size BxF or 1.
|
462 |
+
feat: the feature of the patches. Shape: (BxF, feat_dim, num_patches_per_axis, num_patches_per_axis)
|
463 |
+
mvp: the model-view-projection matrix. Shape: (BxF, 4, 4)
|
464 |
+
|
465 |
+
Returns:
|
466 |
+
shape: a Mesh object, whose v_pos has batch size BxF (collapsed).
|
467 |
+
articulation_angles: the predicted bone rotations. Shape: (B, F, num_bones, 3)
|
468 |
+
aux: a dictionary containing auxiliary information.
|
469 |
+
"""
|
470 |
+
verts = shape.v_pos
|
471 |
+
if len(verts) == 1:
|
472 |
+
verts = verts[None]
|
473 |
+
else:
|
474 |
+
verts = verts.view(batch_size, num_frames, *verts.shape[1:])
|
475 |
+
|
476 |
+
if self.kinematic_tree_epoch != epoch:
|
477 |
+
# if (epoch == self.articulation_epochs[0]) and (self.kinematic_tree_epoch != epoch):
|
478 |
+
# if (epoch in [self.articulation_epochs[0], self.articulation_epochs[0]+2, self.articulation_epochs[0]+4]) and (self.kinematic_tree_epoch != epoch):
|
479 |
+
attach_legs_to_body = epoch in self.attach_legs_to_body_epochs
|
480 |
+
bones, self.kinematic_tree, self.bone_aux = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=True, attach_legs_to_body=attach_legs_to_body)
|
481 |
+
self.kinematic_tree_epoch = epoch
|
482 |
+
else:
|
483 |
+
bones = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=False, aux=self.bone_aux)
|
484 |
+
|
485 |
+
bones_pos = bones # Shape: (B, F, K, 2, 3)
|
486 |
+
if batch_size > bones_pos.shape[0] or num_frames > bones_pos.shape[1]:
|
487 |
+
assert bones_pos.shape[0] == 1 and bones_pos.shape[1] == 1, "If there is a mismatch, then there must be only one canonical mesh."
|
488 |
+
bones_pos = bones_pos.repeat(batch_size, num_frames, 1, 1, 1)
|
489 |
+
num_bones = bones_pos.shape[2]
|
490 |
+
bones_pos = bones_pos.view(batch_size*num_frames, num_bones, 2, 3) # NxKx2x3
|
491 |
+
bones_mid_pos = bones_pos.mean(2) # NxKx3
|
492 |
+
bones_idx = torch.arange(num_bones).to(bones_pos.device)
|
493 |
+
|
494 |
+
bones_mid_pos_world4 = torch.cat([bones_mid_pos, torch.ones_like(bones_mid_pos[..., :1])], -1) # NxKx4
|
495 |
+
bones_mid_pos_clip4 = bones_mid_pos_world4 @ mvp.transpose(-1, -2)
|
496 |
+
bones_mid_pos_uv = bones_mid_pos_clip4[..., :2] / bones_mid_pos_clip4[..., 3:4]
|
497 |
+
bones_mid_pos_uv = bones_mid_pos_uv.detach()
|
498 |
+
|
499 |
+
bones_pos_world4 = torch.cat([bones_pos, torch.ones_like(bones_pos[..., :1])], -1) # NxKx2x4
|
500 |
+
bones_pos_cam4 = bones_pos_world4 @ w2c[:,None].transpose(-1, -2)
|
501 |
+
bones_pos_cam3 = bones_pos_cam4[..., :3] / bones_pos_cam4[..., 3:4]
|
502 |
+
bones_pos_cam3 = bones_pos_cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(bones_pos_cam3.device).view(1, 1, 1, 3)
|
503 |
+
bones_pos_in = bones_pos_cam3.view(batch_size*num_frames, num_bones, 2*3) / self.grid_scale * 2 # (-1, 1), NxKx(2*3)
|
504 |
+
|
505 |
+
bones_idx_in = ((bones_idx[None, :, None] + 0.5) / num_bones * 2 - 1).repeat(batch_size * num_frames, 1, 1) # (-1, 1)
|
506 |
+
bones_pos_in = torch.cat([bones_mid_pos_uv, bones_pos_in, bones_idx_in], -1).detach()
|
507 |
+
|
508 |
+
if self.articulation_feature_mode == 'global':
|
509 |
+
bones_patch_features = feat[:, None].repeat(1, num_bones, 1) # (BxF, K, feat_dim)
|
510 |
+
elif self.articulation_feature_mode == 'sample':
|
511 |
+
bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim)
|
512 |
+
elif self.articulation_feature_mode == 'sample+global':
|
513 |
+
bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # (BxF, K, feat_dim)
|
514 |
+
bones_patch_features = torch.cat([feat[:, None].repeat(1, num_bones, 1), bones_patch_features], -1)
|
515 |
+
elif self.articulation_feature_mode == 'attention':
|
516 |
+
bones_patch_features = self.netFeatureAttn(bones_pos_in, patch_feat)
|
517 |
+
else:
|
518 |
+
raise NotImplementedError
|
519 |
+
|
520 |
+
articulation_angles = self.netArticulation(bones_patch_features, bones_pos_in).view(batch_size, num_frames, num_bones, 3) * self.articulation_multiplier
|
521 |
+
|
522 |
+
if self.static_root_bones:
|
523 |
+
root_bones = [self.num_body_bones // 2 - 1, self.num_body_bones - 1]
|
524 |
+
tmp_mask = torch.ones_like(articulation_angles)
|
525 |
+
tmp_mask[:, :, root_bones] = 0
|
526 |
+
articulation_angles = articulation_angles * tmp_mask
|
527 |
+
|
528 |
+
articulation_angles = articulation_angles.tanh()
|
529 |
+
|
530 |
+
if self.constrain_legs:
|
531 |
+
leg_bones_posx = [self.num_body_bones + i for i in range(self.num_leg_bones * self.num_legs // 2)]
|
532 |
+
leg_bones_negx = [self.num_body_bones + self.num_leg_bones * self.num_legs // 2 + i for i in range(self.num_leg_bones * self.num_legs // 2)]
|
533 |
+
|
534 |
+
tmp_mask = torch.zeros_like(articulation_angles)
|
535 |
+
tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 2] = 1
|
536 |
+
articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # no twist
|
537 |
+
|
538 |
+
tmp_mask = torch.zeros_like(articulation_angles)
|
539 |
+
tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 1] = 1
|
540 |
+
articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles # (-0.4, 0.4), limit side bending
|
541 |
+
|
542 |
+
if epoch in self.perturb_articulation_epochs:
|
543 |
+
articulation_angles = articulation_angles + torch.randn_like(articulation_angles) * 0.1
|
544 |
+
articulation_angles = articulation_angles * self.max_arti_angle / 180 * np.pi
|
545 |
+
|
546 |
+
verts_articulated, aux = skinning(verts, bones, self.kinematic_tree, articulation_angles,
|
547 |
+
output_posed_bones=True, temperature=self.skinning_temperature)
|
548 |
+
verts_articulated = verts_articulated.view(batch_size*num_frames, *verts_articulated.shape[2:])
|
549 |
+
v_tex = shape.v_tex
|
550 |
+
if len(v_tex) != len(verts_articulated):
|
551 |
+
v_tex = v_tex.repeat(len(verts_articulated), 1, 1)
|
552 |
+
shape = mesh.make_mesh(
|
553 |
+
verts_articulated,
|
554 |
+
shape.t_pos_idx,
|
555 |
+
v_tex,
|
556 |
+
shape.t_tex_idx,
|
557 |
+
shape.material)
|
558 |
+
return shape, articulation_angles, aux
|
559 |
+
|
560 |
+
def get_camera_extrinsics_from_pose(self, pose, znear=0.1, zfar=1000.):
|
561 |
+
N = len(pose)
|
562 |
+
cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset]).to(pose.device)
|
563 |
+
pose_R = pose[:, :9].view(N, 3, 3).transpose(2, 1)
|
564 |
+
pose_T = pose[:, -3:] + cam_pos_offset[None, None, :]
|
565 |
+
pose_T = pose_T.view(N, 3, 1)
|
566 |
+
pose_RT = torch.cat([pose_R, pose_T], axis=2) # Nx3x4
|
567 |
+
w2c = torch.cat([pose_RT, torch.FloatTensor([0, 0, 0, 1]).repeat(N, 1, 1).to(pose.device)], axis=1) # Nx4x4
|
568 |
+
# We assume the images are perfect square.
|
569 |
+
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, znear, zfar)[None].to(pose.device)
|
570 |
+
mvp = torch.matmul(proj, w2c)
|
571 |
+
campos = -torch.matmul(pose_R.transpose(2, 1), pose_T).view(N, 3)
|
572 |
+
return mvp, w2c, campos
|
573 |
+
|
574 |
+
def forward(self, images=None, prior_shape=None, epoch=None, dino_features=None, dino_clusters=None, total_iter=None, is_training=True):
|
575 |
+
batch_size, num_frames = images.shape[:2]
|
576 |
+
if self.enable_encoder:
|
577 |
+
feat_out, feat_key, patch_out, patch_key = self.forward_encoder(images, dino_features)
|
578 |
+
else:
|
579 |
+
feat_out = feat_key = patch_out = patch_key = None
|
580 |
+
shape = prior_shape
|
581 |
+
texture = self.netTexture
|
582 |
+
|
583 |
+
multi_hypothesis_aux = {}
|
584 |
+
if self.enable_pose:
|
585 |
+
poses_raw = self.forward_pose(images, feat_out, patch_out, patch_key, dino_features)
|
586 |
+
pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_pose_flag = sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, rot_temp_scalar=self.rot_temp_scalar, num_hypos=self.num_pose_hypos, naive_probs_iter=self.naive_probs_iter, best_pose_start_iter=self.best_pose_start_iter, random_sample=is_training)
|
587 |
+
multi_hypothesis_aux['rot_idx'] = rot_idx
|
588 |
+
multi_hypothesis_aux['rot_prob'] = rot_prob
|
589 |
+
multi_hypothesis_aux['rot_logit'] = rot_logit
|
590 |
+
multi_hypothesis_aux['rots_probs'] = rots_probs
|
591 |
+
multi_hypothesis_aux['rand_pose_flag'] = rand_pose_flag
|
592 |
+
else:
|
593 |
+
raise NotImplementedError
|
594 |
+
mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose)
|
595 |
+
|
596 |
+
deformation = None
|
597 |
+
if self.enable_deform and epoch in self.deform_epochs:
|
598 |
+
shape, deformation = self.forward_deformation(shape, feat_key)
|
599 |
+
|
600 |
+
arti_params, articulation_aux = None, {}
|
601 |
+
if self.enable_articulation and epoch in self.articulation_epochs:
|
602 |
+
shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch)
|
603 |
+
|
604 |
+
if self.enable_lighting:
|
605 |
+
light = self.netLight
|
606 |
+
else:
|
607 |
+
light = None
|
608 |
+
|
609 |
+
aux = articulation_aux
|
610 |
+
aux.update(multi_hypothesis_aux)
|
611 |
+
|
612 |
+
return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, deformation, arti_params, light, aux
|
613 |
+
|
614 |
+
|
615 |
+
class Unsup3D:
|
616 |
+
def __init__(self, cfgs):
|
617 |
+
self.cfgs = cfgs
|
618 |
+
self.device = cfgs.get('device', 'cpu')
|
619 |
+
self.in_image_size = cfgs.get('in_image_size', 128)
|
620 |
+
self.out_image_size = cfgs.get('out_image_size', 128)
|
621 |
+
|
622 |
+
self.num_epochs = cfgs.get('num_epochs', 10)
|
623 |
+
self.lr = cfgs.get('lr', 1e-4)
|
624 |
+
self.use_scheduler = cfgs.get('use_scheduler', False)
|
625 |
+
if self.use_scheduler:
|
626 |
+
scheduler_milestone = cfgs.get('scheduler_milestone', [1,2,3,4,5])
|
627 |
+
scheduler_gamma = cfgs.get('scheduler_gamma', 0.5)
|
628 |
+
self.make_scheduler = lambda optim: torch.optim.lr_scheduler.MultiStepLR(optim, milestones=scheduler_milestone, gamma=scheduler_gamma)
|
629 |
+
|
630 |
+
self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
|
631 |
+
self.full_size_h = cfgs.get('full_size_h', 1080)
|
632 |
+
self.full_size_w = cfgs.get('full_size_w', 1920)
|
633 |
+
# self.fov_w = cfgs.get('fov_w', 60)
|
634 |
+
# self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 # 36
|
635 |
+
self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
|
636 |
+
self.mesh_regularization_mode = cfgs.get('mesh_regularization_mode', 'seq')
|
637 |
+
|
638 |
+
self.enable_prior = cfgs.get('enable_prior', False)
|
639 |
+
if self.enable_prior:
|
640 |
+
self.netPrior = PriorPredictor(self.cfgs)
|
641 |
+
self.prior_lr = cfgs.get('prior_lr', self.lr)
|
642 |
+
self.prior_weight_decay = cfgs.get('prior_weight_decay', 0.)
|
643 |
+
self.prior_only_epochs = cfgs.get('prior_only_epochs', 0)
|
644 |
+
self.netInstance = InstancePredictor(self.cfgs, tet_bbox=self.netPrior.netShape.getAABB())
|
645 |
+
self.perturb_sdf = cfgs.get('perturb_sdf', False)
|
646 |
+
self.blur_mask = cfgs.get('blur_mask', False)
|
647 |
+
self.blur_mask_iter = cfgs.get('blur_mask_iter', 1)
|
648 |
+
|
649 |
+
self.seqshape_epochs = np.arange(*cfgs.get('seqshape_epochs', [0, self.num_epochs]))
|
650 |
+
self.avg_texture_epochs = np.arange(*cfgs.get('avg_texture_epochs', [0, 0]))
|
651 |
+
self.swap_texture_epochs = np.arange(*cfgs.get('swap_texture_epochs', [0, 0]))
|
652 |
+
self.swap_priorshape_epochs = np.arange(*cfgs.get('swap_priorshape_epochs', [0, 0]))
|
653 |
+
self.avg_seqshape_epochs = np.arange(*cfgs.get('avg_seqshape_epochs', [0, 0]))
|
654 |
+
self.swap_seqshape_epochs = np.arange(*cfgs.get('swap_seqshape_epochs', [0, 0]))
|
655 |
+
self.pose_epochs = np.arange(*cfgs.get('pose_epochs', [0, 0]))
|
656 |
+
self.pose_iters = cfgs.get('pose_iters', 0)
|
657 |
+
self.deform_type = cfgs.get('deform_type', None)
|
658 |
+
self.mesh_reg_decay_epoch = cfgs.get('mesh_reg_decay_epoch', 0)
|
659 |
+
self.sdf_reg_decay_start_iter = cfgs.get('sdf_reg_decay_start_iter', 0)
|
660 |
+
self.mesh_reg_decay_rate = cfgs.get('mesh_reg_decay_rate', 1)
|
661 |
+
self.texture_epochs = np.arange(*cfgs.get('texture_epochs', [0, self.num_epochs]))
|
662 |
+
self.zflip_epochs = np.arange(*cfgs.get('zflip_epochs', [0, self.num_epochs]))
|
663 |
+
self.lookat_zflip_loss_epochs = np.arange(*cfgs.get('lookat_zflip_loss_epochs', [0, self.num_epochs]))
|
664 |
+
self.lookat_zflip_no_other_losses = cfgs.get('lookat_zflip_no_other_losses', False)
|
665 |
+
self.flow_loss_epochs = np.arange(*cfgs.get('flow_loss_epochs', [0, self.num_epochs]))
|
666 |
+
self.sdf_inflate_reg_loss_epochs = np.arange(*cfgs.get('sdf_inflate_reg_loss_epochs', [0, self.num_epochs]))
|
667 |
+
self.arti_reg_loss_epochs = np.arange(*cfgs.get('arti_reg_loss_epochs', [0, self.num_epochs]))
|
668 |
+
self.background_mode = cfgs.get('background_mode', 'background')
|
669 |
+
self.shape_prior_type = cfgs.get('shape_prior_type', 'deform')
|
670 |
+
self.backward_prior = cfgs.get('backward_prior', True)
|
671 |
+
self.resume_prior_optim = cfgs.get('resume_prior_optim', True)
|
672 |
+
self.dmtet_grid_smaller_epoch = cfgs.get('dmtet_grid_smaller_epoch', 0)
|
673 |
+
self.dmtet_grid_smaller = cfgs.get('dmtet_grid_smaller', 128)
|
674 |
+
self.dmtet_grid = cfgs.get('dmtet_grid', 256)
|
675 |
+
self.pose_xflip_recon_epochs = np.arange(*cfgs.get('pose_xflip_recon_epochs', [0, 0]))
|
676 |
+
self.rot_rand_quad_epochs = np.arange(*cfgs.get('rot_rand_quad_epochs', [0, 0]))
|
677 |
+
self.rot_all_quad_epochs = np.arange(*cfgs.get('rot_all_quad_epochs', [0, 0]))
|
678 |
+
|
679 |
+
## perceptual loss
|
680 |
+
if cfgs.get('perceptual_loss_weight', 0.) > 0:
|
681 |
+
self.perceptual_loss_use_lin = cfgs.get('perceptual_loss_use_lin', True)
|
682 |
+
self.perceptual_loss = lpips.LPIPS(net='vgg', lpips=self.perceptual_loss_use_lin)
|
683 |
+
|
684 |
+
self.glctx = dr.RasterizeGLContext()
|
685 |
+
self.render_flow = self.cfgs.get('flow_loss_weight', 0.) > 0.
|
686 |
+
self.extra_renders = cfgs.get('extra_renders', [])
|
687 |
+
self.renderer_spp = cfgs.get('renderer_spp', 1)
|
688 |
+
self.dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64)
|
689 |
+
|
690 |
+
self.total_loss = 0.
|
691 |
+
self.all_scores = torch.Tensor()
|
692 |
+
self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results')
|
693 |
+
|
694 |
+
@staticmethod
|
695 |
+
def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None):
|
696 |
+
train_loader = val_loader = test_loader = None
|
697 |
+
color_jitter_train = cfgs.get('color_jitter_train', None)
|
698 |
+
color_jitter_val = cfgs.get('color_jitter_val', None)
|
699 |
+
random_flip_train = cfgs.get('random_flip_train', False)
|
700 |
+
|
701 |
+
## video dataset
|
702 |
+
if dataset == 'video':
|
703 |
+
data_loader_mode = cfgs.get('data_loader_mode', 'n_frame')
|
704 |
+
skip_beginning = cfgs.get('skip_beginning', 4)
|
705 |
+
skip_end = cfgs.get('skip_end', 4)
|
706 |
+
num_sample_frames = cfgs.get('num_sample_frames', 2)
|
707 |
+
min_seq_len = cfgs.get('min_seq_len', 10)
|
708 |
+
max_seq_len = cfgs.get('max_seq_len', 10)
|
709 |
+
debug_seq = cfgs.get('debug_seq', False)
|
710 |
+
random_sample_train_frames = cfgs.get('random_sample_train_frames', False)
|
711 |
+
shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False)
|
712 |
+
random_sample_val_frames = cfgs.get('random_sample_val_frames', False)
|
713 |
+
load_background = cfgs.get('background_mode', 'none') == 'background'
|
714 |
+
rgb_suffix = cfgs.get('rgb_suffix', '.png')
|
715 |
+
load_dino_feature = cfgs.get('load_dino_feature', False)
|
716 |
+
load_dino_cluster = cfgs.get('load_dino_cluster', False)
|
717 |
+
dino_feature_dim = cfgs.get('dino_feature_dim', 64)
|
718 |
+
get_loader = lambda **kwargs: get_sequence_loader(
|
719 |
+
mode=data_loader_mode,
|
720 |
+
batch_size=batch_size,
|
721 |
+
num_workers=num_workers,
|
722 |
+
in_image_size=in_image_size,
|
723 |
+
out_image_size=out_image_size,
|
724 |
+
debug_seq=debug_seq,
|
725 |
+
skip_beginning=skip_beginning,
|
726 |
+
skip_end=skip_end,
|
727 |
+
num_sample_frames=num_sample_frames,
|
728 |
+
min_seq_len=min_seq_len,
|
729 |
+
max_seq_len=max_seq_len,
|
730 |
+
load_background=load_background,
|
731 |
+
rgb_suffix=rgb_suffix,
|
732 |
+
load_dino_feature=load_dino_feature,
|
733 |
+
load_dino_cluster=load_dino_cluster,
|
734 |
+
dino_feature_dim=dino_feature_dim,
|
735 |
+
**kwargs)
|
736 |
+
|
737 |
+
if run_train:
|
738 |
+
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
|
739 |
+
print(f"Loading training data from {train_data_dir}")
|
740 |
+
train_loader = get_loader(data_dir=train_data_dir, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train)
|
741 |
+
|
742 |
+
if val_data_dir is not None:
|
743 |
+
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}"
|
744 |
+
print(f"Loading validation data from {val_data_dir}")
|
745 |
+
val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False)
|
746 |
+
if run_test:
|
747 |
+
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
|
748 |
+
print(f"Loading testing data from {test_data_dir}")
|
749 |
+
test_loader = get_loader(data_dir=test_data_dir, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False)
|
750 |
+
|
751 |
+
## CUB dataset
|
752 |
+
elif dataset == 'cub':
|
753 |
+
get_loader = lambda **kwargs: get_cub_loader(
|
754 |
+
batch_size=batch_size,
|
755 |
+
num_workers=num_workers,
|
756 |
+
image_size=in_image_size,
|
757 |
+
**kwargs)
|
758 |
+
|
759 |
+
if run_train:
|
760 |
+
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
|
761 |
+
print(f"Loading training data from {train_data_dir}")
|
762 |
+
train_loader = get_loader(data_dir=train_data_dir, split='train', is_validation=False)
|
763 |
+
val_loader = get_loader(data_dir=val_data_dir, split='val', is_validation=True)
|
764 |
+
|
765 |
+
if run_test:
|
766 |
+
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
|
767 |
+
print(f"Loading testing data from {test_data_dir}")
|
768 |
+
test_loader = get_loader(data_dir=test_data_dir, split='test', is_validation=True)
|
769 |
+
|
770 |
+
## other datasets
|
771 |
+
else:
|
772 |
+
get_loader = lambda **kwargs: get_image_loader(
|
773 |
+
batch_size=batch_size,
|
774 |
+
num_workers=num_workers,
|
775 |
+
image_size=in_image_size,
|
776 |
+
**kwargs)
|
777 |
+
|
778 |
+
if run_train:
|
779 |
+
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}"
|
780 |
+
print(f"Loading training data from {train_data_dir}")
|
781 |
+
train_loader = get_loader(data_dir=train_data_dir, is_validation=False, color_jitter=color_jitter_train)
|
782 |
+
|
783 |
+
if val_data_dir is not None:
|
784 |
+
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}"
|
785 |
+
print(f"Loading validation data from {val_data_dir}")
|
786 |
+
val_loader = get_loader(data_dir=val_data_dir, is_validation=True, color_jitter=color_jitter_val)
|
787 |
+
|
788 |
+
if run_test:
|
789 |
+
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}"
|
790 |
+
print(f"Loading testing data from {test_data_dir}")
|
791 |
+
test_loader = get_loader(data_dir=test_data_dir, is_validation=True, color_jitter=None)
|
792 |
+
|
793 |
+
return train_loader, val_loader, test_loader
|
794 |
+
|
795 |
+
def load_model_state(self, cp):
|
796 |
+
self.netInstance.load_state_dict(cp["netInstance"])
|
797 |
+
if self.enable_prior:
|
798 |
+
self.netPrior.load_state_dict(cp["netPrior"])
|
799 |
+
|
800 |
+
def load_optimizer_state(self, cp):
|
801 |
+
self.optimizerInstance.load_state_dict(cp["optimizerInstance"])
|
802 |
+
if self.use_scheduler:
|
803 |
+
if 'schedulerInstance' in cp:
|
804 |
+
self.schedulerInstance.load_state_dict(cp["schedulerInstance"])
|
805 |
+
if self.enable_prior and self.resume_prior_optim:
|
806 |
+
self.optimizerPrior.load_state_dict(cp["optimizerPrior"])
|
807 |
+
if self.use_scheduler:
|
808 |
+
if 'schedulerPrior' in cp:
|
809 |
+
self.schedulerPrior.load_state_dict(cp["schedulerPrior"])
|
810 |
+
|
811 |
+
def get_model_state(self):
|
812 |
+
state = {"netInstance": self.netInstance.state_dict()}
|
813 |
+
if self.enable_prior:
|
814 |
+
state["netPrior"] = self.netPrior.state_dict()
|
815 |
+
return state
|
816 |
+
|
817 |
+
def get_optimizer_state(self):
|
818 |
+
state = {"optimizerInstance": self.optimizerInstance.state_dict()}
|
819 |
+
if self.use_scheduler:
|
820 |
+
state["schedulerInstance"] = self.schedulerInstance.state_dict()
|
821 |
+
if self.enable_prior:
|
822 |
+
state["optimizerPrior"] = self.optimizerPrior.state_dict()
|
823 |
+
if self.use_scheduler:
|
824 |
+
state["schedulerPrior"] = self.schedulerPrior.state_dict()
|
825 |
+
return state
|
826 |
+
|
827 |
+
def to(self, device):
|
828 |
+
self.device = device
|
829 |
+
self.netInstance.to(device)
|
830 |
+
if self.enable_prior:
|
831 |
+
self.netPrior.to(device)
|
832 |
+
if hasattr(self, 'perceptual_loss'):
|
833 |
+
self.perceptual_loss.to(device)
|
834 |
+
|
835 |
+
def set_train(self):
|
836 |
+
self.netInstance.train()
|
837 |
+
if self.enable_prior:
|
838 |
+
self.netPrior.train()
|
839 |
+
|
840 |
+
def set_eval(self):
|
841 |
+
self.netInstance.eval()
|
842 |
+
if self.enable_prior:
|
843 |
+
self.netPrior.eval()
|
844 |
+
|
845 |
+
def reset_optimizers(self):
|
846 |
+
print("Resetting optimizers...")
|
847 |
+
self.optimizerInstance = get_optimizer(self.netInstance, self.lr)
|
848 |
+
if self.use_scheduler:
|
849 |
+
self.schedulerInstance = self.make_scheduler(self.optimizerInstance)
|
850 |
+
if self.enable_prior:
|
851 |
+
self.optimizerPrior = get_optimizer(self.netPrior, lr=self.prior_lr, weight_decay=self.prior_weight_decay)
|
852 |
+
if self.use_scheduler:
|
853 |
+
self.schedulerPrior = self.make_scheduler(self.optimizerPrior)
|
854 |
+
|
855 |
+
def backward(self):
|
856 |
+
self.optimizerInstance.zero_grad()
|
857 |
+
if self.backward_prior:
|
858 |
+
self.optimizerPrior.zero_grad()
|
859 |
+
self.total_loss.backward()
|
860 |
+
self.optimizerInstance.step()
|
861 |
+
if self.backward_prior:
|
862 |
+
self.optimizerPrior.step()
|
863 |
+
self.total_loss = 0.
|
864 |
+
|
865 |
+
def scheduler_step(self):
|
866 |
+
if self.use_scheduler:
|
867 |
+
self.schedulerInstance.step()
|
868 |
+
if self.enable_prior:
|
869 |
+
self.schedulerPrior.step()
|
870 |
+
|
871 |
+
def zflip_pose(self, pose):
|
872 |
+
if self.rot_rep == 'lookat':
|
873 |
+
vec_forward = pose[:,:,6:9]
|
874 |
+
vec_forward = vec_forward * torch.FloatTensor([1,1,-1]).view(1,1,3).to(vec_forward.device)
|
875 |
+
up = torch.FloatTensor([0,1,0]).to(pose.device).view(1,1,3)
|
876 |
+
vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1)
|
877 |
+
vec_right = nn.functional.normalize(vec_right, p=2, dim=-1)
|
878 |
+
vec_up = vec_forward.cross(vec_right, dim=-1)
|
879 |
+
vec_up = nn.functional.normalize(vec_up, p=2, dim=-1)
|
880 |
+
rot_mat = torch.stack([vec_right, vec_up, vec_forward], 2)
|
881 |
+
rot_pred = rot_mat.reshape(*pose.shape[:-1], -1)
|
882 |
+
pose_zflip = torch.cat([rot_pred, pose[:,:,9:]], -1)
|
883 |
+
else:
|
884 |
+
raise NotImplementedError
|
885 |
+
return pose_zflip
|
886 |
+
|
887 |
+
def render(self, shape, texture, mvp, w2c, campos, resolution, background='none', im_features=None, light=None, prior_shape=None, render_flow=True, dino_pred=None, render_mode='diffuse', two_sided_shading=True, num_frames=None, spp=1):
|
888 |
+
h, w = resolution
|
889 |
+
N = len(mvp)
|
890 |
+
if background in ['none', 'black']:
|
891 |
+
bg_image = torch.zeros((N, h, w, 3), device=mvp.device)
|
892 |
+
elif background == 'white':
|
893 |
+
bg_image = torch.ones((N, h, w, 3), device=mvp.device)
|
894 |
+
elif background == 'checkerboard':
|
895 |
+
bg_image = torch.FloatTensor(util.checkerboard((h, w), 8), device=self.device).repeat(N, 1, 1, 1) # NxHxWxC
|
896 |
+
else:
|
897 |
+
raise NotImplementedError
|
898 |
+
|
899 |
+
frame_rendered = render.render_mesh(
|
900 |
+
self.glctx,
|
901 |
+
shape,
|
902 |
+
mtx_in=mvp,
|
903 |
+
w2c=w2c,
|
904 |
+
view_pos=campos,
|
905 |
+
material=texture,
|
906 |
+
lgt=light,
|
907 |
+
resolution=resolution,
|
908 |
+
spp=spp,
|
909 |
+
msaa=True,
|
910 |
+
background=bg_image,
|
911 |
+
bsdf=render_mode,
|
912 |
+
feat=im_features,
|
913 |
+
prior_mesh=prior_shape,
|
914 |
+
two_sided_shading=two_sided_shading,
|
915 |
+
render_flow=render_flow,
|
916 |
+
dino_pred=dino_pred,
|
917 |
+
num_frames=num_frames)
|
918 |
+
shaded = frame_rendered['shaded'].permute(0, 3, 1, 2)
|
919 |
+
image_pred = shaded[:, :3, :, :]
|
920 |
+
mask_pred = shaded[:, 3, :, :]
|
921 |
+
albedo = frame_rendered['kd'].permute(0, 3, 1, 2)[:, :3, :, :]
|
922 |
+
if 'shading' in frame_rendered:
|
923 |
+
shading = frame_rendered['shading'].permute(0, 3, 1, 2)[:, :1, :, :]
|
924 |
+
else:
|
925 |
+
shading = None
|
926 |
+
if render_flow:
|
927 |
+
flow_pred = frame_rendered['flow']
|
928 |
+
flow_pred = flow_pred.permute(0, 3, 1, 2)[:, :2, :, :]
|
929 |
+
else:
|
930 |
+
flow_pred = None
|
931 |
+
if dino_pred is not None:
|
932 |
+
dino_feat_im_pred = frame_rendered['dino_feat_im_pred']
|
933 |
+
dino_feat_im_pred = dino_feat_im_pred.permute(0, 3, 1, 2)[:, :-1]
|
934 |
+
else:
|
935 |
+
dino_feat_im_pred = None
|
936 |
+
|
937 |
+
return image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading
|
938 |
+
|
939 |
+
def compute_reconstruction_losses(self, image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode='none', reduce=False):
|
940 |
+
losses = {}
|
941 |
+
batch_size, num_frames, _, h, w = image_pred.shape # BxFxCxHxW
|
942 |
+
|
943 |
+
# image_loss = (image_pred - image_gt) ** 2
|
944 |
+
image_loss = (image_pred - image_gt).abs()
|
945 |
+
|
946 |
+
## silhouette loss
|
947 |
+
mask_pred_valid = mask_pred * mask_valid
|
948 |
+
# mask_pred_valid = mask_pred
|
949 |
+
# losses["silhouette_loss"] = ((mask_pred - mask_gt) ** 2).mean()
|
950 |
+
# mask_loss_mask = (image_loss.mean(2).detach() > 0.05).float()
|
951 |
+
mask_loss = (mask_pred_valid - mask_gt) ** 2
|
952 |
+
# mask_loss = nn.functional.mse_loss(mask_pred, mask_gt)
|
953 |
+
# num_mask_pixels = mask_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
|
954 |
+
# losses["silhouette_loss"] = (mask_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
|
955 |
+
losses['silhouette_loss'] = mask_loss.view(batch_size, num_frames, -1).mean(2)
|
956 |
+
losses['silhouette_dt_loss'] = (mask_pred * mask_dt[:,:,1]).view(batch_size, num_frames, -1).mean(2)
|
957 |
+
losses['silhouette_inv_dt_loss'] = ((1-mask_pred) * mask_dt[:,:,0]).view(batch_size, num_frames, -1).mean(2)
|
958 |
+
|
959 |
+
mask_pred_binary = (mask_pred_valid > 0.).float().detach()
|
960 |
+
mask_both_binary = (mask_pred_binary * mask_gt).view(batch_size*num_frames, 1, *mask_pred.shape[2:])
|
961 |
+
mask_both_binary = (nn.functional.avg_pool2d(mask_both_binary, 3, stride=1, padding=1).view(batch_size, num_frames, *mask_pred.shape[2:]) > 0.99).float().detach() # erode by 1 pixel
|
962 |
+
|
963 |
+
## reconstruction loss
|
964 |
+
# image_loss_mask = (mask_pred*mask_gt).unsqueeze(2).expand_as(image_gt)
|
965 |
+
# image_loss = image_loss * image_loss_mask
|
966 |
+
# num_mask_pixels = image_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1)
|
967 |
+
# losses["rgb_loss"] = (image_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean()
|
968 |
+
if background_mode in ['background', 'input']:
|
969 |
+
pass
|
970 |
+
else:
|
971 |
+
image_loss = image_loss * mask_both_binary.unsqueeze(2)
|
972 |
+
losses['rgb_loss'] = image_loss.reshape(batch_size, num_frames, -1).mean(2)
|
973 |
+
|
974 |
+
if self.cfgs.get('perceptual_loss_weight', 0.) > 0:
|
975 |
+
if background_mode in ['background', 'input']:
|
976 |
+
perc_image_pred = image_pred
|
977 |
+
perc_image_gt = image_gt
|
978 |
+
else:
|
979 |
+
perc_image_pred = image_pred * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
|
980 |
+
perc_image_gt = image_gt * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2))
|
981 |
+
losses['perceptual_loss'] = self.perceptual_loss(perc_image_pred.view(-1, *image_pred.shape[2:]) *2-1, perc_image_gt.view(-1, *image_gt.shape[2:]) *2-1).view(batch_size, num_frames)
|
982 |
+
|
983 |
+
## flow loss - between first and second frame
|
984 |
+
if flow_pred is not None:
|
985 |
+
flow_loss = (flow_pred - flow_gt).abs()
|
986 |
+
flow_loss_mask = mask_both_binary[:,:-1].unsqueeze(2).expand_as(flow_gt).detach()
|
987 |
+
|
988 |
+
## ignore frames where GT flow is too large (likely inaccurate)
|
989 |
+
large_flow = (flow_gt.abs() > 0.5).float() * flow_loss_mask
|
990 |
+
large_flow = (large_flow.view(batch_size, num_frames-1, -1).sum(2) > 0).float()
|
991 |
+
self.large_flow = large_flow
|
992 |
+
|
993 |
+
flow_loss = flow_loss * flow_loss_mask * (1 - large_flow[:,:,None,None,None])
|
994 |
+
num_mask_pixels = flow_loss_mask.reshape(batch_size, num_frames-1, -1).sum(2).clamp(min=1)
|
995 |
+
losses['flow_loss'] = (flow_loss.reshape(batch_size, num_frames-1, -1).sum(2) / num_mask_pixels)
|
996 |
+
# losses["flow_loss"] = flow_loss.mean()
|
997 |
+
|
998 |
+
if dino_feat_im_pred is not None:
|
999 |
+
dino_feat_loss = (dino_feat_im_pred - dino_feat_im_gt) ** 2
|
1000 |
+
dino_feat_loss = dino_feat_loss * mask_both_binary.unsqueeze(2)
|
1001 |
+
losses['dino_feat_im_loss'] = dino_feat_loss.reshape(batch_size, num_frames, -1).mean(2)
|
1002 |
+
|
1003 |
+
if reduce:
|
1004 |
+
for k, v in losses.item():
|
1005 |
+
losses[k] = v.mean()
|
1006 |
+
return losses
|
1007 |
+
|
1008 |
+
def compute_pose_xflip_reg_loss(self, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None):
|
1009 |
+
image_xflip = input_image.flip(4)
|
1010 |
+
if dino_feat_im is not None:
|
1011 |
+
dino_feat_im_xflip = dino_feat_im.flip(4)
|
1012 |
+
else:
|
1013 |
+
dino_feat_im_xflip = None
|
1014 |
+
feat_xflip, _ = self.netInstance.forward_encoder(image_xflip, dino_feat_im_xflip)
|
1015 |
+
batch_size, num_frames = input_image.shape[:2]
|
1016 |
+
pose_xflip_raw = self.netInstance.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip)
|
1017 |
+
|
1018 |
+
if input_image_xflip_flag is not None:
|
1019 |
+
pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x
|
1020 |
+
pose_xflip_raw = pose_xflip_raw * (1 - input_image_xflip_flag.view(batch_size * num_frames, 1)) + pose_xflip_raw_xflip * input_image_xflip_flag.view(batch_size * num_frames, 1)
|
1021 |
+
|
1022 |
+
rot_rep = self.netInstance.rot_rep
|
1023 |
+
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
|
1024 |
+
pose_xflip_xflip = pose_xflip * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x
|
1025 |
+
pose_xflip_reg_loss = ((pose_xflip_xflip - pose) ** 2.).mean()
|
1026 |
+
elif rot_rep == 'quaternion':
|
1027 |
+
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose[...,:4]), convention='XYZ')
|
1028 |
+
pose_euler = torch.cat([rot_euler, pose[...,4:]], -1)
|
1029 |
+
rot_xflip_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip[...,:4]), convention='XYZ')
|
1030 |
+
pose_xflip_euler = torch.cat([rot_xflip_euler, pose_xflip[...,4:]], -1)
|
1031 |
+
pose_xflip_euler_xflip = pose_xflip_euler * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x
|
1032 |
+
pose_xflip_reg_loss = ((pose_xflip_euler_xflip - pose_euler) ** 2.).mean()
|
1033 |
+
elif rot_rep == 'lookat':
|
1034 |
+
pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x
|
1035 |
+
pose_xflip_reg_loss = ((pose_xflip_raw_xflip - pose_raw)[...,0] ** 2.) # compute x only
|
1036 |
+
# if epoch >= self.nolookat_zflip_loss_epochs and self.lookat_zflip_no_other_losses:
|
1037 |
+
# pose_xflip_reg_loss = pose_xflip_reg_loss.mean(1) * is_pose_1_better
|
1038 |
+
pose_xflip_reg_loss = pose_xflip_reg_loss.mean()
|
1039 |
+
return pose_xflip_reg_loss, pose_xflip_raw
|
1040 |
+
|
1041 |
+
def compute_edge_length_reg_loss(self, mesh, prior_mesh):
|
1042 |
+
prior_edge_lengths = get_edge_length(prior_mesh.v_pos, prior_mesh.t_pos_idx)
|
1043 |
+
max_length = prior_edge_lengths.max().detach() *1.1
|
1044 |
+
edge_lengths = get_edge_length(mesh.v_pos, mesh.t_pos_idx)
|
1045 |
+
mesh_edge_length_loss = ((edge_lengths - max_length).clamp(min=0)**2).mean()
|
1046 |
+
return mesh_edge_length_loss, edge_lengths
|
1047 |
+
|
1048 |
+
def compute_regularizers(self, mesh, prior_mesh, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None, arti_params=None, deformation=None):
|
1049 |
+
losses = {}
|
1050 |
+
aux = {}
|
1051 |
+
|
1052 |
+
if self.enable_prior:
|
1053 |
+
losses.update(self.netPrior.netShape.get_sdf_reg_loss())
|
1054 |
+
|
1055 |
+
if self.cfgs.get('pose_xflip_reg_loss_weight', 0.) > 0:
|
1056 |
+
losses["pose_xflip_reg_loss"], aux['pose_xflip_raw'] = self.compute_pose_xflip_reg_loss(input_image, dino_feat_im, pose_raw, input_image_xflip_flag)
|
1057 |
+
|
1058 |
+
b, f = input_image.shape[:2]
|
1059 |
+
if b >= 2:
|
1060 |
+
vec_forward = pose_raw[..., :3]
|
1061 |
+
losses['pose_entropy_loss'] = (vec_forward[:b//2] * vec_forward[b//2:(b//2)*2]).sum(-1).mean()
|
1062 |
+
else:
|
1063 |
+
losses['pose_entropy_loss'] = 0.
|
1064 |
+
|
1065 |
+
losses['mesh_normal_consistency_loss'] = normal_consistency(mesh.v_pos, mesh.t_pos_idx)
|
1066 |
+
losses['mesh_edge_length_loss'], aux['edge_lengths'] = self.compute_edge_length_reg_loss(mesh, prior_mesh)
|
1067 |
+
if arti_params is not None:
|
1068 |
+
losses['arti_reg_loss'] = (arti_params ** 2).mean()
|
1069 |
+
|
1070 |
+
if deformation is not None:
|
1071 |
+
losses['deformation_reg_loss'] = (deformation ** 2).mean()
|
1072 |
+
# losses['deformation_reg_loss'] = deformation.abs().mean()
|
1073 |
+
|
1074 |
+
return losses, aux
|
1075 |
+
|
1076 |
+
def forward(self, batch, epoch, iter, is_train=True, viz_logger=None, total_iter=None, save_results=False, save_dir=None, which_data='', logger_prefix='', is_training=True):
|
1077 |
+
batch = [x.to(self.device) if x is not None else None for x in batch]
|
1078 |
+
input_image, mask_gt, mask_dt, mask_valid, flow_gt, bbox, bg_image, dino_feat_im, dino_cluster_im, seq_idx, frame_idx = batch
|
1079 |
+
batch_size, num_frames, _, h0, w0 = input_image.shape # BxFxCxHxW
|
1080 |
+
h = w = self.out_image_size
|
1081 |
+
|
1082 |
+
def collapseF(x):
|
1083 |
+
return None if x is None else x.view(batch_size * num_frames, *x.shape[2:])
|
1084 |
+
def expandF(x):
|
1085 |
+
return None if x is None else x.view(batch_size, num_frames, *x.shape[1:])
|
1086 |
+
|
1087 |
+
if flow_gt.dim() == 2: # dummy tensor for not loading flow
|
1088 |
+
flow_gt = None
|
1089 |
+
if dino_feat_im.dim() == 2: # dummy tensor for not loading dino features
|
1090 |
+
dino_feat_im = None
|
1091 |
+
dino_feat_im_gt = None
|
1092 |
+
else:
|
1093 |
+
dino_feat_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_feat_im), size=[h, w], mode="bilinear"))[:, :, :self.dino_feature_recon_dim]
|
1094 |
+
if dino_cluster_im.dim() == 2: # dummy tensor for not loading dino clusters
|
1095 |
+
dino_cluster_im = None
|
1096 |
+
dino_cluster_im_gt = None
|
1097 |
+
else:
|
1098 |
+
dino_cluster_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_cluster_im), size=[h, w], mode="nearest"))
|
1099 |
+
|
1100 |
+
seq_idx = seq_idx.squeeze(1)
|
1101 |
+
# seq_idx = seq_idx * 0 # single sequnce model
|
1102 |
+
frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness = bbox.unbind(2) # BxFx7
|
1103 |
+
bbox = torch.stack([crop_x0, crop_y0, crop_w, crop_h], 2)
|
1104 |
+
mask_gt = (mask_gt[:, :, 0, :, :] > 0.9).float() # BxFxHxW
|
1105 |
+
mask_dt = mask_dt / self.in_image_size
|
1106 |
+
|
1107 |
+
if which_data != 'video':
|
1108 |
+
flow_gt = None
|
1109 |
+
|
1110 |
+
aux_viz = {}
|
1111 |
+
|
1112 |
+
## GT
|
1113 |
+
image_gt = input_image
|
1114 |
+
if self.out_image_size != self.in_image_size:
|
1115 |
+
image_gt = expandF(torch.nn.functional.interpolate(collapseF(image_gt), size=[h, w], mode='bilinear'))
|
1116 |
+
if flow_gt is not None:
|
1117 |
+
flow_gt = torch.nn.functional.interpolate(flow_gt.view(batch_size*(num_frames-1), 2, h0, w0), size=[h, w], mode="bilinear").view(batch_size, num_frames-1, 2, h, w)
|
1118 |
+
|
1119 |
+
self.train_pose_only = False
|
1120 |
+
if epoch in self.pose_epochs:
|
1121 |
+
if (total_iter // self.pose_iters) % 2 == 0:
|
1122 |
+
self.train_pose_only = True
|
1123 |
+
|
1124 |
+
## flip input and pose
|
1125 |
+
if epoch in self.pose_xflip_recon_epochs:
|
1126 |
+
input_image_xflip = input_image.flip(-1)
|
1127 |
+
input_image_xflip_flag = torch.randint(0, 2, (batch_size, num_frames), device=input_image.device)
|
1128 |
+
input_image = input_image * (1 - input_image_xflip_flag[:,:,None,None,None]) + input_image_xflip * input_image_xflip_flag[:,:,None,None,None]
|
1129 |
+
else:
|
1130 |
+
input_image_xflip_flag = None
|
1131 |
+
|
1132 |
+
## 1st pose hypothesis with original predictions
|
1133 |
+
|
1134 |
+
# ==============================================================================================
|
1135 |
+
# Predict prior mesh.
|
1136 |
+
# ==============================================================================================
|
1137 |
+
if self.enable_prior:
|
1138 |
+
if epoch < self.dmtet_grid_smaller_epoch:
|
1139 |
+
if self.netPrior.netShape.grid_res != self.dmtet_grid_smaller:
|
1140 |
+
self.netPrior.netShape.load_tets(self.dmtet_grid_smaller)
|
1141 |
+
else:
|
1142 |
+
if self.netPrior.netShape.grid_res != self.dmtet_grid:
|
1143 |
+
self.netPrior.netShape.load_tets(self.dmtet_grid)
|
1144 |
+
|
1145 |
+
perturb_sdf = self.perturb_sdf if is_train else False
|
1146 |
+
prior_shape, dino_pred = self.netPrior(perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training)
|
1147 |
+
else:
|
1148 |
+
prior_shape = None
|
1149 |
+
raise NotImplementedError
|
1150 |
+
|
1151 |
+
shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, deformation, arti_params, light, forward_aux = self.netInstance(input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F)
|
1152 |
+
rot_logit = forward_aux['rot_logit']
|
1153 |
+
rot_idx = forward_aux['rot_idx']
|
1154 |
+
rot_prob = forward_aux['rot_prob']
|
1155 |
+
aux_viz.update(forward_aux)
|
1156 |
+
|
1157 |
+
if self.train_pose_only:
|
1158 |
+
safe_detach = lambda x: x.detach() if x is not None else None
|
1159 |
+
prior_shape = safe_detach(prior_shape)
|
1160 |
+
shape = safe_detach(shape)
|
1161 |
+
im_features = safe_detach(im_features)
|
1162 |
+
arti_params = safe_detach(arti_params)
|
1163 |
+
deformation = safe_detach(deformation)
|
1164 |
+
set_requires_grad(texture, False)
|
1165 |
+
set_requires_grad(light, False)
|
1166 |
+
set_requires_grad(dino_pred, False)
|
1167 |
+
else:
|
1168 |
+
set_requires_grad(texture, True)
|
1169 |
+
set_requires_grad(light, True)
|
1170 |
+
set_requires_grad(dino_pred, True)
|
1171 |
+
|
1172 |
+
render_flow = self.render_flow and num_frames > 1
|
1173 |
+
image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features, light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, num_frames=num_frames, spp=self.renderer_spp)
|
1174 |
+
image_pred, mask_pred, flow_pred, dino_feat_im_pred = map(expandF, (image_pred, mask_pred, flow_pred, dino_feat_im_pred))
|
1175 |
+
if flow_pred is not None:
|
1176 |
+
flow_pred = flow_pred[:, :-1] # Bx(F-1)x2xHxW
|
1177 |
+
|
1178 |
+
if self.blur_mask:
|
1179 |
+
sigma = max(0.5, 3 * (1 - total_iter / self.blur_mask_iter))
|
1180 |
+
if sigma > 0.5:
|
1181 |
+
mask_gt = util.blur_image(mask_gt, kernel_size=9, sigma=sigma, mode='gaussian')
|
1182 |
+
# mask_pred = util.blur_image(mask_pred, kernel_size=7, mode='average')
|
1183 |
+
|
1184 |
+
losses = self.compute_reconstruction_losses(image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode=self.background_mode, reduce=False)
|
1185 |
+
|
1186 |
+
## TODO: assume flow loss is not used
|
1187 |
+
logit_loss_target = torch.zeros_like(expandF(rot_logit))
|
1188 |
+
final_losses = {}
|
1189 |
+
for name, loss in losses.items():
|
1190 |
+
loss_weight_logit = self.cfgs.get(f"{name}_weight", 0.)
|
1191 |
+
# if (name in ['flow_loss'] and epoch not in self.flow_loss_epochs) or (name in ['rgb_loss', 'perceptual_loss'] and epoch not in self.texture_epochs):
|
1192 |
+
# if name in ['flow_loss', 'rgb_loss', 'perceptual_loss']:
|
1193 |
+
# loss_weight_logit = 0.
|
1194 |
+
if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']:
|
1195 |
+
if total_iter >= self.sdf_reg_decay_start_iter:
|
1196 |
+
decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000)
|
1197 |
+
loss_weight_logit = max(loss_weight_logit * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
|
1198 |
+
if name in ['dino_feat_im_loss']:
|
1199 |
+
loss_weight_logit = loss_weight_logit * self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.)
|
1200 |
+
if loss_weight_logit > 0:
|
1201 |
+
logit_loss_target += loss * loss_weight_logit
|
1202 |
+
|
1203 |
+
if self.netInstance.rot_rep in ['quadlookat', 'octlookat']:
|
1204 |
+
loss = loss * rot_prob.detach().view(batch_size, num_frames)[:, :loss.shape[1]] *self.netInstance.num_pose_hypos
|
1205 |
+
if name == 'flow_loss' and num_frames > 1:
|
1206 |
+
ri = rot_idx.view(batch_size, num_frames)
|
1207 |
+
same_rot_idx = (ri[:, 1:] == ri[:, :-1]).float()
|
1208 |
+
loss = loss * same_rot_idx
|
1209 |
+
final_losses[name] = loss.mean()
|
1210 |
+
final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean()
|
1211 |
+
|
1212 |
+
## regularizers
|
1213 |
+
regularizers, aux = self.compute_regularizers(shape, prior_shape, input_image, dino_feat_im, pose_raw, input_image_xflip_flag, arti_params, deformation)
|
1214 |
+
final_losses.update(regularizers)
|
1215 |
+
aux_viz.update(aux)
|
1216 |
+
|
1217 |
+
total_loss = 0
|
1218 |
+
for name, loss in final_losses.items():
|
1219 |
+
loss_weight = self.cfgs.get(f"{name}_weight", 0.)
|
1220 |
+
if loss_weight <= 0:
|
1221 |
+
continue
|
1222 |
+
|
1223 |
+
if self.train_pose_only:
|
1224 |
+
if name not in ['silhouette_loss', 'silhouette_dt_loss', 'silhouette_inv_dt_loss', 'flow_loss', 'pose_xflip_reg_loss', 'lookat_zflip_loss', 'dino_feat_im_loss']:
|
1225 |
+
continue
|
1226 |
+
if epoch not in self.flow_loss_epochs:
|
1227 |
+
if name in ['flow_loss']:
|
1228 |
+
continue
|
1229 |
+
if epoch not in self.texture_epochs:
|
1230 |
+
if name in ['rgb_loss', 'perceptual_loss']:
|
1231 |
+
continue
|
1232 |
+
if epoch not in self.lookat_zflip_loss_epochs:
|
1233 |
+
if name in ['lookat_zflip_loss']:
|
1234 |
+
continue
|
1235 |
+
if name in ['mesh_laplacian_smoothing_loss', 'mesh_normal_consistency_loss']:
|
1236 |
+
if total_iter < self.cfgs.get('mesh_reg_start_iter', 0):
|
1237 |
+
continue
|
1238 |
+
if epoch >= self.mesh_reg_decay_epoch:
|
1239 |
+
decay_rate = self.mesh_reg_decay_rate ** (epoch - self.mesh_reg_decay_epoch)
|
1240 |
+
loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
|
1241 |
+
if epoch not in self.sdf_inflate_reg_loss_epochs:
|
1242 |
+
if name in ['sdf_inflate_reg_loss']:
|
1243 |
+
continue
|
1244 |
+
if epoch not in self.arti_reg_loss_epochs:
|
1245 |
+
if name in ['arti_reg_loss']:
|
1246 |
+
continue
|
1247 |
+
if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']:
|
1248 |
+
if total_iter >= self.sdf_reg_decay_start_iter:
|
1249 |
+
decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000)
|
1250 |
+
loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.))
|
1251 |
+
|
1252 |
+
total_loss += loss * loss_weight
|
1253 |
+
|
1254 |
+
self.total_loss += total_loss # reset to 0 in backward step
|
1255 |
+
|
1256 |
+
if torch.isnan(self.total_loss):
|
1257 |
+
print("NaN in loss...")
|
1258 |
+
import ipdb; ipdb.set_trace()
|
1259 |
+
|
1260 |
+
final_losses['logit_loss_target'] = logit_loss_target.mean()
|
1261 |
+
|
1262 |
+
metrics = {'loss': total_loss, **final_losses}
|
1263 |
+
|
1264 |
+
## log visuals
|
1265 |
+
if viz_logger is not None:
|
1266 |
+
b0 = max(min(batch_size, 16//num_frames), 1)
|
1267 |
+
viz_logger.add_image(logger_prefix+'image/image_gt', misc.image_grid(image_gt.detach().cpu()[:b0,:].reshape(-1,*input_image.shape[2:]).clamp(0,1)), total_iter)
|
1268 |
+
viz_logger.add_image(logger_prefix+'image/image_pred', misc.image_grid(image_pred.detach().cpu()[:b0,:].reshape(-1,*image_pred.shape[2:]).clamp(0,1)), total_iter)
|
1269 |
+
# viz_logger.add_image(logger_prefix+'image/flow_loss_mask', misc.image_grid(flow_loss_mask[:b0,:,:1].reshape(-1,1,*flow_loss_mask.shape[3:]).repeat(1,3,1,1).clamp(0,1)), total_iter)
|
1270 |
+
viz_logger.add_image(logger_prefix+'image/mask_gt', misc.image_grid(mask_gt.detach().cpu()[:b0,:].reshape(-1,*mask_gt.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter)
|
1271 |
+
viz_logger.add_image(logger_prefix+'image/mask_pred', misc.image_grid(mask_pred.detach().cpu()[:b0,:].reshape(-1,*mask_pred.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter)
|
1272 |
+
|
1273 |
+
if self.render_flow and flow_gt is not None:
|
1274 |
+
flow_gt = flow_gt.detach().cpu()
|
1275 |
+
flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5
|
1276 |
+
flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1])
|
1277 |
+
|
1278 |
+
## draw marker on large flow frames
|
1279 |
+
large_flow_marker_mask = torch.zeros_like(flow_gt_viz)
|
1280 |
+
large_flow_marker_mask[:,:,:,:8,:8] = 1.
|
1281 |
+
large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0]
|
1282 |
+
large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None]
|
1283 |
+
red = torch.FloatTensor([1,0,0])[None,None,:,None,None]
|
1284 |
+
flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz
|
1285 |
+
|
1286 |
+
viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter)
|
1287 |
+
|
1288 |
+
if self.render_flow and flow_pred is not None:
|
1289 |
+
flow_pred = flow_pred.detach().cpu()
|
1290 |
+
flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5
|
1291 |
+
flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1])
|
1292 |
+
viz_logger.add_image(logger_prefix+'image/flow_pred', misc.image_grid(flow_pred_viz.reshape(-1,*flow_pred_viz.shape[2:])), total_iter)
|
1293 |
+
|
1294 |
+
if light is not None:
|
1295 |
+
param_names = ['dir_x', 'dir_y', 'dir_z', 'int_ambient', 'int_diffuse']
|
1296 |
+
for name, param in zip(param_names, light.light_params.unbind(-1)):
|
1297 |
+
viz_logger.add_histogram(logger_prefix+'light/'+name, param, total_iter)
|
1298 |
+
viz_logger.add_image(
|
1299 |
+
logger_prefix + f'image/albedo',
|
1300 |
+
misc.image_grid(expandF(albedo)[:b0, ...].view(-1, *albedo.shape[1:])),
|
1301 |
+
total_iter)
|
1302 |
+
viz_logger.add_image(
|
1303 |
+
logger_prefix + f'image/shading',
|
1304 |
+
misc.image_grid(expandF(shading)[:b0, ...].view(-1, *shading.shape[1:]).repeat(1, 3, 1, 1) /2.),
|
1305 |
+
total_iter)
|
1306 |
+
|
1307 |
+
viz_logger.add_histogram(logger_prefix+'sdf', self.netPrior.netShape.get_sdf(perturb_sdf=False), total_iter)
|
1308 |
+
viz_logger.add_histogram(logger_prefix+'coordinates', shape.v_pos, total_iter)
|
1309 |
+
if arti_params is not None:
|
1310 |
+
viz_logger.add_histogram(logger_prefix+'arti_params', arti_params, total_iter)
|
1311 |
+
viz_logger.add_histogram(logger_prefix+'edge_lengths', aux_viz['edge_lengths'], total_iter)
|
1312 |
+
|
1313 |
+
if deformation is not None:
|
1314 |
+
viz_logger.add_histogram(logger_prefix+'deformation', deformation, total_iter)
|
1315 |
+
|
1316 |
+
rot_rep = self.netInstance.rot_rep
|
1317 |
+
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
|
1318 |
+
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']):
|
1319 |
+
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter)
|
1320 |
+
elif rot_rep == 'quaternion':
|
1321 |
+
for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']):
|
1322 |
+
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter)
|
1323 |
+
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose.detach().cpu()[...,:4]), convention='XYZ')
|
1324 |
+
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']):
|
1325 |
+
viz_logger.add_histogram(logger_prefix+'pose/'+name, rot_euler[...,i], total_iter)
|
1326 |
+
elif rot_rep in ['lookat', 'quadlookat', 'octlookat']:
|
1327 |
+
for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']):
|
1328 |
+
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,i], total_iter)
|
1329 |
+
for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']):
|
1330 |
+
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,-3+i], total_iter)
|
1331 |
+
|
1332 |
+
if rot_rep in ['quadlookat', 'octlookat']:
|
1333 |
+
for i, rp in enumerate(forward_aux['rots_probs'].unbind(-1)):
|
1334 |
+
viz_logger.add_histogram(logger_prefix+'pose/rot_prob_%d'%i, rp, total_iter)
|
1335 |
+
|
1336 |
+
if 'pose_xflip_raw' in aux_viz:
|
1337 |
+
pose_xflip_raw = aux_viz['pose_xflip_raw']
|
1338 |
+
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss':
|
1339 |
+
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']):
|
1340 |
+
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter)
|
1341 |
+
elif rot_rep == 'quaternion':
|
1342 |
+
for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']):
|
1343 |
+
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter)
|
1344 |
+
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip.detach().cpu()[...,:4]), convention='XYZ')
|
1345 |
+
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']):
|
1346 |
+
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, rot_euler[...,i], total_iter)
|
1347 |
+
elif rot_rep in ['lookat', 'quadlookat', 'octlookat']:
|
1348 |
+
for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']):
|
1349 |
+
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,i], total_iter)
|
1350 |
+
for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']):
|
1351 |
+
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,-3+i], total_iter)
|
1352 |
+
|
1353 |
+
if dino_feat_im_gt is not None:
|
1354 |
+
dino_feat_im_gt_first3 = dino_feat_im_gt[:,:,:3]
|
1355 |
+
viz_logger.add_image(logger_prefix+'image/dino_feat_im_gt', misc.image_grid(dino_feat_im_gt_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_gt_first3.shape[2:]).clamp(0,1)), total_iter)
|
1356 |
+
|
1357 |
+
if dino_cluster_im_gt is not None:
|
1358 |
+
viz_logger.add_image(logger_prefix+'image/dino_cluster_im_gt', misc.image_grid(dino_cluster_im_gt.detach().cpu()[:b0,:].reshape(-1,*dino_cluster_im_gt.shape[2:]).clamp(0,1)), total_iter)
|
1359 |
+
|
1360 |
+
if dino_feat_im_pred is not None:
|
1361 |
+
dino_feat_im_pred_first3 = dino_feat_im_pred[:,:,:3]
|
1362 |
+
viz_logger.add_image(logger_prefix+'image/dino_feat_im_pred', misc.image_grid(dino_feat_im_pred_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_pred_first3.shape[2:]).clamp(0,1)), total_iter)
|
1363 |
+
|
1364 |
+
for which_shape, modes in self.extra_renders.items():
|
1365 |
+
# This is wrong
|
1366 |
+
# if which_shape == "prior":
|
1367 |
+
# shape_to_render = prior_shape.extend(im_features.shape[0])
|
1368 |
+
# needed_im_features = None
|
1369 |
+
if which_shape == "instance":
|
1370 |
+
shape_to_render = shape
|
1371 |
+
needed_im_features = im_features
|
1372 |
+
else:
|
1373 |
+
raise NotImplementedError
|
1374 |
+
|
1375 |
+
for mode in modes:
|
1376 |
+
rendered, _, _, _, _, _ = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode=mode, render_flow=False, dino_pred=None)
|
1377 |
+
if 'kd' in mode:
|
1378 |
+
rendered = util.rgb_to_srgb(rendered)
|
1379 |
+
rendered = rendered.detach().cpu()
|
1380 |
+
|
1381 |
+
if 'posed_bones' in aux_viz:
|
1382 |
+
rendered_bone_image = self.render_bones(mvp, aux_viz['posed_bones'], (h, w))
|
1383 |
+
rendered_bone_image_mask = (rendered_bone_image < 1).any(1, keepdim=True).float()
|
1384 |
+
# viz_logger.add_image(logger_prefix+'image/articulation_bones', misc.image_grid(self.render_bones(mvp, aux_viz['posed_bones'])), total_iter)
|
1385 |
+
rendered = rendered_bone_image_mask*0.8 * rendered_bone_image + (1-rendered_bone_image_mask*0.8) * rendered
|
1386 |
+
|
1387 |
+
if rot_rep in ['quadlookat', 'octlookat']:
|
1388 |
+
rand_pose_flag = forward_aux['rand_pose_flag'].detach().cpu()
|
1389 |
+
rand_pose_marker_mask = torch.zeros_like(rendered)
|
1390 |
+
rand_pose_marker_mask[:,:,:16,:16] = 1.
|
1391 |
+
rand_pose_marker_mask = rand_pose_marker_mask * rand_pose_flag[:,None,None,None]
|
1392 |
+
red = torch.FloatTensor([1,0,0])[None,:,None,None]
|
1393 |
+
rendered = rand_pose_marker_mask * red + (1-rand_pose_marker_mask) * rendered
|
1394 |
+
|
1395 |
+
viz_logger.add_image(
|
1396 |
+
logger_prefix + f'image/{which_shape}_{mode}',
|
1397 |
+
misc.image_grid(expandF(rendered)[:b0, ...].view(-1, *rendered.shape[1:])),
|
1398 |
+
total_iter)
|
1399 |
+
|
1400 |
+
viz_logger.add_video(
|
1401 |
+
logger_prefix + f'animation/{which_shape}_{mode}',
|
1402 |
+
self.render_rotation_frames(shape_to_render, texture, light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode=mode, b=1).detach().cpu().unsqueeze(0),
|
1403 |
+
total_iter,
|
1404 |
+
fps=2)
|
1405 |
+
|
1406 |
+
viz_logger.add_video(
|
1407 |
+
logger_prefix+'animation/prior_image_rotation',
|
1408 |
+
self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, b=1).detach().cpu().unsqueeze(0).clamp(0,1),
|
1409 |
+
total_iter,
|
1410 |
+
fps=2)
|
1411 |
+
|
1412 |
+
viz_logger.add_video(
|
1413 |
+
logger_prefix+'animation/prior_normal_rotation',
|
1414 |
+
self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, render_mode='geo_normal', b=1).detach().cpu().unsqueeze(0),
|
1415 |
+
total_iter,
|
1416 |
+
fps=2)
|
1417 |
+
|
1418 |
+
if save_results:
|
1419 |
+
b0 = self.cfgs.get('num_saved_from_each_batch', batch_size*num_frames)
|
1420 |
+
fnames = [f'{total_iter:07d}_{fid:10d}' for fid in collapseF(frame_id.int())][:b0]
|
1421 |
+
|
1422 |
+
misc.save_images(save_dir, collapseF(image_gt)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_gt', fnames=fnames)
|
1423 |
+
misc.save_images(save_dir, collapseF(image_pred)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_pred', fnames=fnames)
|
1424 |
+
misc.save_images(save_dir, collapseF(mask_gt)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_gt', fnames=fnames)
|
1425 |
+
misc.save_images(save_dir, collapseF(mask_pred)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_pred', fnames=fnames)
|
1426 |
+
# tmp_shape = shape.first_n(b0).clone()
|
1427 |
+
# tmp_shape.material = texture
|
1428 |
+
# feat = im_features[:b0] if im_features is not None else None
|
1429 |
+
# misc.save_obj(save_dir, tmp_shape, save_material=False, feat=feat, suffix="mesh", fnames=fnames) # Save the first mesh.
|
1430 |
+
# if self.render_flow and flow_gt is not None:
|
1431 |
+
# flow_gt_viz = torch.cat([flow_gt, torch.zeros_like(flow_gt[:,:,:1])], 2) + 0.5 # -0.5~1.5
|
1432 |
+
# flow_gt_viz = flow_gt_viz.view(-1, *flow_gt_viz.shape[2:])
|
1433 |
+
# misc.save_images(save_dir, flow_gt_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_gt', fnames=fnames)
|
1434 |
+
# if flow_pred is not None:
|
1435 |
+
# flow_pred_viz = torch.cat([flow_pred, torch.zeros_like(flow_pred[:,:,:1])], 2) + 0.5 # -0.5~1.5
|
1436 |
+
# flow_pred_viz = flow_pred_viz.view(-1, *flow_pred_viz.shape[2:])
|
1437 |
+
# misc.save_images(save_dir, flow_pred_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_pred', fnames=fnames)
|
1438 |
+
|
1439 |
+
misc.save_txt(save_dir, pose[:b0].detach().cpu().numpy(), suffix='pose', fnames=fnames)
|
1440 |
+
|
1441 |
+
return metrics
|
1442 |
+
|
1443 |
+
def save_scores(self, path):
|
1444 |
+
header = 'mask_mse, \
|
1445 |
+
mask_iou, \
|
1446 |
+
image_mse, \
|
1447 |
+
flow_mse'
|
1448 |
+
mean = self.all_scores.mean(0)
|
1449 |
+
std = self.all_scores.std(0)
|
1450 |
+
header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean])
|
1451 |
+
header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std])
|
1452 |
+
misc.save_scores(path, self.all_scores, header=header)
|
1453 |
+
print(header)
|
1454 |
+
|
1455 |
+
def render_rotation_frames(self, mesh, texture, light, resolution, background='none', im_features=None, prior_shape=None, num_frames=36, render_mode='diffuse', b=None):
|
1456 |
+
frames = []
|
1457 |
+
if b is None:
|
1458 |
+
b = len(mesh)
|
1459 |
+
else:
|
1460 |
+
mesh = mesh.first_n(b)
|
1461 |
+
feat = im_features[:b] if im_features is not None else None
|
1462 |
+
|
1463 |
+
delta_angle = np.pi / num_frames * 2
|
1464 |
+
delta_rot_matrix = torch.FloatTensor([
|
1465 |
+
[np.cos(delta_angle), 0, np.sin(delta_angle), 0],
|
1466 |
+
[0, 1, 0, 0],
|
1467 |
+
[-np.sin(delta_angle), 0, np.cos(delta_angle), 0],
|
1468 |
+
[0, 0, 0, 1],
|
1469 |
+
]).to(self.device).repeat(b, 1, 1)
|
1470 |
+
|
1471 |
+
w2c = torch.FloatTensor(np.diag([1., 1., 1., 1]))
|
1472 |
+
w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.1])
|
1473 |
+
w2c = w2c.repeat(b, 1, 1).to(self.device)
|
1474 |
+
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device)
|
1475 |
+
mvp = torch.bmm(proj, w2c)
|
1476 |
+
campos = -w2c[:, :3, 3]
|
1477 |
+
|
1478 |
+
def rotate_pose(mvp, campos):
|
1479 |
+
mvp = torch.matmul(mvp, delta_rot_matrix)
|
1480 |
+
campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0]
|
1481 |
+
return mvp, campos
|
1482 |
+
|
1483 |
+
for _ in range(num_frames):
|
1484 |
+
image_pred, _, _, _, _, _ = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False)
|
1485 |
+
frames += [misc.image_grid(image_pred)]
|
1486 |
+
mvp, campos = rotate_pose(mvp, campos)
|
1487 |
+
return torch.stack(frames, dim=0) # Shape: (T, C, H, W)
|
1488 |
+
|
1489 |
+
def render_bones(self, mvp, bones_pred, size=(256, 256)):
|
1490 |
+
bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1)
|
1491 |
+
b, f, num_bones = bone_world4.shape[:3]
|
1492 |
+
bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4)
|
1493 |
+
bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2
|
1494 |
+
dpi = 32
|
1495 |
+
fx, fy = size[1] // dpi, size[0] // dpi
|
1496 |
+
|
1497 |
+
rendered = []
|
1498 |
+
for b_idx in range(b):
|
1499 |
+
for f_idx in range(f):
|
1500 |
+
frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy()
|
1501 |
+
fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False)
|
1502 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
1503 |
+
ax.set_axis_off()
|
1504 |
+
for bone in frame_bones_uv:
|
1505 |
+
ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20)
|
1506 |
+
ax.set_xlim(-1, 1)
|
1507 |
+
ax.set_ylim(-1, 1)
|
1508 |
+
ax.invert_yaxis()
|
1509 |
+
# Convert to image
|
1510 |
+
fig.add_axes(ax)
|
1511 |
+
fig.canvas.draw_idle()
|
1512 |
+
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
1513 |
+
w, h = fig.canvas.get_width_height()
|
1514 |
+
image.resize(h, w, 3)
|
1515 |
+
rendered += [image / 255.]
|
1516 |
+
return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2))
|
1517 |
+
|
1518 |
+
def render_deformation_frames(self, mesh, texture, batch_size, num_frames, resolution, background='none', im_features=None, render_mode='diffuse', b=None):
|
1519 |
+
# frames = []
|
1520 |
+
# if b is None:
|
1521 |
+
# b = batch_size
|
1522 |
+
# im_features = im_features[]
|
1523 |
+
# mesh = mesh.first_n(num_frames * b)
|
1524 |
+
# for i in range(b):
|
1525 |
+
# tmp_mesh = mesh.get_m_to_n(i*num_frames:(i+1)*num_frames)
|
1526 |
+
pass
|
video3d/model_ddp.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
video3d/networks.py
ADDED
@@ -0,0 +1,1724 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torchvision
|
5 |
+
import torchvision.models as models
|
6 |
+
from typing import Union, List, Tuple
|
7 |
+
import os
|
8 |
+
import video3d.utils.misc as misc
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from siren_pytorch import SirenNet
|
11 |
+
from video3d.triplane_texture.lift_architecture import Lift_Encoder
|
12 |
+
from video3d.triplane_texture.triplane_transformer import Triplane_Transformer
|
13 |
+
|
14 |
+
|
15 |
+
EPS = 1e-7
|
16 |
+
|
17 |
+
|
18 |
+
def get_activation(name, inplace=True, lrelu_param=0.2):
|
19 |
+
if name == 'tanh':
|
20 |
+
return nn.Tanh()
|
21 |
+
elif name == 'sigmoid':
|
22 |
+
return nn.Sigmoid()
|
23 |
+
elif name == 'relu':
|
24 |
+
return nn.ReLU(inplace=inplace)
|
25 |
+
elif name == 'lrelu':
|
26 |
+
return nn.LeakyReLU(lrelu_param, inplace=inplace)
|
27 |
+
else:
|
28 |
+
raise NotImplementedError
|
29 |
+
|
30 |
+
|
31 |
+
class MLPWithPositionalEncoding(nn.Module):
|
32 |
+
def __init__(self,
|
33 |
+
cin,
|
34 |
+
cout,
|
35 |
+
num_layers,
|
36 |
+
nf=256,
|
37 |
+
dropout=0,
|
38 |
+
activation=None,
|
39 |
+
n_harmonic_functions=10,
|
40 |
+
omega0=1,
|
41 |
+
extra_dim=0,
|
42 |
+
embed_concat_pts=True,
|
43 |
+
symmetrize=False):
|
44 |
+
super().__init__()
|
45 |
+
self.extra_dim = extra_dim
|
46 |
+
|
47 |
+
if n_harmonic_functions > 0:
|
48 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
49 |
+
dim_in = cin * 2 * n_harmonic_functions
|
50 |
+
self.embed_concat_pts = embed_concat_pts
|
51 |
+
if embed_concat_pts:
|
52 |
+
dim_in += cin
|
53 |
+
else:
|
54 |
+
self.embedder = None
|
55 |
+
dim_in = cin
|
56 |
+
|
57 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
58 |
+
self.relu = nn.ReLU(inplace=True)
|
59 |
+
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation)
|
60 |
+
self.symmetrize = symmetrize
|
61 |
+
|
62 |
+
def forward(self, x, feat=None):
|
63 |
+
assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim
|
64 |
+
if self.symmetrize:
|
65 |
+
xs, ys, zs = x.unbind(-1)
|
66 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
67 |
+
|
68 |
+
if self.embedder is not None:
|
69 |
+
x_in = self.embedder(x)
|
70 |
+
if self.embed_concat_pts:
|
71 |
+
x_in = torch.cat([x, x_in], -1)
|
72 |
+
else:
|
73 |
+
x_in = x
|
74 |
+
|
75 |
+
x_in = self.relu(self.in_layer(x_in))
|
76 |
+
|
77 |
+
if feat is not None:
|
78 |
+
# if len(feat.shape) == 1:
|
79 |
+
# for _ in range(len(x_in.shape) - 1):
|
80 |
+
# feat = feat.unsqueeze(0)
|
81 |
+
# feat = feat.repeat(*x_in.shape[:-1], 1)
|
82 |
+
x_in = torch.concat([x_in, feat], dim=-1)
|
83 |
+
|
84 |
+
return self.mlp(x_in)
|
85 |
+
|
86 |
+
|
87 |
+
class MLPWithPositionalEncoding_Style(nn.Module):
|
88 |
+
def __init__(self,
|
89 |
+
cin,
|
90 |
+
cout,
|
91 |
+
num_layers,
|
92 |
+
nf=256,
|
93 |
+
dropout=0,
|
94 |
+
activation=None,
|
95 |
+
n_harmonic_functions=10,
|
96 |
+
omega0=1,
|
97 |
+
extra_dim=0,
|
98 |
+
embed_concat_pts=True,
|
99 |
+
symmetrize=False,
|
100 |
+
style_choice='film'):
|
101 |
+
super().__init__()
|
102 |
+
self.extra_dim = extra_dim
|
103 |
+
|
104 |
+
if n_harmonic_functions > 0:
|
105 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
106 |
+
dim_in = cin * 2 * n_harmonic_functions
|
107 |
+
self.embed_concat_pts = embed_concat_pts
|
108 |
+
if embed_concat_pts:
|
109 |
+
dim_in += cin
|
110 |
+
else:
|
111 |
+
self.embedder = None
|
112 |
+
dim_in = cin
|
113 |
+
|
114 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
115 |
+
self.relu = nn.ReLU(inplace=True)
|
116 |
+
|
117 |
+
if extra_dim == 0:
|
118 |
+
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation)
|
119 |
+
|
120 |
+
else:
|
121 |
+
if style_choice == 'film':
|
122 |
+
self.mlp = MLP_FiLM(nf, cout, num_layers, nf, dropout, activation)
|
123 |
+
self.style_mlp = MLP(extra_dim, nf*2, 2, nf, dropout, None)
|
124 |
+
|
125 |
+
elif style_choice == 'mod':
|
126 |
+
self.mlp = MLP_Mod(nf, cout, num_layers, nf, dropout, activation)
|
127 |
+
self.style_mlp = MLP(extra_dim, nf, 2, nf, dropout, None)
|
128 |
+
|
129 |
+
else:
|
130 |
+
raise NotImplementedError
|
131 |
+
|
132 |
+
self.style_choice = style_choice
|
133 |
+
|
134 |
+
self.symmetrize = symmetrize
|
135 |
+
|
136 |
+
def forward(self, x, feat=None):
|
137 |
+
assert (feat is None and self.extra_dim == 0) or feat.shape[-1] == self.extra_dim
|
138 |
+
if self.symmetrize:
|
139 |
+
xs, ys, zs = x.unbind(-1)
|
140 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
141 |
+
|
142 |
+
if self.embedder is not None:
|
143 |
+
x_in = self.embedder(x)
|
144 |
+
if self.embed_concat_pts:
|
145 |
+
x_in = torch.cat([x, x_in], -1)
|
146 |
+
else:
|
147 |
+
x_in = x
|
148 |
+
|
149 |
+
x_in = self.relu(self.in_layer(x_in))
|
150 |
+
|
151 |
+
if feat is not None:
|
152 |
+
style = self.style_mlp(feat)
|
153 |
+
|
154 |
+
if self.style_choice == 'film':
|
155 |
+
style = style.reshape(style.shape[:-1] + (-1, 2))
|
156 |
+
|
157 |
+
out = self.mlp(x_in, style)
|
158 |
+
|
159 |
+
else:
|
160 |
+
out = self.mlp(x_in)
|
161 |
+
|
162 |
+
return out
|
163 |
+
|
164 |
+
|
165 |
+
class MLP_FiLM(nn.Module):
|
166 |
+
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None):
|
167 |
+
# default no dropout
|
168 |
+
super().__init__()
|
169 |
+
assert num_layers >= 1
|
170 |
+
self.num_layers = num_layers
|
171 |
+
if num_layers == 1:
|
172 |
+
self.network = Linear_FiLM(cin, cout, bias=False)
|
173 |
+
else:
|
174 |
+
self.relu = nn.ReLU(inplace=True)
|
175 |
+
for i in range(num_layers):
|
176 |
+
if i == 0:
|
177 |
+
setattr(self, f'linear_{i}', Linear_FiLM(cin, nf, bias=False))
|
178 |
+
elif i == (num_layers-1):
|
179 |
+
setattr(self, f'linear_{i}', Linear_FiLM(nf, cout, bias=False))
|
180 |
+
else:
|
181 |
+
setattr(self, f'linear_{i}', Linear_FiLM(nf, nf, bias=False))
|
182 |
+
|
183 |
+
def forward(self, input, style):
|
184 |
+
if self.num_layers == 1:
|
185 |
+
out = self.network(input, style)
|
186 |
+
else:
|
187 |
+
x = input
|
188 |
+
for i in range(self.num_layers):
|
189 |
+
linear_layer = getattr(self, f'linear_{i}')
|
190 |
+
if i == (self.num_layers - 1):
|
191 |
+
x = linear_layer(x, style)
|
192 |
+
else:
|
193 |
+
x = linear_layer(x, style)
|
194 |
+
x = self.relu(x)
|
195 |
+
|
196 |
+
out = x
|
197 |
+
return out
|
198 |
+
|
199 |
+
|
200 |
+
class MLP_Mod(nn.Module):
|
201 |
+
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None):
|
202 |
+
# default no dropout
|
203 |
+
super().__init__()
|
204 |
+
assert num_layers >= 1
|
205 |
+
self.num_layers = num_layers
|
206 |
+
if num_layers == 1:
|
207 |
+
self.network = Linear_Mod(cin, cout, bias=False)
|
208 |
+
else:
|
209 |
+
self.relu = nn.ReLU(inplace=True)
|
210 |
+
for i in range(num_layers):
|
211 |
+
if i == 0:
|
212 |
+
setattr(self, f'linear_{i}', Linear_Mod(cin, nf, bias=False))
|
213 |
+
elif i == (num_layers-1):
|
214 |
+
setattr(self, f'linear_{i}', Linear_Mod(nf, cout, bias=False))
|
215 |
+
else:
|
216 |
+
setattr(self, f'linear_{i}', Linear_Mod(nf, nf, bias=False))
|
217 |
+
|
218 |
+
def forward(self, input, style):
|
219 |
+
if self.num_layers == 1:
|
220 |
+
out = self.network(input, style)
|
221 |
+
else:
|
222 |
+
x = input
|
223 |
+
for i in range(self.num_layers):
|
224 |
+
linear_layer = getattr(self, f'linear_{i}')
|
225 |
+
if i == (self.num_layers - 1):
|
226 |
+
x = linear_layer(x, style)
|
227 |
+
else:
|
228 |
+
x = linear_layer(x, style)
|
229 |
+
x = self.relu(x)
|
230 |
+
|
231 |
+
out = x
|
232 |
+
return out
|
233 |
+
|
234 |
+
|
235 |
+
import math
|
236 |
+
|
237 |
+
class Linear_FiLM(nn.Module):
|
238 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
239 |
+
device=None, dtype=None) -> None:
|
240 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
241 |
+
super().__init__()
|
242 |
+
self.in_features = in_features
|
243 |
+
self.out_features = out_features
|
244 |
+
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
245 |
+
if bias:
|
246 |
+
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
247 |
+
else:
|
248 |
+
self.register_parameter('bias', None)
|
249 |
+
self.reset_parameters()
|
250 |
+
|
251 |
+
def reset_parameters(self) -> None:
|
252 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
253 |
+
if self.bias is not None:
|
254 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
255 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
256 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
257 |
+
|
258 |
+
def forward(self, input, style):
|
259 |
+
# if input is [..., D], style should be [..., D, 2]
|
260 |
+
x = input * style[..., 0] + style[..., 1]
|
261 |
+
return torch.nn.functional.linear(x, self.weight, self.bias)
|
262 |
+
|
263 |
+
def extra_repr(self) -> str:
|
264 |
+
return 'in_features={}, out_features={}, bias={}'.format(
|
265 |
+
self.in_features, self.out_features, self.bias is not None
|
266 |
+
)
|
267 |
+
|
268 |
+
|
269 |
+
class Linear_Mod(nn.Module):
|
270 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
271 |
+
device=None, dtype=None) -> None:
|
272 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
273 |
+
super().__init__()
|
274 |
+
self.in_features = in_features
|
275 |
+
self.out_features = out_features
|
276 |
+
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
277 |
+
if bias:
|
278 |
+
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
279 |
+
else:
|
280 |
+
self.register_parameter('bias', None)
|
281 |
+
self.reset_parameters()
|
282 |
+
|
283 |
+
def reset_parameters(self) -> None:
|
284 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
285 |
+
if self.bias is not None:
|
286 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
287 |
+
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
288 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
289 |
+
|
290 |
+
def forward(self, input, style):
|
291 |
+
# weight: [out_features, in_features]
|
292 |
+
# style: [..., in_features]
|
293 |
+
if len(style.shape) > 1:
|
294 |
+
style = style.reshape(-1, style.shape[-1])
|
295 |
+
style = style[0]
|
296 |
+
|
297 |
+
weight = self.weight * style.unsqueeze(0)
|
298 |
+
decoefs = ((weight * weight).sum(dim=-1, keepdim=True) + 1e-5).sqrt()
|
299 |
+
weight = weight / decoefs
|
300 |
+
|
301 |
+
return torch.nn.functional.linear(input, weight, self.bias)
|
302 |
+
|
303 |
+
def extra_repr(self) -> str:
|
304 |
+
return 'in_features={}, out_features={}, bias={}'.format(
|
305 |
+
self.in_features, self.out_features, self.bias is not None
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
class MLPTextureSimple(nn.Module):
|
310 |
+
def __init__(self,
|
311 |
+
cin,
|
312 |
+
cout,
|
313 |
+
num_layers,
|
314 |
+
nf=256,
|
315 |
+
dropout=0,
|
316 |
+
activation=None,
|
317 |
+
min_max=None,
|
318 |
+
n_harmonic_functions=10,
|
319 |
+
omega0=1,
|
320 |
+
extra_dim=0,
|
321 |
+
embed_concat_pts=True,
|
322 |
+
perturb_normal=False,
|
323 |
+
symmetrize=False,
|
324 |
+
texture_act='relu',
|
325 |
+
linear_bias=False):
|
326 |
+
super().__init__()
|
327 |
+
self.extra_dim = extra_dim
|
328 |
+
|
329 |
+
if n_harmonic_functions > 0:
|
330 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
331 |
+
dim_in = cin * 2 * n_harmonic_functions
|
332 |
+
self.embed_concat_pts = embed_concat_pts
|
333 |
+
if embed_concat_pts:
|
334 |
+
dim_in += cin
|
335 |
+
else:
|
336 |
+
self.embedder = None
|
337 |
+
dim_in = cin
|
338 |
+
|
339 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
340 |
+
self.relu = nn.ReLU(inplace=True)
|
341 |
+
|
342 |
+
if texture_act == 'sin':
|
343 |
+
print('using siren network for texture mlp here')
|
344 |
+
self.mlp = SirenNet(
|
345 |
+
dim_in=(nf + extra_dim),
|
346 |
+
dim_hidden=nf,
|
347 |
+
dim_out=cout,
|
348 |
+
num_layers=num_layers,
|
349 |
+
final_activation=get_activation(activation),
|
350 |
+
w0_initial=30,
|
351 |
+
use_bias=linear_bias,
|
352 |
+
dropout=dropout
|
353 |
+
)
|
354 |
+
else:
|
355 |
+
self.mlp = MLP(nf + extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias)
|
356 |
+
self.perturb_normal = perturb_normal
|
357 |
+
self.symmetrize = symmetrize
|
358 |
+
if min_max is not None:
|
359 |
+
self.register_buffer('min_max', min_max)
|
360 |
+
else:
|
361 |
+
self.min_max = None
|
362 |
+
self.bsdf = None
|
363 |
+
|
364 |
+
def sample(self, x, feat=None):
|
365 |
+
assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim)
|
366 |
+
b, h, w, c = x.shape
|
367 |
+
|
368 |
+
if self.symmetrize:
|
369 |
+
xs, ys, zs = x.unbind(-1)
|
370 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
371 |
+
|
372 |
+
x = x.view(-1, c)
|
373 |
+
if self.embedder is not None:
|
374 |
+
x_in = self.embedder(x)
|
375 |
+
if self.embed_concat_pts:
|
376 |
+
x_in = torch.cat([x, x_in], -1)
|
377 |
+
else:
|
378 |
+
x_in = x
|
379 |
+
|
380 |
+
x_in = self.in_layer(x_in)
|
381 |
+
if feat is not None:
|
382 |
+
feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
383 |
+
x_in = torch.concat([x_in, feat], dim=-1)
|
384 |
+
out = self.mlp(self.relu(x_in))
|
385 |
+
if self.min_max is not None:
|
386 |
+
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
387 |
+
return out.view(b, h, w, -1)
|
388 |
+
|
389 |
+
|
390 |
+
class MLPTextureTriplane(nn.Module):
|
391 |
+
def __init__(self,
|
392 |
+
cin,
|
393 |
+
cout,
|
394 |
+
num_layers,
|
395 |
+
nf=256,
|
396 |
+
dropout=0,
|
397 |
+
activation=None,
|
398 |
+
min_max=None,
|
399 |
+
n_harmonic_functions=10,
|
400 |
+
omega0=1,
|
401 |
+
extra_dim=0,
|
402 |
+
embed_concat_pts=True,
|
403 |
+
perturb_normal=False,
|
404 |
+
symmetrize=False,
|
405 |
+
texture_act='relu',
|
406 |
+
linear_bias=False,
|
407 |
+
cam_pos_z_offset=10.,
|
408 |
+
grid_scale=7,):
|
409 |
+
super().__init__()
|
410 |
+
self.extra_dim = extra_dim
|
411 |
+
|
412 |
+
if n_harmonic_functions > 0:
|
413 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
414 |
+
dim_in = cin * 2 * n_harmonic_functions
|
415 |
+
self.embed_concat_pts = embed_concat_pts
|
416 |
+
if embed_concat_pts:
|
417 |
+
dim_in += cin
|
418 |
+
else:
|
419 |
+
self.embedder = None
|
420 |
+
dim_in = cin
|
421 |
+
|
422 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
423 |
+
self.relu = nn.ReLU(inplace=True)
|
424 |
+
|
425 |
+
self.feat_net = Triplane_Transformer(
|
426 |
+
emb_dim=256,
|
427 |
+
num_layers=8,
|
428 |
+
triplane_dim=80,
|
429 |
+
triplane_scale=grid_scale
|
430 |
+
)
|
431 |
+
self.extra_dim -= extra_dim
|
432 |
+
self.extra_dim += (self.feat_net.triplane_dim * 3)
|
433 |
+
|
434 |
+
if texture_act == 'sin':
|
435 |
+
print('using siren network for texture mlp here')
|
436 |
+
self.mlp = SirenNet(
|
437 |
+
dim_in=(nf + self.extra_dim),
|
438 |
+
dim_hidden=nf,
|
439 |
+
dim_out=cout,
|
440 |
+
num_layers=num_layers,
|
441 |
+
final_activation=get_activation(activation),
|
442 |
+
w0_initial=30,
|
443 |
+
use_bias=linear_bias,
|
444 |
+
dropout=dropout
|
445 |
+
)
|
446 |
+
else:
|
447 |
+
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation, inner_act=texture_act, linear_bias=linear_bias)
|
448 |
+
self.perturb_normal = perturb_normal
|
449 |
+
self.symmetrize = symmetrize
|
450 |
+
if min_max is not None:
|
451 |
+
self.register_buffer('min_max', min_max)
|
452 |
+
else:
|
453 |
+
self.min_max = None
|
454 |
+
self.bsdf = None
|
455 |
+
|
456 |
+
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
|
457 |
+
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] == self.extra_dim)
|
458 |
+
b, h, w, c = x.shape
|
459 |
+
|
460 |
+
if self.symmetrize:
|
461 |
+
xs, ys, zs = x.unbind(-1)
|
462 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
463 |
+
|
464 |
+
if isinstance(feat_map, dict):
|
465 |
+
feat_map = feat_map["im_features_map"]
|
466 |
+
|
467 |
+
feat_map = feat_map.permute(0, 2, 3, 1)
|
468 |
+
_, ph, pw, _ = feat_map.shape
|
469 |
+
feat_map = feat_map.reshape(feat_map.shape[0], ph*pw, feat_map.shape[-1])
|
470 |
+
pts_feat = self.feat_net(feat_map, x.reshape(b, -1, 3))
|
471 |
+
pts_c = pts_feat.shape[-1]
|
472 |
+
pts_feat = pts_feat.reshape(-1, pts_c)
|
473 |
+
|
474 |
+
x = x.view(-1, c)
|
475 |
+
if self.embedder is not None:
|
476 |
+
x_in = self.embedder(x)
|
477 |
+
if self.embed_concat_pts:
|
478 |
+
x_in = torch.cat([x, x_in], -1)
|
479 |
+
else:
|
480 |
+
x_in = x
|
481 |
+
|
482 |
+
x_in = self.in_layer(x_in)
|
483 |
+
|
484 |
+
x_in = torch.concat([x_in, pts_feat], dim=-1)
|
485 |
+
|
486 |
+
out = self.mlp(self.relu(x_in))
|
487 |
+
if self.min_max is not None:
|
488 |
+
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
489 |
+
return out.view(b, h, w, -1)
|
490 |
+
|
491 |
+
|
492 |
+
class LocalFeatureBlock(nn.Module):
|
493 |
+
def __init__(self, local_feat_dim, input_dim=384, output_dim=384, upscale_num=3):
|
494 |
+
super().__init__()
|
495 |
+
self.local_feat_dim = local_feat_dim
|
496 |
+
self.conv_list = nn.ModuleList([])
|
497 |
+
self.upscale_list = nn.ModuleList([])
|
498 |
+
|
499 |
+
for i in range(upscale_num):
|
500 |
+
if i == 0:
|
501 |
+
self.conv_list.append(nn.Conv2d(input_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1))
|
502 |
+
else:
|
503 |
+
self.conv_list.append(nn.Conv2d(local_feat_dim, 4 * local_feat_dim, 3, stride=1, padding=1, dilation=1))
|
504 |
+
self.upscale_list.append(nn.PixelShuffle(2))
|
505 |
+
|
506 |
+
self.conv_head = nn.Conv2d(local_feat_dim, output_dim, 3, stride=1, padding=1, dilation=1)
|
507 |
+
|
508 |
+
def forward(self, x):
|
509 |
+
for idx, conv in enumerate(self.conv_list):
|
510 |
+
x = conv(x)
|
511 |
+
x = self.upscale_list[idx](x)
|
512 |
+
|
513 |
+
out = self.conv_head(x)
|
514 |
+
return out
|
515 |
+
|
516 |
+
|
517 |
+
class MLPTextureLocal(nn.Module):
|
518 |
+
def __init__(self,
|
519 |
+
cin,
|
520 |
+
cout,
|
521 |
+
num_layers,
|
522 |
+
nf=256,
|
523 |
+
dropout=0,
|
524 |
+
activation=None,
|
525 |
+
min_max=None,
|
526 |
+
n_harmonic_functions=10,
|
527 |
+
omega0=1,
|
528 |
+
extra_dim=0,
|
529 |
+
embed_concat_pts=True,
|
530 |
+
perturb_normal=False,
|
531 |
+
symmetrize=False,
|
532 |
+
texture_way=None,
|
533 |
+
larger_tex_dim=False,
|
534 |
+
cam_pos_z_offset=10.,
|
535 |
+
grid_scale=7.):
|
536 |
+
super().__init__()
|
537 |
+
self.extra_dim = extra_dim
|
538 |
+
self.cam_pos_z_offset = cam_pos_z_offset
|
539 |
+
self.grid_scale = grid_scale
|
540 |
+
|
541 |
+
local_feat_dim = 64
|
542 |
+
|
543 |
+
assert texture_way is not None
|
544 |
+
self.texture_way = texture_way
|
545 |
+
if 'local' in texture_way and 'global' in texture_way:
|
546 |
+
# self.extra_dim = extra_dim + local_feat_dim
|
547 |
+
self.extra_dim = extra_dim
|
548 |
+
elif 'local' in texture_way and 'global' not in texture_way:
|
549 |
+
# self.extra_dim = local_feat_dim
|
550 |
+
self.extra_dim = extra_dim
|
551 |
+
elif 'local' not in texture_way and 'global' in texture_way:
|
552 |
+
self.extra_dim = extra_dim
|
553 |
+
|
554 |
+
if n_harmonic_functions > 0:
|
555 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
556 |
+
dim_in = cin * 2 * n_harmonic_functions
|
557 |
+
self.embed_concat_pts = embed_concat_pts
|
558 |
+
if embed_concat_pts:
|
559 |
+
dim_in += cin
|
560 |
+
else:
|
561 |
+
self.embedder = None
|
562 |
+
dim_in = cin
|
563 |
+
|
564 |
+
# self.local_feature_block = LocalFeatureBlock(local_feat_dim=local_feat_dim, input_dim=384, output_dim=256)
|
565 |
+
self.local_feature_block = nn.Linear(384, nf, bias=False)
|
566 |
+
|
567 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
568 |
+
self.relu = nn.ReLU(inplace=True)
|
569 |
+
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation)
|
570 |
+
self.perturb_normal = perturb_normal
|
571 |
+
self.symmetrize = symmetrize
|
572 |
+
if min_max is not None:
|
573 |
+
self.register_buffer('min_max', min_max)
|
574 |
+
else:
|
575 |
+
self.min_max = None
|
576 |
+
self.bsdf = None
|
577 |
+
|
578 |
+
def get_uv_depth(self, xyz, mvp):
|
579 |
+
# xyz: [b, k, 3]
|
580 |
+
# mvp: [b, 4, 4]
|
581 |
+
cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2))
|
582 |
+
cam3 = cam4[..., :3] / cam4[..., 3:4]
|
583 |
+
cam_uv = cam3[..., :2]
|
584 |
+
# cam_uv = cam_uv.detach()
|
585 |
+
cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3)
|
586 |
+
cam_depth = cam_depth / self.grid_scale * 2
|
587 |
+
cam_depth = cam_depth[..., 2:3]
|
588 |
+
# cam_depth = cam_depth.detach()
|
589 |
+
return cam_uv, cam_depth
|
590 |
+
|
591 |
+
def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w):
|
592 |
+
# here the xyz is deformed points
|
593 |
+
# and we don't cast any symmtery here
|
594 |
+
b, k, c = xyz.shape
|
595 |
+
THRESHOLD = 1e-4
|
596 |
+
if isinstance(feat_map, torch.Tensor):
|
597 |
+
coordinates = xyz
|
598 |
+
# use pre-symmetry points to get feature and record depth
|
599 |
+
cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp)
|
600 |
+
cam_uv = cam_uv.detach()
|
601 |
+
cam_depth = cam_depth.detach()
|
602 |
+
|
603 |
+
# get local feature
|
604 |
+
feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
605 |
+
|
606 |
+
self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1]
|
607 |
+
self.input_pts = coordinates.detach()
|
608 |
+
|
609 |
+
elif isinstance(feat_map, dict):
|
610 |
+
original_mvp = feat_map['original_mvp']
|
611 |
+
local_feat_map = feat_map['im_features_map']
|
612 |
+
original_depth = self.input_depth[0:b]
|
613 |
+
|
614 |
+
coordinates = xyz
|
615 |
+
cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp)
|
616 |
+
cam_uv = cam_uv.detach()
|
617 |
+
cam_depth = cam_depth.detach()
|
618 |
+
|
619 |
+
project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
620 |
+
project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
|
621 |
+
|
622 |
+
use_mask = cam_depth <= project_depth + THRESHOLD
|
623 |
+
feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1])
|
624 |
+
|
625 |
+
ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value
|
626 |
+
return ret_feature
|
627 |
+
|
628 |
+
def proj_sample(self, xyz, feat_map, mvp, w2c, img_h, img_w, xyz_before_sym=None):
|
629 |
+
# the new one with no input feature map upsampling
|
630 |
+
# feat_map: [B, C, H, W]
|
631 |
+
b, k, c = xyz.shape
|
632 |
+
if isinstance(feat_map, torch.Tensor):
|
633 |
+
if xyz_before_sym is None:
|
634 |
+
coordinates = xyz
|
635 |
+
else:
|
636 |
+
coordinates = xyz_before_sym
|
637 |
+
# use pre-symmetry points to get feature and record depth
|
638 |
+
cam_uv, cam_depth = self.get_uv_depth(coordinates, mvp)
|
639 |
+
cam_uv = cam_uv.detach()
|
640 |
+
cam_depth = cam_depth.detach()
|
641 |
+
|
642 |
+
# get local feature
|
643 |
+
feature = F.grid_sample(feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
644 |
+
|
645 |
+
self.input_depth = cam_depth.reshape(b, 256, 256, 1) # [B, 256, 256, 1]
|
646 |
+
self.input_pts = coordinates.detach()
|
647 |
+
|
648 |
+
elif isinstance(feat_map, dict):
|
649 |
+
original_mvp = feat_map['original_mvp']
|
650 |
+
local_feat_map = feat_map['im_features_map']
|
651 |
+
THRESHOLD = 1e-4
|
652 |
+
original_depth = self.input_depth[0:b]
|
653 |
+
# if b == 1:
|
654 |
+
# from pdb import set_trace; set_trace()
|
655 |
+
# tmp_mask = xyz[0].reshape(256, 256, 3).sum(dim=-1) != 0
|
656 |
+
# tmp_mask = tmp_mask.cpu().numpy()
|
657 |
+
# tmp_mask = tmp_mask * 255
|
658 |
+
# src_dp = self.input_depth[0,:,:,0].cpu().numpy()
|
659 |
+
# input_pts = self.input_pts[0].cpu().numpy()
|
660 |
+
# input_mask = self.input_pts[0].reshape(256, 256, 3).sum(dim=-1) != 0
|
661 |
+
# input_mask = input_mask.int().cpu().numpy()
|
662 |
+
# input_mask = input_mask * 255
|
663 |
+
# np.save('./tmp_save/src_dp.npy', src_dp)
|
664 |
+
# np.save('./tmp_save/input_pts.npy', input_pts)
|
665 |
+
# import cv2
|
666 |
+
# cv2.imwrite('./tmp_save/input_mask.png', input_mask)
|
667 |
+
# cv2.imwrite('./tmp_save/mask.png', tmp_mask)
|
668 |
+
# test_pts_pos = xyz[0].cpu().numpy()
|
669 |
+
# np.save('./tmp_save/test_pts_pos.npy', test_pts_pos)
|
670 |
+
# test_pts_raw = xyz_before_sym[0].cpu().numpy()
|
671 |
+
# np.save('./tmp_save/test_pts_raw.npy', test_pts_raw)
|
672 |
+
# mvp_now = mvp[0].detach().cpu().numpy()
|
673 |
+
# mvp_original = original_mvp[0].detach().cpu().numpy()
|
674 |
+
# np.save('./tmp_save/mvp_now.npy', mvp_now)
|
675 |
+
# np.save('./tmp_save/mvp_original.npy', mvp_original)
|
676 |
+
if xyz_before_sym is None:
|
677 |
+
# just check the project depth of xyz
|
678 |
+
coordinates = xyz
|
679 |
+
cam_uv, cam_depth = self.get_uv_depth(coordinates, original_mvp)
|
680 |
+
cam_uv = cam_uv.detach()
|
681 |
+
cam_depth = cam_depth.detach()
|
682 |
+
|
683 |
+
project_feature = F.grid_sample(local_feat_map, cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
684 |
+
project_depth = F.grid_sample(original_depth.permute(0, 3, 1, 2), cam_uv.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
|
685 |
+
|
686 |
+
use_mask = cam_depth <= project_depth + THRESHOLD
|
687 |
+
feature = project_feature * use_mask.repeat(1, 1, project_feature.shape[-1])
|
688 |
+
else:
|
689 |
+
# need to double check, but now we are still use symmetry! Even if the two points are all visible in input view
|
690 |
+
coords_inp = xyz
|
691 |
+
x_check, y_check, z_check = xyz.unbind(-1)
|
692 |
+
xyz_check = torch.stack([-1 * x_check, y_check, z_check], -1)
|
693 |
+
coords_rev = xyz_check # we directly use neg-x to get the points of another side
|
694 |
+
|
695 |
+
uv_inp, dp_inp = self.get_uv_depth(coords_inp, original_mvp)
|
696 |
+
uv_rev, dp_rev = self.get_uv_depth(coords_rev, original_mvp)
|
697 |
+
uv_inp = uv_inp.detach()
|
698 |
+
uv_rev = uv_rev.detach()
|
699 |
+
dp_inp = dp_inp.detach()
|
700 |
+
dp_rev = dp_rev.detach()
|
701 |
+
|
702 |
+
proj_feat_inp = F.grid_sample(local_feat_map, uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
703 |
+
proj_feat_rev = F.grid_sample(local_feat_map, uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, c]
|
704 |
+
|
705 |
+
proj_dp_inp = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_inp.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
|
706 |
+
proj_dp_rev = F.grid_sample(original_depth.permute(0, 3, 1, 2), uv_rev.view(b, 1, k, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1) # [b, k, 1]
|
707 |
+
|
708 |
+
use_mask_inp = dp_inp <= proj_dp_inp + THRESHOLD
|
709 |
+
use_mask_rev = dp_rev <= proj_dp_rev + THRESHOLD
|
710 |
+
|
711 |
+
# for those points we can see in two sides, we use average
|
712 |
+
use_mask_inp = use_mask_inp.int()
|
713 |
+
use_mask_rev = use_mask_rev.int()
|
714 |
+
both_vis = (use_mask_inp == 1) & (use_mask_rev == 1)
|
715 |
+
use_mask_inp[both_vis] = 0.5
|
716 |
+
use_mask_rev[both_vis] = 0.5
|
717 |
+
|
718 |
+
feature = proj_feat_inp * use_mask_inp.repeat(1, 1, proj_feat_inp.shape[-1]) + proj_feat_rev * use_mask_rev.repeat(1, 1, proj_feat_rev.shape[-1])
|
719 |
+
else:
|
720 |
+
raise NotImplementedError
|
721 |
+
|
722 |
+
ret_feature = self.local_feature_block(feature.reshape(b*k, -1)) # the linear is without bias, so 0 value feature will still get 0 value
|
723 |
+
return ret_feature
|
724 |
+
|
725 |
+
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
|
726 |
+
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim)
|
727 |
+
b, h, w, c = x.shape
|
728 |
+
|
729 |
+
xyz_before_sym = None
|
730 |
+
if self.symmetrize:
|
731 |
+
xyz_before_sym = x.reshape(b, -1, c)
|
732 |
+
xs, ys, zs = x.unbind(-1)
|
733 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
734 |
+
|
735 |
+
mvp = mvp.detach() # [b, 4, 4]
|
736 |
+
w2c = w2c.detach() # [b, 4, 4]
|
737 |
+
|
738 |
+
pts_xyz = x.reshape(b, -1, c)
|
739 |
+
deform_xyz = deform_xyz.reshape(b, -1, c)
|
740 |
+
|
741 |
+
if 'global' in self.texture_way and 'local' in self.texture_way:
|
742 |
+
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
743 |
+
# local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym)
|
744 |
+
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
|
745 |
+
# feature_rep = torch.concat([global_feat, local_feat], dim=-1)
|
746 |
+
feature_rep = global_feat + local_feat
|
747 |
+
elif 'global' not in self.texture_way and 'local' in self.texture_way:
|
748 |
+
# local_feat = self.proj_sample(pts_xyz, feat_map, mvp, w2c, h, w, xyz_before_sym=xyz_before_sym)
|
749 |
+
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
|
750 |
+
feature_rep = local_feat
|
751 |
+
elif 'global' in self.texture_way and 'local' not in self.texture_way:
|
752 |
+
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
753 |
+
feature_rep = global_feat
|
754 |
+
else:
|
755 |
+
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
756 |
+
feature_rep = global_feat
|
757 |
+
|
758 |
+
x = x.view(-1, c)
|
759 |
+
|
760 |
+
if self.embedder is not None:
|
761 |
+
x_in = self.embedder(x)
|
762 |
+
if self.embed_concat_pts:
|
763 |
+
x_in = torch.cat([x, x_in], -1)
|
764 |
+
else:
|
765 |
+
x_in = x
|
766 |
+
|
767 |
+
x_in = self.in_layer(x_in)
|
768 |
+
|
769 |
+
# if feat is not None:
|
770 |
+
# feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
771 |
+
# x_in = torch.concat([x_in, feat], dim=-1)
|
772 |
+
|
773 |
+
x_in = torch.concat([x_in, feature_rep], dim=-1)
|
774 |
+
|
775 |
+
out = self.mlp(self.relu(x_in))
|
776 |
+
if self.min_max is not None:
|
777 |
+
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
778 |
+
return out.view(b, h, w, -1)
|
779 |
+
|
780 |
+
|
781 |
+
class LiftTexture(nn.Module):
|
782 |
+
def __init__(self,
|
783 |
+
cin,
|
784 |
+
cout,
|
785 |
+
num_layers,
|
786 |
+
nf=256,
|
787 |
+
dropout=0,
|
788 |
+
activation=None,
|
789 |
+
min_max=None,
|
790 |
+
n_harmonic_functions=10,
|
791 |
+
omega0=1,
|
792 |
+
extra_dim=0,
|
793 |
+
embed_concat_pts=True,
|
794 |
+
perturb_normal=False,
|
795 |
+
symmetrize=False,
|
796 |
+
texture_way=None,
|
797 |
+
cam_pos_z_offset=10.,
|
798 |
+
grid_scale=7.,
|
799 |
+
local_feat_dim=128,
|
800 |
+
grid_size=32,
|
801 |
+
optim_latent=False):
|
802 |
+
super().__init__()
|
803 |
+
self.extra_dim = extra_dim
|
804 |
+
self.cam_pos_z_offset = cam_pos_z_offset
|
805 |
+
self.grid_scale = grid_scale
|
806 |
+
|
807 |
+
assert texture_way is not None
|
808 |
+
self.extra_dim = local_feat_dim + extra_dim
|
809 |
+
|
810 |
+
if n_harmonic_functions > 0:
|
811 |
+
self.embedder = HarmonicEmbedding(n_harmonic_functions, omega0)
|
812 |
+
dim_in = cin * 2 * n_harmonic_functions
|
813 |
+
self.embed_concat_pts = embed_concat_pts
|
814 |
+
if embed_concat_pts:
|
815 |
+
dim_in += cin
|
816 |
+
else:
|
817 |
+
self.embedder = None
|
818 |
+
dim_in = cin
|
819 |
+
|
820 |
+
self.encoder = Lift_Encoder(
|
821 |
+
cin=384,
|
822 |
+
feat_dim=local_feat_dim,
|
823 |
+
grid_scale=grid_scale / 2, # the dmtet is initialized in (-0.5, 0.5)
|
824 |
+
grid_size=grid_size,
|
825 |
+
optim_latent=optim_latent,
|
826 |
+
with_z_feature=True,
|
827 |
+
cam_pos_z_offset=cam_pos_z_offset
|
828 |
+
)
|
829 |
+
|
830 |
+
|
831 |
+
self.in_layer = nn.Linear(dim_in, nf)
|
832 |
+
self.relu = nn.ReLU(inplace=True)
|
833 |
+
self.mlp = MLP(nf + self.extra_dim, cout, num_layers, nf, dropout, activation)
|
834 |
+
self.perturb_normal = perturb_normal
|
835 |
+
self.symmetrize = symmetrize
|
836 |
+
if min_max is not None:
|
837 |
+
self.register_buffer('min_max', min_max)
|
838 |
+
else:
|
839 |
+
self.min_max = None
|
840 |
+
self.bsdf = None
|
841 |
+
|
842 |
+
def get_uv_depth(self, xyz, mvp):
|
843 |
+
# xyz: [b, k, 3]
|
844 |
+
# mvp: [b, 4, 4]
|
845 |
+
cam4 = torch.matmul(torch.nn.functional.pad(xyz, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2))
|
846 |
+
cam3 = cam4[..., :3] / cam4[..., 3:4]
|
847 |
+
cam_uv = cam3[..., :2]
|
848 |
+
# cam_uv = cam_uv.detach()
|
849 |
+
cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(xyz.device).view(1, 1, 3)
|
850 |
+
cam_depth = cam_depth / self.grid_scale * 2
|
851 |
+
cam_depth = cam_depth[..., 2:3]
|
852 |
+
# cam_depth = cam_depth.detach()
|
853 |
+
return cam_uv, cam_depth
|
854 |
+
|
855 |
+
def proj_sample_deform(self, xyz, feat_map, mvp, w2c, img_h, img_w):
|
856 |
+
# here the xyz is deformed points
|
857 |
+
# and we don't cast any symmtery here
|
858 |
+
if isinstance(feat_map, torch.Tensor):
|
859 |
+
feature = self.encoder(feat_map, mvp, xyz, inference="unproject")
|
860 |
+
|
861 |
+
elif isinstance(feat_map, dict):
|
862 |
+
feature = self.encoder(feat_map['im_features_map'], mvp, xyz, inference="sample")
|
863 |
+
C = feature.shape[-1]
|
864 |
+
feature = feature.reshape(-1, C)
|
865 |
+
return feature
|
866 |
+
|
867 |
+
def sample(self, x, feat=None, feat_map=None, mvp=None, w2c=None, deform_xyz=None):
|
868 |
+
# assert (feat is None and self.extra_dim == 0) or (feat.shape[-1] <= self.extra_dim)
|
869 |
+
b, h, w, c = x.shape
|
870 |
+
|
871 |
+
xyz_before_sym = None
|
872 |
+
if self.symmetrize:
|
873 |
+
xyz_before_sym = x.reshape(b, -1, c)
|
874 |
+
xs, ys, zs = x.unbind(-1)
|
875 |
+
x = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
876 |
+
|
877 |
+
mvp = mvp.detach() # [b, 4, 4]
|
878 |
+
w2c = w2c.detach() # [b, 4, 4]
|
879 |
+
|
880 |
+
pts_xyz = x.reshape(b, -1, c)
|
881 |
+
deform_xyz = deform_xyz.reshape(b, -1, c)
|
882 |
+
|
883 |
+
global_feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
884 |
+
local_feat = self.proj_sample_deform(deform_xyz, feat_map, mvp, w2c, h, w)
|
885 |
+
feature_rep = torch.concat([global_feat, local_feat], dim=-1)
|
886 |
+
x = x.view(-1, c)
|
887 |
+
|
888 |
+
if self.embedder is not None:
|
889 |
+
x_in = self.embedder(x)
|
890 |
+
if self.embed_concat_pts:
|
891 |
+
x_in = torch.cat([x, x_in], -1)
|
892 |
+
else:
|
893 |
+
x_in = x
|
894 |
+
|
895 |
+
x_in = self.in_layer(x_in)
|
896 |
+
|
897 |
+
# if feat is not None:
|
898 |
+
# feat = feat[:,None,None].expand(b, h, w, -1).reshape(b*h*w, -1)
|
899 |
+
# x_in = torch.concat([x_in, feat], dim=-1)
|
900 |
+
|
901 |
+
x_in = torch.concat([x_in, feature_rep], dim=-1)
|
902 |
+
|
903 |
+
out = self.mlp(self.relu(x_in))
|
904 |
+
if self.min_max is not None:
|
905 |
+
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
906 |
+
return out.view(b, h, w, -1)
|
907 |
+
|
908 |
+
|
909 |
+
class HarmonicEmbedding(nn.Module):
|
910 |
+
def __init__(self, n_harmonic_functions=10, omega0=1):
|
911 |
+
"""
|
912 |
+
Positional Embedding implementation (adapted from Pytorch3D).
|
913 |
+
Given an input tensor `x` of shape [minibatch, ... , dim],
|
914 |
+
the harmonic embedding layer converts each feature
|
915 |
+
in `x` into a series of harmonic features `embedding`
|
916 |
+
as follows:
|
917 |
+
embedding[..., i*dim:(i+1)*dim] = [
|
918 |
+
sin(x[..., i]),
|
919 |
+
sin(2*x[..., i]),
|
920 |
+
sin(4*x[..., i]),
|
921 |
+
...
|
922 |
+
sin(2**self.n_harmonic_functions * x[..., i]),
|
923 |
+
cos(x[..., i]),
|
924 |
+
cos(2*x[..., i]),
|
925 |
+
cos(4*x[..., i]),
|
926 |
+
...
|
927 |
+
cos(2**self.n_harmonic_functions * x[..., i])
|
928 |
+
]
|
929 |
+
Note that `x` is also premultiplied by `omega0` before
|
930 |
+
evaluting the harmonic functions.
|
931 |
+
"""
|
932 |
+
super().__init__()
|
933 |
+
self.frequencies = omega0 * (2.0 ** torch.arange(n_harmonic_functions))
|
934 |
+
|
935 |
+
def forward(self, x):
|
936 |
+
"""
|
937 |
+
Args:
|
938 |
+
x: tensor of shape [..., dim]
|
939 |
+
Returns:
|
940 |
+
embedding: a harmonic embedding of `x`
|
941 |
+
of shape [..., n_harmonic_functions * dim * 2]
|
942 |
+
"""
|
943 |
+
embed = (x[..., None] * self.frequencies.to(x.device)).view(*x.shape[:-1], -1)
|
944 |
+
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
945 |
+
|
946 |
+
|
947 |
+
class VGGEncoder(nn.Module):
|
948 |
+
def __init__(self, cout, pretrained=False):
|
949 |
+
super().__init__()
|
950 |
+
if pretrained:
|
951 |
+
raise NotImplementedError
|
952 |
+
vgg = models.vgg16()
|
953 |
+
self.vgg_encoder = nn.Sequential(vgg.features, vgg.avgpool)
|
954 |
+
self.linear1 = nn.Linear(25088, 4096)
|
955 |
+
self.linear2 = nn.Linear(4096, cout)
|
956 |
+
self.relu = nn.ReLU(inplace=True)
|
957 |
+
|
958 |
+
def forward(self, x):
|
959 |
+
batch_size, _, _, _ = x.shape
|
960 |
+
out = self.relu(self.linear1(self.vgg_encoder(x).view(batch_size, -1)))
|
961 |
+
return self.linear2(out)
|
962 |
+
|
963 |
+
|
964 |
+
class ResnetEncoder(nn.Module):
|
965 |
+
def __init__(self, cout, pretrained=False):
|
966 |
+
super().__init__()
|
967 |
+
self.resnet = nn.Sequential(list(models.resnet18(weights="DEFAULT" if pretrained else None).modules())[:-1])
|
968 |
+
self.final_linear = nn.Linear(512, cout)
|
969 |
+
|
970 |
+
def forward(self, x):
|
971 |
+
return self.final_linear(self.resnet(x))
|
972 |
+
|
973 |
+
|
974 |
+
class Encoder(nn.Module):
|
975 |
+
def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None):
|
976 |
+
super().__init__()
|
977 |
+
network = [
|
978 |
+
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
|
979 |
+
nn.GroupNorm(16, nf),
|
980 |
+
# nn.ReLU(inplace=True),
|
981 |
+
nn.LeakyReLU(0.2, inplace=True),
|
982 |
+
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
|
983 |
+
nn.GroupNorm(16*2, nf*2),
|
984 |
+
# nn.ReLU(inplace=True),
|
985 |
+
nn.LeakyReLU(0.2, inplace=True),
|
986 |
+
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
|
987 |
+
nn.GroupNorm(16*4, nf*4),
|
988 |
+
# nn.ReLU(inplace=True),
|
989 |
+
nn.LeakyReLU(0.2, inplace=True),
|
990 |
+
nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
991 |
+
# nn.GroupNorm(16*8, nf*8),
|
992 |
+
# nn.ReLU(inplace=True),
|
993 |
+
nn.LeakyReLU(0.2, inplace=True),
|
994 |
+
]
|
995 |
+
|
996 |
+
add_downsample = int(np.log2(in_size//128))
|
997 |
+
if add_downsample > 0:
|
998 |
+
for _ in range(add_downsample):
|
999 |
+
network += [
|
1000 |
+
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
1001 |
+
# nn.GroupNorm(16*8, nf*8),
|
1002 |
+
# nn.ReLU(inplace=True),
|
1003 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1004 |
+
]
|
1005 |
+
|
1006 |
+
network += [
|
1007 |
+
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
|
1008 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1009 |
+
]
|
1010 |
+
|
1011 |
+
if zdim is None:
|
1012 |
+
network += [
|
1013 |
+
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1014 |
+
]
|
1015 |
+
else:
|
1016 |
+
network += [
|
1017 |
+
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1018 |
+
# nn.ReLU(inplace=True),
|
1019 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1020 |
+
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
|
1021 |
+
]
|
1022 |
+
|
1023 |
+
if activation is not None:
|
1024 |
+
network += [get_activation(activation)]
|
1025 |
+
self.network = nn.Sequential(*network)
|
1026 |
+
|
1027 |
+
def forward(self, input):
|
1028 |
+
return self.network(input).reshape(input.size(0), -1)
|
1029 |
+
|
1030 |
+
|
1031 |
+
class EncoderWithDINO(nn.Module):
|
1032 |
+
def __init__(self, cin_rgb, cin_dino, cout, in_size=128, zdim=None, nf=64, activation=None):
|
1033 |
+
super().__init__()
|
1034 |
+
network_rgb_in = [
|
1035 |
+
nn.Conv2d(cin_rgb, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
|
1036 |
+
nn.GroupNorm(16, nf),
|
1037 |
+
# nn.ReLU(inplace=True),
|
1038 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1039 |
+
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
|
1040 |
+
nn.GroupNorm(16*2, nf*2),
|
1041 |
+
# nn.ReLU(inplace=True),
|
1042 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1043 |
+
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
|
1044 |
+
nn.GroupNorm(16*4, nf*4),
|
1045 |
+
# nn.ReLU(inplace=True),
|
1046 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1047 |
+
]
|
1048 |
+
self.network_rgb_in = nn.Sequential(*network_rgb_in)
|
1049 |
+
network_dino_in = [
|
1050 |
+
nn.Conv2d(cin_dino, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
|
1051 |
+
nn.GroupNorm(16, nf),
|
1052 |
+
# nn.ReLU(inplace=True),
|
1053 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1054 |
+
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
|
1055 |
+
nn.GroupNorm(16*2, nf*2),
|
1056 |
+
# nn.ReLU(inplace=True),
|
1057 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1058 |
+
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
|
1059 |
+
nn.GroupNorm(16*4, nf*4),
|
1060 |
+
# nn.ReLU(inplace=True),
|
1061 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1062 |
+
]
|
1063 |
+
self.network_dino_in = nn.Sequential(*network_dino_in)
|
1064 |
+
|
1065 |
+
network_fusion = [
|
1066 |
+
nn.Conv2d(nf*4*2, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
1067 |
+
# nn.GroupNorm(16*8, nf*8),
|
1068 |
+
# nn.ReLU(inplace=True),
|
1069 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1070 |
+
]
|
1071 |
+
|
1072 |
+
add_downsample = int(np.log2(in_size//128))
|
1073 |
+
if add_downsample > 0:
|
1074 |
+
for _ in range(add_downsample):
|
1075 |
+
network_fusion += [
|
1076 |
+
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
1077 |
+
# nn.GroupNorm(16*8, nf*8),
|
1078 |
+
# nn.ReLU(inplace=True),
|
1079 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1080 |
+
]
|
1081 |
+
|
1082 |
+
network_fusion += [
|
1083 |
+
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
|
1084 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1085 |
+
]
|
1086 |
+
|
1087 |
+
if zdim is None:
|
1088 |
+
network_fusion += [
|
1089 |
+
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1090 |
+
]
|
1091 |
+
else:
|
1092 |
+
network_fusion += [
|
1093 |
+
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1094 |
+
# nn.ReLU(inplace=True),
|
1095 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1096 |
+
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
|
1097 |
+
]
|
1098 |
+
|
1099 |
+
if activation is not None:
|
1100 |
+
network_fusion += [get_activation(activation)]
|
1101 |
+
self.network_fusion = nn.Sequential(*network_fusion)
|
1102 |
+
|
1103 |
+
def forward(self, rgb_image, dino_image):
|
1104 |
+
rgb_feat = self.network_rgb_in(rgb_image)
|
1105 |
+
dino_feat = self.network_dino_in(dino_image)
|
1106 |
+
out = self.network_fusion(torch.cat([rgb_feat, dino_feat], dim=1))
|
1107 |
+
return out.reshape(rgb_image.size(0), -1)
|
1108 |
+
|
1109 |
+
|
1110 |
+
class Encoder32(nn.Module):
|
1111 |
+
def __init__(self, cin, cout, nf=256, activation=None):
|
1112 |
+
super().__init__()
|
1113 |
+
network = [
|
1114 |
+
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
|
1115 |
+
nn.GroupNorm(nf//4, nf),
|
1116 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1117 |
+
nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
1118 |
+
nn.GroupNorm(nf//4, nf),
|
1119 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1120 |
+
nn.Conv2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
|
1121 |
+
nn.GroupNorm(nf//4, nf),
|
1122 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1123 |
+
nn.Conv2d(nf, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1124 |
+
]
|
1125 |
+
if activation is not None:
|
1126 |
+
network += [get_activation(activation)]
|
1127 |
+
self.network = nn.Sequential(*network)
|
1128 |
+
|
1129 |
+
def forward(self, input):
|
1130 |
+
return self.network(input).reshape(input.size(0), -1)
|
1131 |
+
|
1132 |
+
|
1133 |
+
class MLP(nn.Module):
|
1134 |
+
def __init__(self, cin, cout, num_layers, nf=256, dropout=0, activation=None, inner_act='relu', linear_bias=False):
|
1135 |
+
super().__init__()
|
1136 |
+
assert num_layers >= 1
|
1137 |
+
layer_act = get_activation(inner_act)
|
1138 |
+
if num_layers == 1:
|
1139 |
+
network = [nn.Linear(cin, cout, bias=linear_bias)]
|
1140 |
+
else:
|
1141 |
+
# network = [nn.Linear(cin, nf, bias=False)]
|
1142 |
+
# for _ in range(num_layers-2):
|
1143 |
+
# network += [
|
1144 |
+
# nn.ReLU(inplace=True),
|
1145 |
+
# nn.Linear(nf, nf, bias=False)]
|
1146 |
+
# if dropout:
|
1147 |
+
# network += [nn.Dropout(dropout)]
|
1148 |
+
# network += [
|
1149 |
+
# nn.ReLU(inplace=True),
|
1150 |
+
# nn.Linear(nf, cout, bias=False)]
|
1151 |
+
network = [nn.Linear(cin, nf, bias=linear_bias)]
|
1152 |
+
for _ in range(num_layers-2):
|
1153 |
+
network += [
|
1154 |
+
layer_act,
|
1155 |
+
nn.Linear(nf, nf, bias=linear_bias)]
|
1156 |
+
if dropout:
|
1157 |
+
network += [nn.Dropout(dropout)]
|
1158 |
+
network += [
|
1159 |
+
layer_act,
|
1160 |
+
nn.Linear(nf, cout, bias=linear_bias)]
|
1161 |
+
if activation is not None:
|
1162 |
+
network += [get_activation(activation)]
|
1163 |
+
self.network = nn.Sequential(*network)
|
1164 |
+
|
1165 |
+
def forward(self, input):
|
1166 |
+
return self.network(input)
|
1167 |
+
|
1168 |
+
|
1169 |
+
class Embedding(nn.Module):
|
1170 |
+
def __init__(self, cin, cout, zdim=128, nf=64, activation=None):
|
1171 |
+
super().__init__()
|
1172 |
+
network = [
|
1173 |
+
nn.Linear(cin, nf, bias=False),
|
1174 |
+
nn.ReLU(inplace=True),
|
1175 |
+
nn.Linear(nf, zdim, bias=False),
|
1176 |
+
nn.ReLU(inplace=True),
|
1177 |
+
nn.Linear(zdim, cout, bias=False)]
|
1178 |
+
if activation is not None:
|
1179 |
+
network += [get_activation(activation)]
|
1180 |
+
self.network = nn.Sequential(*network)
|
1181 |
+
|
1182 |
+
def forward(self, input):
|
1183 |
+
return self.network(input.reshape(input.size(0), -1)).reshape(input.size(0), -1)
|
1184 |
+
|
1185 |
+
|
1186 |
+
class PerceptualLoss(nn.Module):
|
1187 |
+
def __init__(self, requires_grad=False):
|
1188 |
+
super(PerceptualLoss, self).__init__()
|
1189 |
+
mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
|
1190 |
+
std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
|
1191 |
+
self.register_buffer('mean_rgb', mean_rgb)
|
1192 |
+
self.register_buffer('std_rgb', std_rgb)
|
1193 |
+
|
1194 |
+
vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features
|
1195 |
+
self.slice1 = nn.Sequential()
|
1196 |
+
self.slice2 = nn.Sequential()
|
1197 |
+
self.slice3 = nn.Sequential()
|
1198 |
+
self.slice4 = nn.Sequential()
|
1199 |
+
for x in range(4):
|
1200 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
1201 |
+
for x in range(4, 9):
|
1202 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
1203 |
+
for x in range(9, 16):
|
1204 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
1205 |
+
for x in range(16, 23):
|
1206 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
1207 |
+
if not requires_grad:
|
1208 |
+
for param in self.parameters():
|
1209 |
+
param.requires_grad = False
|
1210 |
+
|
1211 |
+
def normalize(self, x):
|
1212 |
+
out = x/2 + 0.5
|
1213 |
+
out = (out - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1)
|
1214 |
+
return out
|
1215 |
+
|
1216 |
+
def __call__(self, im1, im2, mask=None, conf_sigma=None):
|
1217 |
+
im = torch.cat([im1,im2], 0)
|
1218 |
+
im = self.normalize(im) # normalize input
|
1219 |
+
|
1220 |
+
## compute features
|
1221 |
+
feats = []
|
1222 |
+
f = self.slice1(im)
|
1223 |
+
feats += [torch.chunk(f, 2, dim=0)]
|
1224 |
+
f = self.slice2(f)
|
1225 |
+
feats += [torch.chunk(f, 2, dim=0)]
|
1226 |
+
f = self.slice3(f)
|
1227 |
+
feats += [torch.chunk(f, 2, dim=0)]
|
1228 |
+
f = self.slice4(f)
|
1229 |
+
feats += [torch.chunk(f, 2, dim=0)]
|
1230 |
+
|
1231 |
+
losses = []
|
1232 |
+
for f1, f2 in feats[2:3]: # use relu3_3 features only
|
1233 |
+
loss = (f1-f2)**2
|
1234 |
+
if conf_sigma is not None:
|
1235 |
+
loss = loss / (2*conf_sigma**2 +EPS) + (conf_sigma +EPS).log()
|
1236 |
+
if mask is not None:
|
1237 |
+
b, c, h, w = loss.shape
|
1238 |
+
_, _, hm, wm = mask.shape
|
1239 |
+
sh, sw = hm//h, wm//w
|
1240 |
+
mask0 = nn.functional.avg_pool2d(mask, kernel_size=(sh,sw), stride=(sh,sw)).expand_as(loss)
|
1241 |
+
loss = (loss * mask0).sum() / mask0.sum()
|
1242 |
+
else:
|
1243 |
+
loss = loss.mean()
|
1244 |
+
losses += [loss]
|
1245 |
+
return sum(losses)
|
1246 |
+
|
1247 |
+
|
1248 |
+
## from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
1249 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
1250 |
+
"""3x3 convolution with padding"""
|
1251 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
1252 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
1253 |
+
|
1254 |
+
|
1255 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
1256 |
+
"""1x1 convolution"""
|
1257 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
1258 |
+
|
1259 |
+
|
1260 |
+
class BasicBlock(nn.Module):
|
1261 |
+
expansion = 1
|
1262 |
+
|
1263 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
1264 |
+
base_width=64, dilation=1, norm_layer=None):
|
1265 |
+
super(BasicBlock, self).__init__()
|
1266 |
+
if groups != 1 or base_width != 64:
|
1267 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
1268 |
+
if dilation > 1:
|
1269 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
1270 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
1271 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
1272 |
+
self.relu = nn.ReLU(inplace=True)
|
1273 |
+
self.conv2 = conv3x3(planes, planes)
|
1274 |
+
|
1275 |
+
self.norm_layer = norm_layer
|
1276 |
+
if norm_layer is not None:
|
1277 |
+
self.bn1 = norm_layer(planes)
|
1278 |
+
self.bn2 = norm_layer(planes)
|
1279 |
+
|
1280 |
+
if inplanes != planes:
|
1281 |
+
self.downsample = nn.Sequential(
|
1282 |
+
conv1x1(inplanes, planes, stride),
|
1283 |
+
norm_layer(planes),
|
1284 |
+
)
|
1285 |
+
else:
|
1286 |
+
self.downsample = None
|
1287 |
+
self.stride = stride
|
1288 |
+
|
1289 |
+
def forward(self, x):
|
1290 |
+
identity = x
|
1291 |
+
|
1292 |
+
out = self.conv1(x)
|
1293 |
+
if self.norm_layer is not None:
|
1294 |
+
out = self.bn1(out)
|
1295 |
+
out = self.relu(out)
|
1296 |
+
|
1297 |
+
out = self.conv2(out)
|
1298 |
+
if self.norm_layer is not None:
|
1299 |
+
out = self.bn2(out)
|
1300 |
+
|
1301 |
+
if self.downsample is not None:
|
1302 |
+
identity = self.downsample(x)
|
1303 |
+
|
1304 |
+
out += identity
|
1305 |
+
out = self.relu(out)
|
1306 |
+
|
1307 |
+
return out
|
1308 |
+
|
1309 |
+
|
1310 |
+
class ResEncoder(nn.Module):
|
1311 |
+
def __init__(self, cin, cout, in_size=128, zdim=None, nf=64, activation=None):
|
1312 |
+
super().__init__()
|
1313 |
+
network = [
|
1314 |
+
nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False), # 128x128 -> 64x64
|
1315 |
+
# nn.GroupNorm(16, nf),
|
1316 |
+
# nn.ReLU(inplace=True),
|
1317 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1318 |
+
nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False), # 64x64 -> 32x32
|
1319 |
+
# nn.GroupNorm(16*2, nf*2),
|
1320 |
+
# nn.ReLU(inplace=True),
|
1321 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1322 |
+
BasicBlock(nf*2, nf*2, norm_layer=None),
|
1323 |
+
BasicBlock(nf*2, nf*2, norm_layer=None),
|
1324 |
+
nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False), # 32x32 -> 16x16
|
1325 |
+
# nn.GroupNorm(16*4, nf*4),
|
1326 |
+
# nn.ReLU(inplace=True),
|
1327 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1328 |
+
BasicBlock(nf*4, nf*4, norm_layer=None),
|
1329 |
+
BasicBlock(nf*4, nf*4, norm_layer=None),
|
1330 |
+
nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 16x16 -> 8x8
|
1331 |
+
# nn.ReLU(inplace=True),
|
1332 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1333 |
+
BasicBlock(nf*8, nf*8, norm_layer=None),
|
1334 |
+
BasicBlock(nf*8, nf*8, norm_layer=None),
|
1335 |
+
]
|
1336 |
+
|
1337 |
+
add_downsample = int(np.log2(in_size//64))
|
1338 |
+
if add_downsample > 0:
|
1339 |
+
for _ in range(add_downsample):
|
1340 |
+
network += [
|
1341 |
+
nn.Conv2d(nf*8, nf*8, kernel_size=4, stride=2, padding=1, bias=False), # 8x8 -> 4x4
|
1342 |
+
# nn.ReLU(inplace=True),
|
1343 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1344 |
+
BasicBlock(nf*8, nf*8, norm_layer=None),
|
1345 |
+
BasicBlock(nf*8, nf*8, norm_layer=None),
|
1346 |
+
]
|
1347 |
+
|
1348 |
+
if zdim is None:
|
1349 |
+
network += [
|
1350 |
+
nn.Conv2d(nf*8, cout, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1351 |
+
]
|
1352 |
+
else:
|
1353 |
+
network += [
|
1354 |
+
nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False), # 4x4 -> 1x1
|
1355 |
+
# nn.ReLU(inplace=True),
|
1356 |
+
nn.LeakyReLU(0.2, inplace=True),
|
1357 |
+
nn.Conv2d(zdim, cout, kernel_size=1, stride=1, padding=0, bias=False),
|
1358 |
+
]
|
1359 |
+
|
1360 |
+
if activation is not None:
|
1361 |
+
network += [get_activation(activation)]
|
1362 |
+
self.network = nn.Sequential(*network)
|
1363 |
+
|
1364 |
+
def forward(self, input):
|
1365 |
+
return self.network(input).reshape(input.size(0), -1)
|
1366 |
+
|
1367 |
+
|
1368 |
+
class Attention(nn.Module):
|
1369 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
1370 |
+
super().__init__()
|
1371 |
+
self.num_heads = num_heads
|
1372 |
+
head_dim = dim // num_heads
|
1373 |
+
self.scale = qk_scale or head_dim ** -0.5
|
1374 |
+
|
1375 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
1376 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
1377 |
+
self.proj = nn.Linear(dim, dim)
|
1378 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
1379 |
+
|
1380 |
+
def forward(self, x):
|
1381 |
+
B, N, C = x.shape
|
1382 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
1383 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
1384 |
+
|
1385 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
1386 |
+
attn = attn.softmax(dim=-1)
|
1387 |
+
attn = self.attn_drop(attn)
|
1388 |
+
|
1389 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
1390 |
+
x = self.proj(x)
|
1391 |
+
x = self.proj_drop(x)
|
1392 |
+
return x, attn
|
1393 |
+
|
1394 |
+
|
1395 |
+
class ViTEncoder(nn.Module):
|
1396 |
+
def __init__(self, cout, which_vit='dino_vits8', pretrained=False, frozen=False, in_size=256, final_layer_type='none', root='/root'):
|
1397 |
+
super().__init__()
|
1398 |
+
if misc.is_main_process():
|
1399 |
+
force_reload = not os.path.exists(os.path.join(root, ".cache/torch/hub/checkpoints/"))
|
1400 |
+
else:
|
1401 |
+
force_reload = False
|
1402 |
+
if "dinov2" in which_vit:
|
1403 |
+
self.ViT = torch.hub.load('facebookresearch/dinov2:main', which_vit, pretrained=pretrained, force_reload=force_reload)
|
1404 |
+
else:
|
1405 |
+
self.ViT = torch.hub.load('facebookresearch/dino:main', which_vit, pretrained=pretrained, force_reload=force_reload)
|
1406 |
+
|
1407 |
+
if frozen:
|
1408 |
+
for p in self.ViT.parameters():
|
1409 |
+
p.requires_grad = False
|
1410 |
+
if which_vit == 'dino_vits8':
|
1411 |
+
self.vit_feat_dim = 384
|
1412 |
+
self.patch_size = 8
|
1413 |
+
elif which_vit == 'dinov2_vits14':
|
1414 |
+
self.vit_feat_dim = 384
|
1415 |
+
self.patch_size = 14
|
1416 |
+
elif which_vit == 'dino_vitb8':
|
1417 |
+
self.vit_feat_dim = 768
|
1418 |
+
self.patch_size = 8
|
1419 |
+
|
1420 |
+
self._feats = []
|
1421 |
+
self.hook_handlers = []
|
1422 |
+
|
1423 |
+
if final_layer_type == 'none':
|
1424 |
+
pass
|
1425 |
+
elif final_layer_type == 'conv':
|
1426 |
+
self.final_layer_patch_out = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None)
|
1427 |
+
self.final_layer_patch_key = Encoder32(self.vit_feat_dim, cout, nf=256, activation=None)
|
1428 |
+
elif final_layer_type == 'attention':
|
1429 |
+
raise NotImplementedError
|
1430 |
+
self.final_layer = Attention(
|
1431 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
1432 |
+
self.fc = nn.Linear(self.vit_feat_dim, cout)
|
1433 |
+
else:
|
1434 |
+
raise NotImplementedError
|
1435 |
+
self.final_layer_type = final_layer_type
|
1436 |
+
|
1437 |
+
def _get_hook(self, facet: str):
|
1438 |
+
"""
|
1439 |
+
generate a hook method for a specific block and facet.
|
1440 |
+
"""
|
1441 |
+
if facet in ['attn', 'token']:
|
1442 |
+
def _hook(model, input, output):
|
1443 |
+
self._feats.append(output)
|
1444 |
+
return _hook
|
1445 |
+
|
1446 |
+
if facet == 'query':
|
1447 |
+
facet_idx = 0
|
1448 |
+
elif facet == 'key':
|
1449 |
+
facet_idx = 1
|
1450 |
+
elif facet == 'value':
|
1451 |
+
facet_idx = 2
|
1452 |
+
else:
|
1453 |
+
raise TypeError(f"{facet} is not a supported facet.")
|
1454 |
+
|
1455 |
+
def _inner_hook(module, input, output):
|
1456 |
+
input = input[0]
|
1457 |
+
B, N, C = input.shape
|
1458 |
+
qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
|
1459 |
+
self._feats.append(qkv[facet_idx]) #Bxhxtxd
|
1460 |
+
return _inner_hook
|
1461 |
+
|
1462 |
+
def _register_hooks(self, layers: List[int], facet: str) -> None:
|
1463 |
+
"""
|
1464 |
+
register hook to extract features.
|
1465 |
+
:param layers: layers from which to extract features.
|
1466 |
+
:param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
|
1467 |
+
"""
|
1468 |
+
for block_idx, block in enumerate(self.ViT.blocks):
|
1469 |
+
if block_idx in layers:
|
1470 |
+
if facet == 'token':
|
1471 |
+
self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
|
1472 |
+
elif facet == 'attn':
|
1473 |
+
self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
|
1474 |
+
elif facet in ['key', 'query', 'value']:
|
1475 |
+
self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
|
1476 |
+
else:
|
1477 |
+
raise TypeError(f"{facet} is not a supported facet.")
|
1478 |
+
|
1479 |
+
def _unregister_hooks(self) -> None:
|
1480 |
+
"""
|
1481 |
+
unregisters the hooks. should be called after feature extraction.
|
1482 |
+
"""
|
1483 |
+
for handle in self.hook_handlers:
|
1484 |
+
handle.remove()
|
1485 |
+
self.hook_handlers = []
|
1486 |
+
|
1487 |
+
def forward(self, x, return_patches=False):
|
1488 |
+
b, c, h, w = x.shape
|
1489 |
+
self._feats = []
|
1490 |
+
self._register_hooks([11], 'key')
|
1491 |
+
#self._register_hooks([11], 'token')
|
1492 |
+
x = self.ViT.prepare_tokens(x)
|
1493 |
+
#x = self.ViT.prepare_tokens_with_masks(x)
|
1494 |
+
|
1495 |
+
for blk in self.ViT.blocks:
|
1496 |
+
x = blk(x)
|
1497 |
+
out = self.ViT.norm(x)
|
1498 |
+
self._unregister_hooks()
|
1499 |
+
|
1500 |
+
ph, pw = h // self.patch_size, w // self.patch_size
|
1501 |
+
patch_out = out[:, 1:] # first is class token
|
1502 |
+
patch_out = patch_out.reshape(b, ph, pw, self.vit_feat_dim).permute(0, 3, 1, 2)
|
1503 |
+
|
1504 |
+
patch_key = self._feats[0][:,:,1:] # B, num_heads, num_patches, dim
|
1505 |
+
patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, self.vit_feat_dim, ph, pw)
|
1506 |
+
|
1507 |
+
if self.final_layer_type == 'none':
|
1508 |
+
global_feat_out = out[:, 0].reshape(b, -1) # first is class token
|
1509 |
+
global_feat_key = self._feats[0][:, :, 0].reshape(b, -1) # first is class token
|
1510 |
+
elif self.final_layer_type == 'conv':
|
1511 |
+
global_feat_out = self.final_layer_patch_out(patch_out).view(b, -1)
|
1512 |
+
global_feat_key = self.final_layer_patch_key(patch_key).view(b, -1)
|
1513 |
+
elif self.final_layer_type == 'attention':
|
1514 |
+
raise NotImplementedError
|
1515 |
+
else:
|
1516 |
+
raise NotImplementedError
|
1517 |
+
if not return_patches:
|
1518 |
+
patch_out = patch_key = None
|
1519 |
+
return global_feat_out, global_feat_key, patch_out, patch_key
|
1520 |
+
|
1521 |
+
|
1522 |
+
class ArticulationNetwork(nn.Module):
|
1523 |
+
def __init__(self, net_type, feat_dim, pos_dim, num_layers, nf, n_harmonic_functions=0, omega0=1, activation=None, enable_articulation_idadd=False):
|
1524 |
+
super().__init__()
|
1525 |
+
if n_harmonic_functions > 0:
|
1526 |
+
self.posenc = HarmonicEmbedding(n_harmonic_functions=n_harmonic_functions, omega0=omega0)
|
1527 |
+
pos_dim = pos_dim * (n_harmonic_functions * 2 + 1)
|
1528 |
+
else:
|
1529 |
+
self.posenc = None
|
1530 |
+
pos_dim = 4
|
1531 |
+
cout = 3
|
1532 |
+
|
1533 |
+
if net_type == 'mlp':
|
1534 |
+
self.network = MLP(
|
1535 |
+
feat_dim + pos_dim, # + bone xyz pos and index
|
1536 |
+
cout, # We represent the rotation of each bone by its Euler angles ψ, θ, and φ
|
1537 |
+
num_layers,
|
1538 |
+
nf=nf,
|
1539 |
+
dropout=0,
|
1540 |
+
activation=activation
|
1541 |
+
)
|
1542 |
+
elif net_type == 'attention':
|
1543 |
+
self.in_layer = nn.Sequential(
|
1544 |
+
nn.Linear(feat_dim + pos_dim, nf),
|
1545 |
+
nn.GELU(),
|
1546 |
+
nn.LayerNorm(nf),
|
1547 |
+
)
|
1548 |
+
self.blocks = nn.ModuleList([
|
1549 |
+
Block(
|
1550 |
+
dim=nf, num_heads=8, mlp_ratio=2., qkv_bias=False, qk_scale=None,
|
1551 |
+
drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm)
|
1552 |
+
for i in range(num_layers)])
|
1553 |
+
out_layer = [nn.Linear(nf, cout)]
|
1554 |
+
if activation:
|
1555 |
+
out_layer += [get_activation(activation)]
|
1556 |
+
self.out_layer = nn.Sequential(*out_layer)
|
1557 |
+
else:
|
1558 |
+
raise NotImplementedError
|
1559 |
+
self.net_type = net_type
|
1560 |
+
self.enable_articulation_idadd = enable_articulation_idadd
|
1561 |
+
|
1562 |
+
def forward(self, x, pos):
|
1563 |
+
pos_inp = pos
|
1564 |
+
if self.posenc is not None:
|
1565 |
+
pos = torch.cat([pos, self.posenc(pos)], dim=-1)
|
1566 |
+
x = torch.cat([x, pos], dim=-1)
|
1567 |
+
if self.enable_articulation_idadd:
|
1568 |
+
articulation_id = pos_inp[..., -1:]
|
1569 |
+
x = x + articulation_id
|
1570 |
+
if self.net_type == 'mlp':
|
1571 |
+
out = self.network(x)
|
1572 |
+
elif self.net_type == 'attention':
|
1573 |
+
x = self.in_layer(x)
|
1574 |
+
for blk in self.blocks:
|
1575 |
+
x = blk(x)
|
1576 |
+
out = self.out_layer(x)
|
1577 |
+
else:
|
1578 |
+
raise NotImplementedError
|
1579 |
+
return out
|
1580 |
+
|
1581 |
+
|
1582 |
+
## Attention block from ViT (https://github.com/facebookresearch/dino/blob/main/vision_transformer.py)
|
1583 |
+
class Attention(nn.Module):
|
1584 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
1585 |
+
super().__init__()
|
1586 |
+
self.num_heads = num_heads
|
1587 |
+
head_dim = dim // num_heads
|
1588 |
+
self.scale = qk_scale or head_dim ** -0.5
|
1589 |
+
|
1590 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
1591 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
1592 |
+
self.proj = nn.Linear(dim, dim)
|
1593 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
1594 |
+
|
1595 |
+
def forward(self, x):
|
1596 |
+
B, N, C = x.shape
|
1597 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
1598 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
1599 |
+
|
1600 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
1601 |
+
attn = attn.softmax(dim=-1)
|
1602 |
+
attn = self.attn_drop(attn)
|
1603 |
+
|
1604 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
1605 |
+
x = self.proj(x)
|
1606 |
+
x = self.proj_drop(x)
|
1607 |
+
return x, attn
|
1608 |
+
|
1609 |
+
|
1610 |
+
class Mlp(nn.Module):
|
1611 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
1612 |
+
super().__init__()
|
1613 |
+
out_features = out_features or in_features
|
1614 |
+
hidden_features = hidden_features or in_features
|
1615 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
1616 |
+
self.act = act_layer()
|
1617 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
1618 |
+
self.drop = nn.Dropout(drop)
|
1619 |
+
|
1620 |
+
def forward(self, x):
|
1621 |
+
x = self.fc1(x)
|
1622 |
+
x = self.act(x)
|
1623 |
+
x = self.drop(x)
|
1624 |
+
x = self.fc2(x)
|
1625 |
+
x = self.drop(x)
|
1626 |
+
return x
|
1627 |
+
|
1628 |
+
|
1629 |
+
class DropPath(nn.Module):
|
1630 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
1631 |
+
"""
|
1632 |
+
def __init__(self, drop_prob=None):
|
1633 |
+
super(DropPath, self).__init__()
|
1634 |
+
self.drop_prob = drop_prob
|
1635 |
+
|
1636 |
+
def forward(self, x):
|
1637 |
+
return drop_path(x, self.drop_prob, self.training)
|
1638 |
+
|
1639 |
+
|
1640 |
+
class Block(nn.Module):
|
1641 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
1642 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
1643 |
+
super().__init__()
|
1644 |
+
self.norm1 = norm_layer(dim)
|
1645 |
+
self.attn = Attention(
|
1646 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
1647 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
1648 |
+
self.norm2 = norm_layer(dim)
|
1649 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
1650 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
1651 |
+
|
1652 |
+
def forward(self, x, return_attention=False):
|
1653 |
+
y, attn = self.attn(self.norm1(x))
|
1654 |
+
if return_attention:
|
1655 |
+
return attn
|
1656 |
+
x = x + self.drop_path(y)
|
1657 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
1658 |
+
return x
|
1659 |
+
|
1660 |
+
|
1661 |
+
class FeatureAttention(nn.Module):
|
1662 |
+
def __init__(self, vit_type, pos_dim, embedder_freq=0, zdim=128, img_size=256, activation=None):
|
1663 |
+
super().__init__()
|
1664 |
+
self.zdim = zdim
|
1665 |
+
if embedder_freq > 0:
|
1666 |
+
self.posenc = HarmonicEmbedding(n_harmonic_functions=embedder_freq, omega0=1)
|
1667 |
+
pos_dim = pos_dim * (embedder_freq * 2 + 1)
|
1668 |
+
else:
|
1669 |
+
self.posenc = None
|
1670 |
+
self.pos_dim = pos_dim
|
1671 |
+
|
1672 |
+
if vit_type == 'dino_vits8':
|
1673 |
+
self.vit_feat_dim = 384
|
1674 |
+
patch_size = 8
|
1675 |
+
elif which_vit == 'dinov2_vits14':
|
1676 |
+
self.vit_feat_dim = 384
|
1677 |
+
self.patch_size = 14
|
1678 |
+
elif vit_type == 'dino_vitb8':
|
1679 |
+
self.vit_feat_dim = 768
|
1680 |
+
patch_size = 8
|
1681 |
+
else:
|
1682 |
+
raise NotImplementedError
|
1683 |
+
self.num_patches_per_dim = img_size // patch_size
|
1684 |
+
|
1685 |
+
self.kv = nn.Sequential(
|
1686 |
+
nn.Linear(self.vit_feat_dim, zdim),
|
1687 |
+
nn.ReLU(inplace=True),
|
1688 |
+
nn.LayerNorm(zdim),
|
1689 |
+
nn.Linear(zdim, zdim*2),
|
1690 |
+
)
|
1691 |
+
|
1692 |
+
self.q = nn.Sequential(
|
1693 |
+
nn.Linear(pos_dim, zdim),
|
1694 |
+
nn.ReLU(inplace=True),
|
1695 |
+
nn.LayerNorm(zdim),
|
1696 |
+
nn.Linear(zdim, zdim),
|
1697 |
+
)
|
1698 |
+
|
1699 |
+
final_mlp = [
|
1700 |
+
nn.Linear(zdim, zdim),
|
1701 |
+
nn.ReLU(inplace=True),
|
1702 |
+
nn.LayerNorm(zdim),
|
1703 |
+
nn.Linear(zdim, self.vit_feat_dim)
|
1704 |
+
]
|
1705 |
+
if activation is not None:
|
1706 |
+
final_mlp += [get_activation(activation)]
|
1707 |
+
self.final_ln = nn.Sequential(*final_mlp)
|
1708 |
+
|
1709 |
+
def forward(self, x, feat):
|
1710 |
+
_, vit_feat_dim, ph, pw = feat.shape
|
1711 |
+
assert ph == pw and ph == self.num_patches_per_dim and vit_feat_dim == self.vit_feat_dim
|
1712 |
+
|
1713 |
+
if self.posenc is not None:
|
1714 |
+
x = torch.cat([x, self.posenc(x)], dim=-1)
|
1715 |
+
bxf, k, c = x.shape
|
1716 |
+
assert c == self.pos_dim
|
1717 |
+
|
1718 |
+
query = self.q(x)
|
1719 |
+
feat_in = feat.view(bxf, vit_feat_dim, ph*pw).permute(0, 2, 1) # N, K, C
|
1720 |
+
k, v = self.kv(feat_in).chunk(2, dim=-1)
|
1721 |
+
attn = torch.einsum('bnd,bpd->bnp', query, k).softmax(dim=-1)
|
1722 |
+
out = torch.einsum('bnp,bpd->bnd', attn, v)
|
1723 |
+
out = self.final_ln(out)
|
1724 |
+
return out
|
video3d/render/light.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
import nvdiffrast.torch as dr
|
15 |
+
|
16 |
+
from . import util
|
17 |
+
from . import renderutils as ru
|
18 |
+
from ..networks import MLP
|
19 |
+
|
20 |
+
######################################################################################
|
21 |
+
# Utility functions
|
22 |
+
######################################################################################
|
23 |
+
|
24 |
+
class cubemap_mip(torch.autograd.Function):
|
25 |
+
@staticmethod
|
26 |
+
def forward(ctx, cubemap):
|
27 |
+
return util.avg_pool_nhwc(cubemap, (2,2))
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def backward(ctx, dout):
|
31 |
+
res = dout.shape[1] * 2
|
32 |
+
out = torch.zeros(6, res, res, dout.shape[-1], dtype=torch.float32, device="cuda")
|
33 |
+
for s in range(6):
|
34 |
+
gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
35 |
+
torch.linspace(-1.0 + 1.0 / res, 1.0 - 1.0 / res, res, device="cuda"),
|
36 |
+
indexing='ij')
|
37 |
+
v = util.safe_normalize(util.cube_to_dir(s, gx, gy))
|
38 |
+
out[s, ...] = dr.texture(dout[None, ...] * 0.25, v[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')
|
39 |
+
return out
|
40 |
+
|
41 |
+
######################################################################################
|
42 |
+
# Split-sum environment map light source with automatic mipmap generation
|
43 |
+
######################################################################################
|
44 |
+
|
45 |
+
class EnvironmentLight(torch.nn.Module):
|
46 |
+
LIGHT_MIN_RES = 16
|
47 |
+
|
48 |
+
MIN_ROUGHNESS = 0.08
|
49 |
+
MAX_ROUGHNESS = 0.5
|
50 |
+
|
51 |
+
def __init__(self, base):
|
52 |
+
super(EnvironmentLight, self).__init__()
|
53 |
+
self.mtx = None
|
54 |
+
self.base = torch.nn.Parameter(base.clone().detach(), requires_grad=True)
|
55 |
+
self.register_parameter('env_base', self.base)
|
56 |
+
|
57 |
+
def xfm(self, mtx):
|
58 |
+
self.mtx = mtx
|
59 |
+
|
60 |
+
def clone(self):
|
61 |
+
return EnvironmentLight(self.base.clone().detach())
|
62 |
+
|
63 |
+
def clamp_(self, min=None, max=None):
|
64 |
+
self.base.clamp_(min, max)
|
65 |
+
|
66 |
+
def get_mip(self, roughness):
|
67 |
+
return torch.where(roughness < self.MAX_ROUGHNESS
|
68 |
+
, (torch.clamp(roughness, self.MIN_ROUGHNESS, self.MAX_ROUGHNESS) - self.MIN_ROUGHNESS) / (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) * (len(self.specular) - 2)
|
69 |
+
, (torch.clamp(roughness, self.MAX_ROUGHNESS, 1.0) - self.MAX_ROUGHNESS) / (1.0 - self.MAX_ROUGHNESS) + len(self.specular) - 2)
|
70 |
+
|
71 |
+
def build_mips(self, cutoff=0.99):
|
72 |
+
self.specular = [self.base]
|
73 |
+
while self.specular[-1].shape[1] > self.LIGHT_MIN_RES:
|
74 |
+
self.specular += [cubemap_mip.apply(self.specular[-1])]
|
75 |
+
|
76 |
+
self.diffuse = ru.diffuse_cubemap(self.specular[-1])
|
77 |
+
|
78 |
+
for idx in range(len(self.specular) - 1):
|
79 |
+
roughness = (idx / (len(self.specular) - 2)) * (self.MAX_ROUGHNESS - self.MIN_ROUGHNESS) + self.MIN_ROUGHNESS
|
80 |
+
self.specular[idx] = ru.specular_cubemap(self.specular[idx], roughness, cutoff)
|
81 |
+
self.specular[-1] = ru.specular_cubemap(self.specular[-1], 1.0, cutoff)
|
82 |
+
|
83 |
+
def regularizer(self):
|
84 |
+
white = (self.base[..., 0:1] + self.base[..., 1:2] + self.base[..., 2:3]) / 3.0
|
85 |
+
return torch.mean(torch.abs(self.base - white))
|
86 |
+
|
87 |
+
def shade(self, gb_pos, gb_normal, kd, ks, view_pos, specular=True):
|
88 |
+
wo = util.safe_normalize(view_pos - gb_pos)
|
89 |
+
|
90 |
+
if specular:
|
91 |
+
roughness = ks[..., 1:2] # y component
|
92 |
+
metallic = ks[..., 2:3] # z component
|
93 |
+
spec_col = (1.0 - metallic)*0.04 + kd * metallic
|
94 |
+
diff_col = kd * (1.0 - metallic)
|
95 |
+
else:
|
96 |
+
diff_col = kd
|
97 |
+
|
98 |
+
reflvec = util.safe_normalize(util.reflect(wo, gb_normal))
|
99 |
+
nrmvec = gb_normal
|
100 |
+
if self.mtx is not None: # Rotate lookup
|
101 |
+
mtx = torch.as_tensor(self.mtx, dtype=torch.float32, device='cuda')
|
102 |
+
reflvec = ru.xfm_vectors(reflvec.view(reflvec.shape[0], reflvec.shape[1] * reflvec.shape[2], reflvec.shape[3]), mtx).view(*reflvec.shape)
|
103 |
+
nrmvec = ru.xfm_vectors(nrmvec.view(nrmvec.shape[0], nrmvec.shape[1] * nrmvec.shape[2], nrmvec.shape[3]), mtx).view(*nrmvec.shape)
|
104 |
+
|
105 |
+
# Diffuse lookup
|
106 |
+
diffuse = dr.texture(self.diffuse[None, ...], nrmvec.contiguous(), filter_mode='linear', boundary_mode='cube')
|
107 |
+
shaded_col = diffuse * diff_col
|
108 |
+
|
109 |
+
if specular:
|
110 |
+
# Lookup FG term from lookup texture
|
111 |
+
NdotV = torch.clamp(util.dot(wo, gb_normal), min=1e-4)
|
112 |
+
fg_uv = torch.cat((NdotV, roughness), dim=-1)
|
113 |
+
if not hasattr(self, '_FG_LUT'):
|
114 |
+
self._FG_LUT = torch.as_tensor(np.fromfile('data/irrmaps/bsdf_256_256.bin', dtype=np.float32).reshape(1, 256, 256, 2), dtype=torch.float32, device='cuda')
|
115 |
+
fg_lookup = dr.texture(self._FG_LUT, fg_uv, filter_mode='linear', boundary_mode='clamp')
|
116 |
+
|
117 |
+
# Roughness adjusted specular env lookup
|
118 |
+
miplevel = self.get_mip(roughness)
|
119 |
+
spec = dr.texture(self.specular[0][None, ...], reflvec.contiguous(), mip=list(m[None, ...] for m in self.specular[1:]), mip_level_bias=miplevel[..., 0], filter_mode='linear-mipmap-linear', boundary_mode='cube')
|
120 |
+
|
121 |
+
# Compute aggregate lighting
|
122 |
+
reflectance = spec_col * fg_lookup[...,0:1] + fg_lookup[...,1:2]
|
123 |
+
shaded_col += spec * reflectance
|
124 |
+
|
125 |
+
return shaded_col * (1.0 - ks[..., 0:1]) # Modulate by hemisphere visibility
|
126 |
+
|
127 |
+
######################################################################################
|
128 |
+
# Load and store
|
129 |
+
######################################################################################
|
130 |
+
|
131 |
+
# Load from latlong .HDR file
|
132 |
+
def _load_env_hdr(fn, scale=1.0):
|
133 |
+
latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
|
134 |
+
cubemap = util.latlong_to_cubemap(latlong_img, [512, 512])
|
135 |
+
|
136 |
+
l = EnvironmentLight(cubemap)
|
137 |
+
l.build_mips()
|
138 |
+
|
139 |
+
return l
|
140 |
+
|
141 |
+
def load_env(fn, scale=1.0):
|
142 |
+
if os.path.splitext(fn)[1].lower() == ".hdr":
|
143 |
+
return _load_env_hdr(fn, scale)
|
144 |
+
else:
|
145 |
+
assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1]
|
146 |
+
|
147 |
+
def save_env_map(fn, light):
|
148 |
+
assert isinstance(light, EnvironmentLight), "Can only save EnvironmentLight currently"
|
149 |
+
if isinstance(light, EnvironmentLight):
|
150 |
+
color = util.cubemap_to_latlong(light.base, [512, 1024])
|
151 |
+
util.save_image_raw(fn, color.detach().cpu().numpy())
|
152 |
+
|
153 |
+
######################################################################################
|
154 |
+
# Create trainable env map with random initialization
|
155 |
+
######################################################################################
|
156 |
+
|
157 |
+
def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
|
158 |
+
base = torch.rand(6, base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias
|
159 |
+
return EnvironmentLight(base)
|
160 |
+
|
161 |
+
|
162 |
+
######################################################################################
|
163 |
+
# Directional light source
|
164 |
+
######################################################################################
|
165 |
+
|
166 |
+
class DirectionalLight(torch.nn.Module):
|
167 |
+
def __init__(self, mlp_in, mlp_layers, mlp_hidden_size, intensity_min_max=None):
|
168 |
+
super(DirectionalLight, self).__init__()
|
169 |
+
self.mlp = MLP(mlp_in, 4, mlp_layers, nf=mlp_hidden_size, activation='sigmoid')
|
170 |
+
if intensity_min_max is not None:
|
171 |
+
self.register_buffer('intensity_min_max', intensity_min_max)
|
172 |
+
else:
|
173 |
+
self.intensity_min_max = None
|
174 |
+
|
175 |
+
def forward(self, feat):
|
176 |
+
# print('----------------- forward light !!! -----------------')
|
177 |
+
out = self.mlp(feat)
|
178 |
+
light_dir = F.normalize(torch.cat([out[..., 0:1] *2-1, torch.ones_like(out[..., :1]) * 0.5, out[..., 1:2] *2-1], dim=-1), dim=-1) # upper hemisphere
|
179 |
+
if self.intensity_min_max is not None:
|
180 |
+
int = out[..., 2:] * (self.intensity_min_max[1][None, :] - self.intensity_min_max[0][None, :]) + self.intensity_min_max[0][None, :]
|
181 |
+
self.light_params = torch.cat([light_dir, int], -1)
|
182 |
+
return self.light_params
|
183 |
+
|
184 |
+
def shade(self, feat, kd, normal):
|
185 |
+
light_params = self.forward(feat)
|
186 |
+
light_dir = light_params[..., :3][:, None, None, :]
|
187 |
+
int_amb = light_params[..., 3:4][:, None, None, :]
|
188 |
+
int_diff = light_params[..., 4:5][:, None, None, :]
|
189 |
+
shading = (int_amb + int_diff * torch.clamp(util.dot(light_dir, normal), min=0.0))
|
190 |
+
shaded = shading * kd
|
191 |
+
return shaded, shading
|
video3d/render/material.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import nvdiffrast.torch as dr
|
14 |
+
import cv2
|
15 |
+
|
16 |
+
from video3d.render.render import render_uv
|
17 |
+
|
18 |
+
from . import util
|
19 |
+
from . import texture
|
20 |
+
from . import mlptexture
|
21 |
+
from ..utils import misc
|
22 |
+
|
23 |
+
######################################################################################
|
24 |
+
# Wrapper to make materials behave like a python dict, but register textures as
|
25 |
+
# torch.nn.Module parameters.
|
26 |
+
######################################################################################
|
27 |
+
class Material(torch.nn.Module):
|
28 |
+
def __init__(self, mat_dict):
|
29 |
+
super(Material, self).__init__()
|
30 |
+
self.mat_keys = set()
|
31 |
+
for key in mat_dict.keys():
|
32 |
+
self.mat_keys.add(key)
|
33 |
+
self[key] = mat_dict[key]
|
34 |
+
|
35 |
+
def __contains__(self, key):
|
36 |
+
return hasattr(self, key)
|
37 |
+
|
38 |
+
def __getitem__(self, key):
|
39 |
+
return getattr(self, key)
|
40 |
+
|
41 |
+
def __setitem__(self, key, val):
|
42 |
+
self.mat_keys.add(key)
|
43 |
+
setattr(self, key, val)
|
44 |
+
|
45 |
+
def __delitem__(self, key):
|
46 |
+
self.mat_keys.remove(key)
|
47 |
+
delattr(self, key)
|
48 |
+
|
49 |
+
def keys(self):
|
50 |
+
return self.mat_keys
|
51 |
+
|
52 |
+
######################################################################################
|
53 |
+
# .mtl material format loading / storing
|
54 |
+
######################################################################################
|
55 |
+
@torch.no_grad()
|
56 |
+
def load_mtl(fn, clear_ks=True):
|
57 |
+
import re
|
58 |
+
mtl_path = os.path.dirname(fn)
|
59 |
+
|
60 |
+
# Read file
|
61 |
+
with open(fn, 'r') as f:
|
62 |
+
lines = f.readlines()
|
63 |
+
|
64 |
+
# Parse materials
|
65 |
+
materials = []
|
66 |
+
for line in lines:
|
67 |
+
split_line = re.split(' +|\t+|\n+', line.strip())
|
68 |
+
prefix = split_line[0].lower()
|
69 |
+
data = split_line[1:]
|
70 |
+
if 'newmtl' in prefix:
|
71 |
+
material = Material({'name' : data[0]})
|
72 |
+
materials += [material]
|
73 |
+
elif materials:
|
74 |
+
if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
|
75 |
+
material[prefix] = data[0]
|
76 |
+
else:
|
77 |
+
material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')
|
78 |
+
|
79 |
+
# Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
|
80 |
+
for mat in materials:
|
81 |
+
if not 'bsdf' in mat:
|
82 |
+
mat['bsdf'] = 'pbr'
|
83 |
+
|
84 |
+
if 'map_kd' in mat:
|
85 |
+
mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
|
86 |
+
else:
|
87 |
+
mat['kd'] = texture.Texture2D(mat['kd'])
|
88 |
+
|
89 |
+
if 'map_ks' in mat:
|
90 |
+
mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
|
91 |
+
else:
|
92 |
+
mat['ks'] = texture.Texture2D(mat['ks'])
|
93 |
+
|
94 |
+
if 'bump' in mat:
|
95 |
+
mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)
|
96 |
+
|
97 |
+
# Convert Kd from sRGB to linear RGB
|
98 |
+
mat['kd'] = texture.srgb_to_rgb(mat['kd'])
|
99 |
+
|
100 |
+
if clear_ks:
|
101 |
+
# Override ORM occlusion (red) channel by zeros. We hijack this channel
|
102 |
+
for mip in mat['ks'].getMips():
|
103 |
+
mip[..., 0] = 0.0
|
104 |
+
|
105 |
+
return materials
|
106 |
+
|
107 |
+
@torch.no_grad()
|
108 |
+
def save_mtl(fn, material, mesh=None, feat=None, resolution=[256, 256], prior_shape=None):
|
109 |
+
folder = os.path.dirname(fn)
|
110 |
+
file = os.path.basename(fn)
|
111 |
+
prefix = '_'.join(file.split('_')[:-1]) + '_'
|
112 |
+
with open(fn, "w") as f:
|
113 |
+
f.write('newmtl defaultMat\n')
|
114 |
+
if material is not None:
|
115 |
+
f.write('bsdf %s\n' % material['bsdf'])
|
116 |
+
if 'kd_ks_normal' in material.keys():
|
117 |
+
assert mesh is not None
|
118 |
+
glctx = dr.RasterizeGLContext()
|
119 |
+
mask, kd, ks, normal = render_uv(glctx, mesh, resolution, material['kd_ks_normal'], feat=feat, prior_shape=prior_shape)
|
120 |
+
|
121 |
+
hole_mask = 1. - mask
|
122 |
+
hole_mask = hole_mask.int()[0]
|
123 |
+
def uv_padding(image):
|
124 |
+
uv_padding_size = 4
|
125 |
+
inpaint_image = (
|
126 |
+
cv2.inpaint(
|
127 |
+
(image.detach().cpu().numpy() * 255).astype(np.uint8),
|
128 |
+
(hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
|
129 |
+
uv_padding_size,
|
130 |
+
cv2.INPAINT_TELEA,
|
131 |
+
)
|
132 |
+
/ 255.0
|
133 |
+
)
|
134 |
+
return torch.from_numpy(inpaint_image).to(image)
|
135 |
+
|
136 |
+
kd = uv_padding(kd[0])[None]
|
137 |
+
|
138 |
+
batch_size = kd.shape[0]
|
139 |
+
f.write(f'map_Kd {prefix}texture_kd.png\n')
|
140 |
+
misc.save_images(folder, kd.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_kd"] * batch_size)
|
141 |
+
f.write(f'map_Ks {prefix}texture_ks.png\n')
|
142 |
+
misc.save_images(folder, ks.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_ks"] * batch_size)
|
143 |
+
# disable normal
|
144 |
+
# f.write(f'bump {prefix}texture_n.png\n')
|
145 |
+
# misc.save_images(folder, normal.permute(0,3,1,2).detach().cpu().numpy(), fnames=[prefix + "texture_n"] * batch_size)
|
146 |
+
if 'kd' in material.keys():
|
147 |
+
f.write('map_Kd texture_kd.png\n')
|
148 |
+
texture.save_texture2D(os.path.join(folder, 'texture_Kd.png'), texture.rgb_to_srgb(material['kd']))
|
149 |
+
if 'ks' in material.keys():
|
150 |
+
f.write('map_Ks texture_ks.png\n')
|
151 |
+
texture.save_texture2D(os.path.join(folder, 'texture_Ks.png'), material['ks'])
|
152 |
+
if 'normal' in material.keys():
|
153 |
+
f.write('bump texture_n.png\n')
|
154 |
+
texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
|
155 |
+
else:
|
156 |
+
f.write('Kd 1 1 1\n')
|
157 |
+
f.write('Ks 0 0 0\n')
|
158 |
+
f.write('Ka 0 0 0\n')
|
159 |
+
f.write('Tf 1 1 1\n')
|
160 |
+
f.write('Ni 1\n')
|
161 |
+
f.write('Ns 0\n')
|
162 |
+
|
163 |
+
######################################################################################
|
164 |
+
# Merge multiple materials into a single uber-material
|
165 |
+
######################################################################################
|
166 |
+
|
167 |
+
def _upscale_replicate(x, full_res):
|
168 |
+
x = x.permute(0, 3, 1, 2)
|
169 |
+
x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
|
170 |
+
return x.permute(0, 2, 3, 1).contiguous()
|
171 |
+
|
172 |
+
def merge_materials(materials, texcoords, tfaces, mfaces):
|
173 |
+
assert len(materials) > 0
|
174 |
+
for mat in materials:
|
175 |
+
assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)"
|
176 |
+
assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled"
|
177 |
+
|
178 |
+
uber_material = Material({
|
179 |
+
'name' : 'uber_material',
|
180 |
+
'bsdf' : materials[0]['bsdf'],
|
181 |
+
})
|
182 |
+
|
183 |
+
textures = ['kd', 'ks', 'normal']
|
184 |
+
|
185 |
+
# Find maximum texture resolution across all materials and textures
|
186 |
+
max_res = None
|
187 |
+
for mat in materials:
|
188 |
+
for tex in textures:
|
189 |
+
tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])
|
190 |
+
max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res
|
191 |
+
|
192 |
+
# Compute size of compund texture and round up to nearest PoT
|
193 |
+
full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int)
|
194 |
+
|
195 |
+
# Normalize texture resolution across all materials & combine into a single large texture
|
196 |
+
for tex in textures:
|
197 |
+
if tex in materials[0]:
|
198 |
+
tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x
|
199 |
+
tex_data = _upscale_replicate(tex_data, full_res)
|
200 |
+
uber_material[tex] = texture.Texture2D(tex_data)
|
201 |
+
|
202 |
+
# Compute scaling values for used / unused texture area
|
203 |
+
s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]
|
204 |
+
|
205 |
+
# Recompute texture coordinates to cooincide with new composite texture
|
206 |
+
new_tverts = {}
|
207 |
+
new_tverts_data = []
|
208 |
+
for fi in range(len(tfaces)):
|
209 |
+
matIdx = mfaces[fi]
|
210 |
+
for vi in range(3):
|
211 |
+
ti = tfaces[fi][vi]
|
212 |
+
if not (ti in new_tverts):
|
213 |
+
new_tverts[ti] = {}
|
214 |
+
if not (matIdx in new_tverts[ti]): # create new vertex
|
215 |
+
new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here
|
216 |
+
new_tverts[ti][matIdx] = len(new_tverts_data) - 1
|
217 |
+
tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex
|
218 |
+
|
219 |
+
return uber_material, new_tverts_data, tfaces
|
220 |
+
|
221 |
+
######################################################################################
|
222 |
+
# Utility functions for material
|
223 |
+
######################################################################################
|
224 |
+
|
225 |
+
def initial_guess_material(cfgs, mlp=False, init_mat=None, tet_bbox=None):
|
226 |
+
kd_min = torch.tensor(cfgs.get('kd_min', [0., 0., 0., 0.]), dtype=torch.float32)
|
227 |
+
kd_max = torch.tensor(cfgs.get('kd_max', [1., 1., 1., 1.]), dtype=torch.float32)
|
228 |
+
ks_min = torch.tensor(cfgs.get('ks_min', [0., 0., 0.]), dtype=torch.float32)
|
229 |
+
ks_max = torch.tensor(cfgs.get('ks_max', [0., 0., 0.]), dtype=torch.float32)
|
230 |
+
nrm_min = torch.tensor(cfgs.get('nrm_min', [-1., -1., 0.]), dtype=torch.float32)
|
231 |
+
nrm_max = torch.tensor(cfgs.get('nrm_max', [1., 1., 1.]), dtype=torch.float32)
|
232 |
+
if mlp:
|
233 |
+
num_layers = cfgs.get("num_layers_tex", 5)
|
234 |
+
nf = cfgs.get("hidden_size", 128)
|
235 |
+
enable_encoder = cfgs.get("enable_encoder", False)
|
236 |
+
feat_dim = cfgs.get("latent_dim", 64) if enable_encoder else 0
|
237 |
+
|
238 |
+
mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
|
239 |
+
mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
|
240 |
+
min_max = torch.stack((mlp_min, mlp_max), dim=0)
|
241 |
+
out_chn = 9
|
242 |
+
mlp_map_opt = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=nf, hidden=num_layers-1, feat_dim=feat_dim, min_max=min_max)
|
243 |
+
mat = Material({'kd_ks_normal' : mlp_map_opt})
|
244 |
+
else:
|
245 |
+
# Setup Kd (albedo) and Ks (x, roughness, metalness) textures
|
246 |
+
if cfgs.random_textures or init_mat is None:
|
247 |
+
num_channels = 4 if cfgs.layers > 1 else 3
|
248 |
+
kd_init = torch.rand(size=cfgs.texture_res + [num_channels]) * (kd_max - kd_min)[None, None, 0:num_channels] + kd_min[None, None, 0:num_channels]
|
249 |
+
kd_map_opt = texture.create_trainable(kd_init , cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max])
|
250 |
+
|
251 |
+
ksR = np.random.uniform(size=cfgs.texture_res + [1], low=0.0, high=0.01)
|
252 |
+
ksG = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[1].cpu(), high=ks_max[1].cpu())
|
253 |
+
ksB = np.random.uniform(size=cfgs.texture_res + [1], low=ks_min[2].cpu(), high=ks_max[2].cpu())
|
254 |
+
|
255 |
+
ks_map_opt = texture.create_trainable(np.concatenate((ksR, ksG, ksB), axis=2), cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max])
|
256 |
+
else:
|
257 |
+
kd_map_opt = texture.create_trainable(init_mat['kd'], cfgs.texture_res, not cfgs.custom_mip, [kd_min, kd_max])
|
258 |
+
ks_map_opt = texture.create_trainable(init_mat['ks'], cfgs.texture_res, not cfgs.custom_mip, [ks_min, ks_max])
|
259 |
+
|
260 |
+
# Setup normal map
|
261 |
+
if cfgs.random_textures or init_mat is None or 'normal' not in init_mat:
|
262 |
+
normal_map_opt = texture.create_trainable(np.array([0, 0, 1]), cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max])
|
263 |
+
else:
|
264 |
+
normal_map_opt = texture.create_trainable(init_mat['normal'], cfgs.texture_res, not cfgs.custom_mip, [nrm_min, nrm_max])
|
265 |
+
|
266 |
+
mat = Material({
|
267 |
+
'kd' : kd_map_opt,
|
268 |
+
'ks' : ks_map_opt,
|
269 |
+
'normal' : normal_map_opt
|
270 |
+
})
|
271 |
+
|
272 |
+
if init_mat is not None:
|
273 |
+
mat['bsdf'] = init_mat['bsdf']
|
274 |
+
elif "bsdf" in cfgs:
|
275 |
+
mat['bsdf'] = cfgs["bsdf"]
|
276 |
+
else:
|
277 |
+
mat['bsdf'] = 'pbr'
|
278 |
+
|
279 |
+
if not cfgs.get("perturb_normal", False):
|
280 |
+
mat['no_perturbed_nrm'] = True
|
281 |
+
|
282 |
+
return mat
|
video3d/render/mesh.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
from difflib import unified_diff
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
|
15 |
+
from . import obj
|
16 |
+
from . import util
|
17 |
+
|
18 |
+
#########################################################################################
|
19 |
+
# Base mesh class
|
20 |
+
#
|
21 |
+
# Minibatch in mesh is supported, as long as each mesh shares the same edge connectivity.
|
22 |
+
#########################################################################################
|
23 |
+
class Mesh:
|
24 |
+
def __init__(self,
|
25 |
+
v_pos=None,
|
26 |
+
t_pos_idx=None,
|
27 |
+
v_nrm=None,
|
28 |
+
t_nrm_idx=None,
|
29 |
+
v_tex=None,
|
30 |
+
t_tex_idx=None,
|
31 |
+
v_tng=None,
|
32 |
+
t_tng_idx=None,
|
33 |
+
material=None,
|
34 |
+
base=None):
|
35 |
+
self.v_pos = v_pos
|
36 |
+
self.v_nrm = v_nrm
|
37 |
+
self.v_tex = v_tex
|
38 |
+
self.v_tng = v_tng
|
39 |
+
self.t_pos_idx = t_pos_idx
|
40 |
+
self.t_nrm_idx = t_nrm_idx
|
41 |
+
self.t_tex_idx = t_tex_idx
|
42 |
+
self.t_tng_idx = t_tng_idx
|
43 |
+
self.material = material
|
44 |
+
|
45 |
+
if base is not None:
|
46 |
+
self.copy_none(base)
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.v_pos)
|
50 |
+
|
51 |
+
def copy_none(self, other):
|
52 |
+
if self.v_pos is None:
|
53 |
+
self.v_pos = other.v_pos
|
54 |
+
if self.t_pos_idx is None:
|
55 |
+
self.t_pos_idx = other.t_pos_idx
|
56 |
+
if self.v_nrm is None:
|
57 |
+
self.v_nrm = other.v_nrm
|
58 |
+
if self.t_nrm_idx is None:
|
59 |
+
self.t_nrm_idx = other.t_nrm_idx
|
60 |
+
if self.v_tex is None:
|
61 |
+
self.v_tex = other.v_tex
|
62 |
+
if self.t_tex_idx is None:
|
63 |
+
self.t_tex_idx = other.t_tex_idx
|
64 |
+
if self.v_tng is None:
|
65 |
+
self.v_tng = other.v_tng
|
66 |
+
if self.t_tng_idx is None:
|
67 |
+
self.t_tng_idx = other.t_tng_idx
|
68 |
+
if self.material is None:
|
69 |
+
self.material = other.material
|
70 |
+
|
71 |
+
def clone(self):
|
72 |
+
out = Mesh(base=self)
|
73 |
+
if out.v_pos is not None:
|
74 |
+
out.v_pos = out.v_pos.clone().detach()
|
75 |
+
if out.t_pos_idx is not None:
|
76 |
+
out.t_pos_idx = out.t_pos_idx.clone().detach()
|
77 |
+
if out.v_nrm is not None:
|
78 |
+
out.v_nrm = out.v_nrm.clone().detach()
|
79 |
+
if out.t_nrm_idx is not None:
|
80 |
+
out.t_nrm_idx = out.t_nrm_idx.clone().detach()
|
81 |
+
if out.v_tex is not None:
|
82 |
+
out.v_tex = out.v_tex.clone().detach()
|
83 |
+
if out.t_tex_idx is not None:
|
84 |
+
out.t_tex_idx = out.t_tex_idx.clone().detach()
|
85 |
+
if out.v_tng is not None:
|
86 |
+
out.v_tng = out.v_tng.clone().detach()
|
87 |
+
if out.t_tng_idx is not None:
|
88 |
+
out.t_tng_idx = out.t_tng_idx.clone().detach()
|
89 |
+
return out
|
90 |
+
|
91 |
+
def detach(self):
|
92 |
+
return self.clone()
|
93 |
+
|
94 |
+
def extend(self, N: int):
|
95 |
+
"""
|
96 |
+
Create new Mesh class which contains each input mesh N times.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
N: number of new copies of each mesh.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
new Mesh object.
|
103 |
+
"""
|
104 |
+
verts = self.v_pos.repeat(N, 1, 1)
|
105 |
+
faces = self.t_pos_idx
|
106 |
+
uvs = self.v_tex.repeat(N, 1, 1)
|
107 |
+
uv_idx = self.t_tex_idx
|
108 |
+
mat = self.material
|
109 |
+
|
110 |
+
return make_mesh(verts, faces, uvs, uv_idx, self.material)
|
111 |
+
|
112 |
+
def deform(self, deformation):
|
113 |
+
"""
|
114 |
+
Create new Mesh class which is obtained by performing the deformation to the self.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
deformation: tensor with shape (B, V, 3)
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
new Mesh object after the deformation.
|
121 |
+
"""
|
122 |
+
assert deformation.shape[1] == self.v_pos.shape[1] and deformation.shape[2] == 3
|
123 |
+
verts = self.v_pos + deformation
|
124 |
+
return make_mesh(verts, self.t_pos_idx, self.v_tex.repeat(len(verts), 1, 1), self.t_tex_idx, self.material)
|
125 |
+
|
126 |
+
def get_m_to_n(self, m: int, n: int):
|
127 |
+
"""
|
128 |
+
Create new Mesh class with the n-th (included) mesh to the m-th (not included) mesh in the batch.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
m: the index of the starting mesh to be contained.
|
132 |
+
n: the index of the first mesh not to be contained.
|
133 |
+
"""
|
134 |
+
verts = self.v_pos[m:n, ...]
|
135 |
+
faces = self.t_pos_idx
|
136 |
+
uvs = self.v_tex[m:n, ...]
|
137 |
+
uv_idx = self.t_tex_idx
|
138 |
+
mat = self.material
|
139 |
+
|
140 |
+
return make_mesh(verts, faces, uvs, uv_idx, mat)
|
141 |
+
|
142 |
+
def first_n(self, n: int):
|
143 |
+
"""
|
144 |
+
Create new Mesh class with only the first n meshes in the batch.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
n: number of meshes to be contained.
|
148 |
+
|
149 |
+
Returns:
|
150 |
+
new Mesh object with the first n meshes.
|
151 |
+
"""
|
152 |
+
return self.get_m_to_n(0, n)
|
153 |
+
verts = self.v_pos[:n, ...]
|
154 |
+
faces = self.t_pos_idx
|
155 |
+
uvs = self.v_tex[:n, ...]
|
156 |
+
uv_idx = self.t_tex_idx
|
157 |
+
mat = self.material
|
158 |
+
|
159 |
+
return make_mesh(verts, faces, uvs, uv_idx, mat)
|
160 |
+
|
161 |
+
def get_n(self, n: int):
|
162 |
+
"""
|
163 |
+
Create new Mesh class with only the n-th meshes in the batch.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
n: the index of the mesh to be contained.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
new Mesh object with the n-th mesh.
|
170 |
+
"""
|
171 |
+
verts = self.v_pos[n:n+1, ...]
|
172 |
+
faces = self.t_pos_idx
|
173 |
+
uvs = self.v_tex[n:n+1, ...]
|
174 |
+
uv_idx = self.t_tex_idx
|
175 |
+
mat = self.material
|
176 |
+
|
177 |
+
return make_mesh(verts, faces, uvs, uv_idx, mat)
|
178 |
+
|
179 |
+
|
180 |
+
######################################################################################
|
181 |
+
# Mesh loading helper
|
182 |
+
######################################################################################
|
183 |
+
def load_mesh(filename, mtl_override=None):
|
184 |
+
name, ext = os.path.splitext(filename)
|
185 |
+
if ext == ".obj":
|
186 |
+
return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override)
|
187 |
+
assert False, "Invalid mesh file extension"
|
188 |
+
|
189 |
+
######################################################################################
|
190 |
+
# Compute AABB
|
191 |
+
######################################################################################
|
192 |
+
def aabb(mesh):
|
193 |
+
return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values
|
194 |
+
|
195 |
+
######################################################################################
|
196 |
+
# Compute unique edge list from attribute/vertex index list
|
197 |
+
######################################################################################
|
198 |
+
def compute_edges(attr_idx, return_inverse=False):
|
199 |
+
with torch.no_grad():
|
200 |
+
# Create all edges, packed by triangle
|
201 |
+
idx = attr_idx[0]
|
202 |
+
all_edges = torch.cat((
|
203 |
+
torch.stack((idx[:, 0], idx[:, 1]), dim=-1),
|
204 |
+
torch.stack((idx[:, 1], idx[:, 2]), dim=-1),
|
205 |
+
torch.stack((idx[:, 2], idx[:, 0]), dim=-1),
|
206 |
+
), dim=-1).view(-1, 2)
|
207 |
+
|
208 |
+
# Swap edge order so min index is always first
|
209 |
+
order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
|
210 |
+
sorted_edges = torch.cat((
|
211 |
+
torch.gather(all_edges, 1, order),
|
212 |
+
torch.gather(all_edges, 1, 1 - order)
|
213 |
+
), dim=-1)
|
214 |
+
|
215 |
+
# Eliminate duplicates and return inverse mapping
|
216 |
+
return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse)
|
217 |
+
|
218 |
+
######################################################################################
|
219 |
+
# Compute unique edge to face mapping from attribute/vertex index list
|
220 |
+
######################################################################################
|
221 |
+
def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
|
222 |
+
with torch.no_grad():
|
223 |
+
# Get unique edges
|
224 |
+
# Create all edges, packed by triangle
|
225 |
+
idx = attr_idx[0]
|
226 |
+
all_edges = torch.cat((
|
227 |
+
torch.stack((idx[:, 0], idx[:, 1]), dim=-1),
|
228 |
+
torch.stack((idx[:, 1], idx[:, 2]), dim=-1),
|
229 |
+
torch.stack((idx[:, 2], idx[:, 0]), dim=-1),
|
230 |
+
), dim=-1).view(-1, 2)
|
231 |
+
|
232 |
+
# Swap edge order so min index is always first
|
233 |
+
order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
|
234 |
+
sorted_edges = torch.cat((
|
235 |
+
torch.gather(all_edges, 1, order),
|
236 |
+
torch.gather(all_edges, 1, 1 - order)
|
237 |
+
), dim=-1)
|
238 |
+
|
239 |
+
# Elliminate duplicates and return inverse mapping
|
240 |
+
unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)
|
241 |
+
|
242 |
+
tris = torch.arange(idx.shape[0]).repeat_interleave(3).cuda()
|
243 |
+
|
244 |
+
tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()
|
245 |
+
|
246 |
+
# Compute edge to face table
|
247 |
+
mask0 = order[:,0] == 0
|
248 |
+
mask1 = order[:,0] == 1
|
249 |
+
tris_per_edge[idx_map[mask0], 0] = tris[mask0]
|
250 |
+
tris_per_edge[idx_map[mask1], 1] = tris[mask1]
|
251 |
+
|
252 |
+
return tris_per_edge
|
253 |
+
|
254 |
+
######################################################################################
|
255 |
+
# Align base mesh to reference mesh:move & rescale to match bounding boxes.
|
256 |
+
######################################################################################
|
257 |
+
def unit_size(mesh):
|
258 |
+
with torch.no_grad():
|
259 |
+
vmin, vmax = aabb(mesh)
|
260 |
+
scale = 2 / torch.max(vmax - vmin).item()
|
261 |
+
v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin
|
262 |
+
v_pos = v_pos * scale # Rescale to unit size
|
263 |
+
|
264 |
+
return Mesh(v_pos, base=mesh)
|
265 |
+
|
266 |
+
######################################################################################
|
267 |
+
# Center & scale mesh for rendering
|
268 |
+
######################################################################################
|
269 |
+
def center_by_reference(base_mesh, ref_aabb, scale):
|
270 |
+
center = (ref_aabb[0] + ref_aabb[1]) * 0.5
|
271 |
+
scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()
|
272 |
+
v_pos = (base_mesh.v_pos - center[None, ...]) * scale
|
273 |
+
return Mesh(v_pos, base=base_mesh)
|
274 |
+
|
275 |
+
######################################################################################
|
276 |
+
# Simple smooth vertex normal computation
|
277 |
+
######################################################################################
|
278 |
+
def auto_normals(imesh):
|
279 |
+
batch_size = imesh.v_pos.shape[0]
|
280 |
+
|
281 |
+
i0 = imesh.t_pos_idx[0, :, 0] # Shape: (F)
|
282 |
+
i1 = imesh.t_pos_idx[0, :, 1] # Shape: (F)
|
283 |
+
i2 = imesh.t_pos_idx[0, :, 2] # Shape: (F)
|
284 |
+
|
285 |
+
v0 = imesh.v_pos[:, i0, :] # Shape: (B, F, 3)
|
286 |
+
v1 = imesh.v_pos[:, i1, :] # Shape: (B, F, 3)
|
287 |
+
v2 = imesh.v_pos[:, i2, :] # Shape: (B, F, 3)
|
288 |
+
|
289 |
+
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) # Shape: (B, F, 3)
|
290 |
+
|
291 |
+
# Splat face normals to vertices
|
292 |
+
v_nrm = torch.zeros_like(imesh.v_pos) # Shape: (B, V, 3)
|
293 |
+
v_nrm.scatter_add_(1, i0[None, :, None].repeat(batch_size, 1, 3), face_normals)
|
294 |
+
v_nrm.scatter_add_(1, i1[None, :, None].repeat(batch_size, 1, 3), face_normals)
|
295 |
+
v_nrm.scatter_add_(1, i2[None, :, None].repeat(batch_size, 1, 3), face_normals)
|
296 |
+
|
297 |
+
# Normalize, replace zero (degenerated) normals with some default value
|
298 |
+
v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20,
|
299 |
+
v_nrm, torch.tensor([0.0, 0.0, 1.0],
|
300 |
+
dtype=torch.float32, device='cuda'))
|
301 |
+
v_nrm = util.safe_normalize(v_nrm)
|
302 |
+
|
303 |
+
if torch.is_anomaly_enabled():
|
304 |
+
assert torch.all(torch.isfinite(v_nrm))
|
305 |
+
|
306 |
+
return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh)
|
307 |
+
|
308 |
+
######################################################################################
|
309 |
+
# Compute tangent space from texture map coordinates
|
310 |
+
# Follows http://www.mikktspace.com/ conventions
|
311 |
+
######################################################################################
|
312 |
+
def compute_tangents(imesh):
|
313 |
+
batch_size = imesh.v_pos.shape[0]
|
314 |
+
|
315 |
+
vn_idx = [None] * 3
|
316 |
+
pos = [None] * 3
|
317 |
+
tex = [None] * 3
|
318 |
+
for i in range(0,3):
|
319 |
+
pos[i] = imesh.v_pos[:, imesh.t_pos_idx[0, :, i]]
|
320 |
+
tex[i] = imesh.v_tex[:, imesh.t_tex_idx[0, :, i]]
|
321 |
+
vn_idx[i] = imesh.t_nrm_idx[..., i:i+1]
|
322 |
+
|
323 |
+
tangents = torch.zeros_like(imesh.v_nrm)
|
324 |
+
tansum = torch.zeros_like(imesh.v_nrm)
|
325 |
+
|
326 |
+
# Compute tangent space for each triangle
|
327 |
+
uve1 = tex[1] - tex[0] # Shape: (B, F, 2)
|
328 |
+
uve2 = tex[2] - tex[0] # Shape: (B, F, 2)
|
329 |
+
pe1 = pos[1] - pos[0] # Shape: (B, F, 3)
|
330 |
+
pe2 = pos[2] - pos[0] # Shape: (B, F, 3)
|
331 |
+
|
332 |
+
nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2] # Shape: (B, F, 3)
|
333 |
+
denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1] # Shape: (B, F, 1)
|
334 |
+
|
335 |
+
# Avoid division by zero for degenerated texture coordinates
|
336 |
+
tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)) # Shape: (B, F, 3)
|
337 |
+
|
338 |
+
# Update all 3 vertices
|
339 |
+
for i in range(0,3):
|
340 |
+
idx = vn_idx[i].repeat(batch_size, 1, 3) # Shape: (B, F, 3)
|
341 |
+
tangents.scatter_add_(1, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
342 |
+
tansum.scatter_add_(1, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1
|
343 |
+
tangents = tangents / tansum
|
344 |
+
|
345 |
+
# Normalize and make sure tangent is perpendicular to normal
|
346 |
+
tangents = util.safe_normalize(tangents)
|
347 |
+
tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)
|
348 |
+
|
349 |
+
if torch.is_anomaly_enabled():
|
350 |
+
assert torch.all(torch.isfinite(tangents))
|
351 |
+
|
352 |
+
return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)
|
353 |
+
|
354 |
+
######################################################################################
|
355 |
+
# Create new Mesh from verts, faces, uvs, and uv_idx. The rest is auto computed.
|
356 |
+
######################################################################################
|
357 |
+
def make_mesh(verts, faces, uvs, uv_idx, material):
|
358 |
+
"""
|
359 |
+
Create new Mesh class with given verts, faces, uvs, and uv_idx.
|
360 |
+
|
361 |
+
Args:
|
362 |
+
verts: tensor of shape (B, V, 3)
|
363 |
+
faces: tensor of shape (1, F, 3)
|
364 |
+
uvs: tensor of shape (B, V, 2)
|
365 |
+
uv_idx: tensor of shape (1, F, 3)
|
366 |
+
material: an Material instance, specifying the material of the mesh.
|
367 |
+
|
368 |
+
Returns:
|
369 |
+
new Mesh object.
|
370 |
+
"""
|
371 |
+
assert len(verts.shape) == 3 and len(faces.shape) == 3 and len(uvs.shape) == 3 and len(uv_idx.shape) == 3, "All components must be batched."
|
372 |
+
assert faces.shape[0] == 1 and uv_idx.shape[0] == 1, "Every mesh must share the same edge connectivity."
|
373 |
+
assert verts.shape[0] == uvs.shape[0], "Batch size must be consistent."
|
374 |
+
ret = Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)
|
375 |
+
ret = auto_normals(ret)
|
376 |
+
ret = compute_tangents(ret)
|
377 |
+
return ret
|
video3d/render/mlptexture.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import tinycudann as tcnn
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
#######################################################################################################################################################
|
15 |
+
# Small MLP using PyTorch primitives, internal helper class
|
16 |
+
#######################################################################################################################################################
|
17 |
+
|
18 |
+
class _MLP(torch.nn.Module):
|
19 |
+
def __init__(self, cfg, loss_scale=1.0):
|
20 |
+
super(_MLP, self).__init__()
|
21 |
+
self.loss_scale = loss_scale
|
22 |
+
net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
|
23 |
+
for i in range(cfg['n_hidden_layers']-1):
|
24 |
+
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
|
25 |
+
net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
|
26 |
+
self.net = torch.nn.Sequential(*net).cuda()
|
27 |
+
|
28 |
+
self.net.apply(self._init_weights)
|
29 |
+
|
30 |
+
if self.loss_scale != 1.0:
|
31 |
+
self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return self.net(x.to(torch.float32))
|
35 |
+
|
36 |
+
@staticmethod
|
37 |
+
def _init_weights(m):
|
38 |
+
if type(m) == torch.nn.Linear:
|
39 |
+
torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
|
40 |
+
if hasattr(m.bias, 'data'):
|
41 |
+
m.bias.data.fill_(0.0)
|
42 |
+
|
43 |
+
#######################################################################################################################################################
|
44 |
+
# Outward visible MLP class
|
45 |
+
#######################################################################################################################################################
|
46 |
+
|
47 |
+
class MLPTexture3D(torch.nn.Module):
|
48 |
+
def __init__(self, AABB, channels=3, internal_dims=32, hidden=2, feat_dim=0, min_max=None, bsdf='diffuse', perturb_normal=False, symmetrize=False):
|
49 |
+
super(MLPTexture3D, self).__init__()
|
50 |
+
|
51 |
+
self.channels = channels
|
52 |
+
self.feat_dim = feat_dim
|
53 |
+
self.internal_dims = internal_dims
|
54 |
+
self.AABB = AABB
|
55 |
+
self.bsdf = bsdf
|
56 |
+
self.perturb_normal = perturb_normal
|
57 |
+
self.symmetrize = symmetrize
|
58 |
+
if min_max is not None:
|
59 |
+
self.register_buffer('min_max', min_max)
|
60 |
+
else:
|
61 |
+
self.min_max = None
|
62 |
+
|
63 |
+
# Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details.
|
64 |
+
desired_resolution = 4096
|
65 |
+
base_grid_resolution = 16
|
66 |
+
num_levels = 16
|
67 |
+
per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))
|
68 |
+
|
69 |
+
enc_cfg = {
|
70 |
+
"otype": "HashGrid",
|
71 |
+
"n_levels": num_levels,
|
72 |
+
"n_features_per_level": 2,
|
73 |
+
"log2_hashmap_size": 19,
|
74 |
+
"base_resolution": base_grid_resolution,
|
75 |
+
"per_level_scale" : per_level_scale
|
76 |
+
}
|
77 |
+
|
78 |
+
# gradient_scaling = 128.0
|
79 |
+
gradient_scaling = 1.0
|
80 |
+
self.encoder = tcnn.Encoding(3, enc_cfg)
|
81 |
+
self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))
|
82 |
+
|
83 |
+
# Setup MLP
|
84 |
+
mlp_cfg = {
|
85 |
+
"n_input_dims" : internal_dims + feat_dim,
|
86 |
+
"n_output_dims" : self.channels,
|
87 |
+
"n_hidden_layers" : hidden,
|
88 |
+
"n_neurons" : self.internal_dims
|
89 |
+
}
|
90 |
+
self.linear = torch.nn.Linear(self.encoder.n_output_dims, internal_dims)
|
91 |
+
self.net = _MLP(mlp_cfg, gradient_scaling)
|
92 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
93 |
+
print("Encoder output: %d dims" % (self.encoder.n_output_dims))
|
94 |
+
|
95 |
+
# Sample texture at a given location
|
96 |
+
def sample(self, texc, feat=None):
|
97 |
+
assert (feat is None and self.feat_dim == 0) or feat.shape[-1] == self.feat_dim
|
98 |
+
|
99 |
+
if self.symmetrize:
|
100 |
+
xs, ys, zs = texc.unbind(-1)
|
101 |
+
texc = torch.stack([xs.abs(), ys, zs], -1) # mirror -x to +x
|
102 |
+
|
103 |
+
_texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])
|
104 |
+
_texc = torch.clamp(_texc, min=0, max=1)
|
105 |
+
|
106 |
+
_, image_h, image_w, _ = texc.shape
|
107 |
+
p_enc = self.encoder(_texc.contiguous())
|
108 |
+
x_in = self.linear(p_enc.type(texc.dtype))
|
109 |
+
if feat is not None:
|
110 |
+
feat_in = feat[:, None, None, :].repeat(1, image_h, image_w, 1).view(-1, self.feat_dim)
|
111 |
+
x_in = torch.concat([x_in, feat_in], dim=-1)
|
112 |
+
out = self.net(self.relu(x_in))
|
113 |
+
|
114 |
+
# Sigmoid limit and scale to the allowed range
|
115 |
+
out = torch.sigmoid(out)
|
116 |
+
if self.min_max is not None:
|
117 |
+
out = out * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]
|
118 |
+
|
119 |
+
return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]
|
120 |
+
|
121 |
+
def cleanup(self):
|
122 |
+
tcnn.free_temporary_memory()
|
video3d/render/obj.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import os
|
11 |
+
import torch
|
12 |
+
import xatlas
|
13 |
+
import trimesh
|
14 |
+
import numpy as np
|
15 |
+
import cv2
|
16 |
+
import nvdiffrast.torch as dr
|
17 |
+
from video3d.render.render import render_uv
|
18 |
+
from video3d.render.mesh import Mesh
|
19 |
+
from . import texture
|
20 |
+
from . import mesh
|
21 |
+
from . import material
|
22 |
+
|
23 |
+
######################################################################################
|
24 |
+
# Utility functions
|
25 |
+
######################################################################################
|
26 |
+
|
27 |
+
def _find_mat(materials, name):
|
28 |
+
for mat in materials:
|
29 |
+
if mat['name'] == name:
|
30 |
+
return mat
|
31 |
+
return materials[0] # Materials 0 is the default
|
32 |
+
|
33 |
+
######################################################################################
|
34 |
+
# Create mesh object from objfile
|
35 |
+
######################################################################################
|
36 |
+
|
37 |
+
def load_obj(filename, clear_ks=True, mtl_override=None):
|
38 |
+
obj_path = os.path.dirname(filename)
|
39 |
+
|
40 |
+
# Read entire file
|
41 |
+
with open(filename, 'r') as f:
|
42 |
+
lines = f.readlines()
|
43 |
+
|
44 |
+
# Load materials
|
45 |
+
all_materials = [
|
46 |
+
{
|
47 |
+
'name' : '_default_mat',
|
48 |
+
'bsdf' : 'pbr',
|
49 |
+
'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),
|
50 |
+
'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
|
51 |
+
}
|
52 |
+
]
|
53 |
+
if mtl_override is None:
|
54 |
+
for line in lines:
|
55 |
+
if len(line.split()) == 0:
|
56 |
+
continue
|
57 |
+
if line.split()[0] == 'mtllib':
|
58 |
+
all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library
|
59 |
+
else:
|
60 |
+
all_materials += material.load_mtl(mtl_override)
|
61 |
+
|
62 |
+
# load vertices
|
63 |
+
vertices, texcoords, normals = [], [], []
|
64 |
+
for line in lines:
|
65 |
+
if len(line.split()) == 0:
|
66 |
+
continue
|
67 |
+
|
68 |
+
prefix = line.split()[0].lower()
|
69 |
+
if prefix == 'v':
|
70 |
+
vertices.append([float(v) for v in line.split()[1:]])
|
71 |
+
elif prefix == 'vt':
|
72 |
+
val = [float(v) for v in line.split()[1:]]
|
73 |
+
texcoords.append([val[0], 1.0 - val[1]])
|
74 |
+
elif prefix == 'vn':
|
75 |
+
normals.append([float(v) for v in line.split()[1:]])
|
76 |
+
|
77 |
+
# load faces
|
78 |
+
activeMatIdx = None
|
79 |
+
used_materials = []
|
80 |
+
faces, tfaces, nfaces, mfaces = [], [], [], []
|
81 |
+
for line in lines:
|
82 |
+
if len(line.split()) == 0:
|
83 |
+
continue
|
84 |
+
|
85 |
+
prefix = line.split()[0].lower()
|
86 |
+
if prefix == 'usemtl': # Track used materials
|
87 |
+
mat = _find_mat(all_materials, line.split()[1])
|
88 |
+
if not mat in used_materials:
|
89 |
+
used_materials.append(mat)
|
90 |
+
activeMatIdx = used_materials.index(mat)
|
91 |
+
elif prefix == 'f': # Parse face
|
92 |
+
vs = line.split()[1:]
|
93 |
+
nv = len(vs)
|
94 |
+
vv = vs[0].split('/')
|
95 |
+
v0 = int(vv[0]) - 1
|
96 |
+
t0 = int(vv[1]) - 1 if vv[1] != "" else -1
|
97 |
+
n0 = int(vv[2]) - 1 if vv[2] != "" else -1
|
98 |
+
for i in range(nv - 2): # Triangulate polygons
|
99 |
+
vv = vs[i + 1].split('/')
|
100 |
+
v1 = int(vv[0]) - 1
|
101 |
+
t1 = int(vv[1]) - 1 if vv[1] != "" else -1
|
102 |
+
n1 = int(vv[2]) - 1 if vv[2] != "" else -1
|
103 |
+
vv = vs[i + 2].split('/')
|
104 |
+
v2 = int(vv[0]) - 1
|
105 |
+
t2 = int(vv[1]) - 1 if vv[1] != "" else -1
|
106 |
+
n2 = int(vv[2]) - 1 if vv[2] != "" else -1
|
107 |
+
mfaces.append(activeMatIdx)
|
108 |
+
faces.append([v0, v1, v2])
|
109 |
+
tfaces.append([t0, t1, t2])
|
110 |
+
nfaces.append([n0, n1, n2])
|
111 |
+
assert len(tfaces) == len(faces) and len(nfaces) == len (faces)
|
112 |
+
|
113 |
+
# Create an "uber" material by combining all textures into a larger texture
|
114 |
+
if len(used_materials) > 1:
|
115 |
+
uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)
|
116 |
+
else:
|
117 |
+
uber_material = used_materials[0]
|
118 |
+
|
119 |
+
vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')
|
120 |
+
texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None
|
121 |
+
normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None
|
122 |
+
|
123 |
+
faces = torch.tensor(faces, dtype=torch.int64, device='cuda')
|
124 |
+
tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None
|
125 |
+
nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None
|
126 |
+
|
127 |
+
return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)
|
128 |
+
|
129 |
+
######################################################################################
|
130 |
+
# Save mesh object to objfile
|
131 |
+
######################################################################################
|
132 |
+
|
133 |
+
def write_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256]):
|
134 |
+
obj_file = os.path.join(folder, fname + '.obj')
|
135 |
+
print("Writing mesh: ", obj_file)
|
136 |
+
with open(obj_file, "w") as f:
|
137 |
+
f.write(f"mtllib {fname}.mtl\n")
|
138 |
+
f.write("g default\n")
|
139 |
+
|
140 |
+
v_pos = mesh.v_pos[idx].detach().cpu().numpy() if mesh.v_pos is not None else None
|
141 |
+
v_nrm = mesh.v_nrm[idx].detach().cpu().numpy() if mesh.v_nrm is not None else None
|
142 |
+
v_tex = mesh.v_tex[idx].detach().cpu().numpy() if mesh.v_tex is not None else None
|
143 |
+
|
144 |
+
t_pos_idx = mesh.t_pos_idx[0].detach().cpu().numpy() if mesh.t_pos_idx is not None else None
|
145 |
+
t_nrm_idx = mesh.t_nrm_idx[0].detach().cpu().numpy() if mesh.t_nrm_idx is not None else None
|
146 |
+
t_tex_idx = mesh.t_tex_idx[0].detach().cpu().numpy() if mesh.t_tex_idx is not None else None
|
147 |
+
|
148 |
+
print(" writing %d vertices" % len(v_pos))
|
149 |
+
for v in v_pos:
|
150 |
+
f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
|
151 |
+
|
152 |
+
if v_tex is not None and save_material:
|
153 |
+
print(" writing %d texcoords" % len(v_tex))
|
154 |
+
assert(len(t_pos_idx) == len(t_tex_idx))
|
155 |
+
for v in v_tex:
|
156 |
+
f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
|
157 |
+
|
158 |
+
if v_nrm is not None:
|
159 |
+
print(" writing %d normals" % len(v_nrm))
|
160 |
+
assert(len(t_pos_idx) == len(t_nrm_idx))
|
161 |
+
for v in v_nrm:
|
162 |
+
f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
|
163 |
+
|
164 |
+
# faces
|
165 |
+
f.write("s 1 \n")
|
166 |
+
f.write("g pMesh1\n")
|
167 |
+
f.write("usemtl defaultMat\n")
|
168 |
+
|
169 |
+
# Write faces
|
170 |
+
print(" writing %d faces" % len(t_pos_idx))
|
171 |
+
for i in range(len(t_pos_idx)):
|
172 |
+
f.write("f ")
|
173 |
+
for j in range(3):
|
174 |
+
f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
|
175 |
+
f.write("\n")
|
176 |
+
|
177 |
+
if save_material and mesh.material is not None:
|
178 |
+
mtl_file = os.path.join(folder, fname + '.mtl')
|
179 |
+
print("Writing material: ", mtl_file)
|
180 |
+
material.save_mtl(mtl_file, mesh.material, mesh=mesh.get_n(idx), feat=feat, resolution=resolution)
|
181 |
+
|
182 |
+
print("Done exporting mesh")
|
183 |
+
|
184 |
+
|
185 |
+
def write_textured_obj(folder, fname, mesh, idx, save_material=True, feat=None, resolution=[256, 256], prior_shape=None):
|
186 |
+
mesh = mesh.get_n(idx)
|
187 |
+
obj_file = os.path.join(folder, fname + '.obj')
|
188 |
+
print("Writing mesh: ", obj_file)
|
189 |
+
|
190 |
+
# Create uvs with xatlas
|
191 |
+
v_pos = mesh.v_pos.detach().cpu().numpy()
|
192 |
+
t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy()
|
193 |
+
|
194 |
+
# v_color = torch.Tensor(v_pos)[None].to("cuda")
|
195 |
+
# v_color = mesh.material.sample(v_color, feat)
|
196 |
+
# v_color = v_color[0,0,:,:3].detach().cpu()
|
197 |
+
# v_color = torch.concat([v_color, torch.ones((v_color.shape[0], 1))], dim=-1)
|
198 |
+
# v_color = v_color.numpy() * 255
|
199 |
+
# v_color = v_color.astype(np.int32)
|
200 |
+
# tmp = trimesh.Trimesh(vertices=v_pos[0], faces=t_pos_idx[0], vertex_colors=v_color)
|
201 |
+
# _ = tmp.export("tmp.obj")
|
202 |
+
# from pdb import set_trace; set_trace()
|
203 |
+
|
204 |
+
atlas = xatlas.Atlas()
|
205 |
+
atlas.add_mesh(
|
206 |
+
v_pos[0],
|
207 |
+
t_pos_idx[0],
|
208 |
+
)
|
209 |
+
co = xatlas.ChartOptions()
|
210 |
+
po = xatlas.PackOptions()
|
211 |
+
# for k, v in xatlas_chart_options.items():
|
212 |
+
# setattr(co, k, v)
|
213 |
+
# for k, v in xatlas_pack_options.items():
|
214 |
+
# setattr(po, k, v)
|
215 |
+
atlas.generate(co, po)
|
216 |
+
vmapping, indices, uvs = atlas.get_mesh(0)
|
217 |
+
# vmapping, indices, uvs = xatlas.parametrize(v_pos[0], t_pos_idx[0])
|
218 |
+
|
219 |
+
# Convert to tensors
|
220 |
+
indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
|
221 |
+
|
222 |
+
uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
|
223 |
+
faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')
|
224 |
+
|
225 |
+
# new_mesh = Mesh(v_tex=uvs, t_tex_idx=faces, base=mesh)
|
226 |
+
new_mesh = Mesh(v_tex=uvs[None], t_tex_idx=faces[None], base=mesh)
|
227 |
+
|
228 |
+
# glctx = dr.RasterizeGLContext()
|
229 |
+
# mask, kd, ks, normal = render_uv(glctx, new_mesh, resolution, mesh.material, feat=feat)
|
230 |
+
|
231 |
+
# kd_min, kd_max = torch.tensor([ 0.0, 0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda')
|
232 |
+
# ks_min, ks_max = torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda'), torch.tensor([ 0.0, 0.0, 0.0] , dtype=torch.float32, device='cuda')
|
233 |
+
# nrm_min, nrm_max = torch.tensor([-1.0, -1.0, 0.0], dtype=torch.float32, device='cuda'), torch.tensor([ 1.0, 1.0, 1.0], dtype=torch.float32, device='cuda')
|
234 |
+
|
235 |
+
new_mesh.material = material.Material({
|
236 |
+
'bsdf' : 'diffuse',
|
237 |
+
# 'kd' : texture.Texture2D(kd, min_max=[kd_min, kd_max]),
|
238 |
+
# 'ks' : texture.Texture2D(ks, min_max=[ks_min, ks_max]),
|
239 |
+
# 'normal' : texture.Texture2D(normal, min_max=[nrm_min, nrm_max]),
|
240 |
+
'kd_ks_normal': mesh.material
|
241 |
+
})
|
242 |
+
|
243 |
+
with open(obj_file, "w") as f:
|
244 |
+
f.write(f"mtllib {fname}.mtl\n")
|
245 |
+
f.write("g default\n")
|
246 |
+
|
247 |
+
v_pos = new_mesh.v_pos[idx].detach().cpu().numpy() if new_mesh.v_pos is not None else None
|
248 |
+
v_nrm = new_mesh.v_nrm[idx].detach().cpu().numpy() if new_mesh.v_nrm is not None else None
|
249 |
+
v_tex = new_mesh.v_tex[idx].detach().cpu().numpy() if new_mesh.v_tex is not None else None
|
250 |
+
|
251 |
+
t_pos_idx = new_mesh.t_pos_idx[0].detach().cpu().numpy() if new_mesh.t_pos_idx is not None else None
|
252 |
+
t_nrm_idx = new_mesh.t_nrm_idx[0].detach().cpu().numpy() if new_mesh.t_nrm_idx is not None else None
|
253 |
+
t_tex_idx = new_mesh.t_tex_idx[0].detach().cpu().numpy() if new_mesh.t_tex_idx is not None else None
|
254 |
+
|
255 |
+
print(" writing %d vertices" % len(v_pos))
|
256 |
+
for v in v_pos:
|
257 |
+
f.write('v {} {} {} \n'.format(v[0], v[1], v[2]))
|
258 |
+
|
259 |
+
if v_tex is not None and save_material:
|
260 |
+
print(" writing %d texcoords" % len(v_tex))
|
261 |
+
assert(len(t_pos_idx) == len(t_tex_idx))
|
262 |
+
for v in v_tex:
|
263 |
+
f.write('vt {} {} \n'.format(v[0], 1.0 - v[1]))
|
264 |
+
|
265 |
+
if v_nrm is not None:
|
266 |
+
print(" writing %d normals" % len(v_nrm))
|
267 |
+
assert(len(t_pos_idx) == len(t_nrm_idx))
|
268 |
+
for v in v_nrm:
|
269 |
+
f.write('vn {} {} {}\n'.format(v[0], v[1], v[2]))
|
270 |
+
|
271 |
+
# faces
|
272 |
+
f.write("s 1 \n")
|
273 |
+
f.write("g pMesh1\n")
|
274 |
+
f.write("usemtl defaultMat\n")
|
275 |
+
|
276 |
+
# Write faces
|
277 |
+
print(" writing %d faces" % len(t_pos_idx))
|
278 |
+
for i in range(len(t_pos_idx)):
|
279 |
+
f.write("f ")
|
280 |
+
for j in range(3):
|
281 |
+
f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))
|
282 |
+
f.write("\n")
|
283 |
+
|
284 |
+
mtl_file = os.path.join(folder, fname + '.mtl')
|
285 |
+
print("Writing material: ", mtl_file)
|
286 |
+
material.save_mtl(mtl_file, new_mesh.material, mesh=new_mesh, feat=feat, resolution=resolution, prior_shape=prior_shape)
|
287 |
+
|
288 |
+
print("Done exporting mesh")
|
video3d/render/regularizer.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import nvdiffrast.torch as dr
|
12 |
+
|
13 |
+
from . import util
|
14 |
+
from . import mesh
|
15 |
+
|
16 |
+
######################################################################################
|
17 |
+
# Computes the image gradient, useful for kd/ks smoothness losses
|
18 |
+
######################################################################################
|
19 |
+
def image_grad(buf, std=0.01):
|
20 |
+
t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
|
21 |
+
torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
|
22 |
+
indexing='ij')
|
23 |
+
tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
|
24 |
+
tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
|
25 |
+
return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]
|
26 |
+
|
27 |
+
######################################################################################
|
28 |
+
# Computes the avergage edge length of a mesh.
|
29 |
+
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
|
30 |
+
######################################################################################
|
31 |
+
def avg_edge_length(v_pos, t_pos_idx):
|
32 |
+
e_pos_idx = mesh.compute_edges(t_pos_idx)
|
33 |
+
edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]])
|
34 |
+
return torch.mean(edge_len)
|
35 |
+
|
36 |
+
######################################################################################
|
37 |
+
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
|
38 |
+
# https://mgarland.org/class/geom04/material/smoothing.pdf
|
39 |
+
######################################################################################
|
40 |
+
def laplace_regularizer_const(v_pos, t_pos_idx):
|
41 |
+
batch_size = v_pos.shape[0]
|
42 |
+
|
43 |
+
term = torch.zeros_like(v_pos)
|
44 |
+
norm = torch.zeros_like(v_pos[..., 0:1])
|
45 |
+
|
46 |
+
v0 = v_pos[:, t_pos_idx[0, :, 0], :]
|
47 |
+
v1 = v_pos[:, t_pos_idx[0, :, 1], :]
|
48 |
+
v2 = v_pos[:, t_pos_idx[0, :, 2], :]
|
49 |
+
|
50 |
+
term.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), (v1 - v0) + (v2 - v0))
|
51 |
+
term.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), (v0 - v1) + (v2 - v1))
|
52 |
+
term.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), (v0 - v2) + (v1 - v2))
|
53 |
+
|
54 |
+
two = torch.ones_like(v0) * 2.0
|
55 |
+
# norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 3), two)
|
56 |
+
# norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 3), two)
|
57 |
+
# norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 3), two)
|
58 |
+
norm.scatter_add_(1, t_pos_idx[..., 0:1].repeat(batch_size, 1, 1), two)
|
59 |
+
norm.scatter_add_(1, t_pos_idx[..., 1:2].repeat(batch_size, 1, 1), two)
|
60 |
+
norm.scatter_add_(1, t_pos_idx[..., 2:3].repeat(batch_size, 1, 1), two)
|
61 |
+
|
62 |
+
term = term / torch.clamp(norm, min=1.0)
|
63 |
+
|
64 |
+
return torch.mean(term ** 2)
|
65 |
+
|
66 |
+
######################################################################################
|
67 |
+
# Smooth vertex normals
|
68 |
+
######################################################################################
|
69 |
+
def normal_consistency(v_pos, t_pos_idx):
|
70 |
+
# Compute face normals
|
71 |
+
v0 = v_pos[:, t_pos_idx[0, :, 0]]
|
72 |
+
v1 = v_pos[:, t_pos_idx[0, :, 1]]
|
73 |
+
v2 = v_pos[:, t_pos_idx[0, :, 2]]
|
74 |
+
|
75 |
+
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
|
76 |
+
|
77 |
+
tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)
|
78 |
+
|
79 |
+
# Fetch normals for both faces sharing an edge
|
80 |
+
n0 = face_normals[:, tris_per_edge[:, 0], :]
|
81 |
+
n1 = face_normals[:, tris_per_edge[:, 1], :]
|
82 |
+
|
83 |
+
# Compute error metric based on normal difference
|
84 |
+
term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)
|
85 |
+
term = (1.0 - term) * 0.5
|
86 |
+
|
87 |
+
return torch.mean(torch.abs(term))
|
88 |
+
|
89 |
+
|
90 |
+
def get_edge_length(v_pos, t_pos_idx):
|
91 |
+
e_pos_idx = mesh.compute_edges(t_pos_idx)
|
92 |
+
edge_len = util.length(v_pos[:, e_pos_idx[:, 0]] - v_pos[:, e_pos_idx[:, 1]])
|
93 |
+
return edge_len
|
video3d/render/render.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import nvdiffrast.torch as dr
|
12 |
+
|
13 |
+
from . import util
|
14 |
+
from . import renderutils as ru
|
15 |
+
from . import light
|
16 |
+
|
17 |
+
# ==============================================================================================
|
18 |
+
# Helper functions
|
19 |
+
# ==============================================================================================
|
20 |
+
def interpolate(attr, rast, attr_idx, rast_db=None):
|
21 |
+
return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')
|
22 |
+
|
23 |
+
# ==============================================================================================
|
24 |
+
# pixel shader
|
25 |
+
# ==============================================================================================
|
26 |
+
def shade(
|
27 |
+
gb_pos,
|
28 |
+
gb_geometric_normal,
|
29 |
+
gb_normal,
|
30 |
+
gb_tangent,
|
31 |
+
gb_tex_pos,
|
32 |
+
gb_texc,
|
33 |
+
gb_texc_deriv,
|
34 |
+
w2c,
|
35 |
+
view_pos,
|
36 |
+
lgt,
|
37 |
+
material,
|
38 |
+
bsdf,
|
39 |
+
feat,
|
40 |
+
two_sided_shading,
|
41 |
+
delta_xy_interp=None,
|
42 |
+
dino_pred=None,
|
43 |
+
class_vector=None,
|
44 |
+
im_features_map=None,
|
45 |
+
mvp=None
|
46 |
+
):
|
47 |
+
|
48 |
+
################################################################################
|
49 |
+
# Texture lookups
|
50 |
+
################################################################################
|
51 |
+
perturbed_nrm = None
|
52 |
+
# Combined texture, used for MLPs because lookups are expensive
|
53 |
+
# all_tex_jitter = material.sample(gb_tex_pos + torch.normal(mean=0, std=0.01, size=gb_tex_pos.shape, device="cuda"), feat=feat)
|
54 |
+
if material is not None:
|
55 |
+
if im_features_map is None:
|
56 |
+
all_tex = material.sample(gb_tex_pos, feat=feat)
|
57 |
+
else:
|
58 |
+
all_tex = material.sample(gb_tex_pos, feat=feat, feat_map=im_features_map, mvp=mvp, w2c=w2c, deform_xyz=gb_pos)
|
59 |
+
else:
|
60 |
+
all_tex = torch.ones(*gb_pos.shape[:-1], 9, device=gb_pos.device)
|
61 |
+
kd, ks, perturbed_nrm = all_tex[..., :3], all_tex[..., 3:6], all_tex[..., 6:9]
|
62 |
+
|
63 |
+
# Compute albedo (kd) gradient, used for material regularizer
|
64 |
+
# kd_grad = torch.sum(torch.abs(all_tex_jitter[..., :-6] - all_tex[..., :-6]), dim=-1, keepdim=True) /
|
65 |
+
|
66 |
+
if dino_pred is not None and class_vector is None:
|
67 |
+
# DOR: predive the dino value using x,y,z, we would concatenate the label vector.
|
68 |
+
# trained together, generated image as the supervision for the one-hot-vector.
|
69 |
+
dino_feat_im_pred = dino_pred.sample(gb_tex_pos)
|
70 |
+
# dino_feat_im_pred = dino_pred.sample(gb_tex_pos.detach())
|
71 |
+
if dino_pred is not None and class_vector is not None:
|
72 |
+
dino_feat_im_pred = dino_pred.sample(gb_tex_pos, feat=class_vector)
|
73 |
+
|
74 |
+
# else:
|
75 |
+
# kd_jitter = material['kd'].sample(gb_texc + torch.normal(mean=0, std=0.005, size=gb_texc.shape, device="cuda"), gb_texc_deriv)
|
76 |
+
# kd = material['kd'].sample(gb_texc, gb_texc_deriv)
|
77 |
+
# ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
|
78 |
+
# if 'normal' in material:
|
79 |
+
# perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
|
80 |
+
# kd_grad = torch.sum(torch.abs(kd_jitter[..., 0:3] - kd[..., 0:3]), dim=-1, keepdim=True) / 3
|
81 |
+
|
82 |
+
# Separate kd into alpha and color, default alpha = 1
|
83 |
+
# alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
|
84 |
+
# kd = kd[..., 0:3]
|
85 |
+
alpha = torch.ones_like(kd[..., 0:1])
|
86 |
+
|
87 |
+
################################################################################
|
88 |
+
# Normal perturbation & normal bend
|
89 |
+
################################################################################
|
90 |
+
if material is None or not material.perturb_normal:
|
91 |
+
perturbed_nrm = None
|
92 |
+
|
93 |
+
gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=two_sided_shading, opengl=True, use_python=True)
|
94 |
+
|
95 |
+
# if two_sided_shading:
|
96 |
+
# view_vec = util.safe_normalize(view_pos - gb_pos, -1)
|
97 |
+
# gb_normal = torch.where(torch.sum(gb_geometric_normal * view_vec, -1, keepdim=True) > 0, gb_geometric_normal, -gb_geometric_normal)
|
98 |
+
# else:
|
99 |
+
# gb_normal = gb_geometric_normal
|
100 |
+
|
101 |
+
b, h, w, _ = gb_normal.shape
|
102 |
+
cam_normal = util.safe_normalize(torch.matmul(gb_normal.view(b, -1, 3), w2c[:,:3,:3].transpose(2,1))).view(b, h, w, 3)
|
103 |
+
|
104 |
+
################################################################################
|
105 |
+
# Evaluate BSDF
|
106 |
+
################################################################################
|
107 |
+
|
108 |
+
assert bsdf is not None or material.bsdf is not None, "Material must specify a BSDF type"
|
109 |
+
bsdf = bsdf if bsdf is not None else material.bsdf
|
110 |
+
shading = None
|
111 |
+
if bsdf == 'pbr':
|
112 |
+
if isinstance(lgt, light.EnvironmentLight):
|
113 |
+
shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
|
114 |
+
else:
|
115 |
+
assert False, "Invalid light type"
|
116 |
+
elif bsdf == 'diffuse':
|
117 |
+
if lgt is None:
|
118 |
+
shaded_col = kd
|
119 |
+
elif isinstance(lgt, light.EnvironmentLight):
|
120 |
+
shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
|
121 |
+
# elif isinstance(lgt, light.DirectionalLight):
|
122 |
+
# shaded_col, shading = lgt.shade(feat, kd, cam_normal)
|
123 |
+
# else:
|
124 |
+
# assert False, "Invalid light type"
|
125 |
+
else:
|
126 |
+
shaded_col, shading = lgt.shade(feat, kd, cam_normal)
|
127 |
+
elif bsdf == 'normal':
|
128 |
+
shaded_col = (gb_normal + 1.0) * 0.5
|
129 |
+
elif bsdf == 'geo_normal':
|
130 |
+
shaded_col = (gb_geometric_normal + 1.0) * 0.5
|
131 |
+
elif bsdf == 'tangent':
|
132 |
+
shaded_col = (gb_tangent + 1.0) * 0.5
|
133 |
+
elif bsdf == 'kd':
|
134 |
+
shaded_col = kd
|
135 |
+
elif bsdf == 'ks':
|
136 |
+
shaded_col = ks
|
137 |
+
else:
|
138 |
+
assert False, "Invalid BSDF '%s'" % bsdf
|
139 |
+
|
140 |
+
# Return multiple buffers
|
141 |
+
buffers = {
|
142 |
+
'kd' : torch.cat((kd, alpha), dim=-1),
|
143 |
+
'shaded' : torch.cat((shaded_col, alpha), dim=-1),
|
144 |
+
# 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
|
145 |
+
# 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1),
|
146 |
+
}
|
147 |
+
|
148 |
+
if dino_pred is not None:
|
149 |
+
buffers['dino_feat_im_pred'] = torch.cat((dino_feat_im_pred, alpha), dim=-1)
|
150 |
+
|
151 |
+
if delta_xy_interp is not None:
|
152 |
+
buffers['flow'] = torch.cat((delta_xy_interp, alpha), dim=-1)
|
153 |
+
|
154 |
+
if shading is not None:
|
155 |
+
buffers['shading'] = torch.cat((shading, alpha), dim=-1)
|
156 |
+
|
157 |
+
return buffers
|
158 |
+
|
159 |
+
# ==============================================================================================
|
160 |
+
# Render a depth slice of the mesh (scene), some limitations:
|
161 |
+
# - Single light
|
162 |
+
# - Single material
|
163 |
+
# ==============================================================================================
|
164 |
+
def render_layer(
|
165 |
+
rast,
|
166 |
+
rast_deriv,
|
167 |
+
mesh,
|
168 |
+
w2c,
|
169 |
+
view_pos,
|
170 |
+
material,
|
171 |
+
lgt,
|
172 |
+
resolution,
|
173 |
+
spp,
|
174 |
+
msaa,
|
175 |
+
bsdf,
|
176 |
+
feat,
|
177 |
+
prior_mesh,
|
178 |
+
two_sided_shading,
|
179 |
+
render_flow,
|
180 |
+
delta_xy=None,
|
181 |
+
dino_pred=None,
|
182 |
+
class_vector=None,
|
183 |
+
im_features_map=None,
|
184 |
+
mvp=None
|
185 |
+
):
|
186 |
+
|
187 |
+
full_res = [resolution[0]*spp, resolution[1]*spp]
|
188 |
+
|
189 |
+
if prior_mesh is None:
|
190 |
+
prior_mesh = mesh
|
191 |
+
|
192 |
+
################################################################################
|
193 |
+
# Rasterize
|
194 |
+
################################################################################
|
195 |
+
|
196 |
+
# Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
|
197 |
+
if spp > 1 and msaa:
|
198 |
+
rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
|
199 |
+
rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp
|
200 |
+
else:
|
201 |
+
rast_out_s = rast
|
202 |
+
rast_out_deriv_s = rast_deriv
|
203 |
+
|
204 |
+
if render_flow:
|
205 |
+
delta_xy_interp, _ = interpolate(delta_xy, rast_out_s, mesh.t_pos_idx[0].int())
|
206 |
+
else:
|
207 |
+
delta_xy_interp = None
|
208 |
+
|
209 |
+
################################################################################
|
210 |
+
# Interpolate attributes
|
211 |
+
################################################################################
|
212 |
+
|
213 |
+
# Interpolate world space position
|
214 |
+
gb_pos, _ = interpolate(mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int())
|
215 |
+
|
216 |
+
# Compute geometric normals. We need those because of bent normals trick (for bump mapping)
|
217 |
+
v0 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 0], :]
|
218 |
+
v1 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 1], :]
|
219 |
+
v2 = mesh.v_pos[:, mesh.t_pos_idx[0, :, 2], :]
|
220 |
+
face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
|
221 |
+
num_faces = face_normals.shape[1]
|
222 |
+
face_normal_indices = (torch.arange(0, num_faces, dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
|
223 |
+
gb_geometric_normal, _ = interpolate(face_normals, rast_out_s, face_normal_indices.int())
|
224 |
+
|
225 |
+
# Compute tangent space
|
226 |
+
assert mesh.v_nrm is not None and mesh.v_tng is not None
|
227 |
+
gb_normal, _ = interpolate(mesh.v_nrm, rast_out_s, mesh.t_nrm_idx[0].int())
|
228 |
+
gb_tangent, _ = interpolate(mesh.v_tng, rast_out_s, mesh.t_tng_idx[0].int()) # Interpolate tangents
|
229 |
+
|
230 |
+
# Texture coordinate
|
231 |
+
assert mesh.v_tex is not None
|
232 |
+
gb_texc, gb_texc_deriv = interpolate(mesh.v_tex, rast_out_s, mesh.t_tex_idx[0].int(), rast_db=rast_out_deriv_s)
|
233 |
+
|
234 |
+
################################################################################
|
235 |
+
# Shade
|
236 |
+
################################################################################
|
237 |
+
|
238 |
+
gb_tex_pos, _ = interpolate(prior_mesh.v_pos, rast_out_s, mesh.t_pos_idx[0].int())
|
239 |
+
buffers = shade(gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_tex_pos, gb_texc, gb_texc_deriv, w2c, view_pos, lgt, material, bsdf, feat=feat, two_sided_shading=two_sided_shading, delta_xy_interp=delta_xy_interp, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mvp)
|
240 |
+
|
241 |
+
################################################################################
|
242 |
+
# Prepare output
|
243 |
+
################################################################################
|
244 |
+
|
245 |
+
# Scale back up to visibility resolution if using MSAA
|
246 |
+
if spp > 1 and msaa:
|
247 |
+
for key in buffers.keys():
|
248 |
+
buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')
|
249 |
+
|
250 |
+
# Return buffers
|
251 |
+
return buffers
|
252 |
+
|
253 |
+
# ==============================================================================================
|
254 |
+
# Render a depth peeled mesh (scene), some limitations:
|
255 |
+
# - Single light
|
256 |
+
# - Single material
|
257 |
+
# ==============================================================================================
|
258 |
+
def render_mesh(
|
259 |
+
ctx,
|
260 |
+
mesh,
|
261 |
+
mtx_in,
|
262 |
+
w2c,
|
263 |
+
view_pos,
|
264 |
+
material,
|
265 |
+
lgt,
|
266 |
+
resolution,
|
267 |
+
spp = 1,
|
268 |
+
num_layers = 1,
|
269 |
+
msaa = False,
|
270 |
+
background = None,
|
271 |
+
bsdf = None,
|
272 |
+
feat = None,
|
273 |
+
prior_mesh = None,
|
274 |
+
two_sided_shading = True,
|
275 |
+
render_flow = False,
|
276 |
+
dino_pred = None,
|
277 |
+
class_vector = None,
|
278 |
+
num_frames = None,
|
279 |
+
im_features_map = None
|
280 |
+
):
|
281 |
+
|
282 |
+
def prepare_input_vector(x):
|
283 |
+
x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
|
284 |
+
return x[:, None, None, :] if len(x.shape) == 2 else x
|
285 |
+
|
286 |
+
def composite_buffer(key, layers, background, antialias):
|
287 |
+
accum = background
|
288 |
+
for buffers, rast in reversed(layers):
|
289 |
+
alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
|
290 |
+
accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
|
291 |
+
if antialias:
|
292 |
+
accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx[0].int())
|
293 |
+
return accum
|
294 |
+
|
295 |
+
assert mesh.t_pos_idx.shape[1] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
|
296 |
+
assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])
|
297 |
+
|
298 |
+
full_res = [resolution[0] * spp, resolution[1] * spp]
|
299 |
+
|
300 |
+
# Convert numpy arrays to torch tensors
|
301 |
+
mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
|
302 |
+
view_pos = prepare_input_vector(view_pos) # Shape: (B, 1, 1, 3)
|
303 |
+
|
304 |
+
# clip space transform
|
305 |
+
v_pos_clip = ru.xfm_points(mesh.v_pos, mtx_in, use_python=True)
|
306 |
+
|
307 |
+
# render flow
|
308 |
+
if render_flow:
|
309 |
+
v_pos_clip2 = v_pos_clip[..., :2] / v_pos_clip[..., -1:]
|
310 |
+
v_pos_clip2 = v_pos_clip2.view(-1, num_frames, *v_pos_clip2.shape[1:])
|
311 |
+
delta_xy = v_pos_clip2[:, 1:] - v_pos_clip2[:, :-1]
|
312 |
+
delta_xy = torch.cat([delta_xy, torch.zeros_like(delta_xy[:, :1])], dim=1)
|
313 |
+
delta_xy = delta_xy.view(-1, *delta_xy.shape[2:])
|
314 |
+
else:
|
315 |
+
delta_xy = None
|
316 |
+
|
317 |
+
# Render all layers front-to-back
|
318 |
+
layers = []
|
319 |
+
with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx[0].int(), full_res) as peeler:
|
320 |
+
for _ in range(num_layers):
|
321 |
+
rast, db = peeler.rasterize_next_layer()
|
322 |
+
rendered = render_layer(rast, db, mesh, w2c, view_pos, material, lgt, resolution, spp, msaa, bsdf, feat=feat, prior_mesh=prior_mesh, two_sided_shading=two_sided_shading, render_flow=render_flow, delta_xy=delta_xy, dino_pred=dino_pred, class_vector=class_vector, im_features_map=im_features_map, mvp=mtx_in)
|
323 |
+
layers += [(rendered, rast)]
|
324 |
+
|
325 |
+
# Setup background
|
326 |
+
if background is not None:
|
327 |
+
if spp > 1:
|
328 |
+
background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
|
329 |
+
background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)
|
330 |
+
else:
|
331 |
+
background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')
|
332 |
+
|
333 |
+
# Composite layers front-to-back
|
334 |
+
out_buffers = {}
|
335 |
+
for key in layers[0][0].keys():
|
336 |
+
antialias = key in ['shaded', 'dino_feat_im_pred', 'flow']
|
337 |
+
bg = background if key in ['shaded'] else torch.zeros_like(layers[0][0][key])
|
338 |
+
accum = composite_buffer(key, layers, bg, antialias)
|
339 |
+
|
340 |
+
# Downscale to framebuffer resolution. Use avg pooling
|
341 |
+
out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
|
342 |
+
|
343 |
+
return out_buffers
|
344 |
+
|
345 |
+
# ==============================================================================================
|
346 |
+
# Render UVs
|
347 |
+
# ==============================================================================================
|
348 |
+
def render_uv(ctx, mesh, resolution, mlp_texture, feat=None, prior_shape=None):
|
349 |
+
|
350 |
+
# clip space transform
|
351 |
+
uv_clip = mesh.v_tex * 2.0 - 1.0
|
352 |
+
|
353 |
+
# pad to four component coordinate
|
354 |
+
uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)
|
355 |
+
|
356 |
+
# rasterize
|
357 |
+
rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx[0].int(), resolution)
|
358 |
+
|
359 |
+
# Interpolate world space position
|
360 |
+
if prior_shape is not None:
|
361 |
+
gb_pos, _ = interpolate(prior_shape.v_pos, rast, mesh.t_pos_idx[0].int())
|
362 |
+
else:
|
363 |
+
gb_pos, _ = interpolate(mesh.v_pos, rast, mesh.t_pos_idx[0].int())
|
364 |
+
|
365 |
+
# Sample out textures from MLP
|
366 |
+
all_tex = mlp_texture.sample(gb_pos, feat=feat)
|
367 |
+
assert all_tex.shape[-1] == 9 or all_tex.shape[-1] == 10, "Combined kd_ks_normal must be 9 or 10 channels"
|
368 |
+
perturbed_nrm = all_tex[..., -3:]
|
369 |
+
return (rast[..., -1:] > 0).float(), all_tex[..., :-6], all_tex[..., -6:-3], util.safe_normalize(perturbed_nrm)
|
video3d/render/renderutils/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
|
11 |
+
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
|
video3d/render/renderutils/bsdf.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
|
13 |
+
NORMAL_THRESHOLD = 0.1
|
14 |
+
|
15 |
+
################################################################################
|
16 |
+
# Vector utility functions
|
17 |
+
################################################################################
|
18 |
+
|
19 |
+
def _dot(x, y):
|
20 |
+
return torch.sum(x*y, -1, keepdim=True)
|
21 |
+
|
22 |
+
def _reflect(x, n):
|
23 |
+
return 2*_dot(x, n)*n - x
|
24 |
+
|
25 |
+
def _safe_normalize(x):
|
26 |
+
return torch.nn.functional.normalize(x, dim = -1)
|
27 |
+
|
28 |
+
def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
|
29 |
+
# Swap normal direction for backfacing surfaces
|
30 |
+
if two_sided_shading:
|
31 |
+
smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
|
32 |
+
geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
|
33 |
+
|
34 |
+
t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
|
35 |
+
return torch.lerp(geom_nrm, smooth_nrm, t)
|
36 |
+
|
37 |
+
|
38 |
+
def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
|
39 |
+
smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm, dim=-1))
|
40 |
+
if opengl:
|
41 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
42 |
+
else:
|
43 |
+
shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
|
44 |
+
return _safe_normalize(shading_nrm)
|
45 |
+
|
46 |
+
def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
47 |
+
smooth_nrm = _safe_normalize(smooth_nrm)
|
48 |
+
smooth_tng = _safe_normalize(smooth_tng)
|
49 |
+
view_vec = _safe_normalize(view_pos - pos)
|
50 |
+
shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
|
51 |
+
return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
|
52 |
+
|
53 |
+
################################################################################
|
54 |
+
# Simple lambertian diffuse BSDF
|
55 |
+
################################################################################
|
56 |
+
|
57 |
+
def bsdf_lambert(nrm, wi):
|
58 |
+
return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
|
59 |
+
|
60 |
+
################################################################################
|
61 |
+
# Frostbite diffuse
|
62 |
+
################################################################################
|
63 |
+
|
64 |
+
def bsdf_frostbite(nrm, wi, wo, linearRoughness):
|
65 |
+
wiDotN = _dot(wi, nrm)
|
66 |
+
woDotN = _dot(wo, nrm)
|
67 |
+
|
68 |
+
h = _safe_normalize(wo + wi)
|
69 |
+
wiDotH = _dot(wi, h)
|
70 |
+
|
71 |
+
energyBias = 0.5 * linearRoughness
|
72 |
+
energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
|
73 |
+
f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
|
74 |
+
f0 = 1.0
|
75 |
+
|
76 |
+
wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
|
77 |
+
woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
|
78 |
+
res = wiScatter * woScatter * energyFactor
|
79 |
+
return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
|
80 |
+
|
81 |
+
################################################################################
|
82 |
+
# Phong specular, loosely based on mitsuba implementation
|
83 |
+
################################################################################
|
84 |
+
|
85 |
+
def bsdf_phong(nrm, wo, wi, N):
|
86 |
+
dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
|
87 |
+
dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
|
88 |
+
return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
|
89 |
+
|
90 |
+
################################################################################
|
91 |
+
# PBR's implementation of GGX specular
|
92 |
+
################################################################################
|
93 |
+
|
94 |
+
specular_epsilon = 1e-4
|
95 |
+
|
96 |
+
def bsdf_fresnel_shlick(f0, f90, cosTheta):
|
97 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
98 |
+
return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
|
99 |
+
|
100 |
+
def bsdf_ndf_ggx(alphaSqr, cosTheta):
|
101 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
102 |
+
d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
|
103 |
+
return alphaSqr / (d * d * math.pi)
|
104 |
+
|
105 |
+
def bsdf_lambda_ggx(alphaSqr, cosTheta):
|
106 |
+
_cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
|
107 |
+
cosThetaSqr = _cosTheta * _cosTheta
|
108 |
+
tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
|
109 |
+
res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
|
110 |
+
return res
|
111 |
+
|
112 |
+
def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
|
113 |
+
lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
|
114 |
+
lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
|
115 |
+
return 1 / (1 + lambdaI + lambdaO)
|
116 |
+
|
117 |
+
def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
|
118 |
+
_alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
|
119 |
+
alphaSqr = _alpha * _alpha
|
120 |
+
|
121 |
+
h = _safe_normalize(wo + wi)
|
122 |
+
woDotN = _dot(wo, nrm)
|
123 |
+
wiDotN = _dot(wi, nrm)
|
124 |
+
woDotH = _dot(wo, h)
|
125 |
+
nDotH = _dot(nrm, h)
|
126 |
+
|
127 |
+
D = bsdf_ndf_ggx(alphaSqr, nDotH)
|
128 |
+
G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
|
129 |
+
F = bsdf_fresnel_shlick(col, 1, woDotH)
|
130 |
+
|
131 |
+
w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
|
132 |
+
|
133 |
+
frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
|
134 |
+
return torch.where(frontfacing, w, torch.zeros_like(w))
|
135 |
+
|
136 |
+
def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
|
137 |
+
wo = _safe_normalize(view_pos - pos)
|
138 |
+
wi = _safe_normalize(light_pos - pos)
|
139 |
+
|
140 |
+
spec_str = arm[..., 0:1] # x component
|
141 |
+
roughness = arm[..., 1:2] # y component
|
142 |
+
metallic = arm[..., 2:3] # z component
|
143 |
+
ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
|
144 |
+
kd = kd * (1.0 - metallic)
|
145 |
+
|
146 |
+
if BSDF == 0:
|
147 |
+
diffuse = kd * bsdf_lambert(nrm, wi)
|
148 |
+
else:
|
149 |
+
diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
|
150 |
+
specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
|
151 |
+
return diffuse + specular
|
video3d/render/renderutils/c_src/bsdf.cu
ADDED
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "common.h"
|
13 |
+
#include "bsdf.h"
|
14 |
+
|
15 |
+
#define SPECULAR_EPSILON 1e-4f
|
16 |
+
|
17 |
+
//------------------------------------------------------------------------
|
18 |
+
// Lambert functions
|
19 |
+
|
20 |
+
__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
|
21 |
+
{
|
22 |
+
return max(dot(nrm, wi) / M_PI, 0.0f);
|
23 |
+
}
|
24 |
+
|
25 |
+
__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
|
26 |
+
{
|
27 |
+
if (dot(nrm, wi) > 0.0f)
|
28 |
+
bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
|
29 |
+
}
|
30 |
+
|
31 |
+
//------------------------------------------------------------------------
|
32 |
+
// Fresnel Schlick
|
33 |
+
|
34 |
+
__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
|
35 |
+
{
|
36 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
37 |
+
float scale = powf(1.0f - _cosTheta, 5.0f);
|
38 |
+
return f0 * (1.0f - scale) + f90 * scale;
|
39 |
+
}
|
40 |
+
|
41 |
+
__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
|
42 |
+
{
|
43 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
44 |
+
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
45 |
+
d_f0 += d_out * (1.0 - scale);
|
46 |
+
d_f90 += d_out * scale;
|
47 |
+
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
48 |
+
{
|
49 |
+
d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
|
54 |
+
{
|
55 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
56 |
+
float scale = powf(1.0f - _cosTheta, 5.0f);
|
57 |
+
return f0 * (1.0f - scale) + f90 * scale;
|
58 |
+
}
|
59 |
+
|
60 |
+
__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
|
61 |
+
{
|
62 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
63 |
+
float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
|
64 |
+
d_f0 += d_out * (1.0 - scale);
|
65 |
+
d_f90 += d_out * scale;
|
66 |
+
if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
67 |
+
{
|
68 |
+
d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
//------------------------------------------------------------------------
|
73 |
+
// Frostbite diffuse
|
74 |
+
|
75 |
+
__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
|
76 |
+
{
|
77 |
+
float wiDotN = dot(wi, nrm);
|
78 |
+
float woDotN = dot(wo, nrm);
|
79 |
+
if (wiDotN > 0.0f && woDotN > 0.0f)
|
80 |
+
{
|
81 |
+
vec3f h = safeNormalize(wo + wi);
|
82 |
+
float wiDotH = dot(wi, h);
|
83 |
+
|
84 |
+
float energyBias = 0.5f * linearRoughness;
|
85 |
+
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
86 |
+
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
87 |
+
float f0 = 1.f;
|
88 |
+
|
89 |
+
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
90 |
+
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
91 |
+
|
92 |
+
return wiScatter * woScatter * energyFactor;
|
93 |
+
}
|
94 |
+
else return 0.0f;
|
95 |
+
}
|
96 |
+
|
97 |
+
__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
|
98 |
+
{
|
99 |
+
float wiDotN = dot(wi, nrm);
|
100 |
+
float woDotN = dot(wo, nrm);
|
101 |
+
|
102 |
+
if (wiDotN > 0.0f && woDotN > 0.0f)
|
103 |
+
{
|
104 |
+
vec3f h = safeNormalize(wo + wi);
|
105 |
+
float wiDotH = dot(wi, h);
|
106 |
+
|
107 |
+
float energyBias = 0.5f * linearRoughness;
|
108 |
+
float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
109 |
+
float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
110 |
+
float f0 = 1.f;
|
111 |
+
|
112 |
+
float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
|
113 |
+
float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
114 |
+
|
115 |
+
// -------------- BWD --------------
|
116 |
+
// Backprop: return wiScatter * woScatter * energyFactor;
|
117 |
+
float d_wiScatter = d_out * woScatter * energyFactor;
|
118 |
+
float d_woScatter = d_out * wiScatter * energyFactor;
|
119 |
+
float d_energyFactor = d_out * wiScatter * woScatter;
|
120 |
+
|
121 |
+
// Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
|
122 |
+
float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
|
123 |
+
bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
|
124 |
+
|
125 |
+
// Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
|
126 |
+
float d_wiDotN = 0.0f;
|
127 |
+
bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
|
128 |
+
|
129 |
+
// Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
|
130 |
+
float d_energyBias = d_f90;
|
131 |
+
float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
|
132 |
+
d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
|
133 |
+
|
134 |
+
// Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
|
135 |
+
d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
|
136 |
+
|
137 |
+
// Backprop: float energyBias = 0.5f * linearRoughness;
|
138 |
+
d_linearRoughness += 0.5 * d_energyBias;
|
139 |
+
|
140 |
+
// Backprop: float wiDotH = dot(wi, h);
|
141 |
+
vec3f d_h(0);
|
142 |
+
bwdDot(wi, h, d_wi, d_h, d_wiDotH);
|
143 |
+
|
144 |
+
// Backprop: vec3f h = safeNormalize(wo + wi);
|
145 |
+
vec3f d_wo_wi(0);
|
146 |
+
bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
|
147 |
+
d_wi += d_wo_wi; d_wo += d_wo_wi;
|
148 |
+
|
149 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
150 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
151 |
+
}
|
152 |
+
}
|
153 |
+
|
154 |
+
//------------------------------------------------------------------------
|
155 |
+
// Ndf GGX
|
156 |
+
|
157 |
+
__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
|
158 |
+
{
|
159 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
160 |
+
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
161 |
+
return alphaSqr / (d * d * M_PI);
|
162 |
+
}
|
163 |
+
|
164 |
+
__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
165 |
+
{
|
166 |
+
// Torch only back propagates if clamp doesn't trigger
|
167 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
168 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
169 |
+
d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
170 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
171 |
+
{
|
172 |
+
d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
|
173 |
+
}
|
174 |
+
}
|
175 |
+
|
176 |
+
//------------------------------------------------------------------------
|
177 |
+
// Lambda GGX
|
178 |
+
|
179 |
+
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
180 |
+
{
|
181 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
182 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
183 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
184 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
185 |
+
return res;
|
186 |
+
}
|
187 |
+
|
188 |
+
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
189 |
+
{
|
190 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
191 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
192 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
193 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
194 |
+
|
195 |
+
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
196 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
197 |
+
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
198 |
+
}
|
199 |
+
|
200 |
+
//------------------------------------------------------------------------
|
201 |
+
// Masking GGX
|
202 |
+
|
203 |
+
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
204 |
+
{
|
205 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
206 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
207 |
+
return 1.0f / (1.0f + lambdaI + lambdaO);
|
208 |
+
}
|
209 |
+
|
210 |
+
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
211 |
+
{
|
212 |
+
// FWD eval
|
213 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
214 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
215 |
+
|
216 |
+
// BWD eval
|
217 |
+
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
218 |
+
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
219 |
+
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
220 |
+
}
|
221 |
+
|
222 |
+
//------------------------------------------------------------------------
|
223 |
+
// GGX specular
|
224 |
+
|
225 |
+
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
226 |
+
{
|
227 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
228 |
+
float alphaSqr = _alpha * _alpha;
|
229 |
+
|
230 |
+
vec3f h = safeNormalize(wo + wi);
|
231 |
+
float woDotN = dot(wo, nrm);
|
232 |
+
float wiDotN = dot(wi, nrm);
|
233 |
+
float woDotH = dot(wo, h);
|
234 |
+
float nDotH = dot(nrm, h);
|
235 |
+
|
236 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
237 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
238 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
239 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
240 |
+
|
241 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
242 |
+
return frontfacing ? w : 0.0f;
|
243 |
+
}
|
244 |
+
|
245 |
+
__device__ void bwdPbrSpecular(
|
246 |
+
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
247 |
+
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
248 |
+
{
|
249 |
+
///////////////////////////////////////////////////////////////////////
|
250 |
+
// FWD eval
|
251 |
+
|
252 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
253 |
+
float alphaSqr = _alpha * _alpha;
|
254 |
+
|
255 |
+
vec3f h = safeNormalize(wo + wi);
|
256 |
+
float woDotN = dot(wo, nrm);
|
257 |
+
float wiDotN = dot(wi, nrm);
|
258 |
+
float woDotH = dot(wo, h);
|
259 |
+
float nDotH = dot(nrm, h);
|
260 |
+
|
261 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
262 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
263 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
264 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
265 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
266 |
+
|
267 |
+
if (frontfacing)
|
268 |
+
{
|
269 |
+
///////////////////////////////////////////////////////////////////////
|
270 |
+
// BWD eval
|
271 |
+
|
272 |
+
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
273 |
+
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
274 |
+
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
275 |
+
|
276 |
+
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
277 |
+
|
278 |
+
vec3f d_f90(0);
|
279 |
+
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
280 |
+
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
281 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
282 |
+
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
283 |
+
|
284 |
+
vec3f d_h(0);
|
285 |
+
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
286 |
+
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
287 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
288 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
289 |
+
|
290 |
+
vec3f d_h_unnorm(0);
|
291 |
+
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
292 |
+
d_wo += d_h_unnorm;
|
293 |
+
d_wi += d_h_unnorm;
|
294 |
+
|
295 |
+
if (alpha > min_roughness * min_roughness)
|
296 |
+
d_alpha += d_alphaSqr * 2 * alpha;
|
297 |
+
}
|
298 |
+
}
|
299 |
+
|
300 |
+
//------------------------------------------------------------------------
|
301 |
+
// Full PBR BSDF
|
302 |
+
|
303 |
+
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
|
304 |
+
{
|
305 |
+
vec3f wo = safeNormalize(view_pos - pos);
|
306 |
+
vec3f wi = safeNormalize(light_pos - pos);
|
307 |
+
|
308 |
+
float alpha = arm.y * arm.y;
|
309 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
310 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
311 |
+
|
312 |
+
float diff = 0.0f;
|
313 |
+
if (BSDF == 0)
|
314 |
+
diff = fwdLambert(nrm, wi);
|
315 |
+
else
|
316 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
317 |
+
vec3f diffuse = diff_col * diff;
|
318 |
+
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
319 |
+
|
320 |
+
return diffuse + specular;
|
321 |
+
}
|
322 |
+
|
323 |
+
__device__ void bwdPbrBSDF(
|
324 |
+
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
|
325 |
+
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
326 |
+
{
|
327 |
+
////////////////////////////////////////////////////////////////////////
|
328 |
+
// FWD
|
329 |
+
vec3f _wi = light_pos - pos;
|
330 |
+
vec3f _wo = view_pos - pos;
|
331 |
+
vec3f wi = safeNormalize(_wi);
|
332 |
+
vec3f wo = safeNormalize(_wo);
|
333 |
+
|
334 |
+
float alpha = arm.y * arm.y;
|
335 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
336 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
337 |
+
float diff = 0.0f;
|
338 |
+
if (BSDF == 0)
|
339 |
+
diff = fwdLambert(nrm, wi);
|
340 |
+
else
|
341 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
342 |
+
|
343 |
+
////////////////////////////////////////////////////////////////////////
|
344 |
+
// BWD
|
345 |
+
|
346 |
+
float d_alpha(0);
|
347 |
+
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
348 |
+
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
349 |
+
|
350 |
+
float d_diff = sum(diff_col * d_out);
|
351 |
+
if (BSDF == 0)
|
352 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
|
353 |
+
else
|
354 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
|
355 |
+
|
356 |
+
// Backprop: diff_col = kd * (1.0f - arm.z)
|
357 |
+
vec3f d_diff_col = d_out * diff;
|
358 |
+
d_kd += d_diff_col * (1.0f - arm.z);
|
359 |
+
d_arm.z -= sum(d_diff_col * kd);
|
360 |
+
|
361 |
+
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
362 |
+
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
363 |
+
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
364 |
+
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
365 |
+
|
366 |
+
// Backprop: alpha = arm.y * arm.y
|
367 |
+
d_arm.y += d_alpha * 2 * arm.y;
|
368 |
+
|
369 |
+
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
370 |
+
vec3f d__wi(0);
|
371 |
+
bwdSafeNormalize(_wi, d__wi, d_wi);
|
372 |
+
d_light_pos += d__wi;
|
373 |
+
d_pos -= d__wi;
|
374 |
+
|
375 |
+
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
376 |
+
vec3f d__wo(0);
|
377 |
+
bwdSafeNormalize(_wo, d__wo, d_wo);
|
378 |
+
d_view_pos += d__wo;
|
379 |
+
d_pos -= d__wo;
|
380 |
+
}
|
381 |
+
|
382 |
+
//------------------------------------------------------------------------
|
383 |
+
// Kernels
|
384 |
+
|
385 |
+
__global__ void LambertFwdKernel(LambertKernelParams p)
|
386 |
+
{
|
387 |
+
// Calculate pixel position.
|
388 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
389 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
390 |
+
unsigned int pz = blockIdx.z;
|
391 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
392 |
+
return;
|
393 |
+
|
394 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
395 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
396 |
+
|
397 |
+
float res = fwdLambert(nrm, wi);
|
398 |
+
|
399 |
+
p.out.store(px, py, pz, res);
|
400 |
+
}
|
401 |
+
|
402 |
+
__global__ void LambertBwdKernel(LambertKernelParams p)
|
403 |
+
{
|
404 |
+
// Calculate pixel position.
|
405 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
406 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
407 |
+
unsigned int pz = blockIdx.z;
|
408 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
409 |
+
return;
|
410 |
+
|
411 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
412 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
413 |
+
float d_out = p.out.fetch1(px, py, pz);
|
414 |
+
|
415 |
+
vec3f d_nrm(0), d_wi(0);
|
416 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
417 |
+
|
418 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
419 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
420 |
+
}
|
421 |
+
|
422 |
+
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
|
423 |
+
{
|
424 |
+
// Calculate pixel position.
|
425 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
426 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
427 |
+
unsigned int pz = blockIdx.z;
|
428 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
429 |
+
return;
|
430 |
+
|
431 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
432 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
433 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
434 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
435 |
+
|
436 |
+
float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
|
437 |
+
|
438 |
+
p.out.store(px, py, pz, res);
|
439 |
+
}
|
440 |
+
|
441 |
+
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
|
442 |
+
{
|
443 |
+
// Calculate pixel position.
|
444 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
445 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
446 |
+
unsigned int pz = blockIdx.z;
|
447 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
448 |
+
return;
|
449 |
+
|
450 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
451 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
452 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
453 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
454 |
+
float d_out = p.out.fetch1(px, py, pz);
|
455 |
+
|
456 |
+
float d_linearRoughness = 0.0f;
|
457 |
+
vec3f d_nrm(0), d_wi(0), d_wo(0);
|
458 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
|
459 |
+
|
460 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
461 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
462 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
463 |
+
p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
|
464 |
+
}
|
465 |
+
|
466 |
+
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
467 |
+
{
|
468 |
+
// Calculate pixel position.
|
469 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
470 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
471 |
+
unsigned int pz = blockIdx.z;
|
472 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
473 |
+
return;
|
474 |
+
|
475 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
476 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
477 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
478 |
+
|
479 |
+
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
480 |
+
p.out.store(px, py, pz, res);
|
481 |
+
}
|
482 |
+
|
483 |
+
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
484 |
+
{
|
485 |
+
// Calculate pixel position.
|
486 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
487 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
488 |
+
unsigned int pz = blockIdx.z;
|
489 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
490 |
+
return;
|
491 |
+
|
492 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
493 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
494 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
495 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
496 |
+
|
497 |
+
vec3f d_f0(0), d_f90(0);
|
498 |
+
float d_cosTheta(0);
|
499 |
+
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
500 |
+
|
501 |
+
p.f0.store_grad(px, py, pz, d_f0);
|
502 |
+
p.f90.store_grad(px, py, pz, d_f90);
|
503 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
504 |
+
}
|
505 |
+
|
506 |
+
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
507 |
+
{
|
508 |
+
// Calculate pixel position.
|
509 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
510 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
511 |
+
unsigned int pz = blockIdx.z;
|
512 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
513 |
+
return;
|
514 |
+
|
515 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
516 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
517 |
+
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
518 |
+
|
519 |
+
p.out.store(px, py, pz, res);
|
520 |
+
}
|
521 |
+
|
522 |
+
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
523 |
+
{
|
524 |
+
// Calculate pixel position.
|
525 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
526 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
527 |
+
unsigned int pz = blockIdx.z;
|
528 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
529 |
+
return;
|
530 |
+
|
531 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
532 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
533 |
+
float d_out = p.out.fetch1(px, py, pz);
|
534 |
+
|
535 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
536 |
+
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
537 |
+
|
538 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
539 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
540 |
+
}
|
541 |
+
|
542 |
+
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
543 |
+
{
|
544 |
+
// Calculate pixel position.
|
545 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
546 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
547 |
+
unsigned int pz = blockIdx.z;
|
548 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
549 |
+
return;
|
550 |
+
|
551 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
552 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
553 |
+
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
554 |
+
|
555 |
+
p.out.store(px, py, pz, res);
|
556 |
+
}
|
557 |
+
|
558 |
+
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
559 |
+
{
|
560 |
+
// Calculate pixel position.
|
561 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
562 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
563 |
+
unsigned int pz = blockIdx.z;
|
564 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
565 |
+
return;
|
566 |
+
|
567 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
568 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
569 |
+
float d_out = p.out.fetch1(px, py, pz);
|
570 |
+
|
571 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
572 |
+
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
573 |
+
|
574 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
575 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
576 |
+
}
|
577 |
+
|
578 |
+
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
579 |
+
{
|
580 |
+
// Calculate pixel position.
|
581 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
582 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
583 |
+
unsigned int pz = blockIdx.z;
|
584 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
585 |
+
return;
|
586 |
+
|
587 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
588 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
589 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
590 |
+
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
591 |
+
|
592 |
+
p.out.store(px, py, pz, res);
|
593 |
+
}
|
594 |
+
|
595 |
+
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
596 |
+
{
|
597 |
+
// Calculate pixel position.
|
598 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
599 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
600 |
+
unsigned int pz = blockIdx.z;
|
601 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
602 |
+
return;
|
603 |
+
|
604 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
605 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
606 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
607 |
+
float d_out = p.out.fetch1(px, py, pz);
|
608 |
+
|
609 |
+
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
610 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
611 |
+
|
612 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
613 |
+
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
614 |
+
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
615 |
+
}
|
616 |
+
|
617 |
+
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
618 |
+
{
|
619 |
+
// Calculate pixel position.
|
620 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
621 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
622 |
+
unsigned int pz = blockIdx.z;
|
623 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
624 |
+
return;
|
625 |
+
|
626 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
627 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
628 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
629 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
630 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
631 |
+
|
632 |
+
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
633 |
+
|
634 |
+
p.out.store(px, py, pz, res);
|
635 |
+
}
|
636 |
+
|
637 |
+
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
638 |
+
{
|
639 |
+
// Calculate pixel position.
|
640 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
641 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
642 |
+
unsigned int pz = blockIdx.z;
|
643 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
644 |
+
return;
|
645 |
+
|
646 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
647 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
648 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
649 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
650 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
651 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
652 |
+
|
653 |
+
float d_alpha(0);
|
654 |
+
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
655 |
+
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
656 |
+
|
657 |
+
p.col.store_grad(px, py, pz, d_col);
|
658 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
659 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
660 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
661 |
+
p.alpha.store_grad(px, py, pz, d_alpha);
|
662 |
+
}
|
663 |
+
|
664 |
+
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
665 |
+
{
|
666 |
+
// Calculate pixel position.
|
667 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
668 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
669 |
+
unsigned int pz = blockIdx.z;
|
670 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
671 |
+
return;
|
672 |
+
|
673 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
674 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
675 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
676 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
677 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
678 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
679 |
+
|
680 |
+
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
|
681 |
+
|
682 |
+
p.out.store(px, py, pz, res);
|
683 |
+
}
|
684 |
+
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
685 |
+
{
|
686 |
+
// Calculate pixel position.
|
687 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
688 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
689 |
+
unsigned int pz = blockIdx.z;
|
690 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
691 |
+
return;
|
692 |
+
|
693 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
694 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
695 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
696 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
697 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
698 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
699 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
700 |
+
|
701 |
+
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
702 |
+
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
703 |
+
|
704 |
+
p.kd.store_grad(px, py, pz, d_kd);
|
705 |
+
p.arm.store_grad(px, py, pz, d_arm);
|
706 |
+
p.pos.store_grad(px, py, pz, d_pos);
|
707 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
708 |
+
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
709 |
+
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
710 |
+
}
|
video3d/render/renderutils/c_src/bsdf.h
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
|
16 |
+
struct LambertKernelParams
|
17 |
+
{
|
18 |
+
Tensor nrm;
|
19 |
+
Tensor wi;
|
20 |
+
Tensor out;
|
21 |
+
dim3 gridSize;
|
22 |
+
};
|
23 |
+
|
24 |
+
struct FrostbiteDiffuseKernelParams
|
25 |
+
{
|
26 |
+
Tensor nrm;
|
27 |
+
Tensor wi;
|
28 |
+
Tensor wo;
|
29 |
+
Tensor linearRoughness;
|
30 |
+
Tensor out;
|
31 |
+
dim3 gridSize;
|
32 |
+
};
|
33 |
+
|
34 |
+
struct FresnelShlickKernelParams
|
35 |
+
{
|
36 |
+
Tensor f0;
|
37 |
+
Tensor f90;
|
38 |
+
Tensor cosTheta;
|
39 |
+
Tensor out;
|
40 |
+
dim3 gridSize;
|
41 |
+
};
|
42 |
+
|
43 |
+
struct NdfGGXParams
|
44 |
+
{
|
45 |
+
Tensor alphaSqr;
|
46 |
+
Tensor cosTheta;
|
47 |
+
Tensor out;
|
48 |
+
dim3 gridSize;
|
49 |
+
};
|
50 |
+
|
51 |
+
struct MaskingSmithParams
|
52 |
+
{
|
53 |
+
Tensor alphaSqr;
|
54 |
+
Tensor cosThetaI;
|
55 |
+
Tensor cosThetaO;
|
56 |
+
Tensor out;
|
57 |
+
dim3 gridSize;
|
58 |
+
};
|
59 |
+
|
60 |
+
struct PbrSpecular
|
61 |
+
{
|
62 |
+
Tensor col;
|
63 |
+
Tensor nrm;
|
64 |
+
Tensor wo;
|
65 |
+
Tensor wi;
|
66 |
+
Tensor alpha;
|
67 |
+
Tensor out;
|
68 |
+
dim3 gridSize;
|
69 |
+
float min_roughness;
|
70 |
+
};
|
71 |
+
|
72 |
+
struct PbrBSDF
|
73 |
+
{
|
74 |
+
Tensor kd;
|
75 |
+
Tensor arm;
|
76 |
+
Tensor pos;
|
77 |
+
Tensor nrm;
|
78 |
+
Tensor view_pos;
|
79 |
+
Tensor light_pos;
|
80 |
+
Tensor out;
|
81 |
+
dim3 gridSize;
|
82 |
+
float min_roughness;
|
83 |
+
int BSDF;
|
84 |
+
};
|
video3d/render/renderutils/c_src/common.cpp
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <cuda_runtime.h>
|
13 |
+
#include <algorithm>
|
14 |
+
|
15 |
+
//------------------------------------------------------------------------
|
16 |
+
// Block and grid size calculators for kernel launches.
|
17 |
+
|
18 |
+
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
|
19 |
+
{
|
20 |
+
int maxThreads = maxWidth * maxHeight;
|
21 |
+
if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
|
22 |
+
return dim3(1, 1, 1); // Degenerate.
|
23 |
+
|
24 |
+
// Start from max size.
|
25 |
+
int bw = maxWidth;
|
26 |
+
int bh = maxHeight;
|
27 |
+
|
28 |
+
// Optimizations for weirdly sized buffers.
|
29 |
+
if (dims.x < bw)
|
30 |
+
{
|
31 |
+
// Decrease block width to smallest power of two that covers the buffer width.
|
32 |
+
while ((bw >> 1) >= dims.x)
|
33 |
+
bw >>= 1;
|
34 |
+
|
35 |
+
// Maximize height.
|
36 |
+
bh = maxThreads / bw;
|
37 |
+
if (bh > dims.y)
|
38 |
+
bh = dims.y;
|
39 |
+
}
|
40 |
+
else if (dims.y < bh)
|
41 |
+
{
|
42 |
+
// Halve height and double width until fits completely inside buffer vertically.
|
43 |
+
while (bh > dims.y)
|
44 |
+
{
|
45 |
+
bh >>= 1;
|
46 |
+
if (bw < dims.x)
|
47 |
+
bw <<= 1;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
// Done.
|
52 |
+
return dim3(bw, bh, 1);
|
53 |
+
}
|
54 |
+
|
55 |
+
// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
|
56 |
+
dim3 getWarpSize(dim3 blockSize)
|
57 |
+
{
|
58 |
+
return dim3(
|
59 |
+
std::min(blockSize.x, 32u),
|
60 |
+
std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
|
61 |
+
std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
|
62 |
+
);
|
63 |
+
}
|
64 |
+
|
65 |
+
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
|
66 |
+
{
|
67 |
+
dim3 gridSize;
|
68 |
+
gridSize.x = (dims.x - 1) / blockSize.x + 1;
|
69 |
+
gridSize.y = (dims.y - 1) / blockSize.y + 1;
|
70 |
+
gridSize.z = (dims.z - 1) / blockSize.z + 1;
|
71 |
+
return gridSize;
|
72 |
+
}
|
73 |
+
|
74 |
+
//------------------------------------------------------------------------
|
video3d/render/renderutils/c_src/common.h
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
#include <cuda.h>
|
14 |
+
#include <stdint.h>
|
15 |
+
|
16 |
+
#include "vec3f.h"
|
17 |
+
#include "vec4f.h"
|
18 |
+
#include "tensor.h"
|
19 |
+
|
20 |
+
dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
|
21 |
+
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
|
22 |
+
|
23 |
+
#ifdef __CUDACC__
|
24 |
+
|
25 |
+
#ifdef _MSC_VER
|
26 |
+
#define M_PI 3.14159265358979323846f
|
27 |
+
#endif
|
28 |
+
|
29 |
+
__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
|
30 |
+
{
|
31 |
+
return dim3(
|
32 |
+
min(blockSize.x, 32u),
|
33 |
+
min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
|
34 |
+
min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
|
35 |
+
);
|
36 |
+
}
|
37 |
+
|
38 |
+
__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
|
39 |
+
#else
|
40 |
+
dim3 getWarpSize(dim3 blockSize);
|
41 |
+
#endif
|
video3d/render/renderutils/c_src/cubemap.cu
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "common.h"
|
13 |
+
#include "cubemap.h"
|
14 |
+
#include <float.h>
|
15 |
+
|
16 |
+
// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
|
17 |
+
__device__ float pixel_area(int x, int y, int N)
|
18 |
+
{
|
19 |
+
if (N > 1)
|
20 |
+
{
|
21 |
+
int H = N / 2;
|
22 |
+
x = abs(x - H);
|
23 |
+
y = abs(y - H);
|
24 |
+
float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
|
25 |
+
float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
|
26 |
+
return dx * dy;
|
27 |
+
}
|
28 |
+
else
|
29 |
+
return 1;
|
30 |
+
}
|
31 |
+
|
32 |
+
__device__ vec3f cube_to_dir(int x, int y, int side, int N)
|
33 |
+
{
|
34 |
+
float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
|
35 |
+
float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
|
36 |
+
switch (side)
|
37 |
+
{
|
38 |
+
case 0: return safeNormalize(vec3f(1, -fy, -fx));
|
39 |
+
case 1: return safeNormalize(vec3f(-1, -fy, fx));
|
40 |
+
case 2: return safeNormalize(vec3f(fx, 1, fy));
|
41 |
+
case 3: return safeNormalize(vec3f(fx, -1, -fy));
|
42 |
+
case 4: return safeNormalize(vec3f(fx, -fy, 1));
|
43 |
+
case 5: return safeNormalize(vec3f(-fx, -fy, -1));
|
44 |
+
}
|
45 |
+
return vec3f(0,0,0); // Unreachable
|
46 |
+
}
|
47 |
+
|
48 |
+
__device__ vec3f dir_to_side(int side, vec3f v)
|
49 |
+
{
|
50 |
+
switch (side)
|
51 |
+
{
|
52 |
+
case 0: return vec3f(-v.z, -v.y, v.x);
|
53 |
+
case 1: return vec3f( v.z, -v.y, -v.x);
|
54 |
+
case 2: return vec3f( v.x, v.z, v.y);
|
55 |
+
case 3: return vec3f( v.x, -v.z, -v.y);
|
56 |
+
case 4: return vec3f( v.x, -v.y, v.z);
|
57 |
+
case 5: return vec3f(-v.x, -v.y, -v.z);
|
58 |
+
}
|
59 |
+
return vec3f(0,0,0); // Unreachable
|
60 |
+
}
|
61 |
+
|
62 |
+
__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
|
63 |
+
{
|
64 |
+
float l = sqrtf(x * x + z * z);
|
65 |
+
float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
|
66 |
+
float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
|
67 |
+
if (pzl <= 0.00001f)
|
68 |
+
_min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
|
69 |
+
else
|
70 |
+
_min = pxl / pzl;
|
71 |
+
if (pzr <= 0.00001f)
|
72 |
+
_max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
|
73 |
+
else
|
74 |
+
_max = pxr / pzr;
|
75 |
+
}
|
76 |
+
|
77 |
+
__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
|
78 |
+
{
|
79 |
+
vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1
|
80 |
+
|
81 |
+
if (theta < 0.785398f) // PI/4
|
82 |
+
{
|
83 |
+
float xmin, xmax, ymin, ymax;
|
84 |
+
extents_1d(c.x, c.z, theta, xmin, xmax);
|
85 |
+
extents_1d(c.y, c.z, theta, ymin, ymax);
|
86 |
+
|
87 |
+
if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
|
88 |
+
{
|
89 |
+
_xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
|
90 |
+
}
|
91 |
+
else
|
92 |
+
{
|
93 |
+
_xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
94 |
+
_xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
95 |
+
_ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
96 |
+
_ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
97 |
+
}
|
98 |
+
}
|
99 |
+
else
|
100 |
+
{
|
101 |
+
_xmin = 0.0f;
|
102 |
+
_xmax = (float)(N-1);
|
103 |
+
_ymin = 0.0f;
|
104 |
+
_ymax = (float)(N-1);
|
105 |
+
}
|
106 |
+
}
|
107 |
+
|
108 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
109 |
+
// Diffuse kernel
|
110 |
+
__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
|
111 |
+
{
|
112 |
+
// Calculate pixel position.
|
113 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
114 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
115 |
+
int pz = blockIdx.z;
|
116 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
117 |
+
return;
|
118 |
+
|
119 |
+
int Npx = p.cubemap.dims[1];
|
120 |
+
vec3f N = cube_to_dir(px, py, pz, Npx);
|
121 |
+
|
122 |
+
vec3f col(0);
|
123 |
+
|
124 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
125 |
+
{
|
126 |
+
for (int y = 0; y < Npx; ++y)
|
127 |
+
{
|
128 |
+
for (int x = 0; x < Npx; ++x)
|
129 |
+
{
|
130 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
131 |
+
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
132 |
+
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
133 |
+
col += p.cubemap.fetch3(x, y, s) * w;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
}
|
137 |
+
|
138 |
+
p.out.store(px, py, pz, col);
|
139 |
+
}
|
140 |
+
|
141 |
+
__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
|
142 |
+
{
|
143 |
+
// Calculate pixel position.
|
144 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
145 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
146 |
+
int pz = blockIdx.z;
|
147 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
148 |
+
return;
|
149 |
+
|
150 |
+
int Npx = p.cubemap.dims[1];
|
151 |
+
vec3f N = cube_to_dir(px, py, pz, Npx);
|
152 |
+
vec3f grad = p.out.fetch3(px, py, pz);
|
153 |
+
|
154 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
155 |
+
{
|
156 |
+
for (int y = 0; y < Npx; ++y)
|
157 |
+
{
|
158 |
+
for (int x = 0; x < Npx; ++x)
|
159 |
+
{
|
160 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
161 |
+
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
162 |
+
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
163 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
164 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
165 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
166 |
+
}
|
167 |
+
}
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
172 |
+
// GGX splitsum kernel
|
173 |
+
|
174 |
+
__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
|
175 |
+
{
|
176 |
+
float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
|
177 |
+
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
178 |
+
return alphaSqr / (d * d * M_PI);
|
179 |
+
}
|
180 |
+
|
181 |
+
__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
|
182 |
+
{
|
183 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
184 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
185 |
+
int pz = blockIdx.z;
|
186 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
187 |
+
return;
|
188 |
+
|
189 |
+
int Npx = p.gridSize.x;
|
190 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
191 |
+
|
192 |
+
const int TILE_SIZE = 16;
|
193 |
+
|
194 |
+
// Brute force entire cubemap and compute bounds for the cone
|
195 |
+
for (int s = 0; s < p.gridSize.z; ++s)
|
196 |
+
{
|
197 |
+
// Assume empty BBox
|
198 |
+
int _min_x = p.gridSize.x - 1, _max_x = 0;
|
199 |
+
int _min_y = p.gridSize.y - 1, _max_y = 0;
|
200 |
+
|
201 |
+
// For each (8x8) tile
|
202 |
+
for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
|
203 |
+
{
|
204 |
+
for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
|
205 |
+
{
|
206 |
+
// Compute tile extents
|
207 |
+
int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
|
208 |
+
int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);
|
209 |
+
|
210 |
+
// Use some blunt interval arithmetics to cull tiles
|
211 |
+
vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
|
212 |
+
vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);
|
213 |
+
|
214 |
+
float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
|
215 |
+
float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
|
216 |
+
float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));
|
217 |
+
|
218 |
+
float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
|
219 |
+
if (maxdp >= p.costheta_cutoff)
|
220 |
+
{
|
221 |
+
// Test all pixels in tile.
|
222 |
+
for (int y = tsy; y < tey; ++y)
|
223 |
+
{
|
224 |
+
for (int x = tsx; x < tex; ++x)
|
225 |
+
{
|
226 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
227 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
228 |
+
{
|
229 |
+
_min_x = min(_min_x, x);
|
230 |
+
_max_x = max(_max_x, x);
|
231 |
+
_min_y = min(_min_y, y);
|
232 |
+
_max_y = max(_max_y, y);
|
233 |
+
}
|
234 |
+
}
|
235 |
+
}
|
236 |
+
}
|
237 |
+
}
|
238 |
+
}
|
239 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
|
240 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
|
241 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
|
242 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
|
243 |
+
}
|
244 |
+
}
|
245 |
+
|
246 |
+
__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
|
247 |
+
{
|
248 |
+
// Calculate pixel position.
|
249 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
250 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
251 |
+
int pz = blockIdx.z;
|
252 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
253 |
+
return;
|
254 |
+
|
255 |
+
int Npx = p.cubemap.dims[1];
|
256 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
257 |
+
|
258 |
+
float alpha = p.roughness * p.roughness;
|
259 |
+
float alphaSqr = alpha * alpha;
|
260 |
+
|
261 |
+
float wsum = 0.0f;
|
262 |
+
vec3f col(0);
|
263 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
264 |
+
{
|
265 |
+
int xmin, xmax, ymin, ymax;
|
266 |
+
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
267 |
+
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
268 |
+
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
269 |
+
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
270 |
+
|
271 |
+
if (xmin <= xmax)
|
272 |
+
{
|
273 |
+
for (int y = ymin; y <= ymax; ++y)
|
274 |
+
{
|
275 |
+
for (int x = xmin; x <= xmax; ++x)
|
276 |
+
{
|
277 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
278 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
279 |
+
{
|
280 |
+
vec3f H = safeNormalize(L + VNR);
|
281 |
+
|
282 |
+
float wiDotN = max(dot(L, VNR), 0.0f);
|
283 |
+
float VNRDotH = max(dot(VNR, H), 0.0f);
|
284 |
+
|
285 |
+
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
286 |
+
col += p.cubemap.fetch3(x, y, s) * w;
|
287 |
+
wsum += w;
|
288 |
+
}
|
289 |
+
}
|
290 |
+
}
|
291 |
+
}
|
292 |
+
}
|
293 |
+
|
294 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
|
295 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
|
296 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
|
297 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
|
298 |
+
}
|
299 |
+
|
300 |
+
__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
|
301 |
+
{
|
302 |
+
// Calculate pixel position.
|
303 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
304 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
305 |
+
int pz = blockIdx.z;
|
306 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
307 |
+
return;
|
308 |
+
|
309 |
+
int Npx = p.cubemap.dims[1];
|
310 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
311 |
+
|
312 |
+
vec3f grad = p.out.fetch3(px, py, pz);
|
313 |
+
|
314 |
+
float alpha = p.roughness * p.roughness;
|
315 |
+
float alphaSqr = alpha * alpha;
|
316 |
+
|
317 |
+
vec3f col(0);
|
318 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
319 |
+
{
|
320 |
+
int xmin, xmax, ymin, ymax;
|
321 |
+
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
322 |
+
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
323 |
+
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
324 |
+
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
325 |
+
|
326 |
+
if (xmin <= xmax)
|
327 |
+
{
|
328 |
+
for (int y = ymin; y <= ymax; ++y)
|
329 |
+
{
|
330 |
+
for (int x = xmin; x <= xmax; ++x)
|
331 |
+
{
|
332 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
333 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
334 |
+
{
|
335 |
+
vec3f H = safeNormalize(L + VNR);
|
336 |
+
|
337 |
+
float wiDotN = max(dot(L, VNR), 0.0f);
|
338 |
+
float VNRDotH = max(dot(VNR, H), 0.0f);
|
339 |
+
|
340 |
+
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
341 |
+
|
342 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
343 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
344 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
345 |
+
}
|
346 |
+
}
|
347 |
+
}
|
348 |
+
}
|
349 |
+
}
|
350 |
+
}
|
video3d/render/renderutils/c_src/cubemap.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
|
16 |
+
struct DiffuseCubemapKernelParams
|
17 |
+
{
|
18 |
+
Tensor cubemap;
|
19 |
+
Tensor out;
|
20 |
+
dim3 gridSize;
|
21 |
+
};
|
22 |
+
|
23 |
+
struct SpecularCubemapKernelParams
|
24 |
+
{
|
25 |
+
Tensor cubemap;
|
26 |
+
Tensor bounds;
|
27 |
+
Tensor out;
|
28 |
+
dim3 gridSize;
|
29 |
+
float costheta_cutoff;
|
30 |
+
float roughness;
|
31 |
+
};
|
32 |
+
|
33 |
+
struct SpecularBoundsKernelParams
|
34 |
+
{
|
35 |
+
float costheta_cutoff;
|
36 |
+
Tensor out;
|
37 |
+
dim3 gridSize;
|
38 |
+
};
|
video3d/render/renderutils/c_src/loss.cu
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <cuda.h>
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
#include "loss.h"
|
16 |
+
|
17 |
+
//------------------------------------------------------------------------
|
18 |
+
// Utils
|
19 |
+
|
20 |
+
__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
|
21 |
+
|
22 |
+
__device__ float warpSum(float val) {
|
23 |
+
for (int i = 1; i < 32; i *= 2)
|
24 |
+
val += __shfl_xor_sync(0xFFFFFFFF, val, i);
|
25 |
+
return val;
|
26 |
+
}
|
27 |
+
|
28 |
+
//------------------------------------------------------------------------
|
29 |
+
// Tonemapping
|
30 |
+
|
31 |
+
__device__ inline float fwdSRGB(float x)
|
32 |
+
{
|
33 |
+
return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
|
34 |
+
}
|
35 |
+
|
36 |
+
__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
|
37 |
+
{
|
38 |
+
if (x > 0.0031308f)
|
39 |
+
d_x += d_out * 0.439583f / powf(x, 0.583333f);
|
40 |
+
else if (x > 0.0f)
|
41 |
+
d_x += d_out * 12.92f;
|
42 |
+
}
|
43 |
+
|
44 |
+
__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
|
45 |
+
{
|
46 |
+
return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
|
47 |
+
}
|
48 |
+
|
49 |
+
__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
|
50 |
+
{
|
51 |
+
if (x.x > 0.0f && x.x < 65535.0f)
|
52 |
+
{
|
53 |
+
bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
|
54 |
+
d_x.x *= 1 / (x.x + 1.0f);
|
55 |
+
}
|
56 |
+
if (x.y > 0.0f && x.y < 65535.0f)
|
57 |
+
{
|
58 |
+
bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
|
59 |
+
d_x.y *= 1 / (x.y + 1.0f);
|
60 |
+
}
|
61 |
+
if (x.z > 0.0f && x.z < 65535.0f)
|
62 |
+
{
|
63 |
+
bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
|
64 |
+
d_x.z *= 1 / (x.z + 1.0f);
|
65 |
+
}
|
66 |
+
}
|
67 |
+
|
68 |
+
__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
|
69 |
+
{
|
70 |
+
return (img - target) * (img - target) / (img * img + target * target + eps);
|
71 |
+
}
|
72 |
+
|
73 |
+
__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
|
74 |
+
{
|
75 |
+
float denom = (target * target + img * img + eps);
|
76 |
+
d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
|
77 |
+
d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
|
78 |
+
}
|
79 |
+
|
80 |
+
__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
|
81 |
+
{
|
82 |
+
return abs(img - target) / (img + target + eps);
|
83 |
+
}
|
84 |
+
|
85 |
+
__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
|
86 |
+
{
|
87 |
+
float denom = (target + img + eps);
|
88 |
+
d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
|
89 |
+
d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
|
90 |
+
}
|
91 |
+
|
92 |
+
//------------------------------------------------------------------------
|
93 |
+
// Kernels
|
94 |
+
|
95 |
+
__global__ void imgLossFwdKernel(LossKernelParams p)
|
96 |
+
{
|
97 |
+
// Calculate pixel position.
|
98 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
99 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
100 |
+
unsigned int pz = blockIdx.z;
|
101 |
+
|
102 |
+
float floss = 0.0f;
|
103 |
+
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
|
104 |
+
{
|
105 |
+
vec3f img = p.img.fetch3(px, py, pz);
|
106 |
+
vec3f target = p.target.fetch3(px, py, pz);
|
107 |
+
|
108 |
+
img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
|
109 |
+
target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
|
110 |
+
|
111 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
112 |
+
{
|
113 |
+
img = fwdTonemapLogSRGB(img);
|
114 |
+
target = fwdTonemapLogSRGB(target);
|
115 |
+
}
|
116 |
+
|
117 |
+
vec3f vloss(0);
|
118 |
+
if (p.loss == LOSS_MSE)
|
119 |
+
vloss = (img - target) * (img - target);
|
120 |
+
else if (p.loss == LOSS_RELMSE)
|
121 |
+
vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
|
122 |
+
else if (p.loss == LOSS_SMAPE)
|
123 |
+
vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
|
124 |
+
else
|
125 |
+
vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
|
126 |
+
|
127 |
+
floss = sum(vloss) / 3.0f;
|
128 |
+
}
|
129 |
+
|
130 |
+
floss = warpSum(floss);
|
131 |
+
|
132 |
+
dim3 warpSize = getWarpSize(blockDim);
|
133 |
+
if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
|
134 |
+
p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
|
135 |
+
}
|
136 |
+
|
137 |
+
__global__ void imgLossBwdKernel(LossKernelParams p)
|
138 |
+
{
|
139 |
+
// Calculate pixel position.
|
140 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
141 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
142 |
+
unsigned int pz = blockIdx.z;
|
143 |
+
|
144 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
145 |
+
return;
|
146 |
+
|
147 |
+
dim3 warpSize = getWarpSize(blockDim);
|
148 |
+
|
149 |
+
vec3f _img = p.img.fetch3(px, py, pz);
|
150 |
+
vec3f _target = p.target.fetch3(px, py, pz);
|
151 |
+
float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
|
152 |
+
|
153 |
+
/////////////////////////////////////////////////////////////////////
|
154 |
+
// FWD
|
155 |
+
|
156 |
+
vec3f img = _img, target = _target;
|
157 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
158 |
+
{
|
159 |
+
img = fwdTonemapLogSRGB(img);
|
160 |
+
target = fwdTonemapLogSRGB(target);
|
161 |
+
}
|
162 |
+
|
163 |
+
/////////////////////////////////////////////////////////////////////
|
164 |
+
// BWD
|
165 |
+
|
166 |
+
vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
|
167 |
+
|
168 |
+
vec3f d_img(0), d_target(0);
|
169 |
+
if (p.loss == LOSS_MSE)
|
170 |
+
{
|
171 |
+
d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z));
|
172 |
+
d_target = -d_img;
|
173 |
+
}
|
174 |
+
else if (p.loss == LOSS_RELMSE)
|
175 |
+
{
|
176 |
+
bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
177 |
+
bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
178 |
+
bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
179 |
+
}
|
180 |
+
else if (p.loss == LOSS_SMAPE)
|
181 |
+
{
|
182 |
+
bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
|
183 |
+
bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
|
184 |
+
bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
|
185 |
+
}
|
186 |
+
else
|
187 |
+
{
|
188 |
+
d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
|
189 |
+
d_target = -d_img;
|
190 |
+
}
|
191 |
+
|
192 |
+
|
193 |
+
if (p.tonemapper == TONEMAPPER_LOG_SRGB)
|
194 |
+
{
|
195 |
+
vec3f d__img(0), d__target(0);
|
196 |
+
bwdTonemapLogSRGB(_img, d__img, d_img);
|
197 |
+
bwdTonemapLogSRGB(_target, d__target, d_target);
|
198 |
+
d_img = d__img; d_target = d__target;
|
199 |
+
}
|
200 |
+
|
201 |
+
if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
|
202 |
+
if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
|
203 |
+
if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
|
204 |
+
if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
|
205 |
+
if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
|
206 |
+
if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
|
207 |
+
|
208 |
+
p.img.store_grad(px, py, pz, d_img);
|
209 |
+
p.target.store_grad(px, py, pz, d_target);
|
210 |
+
}
|
video3d/render/renderutils/c_src/loss.h
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
|
16 |
+
enum TonemapperType
|
17 |
+
{
|
18 |
+
TONEMAPPER_NONE = 0,
|
19 |
+
TONEMAPPER_LOG_SRGB = 1
|
20 |
+
};
|
21 |
+
|
22 |
+
enum LossType
|
23 |
+
{
|
24 |
+
LOSS_L1 = 0,
|
25 |
+
LOSS_MSE = 1,
|
26 |
+
LOSS_RELMSE = 2,
|
27 |
+
LOSS_SMAPE = 3
|
28 |
+
};
|
29 |
+
|
30 |
+
struct LossKernelParams
|
31 |
+
{
|
32 |
+
Tensor img;
|
33 |
+
Tensor target;
|
34 |
+
Tensor out;
|
35 |
+
dim3 gridSize;
|
36 |
+
TonemapperType tonemapper;
|
37 |
+
LossType loss;
|
38 |
+
};
|
video3d/render/renderutils/c_src/mesh.cu
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include <cuda.h>
|
13 |
+
#include <stdio.h>
|
14 |
+
|
15 |
+
#include "common.h"
|
16 |
+
#include "mesh.h"
|
17 |
+
|
18 |
+
|
19 |
+
//------------------------------------------------------------------------
|
20 |
+
// Kernels
|
21 |
+
|
22 |
+
__global__ void xfmPointsFwdKernel(XfmKernelParams p)
|
23 |
+
{
|
24 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
25 |
+
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
26 |
+
|
27 |
+
__shared__ float mtx[4][4];
|
28 |
+
if (threadIdx.x < 16)
|
29 |
+
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
30 |
+
__syncthreads();
|
31 |
+
|
32 |
+
if (px >= p.gridSize.x)
|
33 |
+
return;
|
34 |
+
|
35 |
+
vec3f pos(
|
36 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
37 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
38 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
39 |
+
);
|
40 |
+
|
41 |
+
if (p.isPoints)
|
42 |
+
{
|
43 |
+
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
|
44 |
+
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
|
45 |
+
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
|
46 |
+
p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
|
47 |
+
}
|
48 |
+
else
|
49 |
+
{
|
50 |
+
p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
|
51 |
+
p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
|
52 |
+
p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
+
__global__ void xfmPointsBwdKernel(XfmKernelParams p)
|
57 |
+
{
|
58 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
59 |
+
unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;
|
60 |
+
|
61 |
+
__shared__ float mtx[4][4];
|
62 |
+
if (threadIdx.x < 16)
|
63 |
+
mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
|
64 |
+
__syncthreads();
|
65 |
+
|
66 |
+
if (px >= p.gridSize.x)
|
67 |
+
return;
|
68 |
+
|
69 |
+
vec3f pos(
|
70 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
|
71 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
|
72 |
+
p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
|
73 |
+
);
|
74 |
+
|
75 |
+
vec4f d_out(
|
76 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
|
77 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
|
78 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
|
79 |
+
p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
|
80 |
+
);
|
81 |
+
|
82 |
+
if (p.isPoints)
|
83 |
+
{
|
84 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
|
85 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
|
86 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
|
87 |
+
}
|
88 |
+
else
|
89 |
+
{
|
90 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
|
91 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
|
92 |
+
p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
|
93 |
+
}
|
94 |
+
}
|
video3d/render/renderutils/c_src/mesh.h
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
|
16 |
+
struct XfmKernelParams
|
17 |
+
{
|
18 |
+
bool isPoints;
|
19 |
+
Tensor points;
|
20 |
+
Tensor matrix;
|
21 |
+
Tensor out;
|
22 |
+
dim3 gridSize;
|
23 |
+
};
|
video3d/render/renderutils/c_src/normal.cu
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "common.h"
|
13 |
+
#include "normal.h"
|
14 |
+
|
15 |
+
#define NORMAL_THRESHOLD 0.1f
|
16 |
+
|
17 |
+
//------------------------------------------------------------------------
|
18 |
+
// Perturb shading normal by tangent frame
|
19 |
+
|
20 |
+
__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
|
21 |
+
{
|
22 |
+
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
23 |
+
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
24 |
+
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
25 |
+
return safeNormalize(_shading_nrm);
|
26 |
+
}
|
27 |
+
|
28 |
+
__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
|
29 |
+
{
|
30 |
+
////////////////////////////////////////////////////////////////////////
|
31 |
+
// FWD
|
32 |
+
vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
|
33 |
+
vec3f smooth_bitng = safeNormalize(_smooth_bitng);
|
34 |
+
vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
|
35 |
+
|
36 |
+
////////////////////////////////////////////////////////////////////////
|
37 |
+
// BWD
|
38 |
+
vec3f d_shading_nrm(0);
|
39 |
+
bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);
|
40 |
+
|
41 |
+
vec3f d_smooth_bitng(0);
|
42 |
+
|
43 |
+
if (perturbed_nrm.z > 0.0f)
|
44 |
+
{
|
45 |
+
d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
|
46 |
+
d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
|
47 |
+
}
|
48 |
+
|
49 |
+
d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
|
50 |
+
d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);
|
51 |
+
|
52 |
+
d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
|
53 |
+
d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);
|
54 |
+
|
55 |
+
vec3f d__smooth_bitng(0);
|
56 |
+
bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);
|
57 |
+
|
58 |
+
bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
|
59 |
+
}
|
60 |
+
|
61 |
+
//------------------------------------------------------------------------
|
62 |
+
#define bent_nrm_eps 0.001f
|
63 |
+
|
64 |
+
__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
|
65 |
+
{
|
66 |
+
float dp = dot(view_vec, smooth_nrm);
|
67 |
+
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
68 |
+
return geom_nrm * (1.0f - t) + smooth_nrm * t;
|
69 |
+
}
|
70 |
+
|
71 |
+
__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
|
72 |
+
{
|
73 |
+
////////////////////////////////////////////////////////////////////////
|
74 |
+
// FWD
|
75 |
+
float dp = dot(view_vec, smooth_nrm);
|
76 |
+
float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
|
77 |
+
|
78 |
+
////////////////////////////////////////////////////////////////////////
|
79 |
+
// BWD
|
80 |
+
if (dp > NORMAL_THRESHOLD)
|
81 |
+
d_smooth_nrm += d_out;
|
82 |
+
else
|
83 |
+
{
|
84 |
+
// geom_nrm * (1.0f - t) + smooth_nrm * t;
|
85 |
+
d_geom_nrm += d_out * (1.0f - t);
|
86 |
+
d_smooth_nrm += d_out * t;
|
87 |
+
float d_t = sum(d_out * (smooth_nrm - geom_nrm));
|
88 |
+
|
89 |
+
float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
|
90 |
+
|
91 |
+
bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
|
92 |
+
}
|
93 |
+
}
|
94 |
+
|
95 |
+
//------------------------------------------------------------------------
|
96 |
+
// Kernels
|
97 |
+
|
98 |
+
__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
|
99 |
+
{
|
100 |
+
// Calculate pixel position.
|
101 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
102 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
103 |
+
unsigned int pz = blockIdx.z;
|
104 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
105 |
+
return;
|
106 |
+
|
107 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
108 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
109 |
+
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
110 |
+
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
111 |
+
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
112 |
+
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
113 |
+
|
114 |
+
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
115 |
+
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
116 |
+
vec3f view_vec = safeNormalize(view_pos - pos);
|
117 |
+
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
118 |
+
|
119 |
+
vec3f res;
|
120 |
+
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
121 |
+
res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);
|
122 |
+
else
|
123 |
+
res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);
|
124 |
+
|
125 |
+
p.out.store(px, py, pz, res);
|
126 |
+
}
|
127 |
+
|
128 |
+
__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
|
129 |
+
{
|
130 |
+
// Calculate pixel position.
|
131 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
132 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
133 |
+
unsigned int pz = blockIdx.z;
|
134 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
135 |
+
return;
|
136 |
+
|
137 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
138 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
139 |
+
vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
|
140 |
+
vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
|
141 |
+
vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
|
142 |
+
vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
|
143 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
144 |
+
|
145 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
146 |
+
// FWD
|
147 |
+
|
148 |
+
vec3f smooth_nrm = safeNormalize(_smooth_nrm);
|
149 |
+
vec3f smooth_tng = safeNormalize(_smooth_tng);
|
150 |
+
vec3f _view_vec = view_pos - pos;
|
151 |
+
vec3f view_vec = safeNormalize(view_pos - pos);
|
152 |
+
|
153 |
+
vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);
|
154 |
+
|
155 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
156 |
+
// BWD
|
157 |
+
|
158 |
+
vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
|
159 |
+
if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
|
160 |
+
{
|
161 |
+
bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
162 |
+
d_shading_nrm = -d_shading_nrm;
|
163 |
+
d_geom_nrm = -d_geom_nrm;
|
164 |
+
}
|
165 |
+
else
|
166 |
+
bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
|
167 |
+
|
168 |
+
vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
|
169 |
+
bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);
|
170 |
+
|
171 |
+
vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
|
172 |
+
bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
|
173 |
+
bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
|
174 |
+
bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);
|
175 |
+
|
176 |
+
p.pos.store_grad(px, py, pz, -d__view_vec);
|
177 |
+
p.view_pos.store_grad(px, py, pz, d__view_vec);
|
178 |
+
p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
|
179 |
+
p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
|
180 |
+
p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
|
181 |
+
p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
|
182 |
+
}
|
video3d/render/renderutils/c_src/normal.h
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
#include "common.h"
|
15 |
+
|
16 |
+
struct PrepareShadingNormalKernelParams
|
17 |
+
{
|
18 |
+
Tensor pos;
|
19 |
+
Tensor view_pos;
|
20 |
+
Tensor perturbed_nrm;
|
21 |
+
Tensor smooth_nrm;
|
22 |
+
Tensor smooth_tng;
|
23 |
+
Tensor geom_nrm;
|
24 |
+
Tensor out;
|
25 |
+
dim3 gridSize;
|
26 |
+
bool two_sided_shading, opengl;
|
27 |
+
};
|
video3d/render/renderutils/c_src/tensor.h
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
#if defined(__CUDACC__) && defined(BFLOAT16)
|
14 |
+
#include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits
|
15 |
+
#endif
|
16 |
+
|
17 |
+
//---------------------------------------------------------------------------------
|
18 |
+
// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16
|
19 |
+
|
20 |
+
struct Tensor
|
21 |
+
{
|
22 |
+
void* val;
|
23 |
+
void* d_val;
|
24 |
+
int dims[4], _dims[4];
|
25 |
+
int strides[4];
|
26 |
+
bool fp16;
|
27 |
+
|
28 |
+
#if defined(__CUDA__) && !defined(__CUDA_ARCH__)
|
29 |
+
Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
|
30 |
+
#endif
|
31 |
+
|
32 |
+
#ifdef __CUDACC__
|
33 |
+
// Helpers to index and read/write a single element
|
34 |
+
__device__ inline int _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }
|
35 |
+
__device__ inline int nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }
|
36 |
+
__device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }
|
37 |
+
#ifdef BFLOAT16
|
38 |
+
__device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
|
39 |
+
__device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
|
40 |
+
__device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
|
41 |
+
#else
|
42 |
+
__device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
|
43 |
+
__device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
|
44 |
+
__device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
|
45 |
+
#endif
|
46 |
+
|
47 |
+
//////////////////////////////////////////////////////////////////////////////////////////
|
48 |
+
// Fetch, use broadcasting for tensor dimensions of size 1
|
49 |
+
__device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
|
50 |
+
{
|
51 |
+
return fetch(nhwcIndex(z, y, x, 0));
|
52 |
+
}
|
53 |
+
|
54 |
+
__device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
|
55 |
+
{
|
56 |
+
return vec3f(
|
57 |
+
fetch(nhwcIndex(z, y, x, 0)),
|
58 |
+
fetch(nhwcIndex(z, y, x, 1)),
|
59 |
+
fetch(nhwcIndex(z, y, x, 2))
|
60 |
+
);
|
61 |
+
}
|
62 |
+
|
63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
64 |
+
// Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
65 |
+
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
|
66 |
+
{
|
67 |
+
store(_nhwcIndex(z, y, x, 0), _val);
|
68 |
+
}
|
69 |
+
|
70 |
+
__device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
71 |
+
{
|
72 |
+
store(_nhwcIndex(z, y, x, 0), _val.x);
|
73 |
+
store(_nhwcIndex(z, y, x, 1), _val.y);
|
74 |
+
store(_nhwcIndex(z, y, x, 2), _val.z);
|
75 |
+
}
|
76 |
+
|
77 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
78 |
+
// Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
|
79 |
+
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
|
80 |
+
{
|
81 |
+
store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
|
82 |
+
}
|
83 |
+
|
84 |
+
__device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
|
85 |
+
{
|
86 |
+
store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
|
87 |
+
store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
|
88 |
+
store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
|
89 |
+
}
|
90 |
+
#endif
|
91 |
+
|
92 |
+
};
|
video3d/render/renderutils/c_src/torch_bindings.cpp
ADDED
@@ -0,0 +1,1062 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#ifdef _MSC_VER
|
13 |
+
#pragma warning(push, 0)
|
14 |
+
#include <torch/extension.h>
|
15 |
+
#pragma warning(pop)
|
16 |
+
#else
|
17 |
+
#include <torch/extension.h>
|
18 |
+
#endif
|
19 |
+
|
20 |
+
#include <ATen/cuda/CUDAContext.h>
|
21 |
+
#include <ATen/cuda/CUDAUtils.h>
|
22 |
+
#include <algorithm>
|
23 |
+
#include <string>
|
24 |
+
|
25 |
+
#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); }
|
26 |
+
#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); }
|
27 |
+
#define CHECK_TENSOR(X, DIMS, CHANNELS) \
|
28 |
+
TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \
|
29 |
+
TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \
|
30 |
+
TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \
|
31 |
+
TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels")
|
32 |
+
|
33 |
+
#include "common.h"
|
34 |
+
#include "loss.h"
|
35 |
+
#include "normal.h"
|
36 |
+
#include "cubemap.h"
|
37 |
+
#include "bsdf.h"
|
38 |
+
#include "mesh.h"
|
39 |
+
|
40 |
+
#define BLOCK_X 8
|
41 |
+
#define BLOCK_Y 8
|
42 |
+
|
43 |
+
//------------------------------------------------------------------------
|
44 |
+
// mesh.cu
|
45 |
+
|
46 |
+
void xfmPointsFwdKernel(XfmKernelParams p);
|
47 |
+
void xfmPointsBwdKernel(XfmKernelParams p);
|
48 |
+
|
49 |
+
//------------------------------------------------------------------------
|
50 |
+
// loss.cu
|
51 |
+
|
52 |
+
void imgLossFwdKernel(LossKernelParams p);
|
53 |
+
void imgLossBwdKernel(LossKernelParams p);
|
54 |
+
|
55 |
+
//------------------------------------------------------------------------
|
56 |
+
// normal.cu
|
57 |
+
|
58 |
+
void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p);
|
59 |
+
void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p);
|
60 |
+
|
61 |
+
//------------------------------------------------------------------------
|
62 |
+
// cubemap.cu
|
63 |
+
|
64 |
+
void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p);
|
65 |
+
void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p);
|
66 |
+
void SpecularBoundsKernel(SpecularBoundsKernelParams p);
|
67 |
+
void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p);
|
68 |
+
void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p);
|
69 |
+
|
70 |
+
//------------------------------------------------------------------------
|
71 |
+
// bsdf.cu
|
72 |
+
|
73 |
+
void LambertFwdKernel(LambertKernelParams p);
|
74 |
+
void LambertBwdKernel(LambertKernelParams p);
|
75 |
+
|
76 |
+
void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p);
|
77 |
+
void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p);
|
78 |
+
|
79 |
+
void FresnelShlickFwdKernel(FresnelShlickKernelParams p);
|
80 |
+
void FresnelShlickBwdKernel(FresnelShlickKernelParams p);
|
81 |
+
|
82 |
+
void ndfGGXFwdKernel(NdfGGXParams p);
|
83 |
+
void ndfGGXBwdKernel(NdfGGXParams p);
|
84 |
+
|
85 |
+
void lambdaGGXFwdKernel(NdfGGXParams p);
|
86 |
+
void lambdaGGXBwdKernel(NdfGGXParams p);
|
87 |
+
|
88 |
+
void maskingSmithFwdKernel(MaskingSmithParams p);
|
89 |
+
void maskingSmithBwdKernel(MaskingSmithParams p);
|
90 |
+
|
91 |
+
void pbrSpecularFwdKernel(PbrSpecular p);
|
92 |
+
void pbrSpecularBwdKernel(PbrSpecular p);
|
93 |
+
|
94 |
+
void pbrBSDFFwdKernel(PbrBSDF p);
|
95 |
+
void pbrBSDFBwdKernel(PbrBSDF p);
|
96 |
+
|
97 |
+
//------------------------------------------------------------------------
|
98 |
+
// Tensor helpers
|
99 |
+
|
100 |
+
void update_grid(dim3 &gridSize, torch::Tensor x)
|
101 |
+
{
|
102 |
+
gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
|
103 |
+
gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
|
104 |
+
gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
|
105 |
+
}
|
106 |
+
|
107 |
+
template<typename... Ts>
|
108 |
+
void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)
|
109 |
+
{
|
110 |
+
gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));
|
111 |
+
gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));
|
112 |
+
gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));
|
113 |
+
update_grid(gridSize, std::forward<Ts>(vs)...);
|
114 |
+
}
|
115 |
+
|
116 |
+
Tensor make_cuda_tensor(torch::Tensor val)
|
117 |
+
{
|
118 |
+
Tensor res;
|
119 |
+
for (int i = 0; i < val.dim(); ++i)
|
120 |
+
{
|
121 |
+
res.dims[i] = val.size(i);
|
122 |
+
res.strides[i] = val.stride(i);
|
123 |
+
}
|
124 |
+
res.fp16 = val.scalar_type() == torch::kBFloat16;
|
125 |
+
res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();
|
126 |
+
res.d_val = nullptr;
|
127 |
+
return res;
|
128 |
+
}
|
129 |
+
|
130 |
+
Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr)
|
131 |
+
{
|
132 |
+
Tensor res;
|
133 |
+
for (int i = 0; i < val.dim(); ++i)
|
134 |
+
{
|
135 |
+
res.dims[i] = val.size(i);
|
136 |
+
res.strides[i] = val.stride(i);
|
137 |
+
}
|
138 |
+
if (val.dim() == 4)
|
139 |
+
res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3);
|
140 |
+
else
|
141 |
+
res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out
|
142 |
+
|
143 |
+
res.fp16 = val.scalar_type() == torch::kBFloat16;
|
144 |
+
res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();
|
145 |
+
res.d_val = nullptr;
|
146 |
+
if (grad != nullptr)
|
147 |
+
{
|
148 |
+
if (val.dim() == 4)
|
149 |
+
*grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
|
150 |
+
else // 3
|
151 |
+
*grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));
|
152 |
+
|
153 |
+
res.d_val = res.fp16 ? (void*)grad->data_ptr<torch::BFloat16>() : (void*)grad->data_ptr<float>();
|
154 |
+
}
|
155 |
+
return res;
|
156 |
+
}
|
157 |
+
|
158 |
+
//------------------------------------------------------------------------
|
159 |
+
// prepare_shading_normal
|
160 |
+
|
161 |
+
torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16)
|
162 |
+
{
|
163 |
+
CHECK_TENSOR(pos, 4, 3);
|
164 |
+
CHECK_TENSOR(view_pos, 4, 3);
|
165 |
+
CHECK_TENSOR(perturbed_nrm, 4, 3);
|
166 |
+
CHECK_TENSOR(smooth_nrm, 4, 3);
|
167 |
+
CHECK_TENSOR(smooth_tng, 4, 3);
|
168 |
+
CHECK_TENSOR(geom_nrm, 4, 3);
|
169 |
+
|
170 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
171 |
+
|
172 |
+
// Extract input parameters.
|
173 |
+
PrepareShadingNormalKernelParams p;
|
174 |
+
p.two_sided_shading = two_sided_shading;
|
175 |
+
p.opengl = opengl;
|
176 |
+
p.out.fp16 = fp16;
|
177 |
+
update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
|
178 |
+
|
179 |
+
// Allocate output tensors.
|
180 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
181 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
182 |
+
|
183 |
+
// Choose launch parameters.
|
184 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
185 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
186 |
+
|
187 |
+
// Setup tensors
|
188 |
+
p.pos = make_cuda_tensor(pos, p.gridSize);
|
189 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
|
190 |
+
p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize);
|
191 |
+
p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize);
|
192 |
+
p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize);
|
193 |
+
p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize);
|
194 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
195 |
+
|
196 |
+
// Launch CUDA kernel.
|
197 |
+
void* args[] = { &p };
|
198 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream));
|
199 |
+
|
200 |
+
return out;
|
201 |
+
}
|
202 |
+
|
203 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl)
|
204 |
+
{
|
205 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
206 |
+
|
207 |
+
// Extract input parameters.
|
208 |
+
PrepareShadingNormalKernelParams p;
|
209 |
+
p.two_sided_shading = two_sided_shading;
|
210 |
+
p.opengl = opengl;
|
211 |
+
update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);
|
212 |
+
|
213 |
+
// Choose launch parameters.
|
214 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
215 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
216 |
+
|
217 |
+
// Setup tensors
|
218 |
+
torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad;
|
219 |
+
p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
|
220 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
|
221 |
+
p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad);
|
222 |
+
p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad);
|
223 |
+
p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad);
|
224 |
+
p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad);
|
225 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
226 |
+
|
227 |
+
// Launch CUDA kernel.
|
228 |
+
void* args[] = { &p };
|
229 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream));
|
230 |
+
|
231 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad);
|
232 |
+
}
|
233 |
+
|
234 |
+
//------------------------------------------------------------------------
|
235 |
+
// lambert
|
236 |
+
|
237 |
+
torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)
|
238 |
+
{
|
239 |
+
CHECK_TENSOR(nrm, 4, 3);
|
240 |
+
CHECK_TENSOR(wi, 4, 3);
|
241 |
+
|
242 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
243 |
+
|
244 |
+
// Extract input parameters.
|
245 |
+
LambertKernelParams p;
|
246 |
+
p.out.fp16 = fp16;
|
247 |
+
update_grid(p.gridSize, nrm, wi);
|
248 |
+
|
249 |
+
// Allocate output tensors.
|
250 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
251 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
252 |
+
|
253 |
+
// Choose launch parameters.
|
254 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
255 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
256 |
+
|
257 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
258 |
+
p.wi = make_cuda_tensor(wi, p.gridSize);
|
259 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
260 |
+
|
261 |
+
// Launch CUDA kernel.
|
262 |
+
void* args[] = { &p };
|
263 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream));
|
264 |
+
|
265 |
+
return out;
|
266 |
+
}
|
267 |
+
|
268 |
+
std::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad)
|
269 |
+
{
|
270 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
271 |
+
|
272 |
+
// Extract input parameters.
|
273 |
+
LambertKernelParams p;
|
274 |
+
update_grid(p.gridSize, nrm, wi);
|
275 |
+
|
276 |
+
// Choose launch parameters.
|
277 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
278 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
279 |
+
|
280 |
+
torch::Tensor nrm_grad, wi_grad;
|
281 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
282 |
+
p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
|
283 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
284 |
+
|
285 |
+
// Launch CUDA kernel.
|
286 |
+
void* args[] = { &p };
|
287 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream));
|
288 |
+
|
289 |
+
return std::tuple<torch::Tensor, torch::Tensor>(nrm_grad, wi_grad);
|
290 |
+
}
|
291 |
+
|
292 |
+
//------------------------------------------------------------------------
|
293 |
+
// frostbite diffuse
|
294 |
+
|
295 |
+
torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16)
|
296 |
+
{
|
297 |
+
CHECK_TENSOR(nrm, 4, 3);
|
298 |
+
CHECK_TENSOR(wi, 4, 3);
|
299 |
+
CHECK_TENSOR(wo, 4, 3);
|
300 |
+
CHECK_TENSOR(linearRoughness, 4, 1);
|
301 |
+
|
302 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
303 |
+
|
304 |
+
// Extract input parameters.
|
305 |
+
FrostbiteDiffuseKernelParams p;
|
306 |
+
p.out.fp16 = fp16;
|
307 |
+
update_grid(p.gridSize, nrm, wi, wo, linearRoughness);
|
308 |
+
|
309 |
+
// Allocate output tensors.
|
310 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
311 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
312 |
+
|
313 |
+
// Choose launch parameters.
|
314 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
315 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
316 |
+
|
317 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
318 |
+
p.wi = make_cuda_tensor(wi, p.gridSize);
|
319 |
+
p.wo = make_cuda_tensor(wo, p.gridSize);
|
320 |
+
p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize);
|
321 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
322 |
+
|
323 |
+
// Launch CUDA kernel.
|
324 |
+
void* args[] = { &p };
|
325 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream));
|
326 |
+
|
327 |
+
return out;
|
328 |
+
}
|
329 |
+
|
330 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad)
|
331 |
+
{
|
332 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
333 |
+
|
334 |
+
// Extract input parameters.
|
335 |
+
FrostbiteDiffuseKernelParams p;
|
336 |
+
update_grid(p.gridSize, nrm, wi, wo, linearRoughness);
|
337 |
+
|
338 |
+
// Choose launch parameters.
|
339 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
340 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
341 |
+
|
342 |
+
torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad;
|
343 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
344 |
+
p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
|
345 |
+
p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);
|
346 |
+
p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad);
|
347 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
348 |
+
|
349 |
+
// Launch CUDA kernel.
|
350 |
+
void* args[] = { &p };
|
351 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream));
|
352 |
+
|
353 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(nrm_grad, wi_grad, wo_grad, linearRoughness_grad);
|
354 |
+
}
|
355 |
+
|
356 |
+
//------------------------------------------------------------------------
|
357 |
+
// fresnel_shlick
|
358 |
+
|
359 |
+
torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16)
|
360 |
+
{
|
361 |
+
CHECK_TENSOR(f0, 4, 3);
|
362 |
+
CHECK_TENSOR(f90, 4, 3);
|
363 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
364 |
+
|
365 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
366 |
+
|
367 |
+
// Extract input parameters.
|
368 |
+
FresnelShlickKernelParams p;
|
369 |
+
p.out.fp16 = fp16;
|
370 |
+
update_grid(p.gridSize, f0, f90, cosTheta);
|
371 |
+
|
372 |
+
// Allocate output tensors.
|
373 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
374 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
375 |
+
|
376 |
+
// Choose launch parameters.
|
377 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
378 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
379 |
+
|
380 |
+
p.f0 = make_cuda_tensor(f0, p.gridSize);
|
381 |
+
p.f90 = make_cuda_tensor(f90, p.gridSize);
|
382 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
383 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
384 |
+
|
385 |
+
// Launch CUDA kernel.
|
386 |
+
void* args[] = { &p };
|
387 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream));
|
388 |
+
|
389 |
+
return out;
|
390 |
+
}
|
391 |
+
|
392 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad)
|
393 |
+
{
|
394 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
395 |
+
|
396 |
+
// Extract input parameters.
|
397 |
+
FresnelShlickKernelParams p;
|
398 |
+
update_grid(p.gridSize, f0, f90, cosTheta);
|
399 |
+
|
400 |
+
// Choose launch parameters.
|
401 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
402 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
403 |
+
|
404 |
+
torch::Tensor f0_grad, f90_grad, cosT_grad;
|
405 |
+
p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad);
|
406 |
+
p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad);
|
407 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad);
|
408 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
409 |
+
|
410 |
+
// Launch CUDA kernel.
|
411 |
+
void* args[] = { &p };
|
412 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream));
|
413 |
+
|
414 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(f0_grad, f90_grad, cosT_grad);
|
415 |
+
}
|
416 |
+
|
417 |
+
//------------------------------------------------------------------------
|
418 |
+
// ndf_ggd
|
419 |
+
|
420 |
+
torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
|
421 |
+
{
|
422 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
423 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
424 |
+
|
425 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
426 |
+
|
427 |
+
// Extract input parameters.
|
428 |
+
NdfGGXParams p;
|
429 |
+
p.out.fp16 = fp16;
|
430 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
431 |
+
|
432 |
+
// Allocate output tensors.
|
433 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
434 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
435 |
+
|
436 |
+
// Choose launch parameters.
|
437 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
438 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
439 |
+
|
440 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
441 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
442 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
443 |
+
|
444 |
+
// Launch CUDA kernel.
|
445 |
+
void* args[] = { &p };
|
446 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream));
|
447 |
+
|
448 |
+
return out;
|
449 |
+
}
|
450 |
+
|
451 |
+
std::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
|
452 |
+
{
|
453 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
454 |
+
|
455 |
+
// Extract input parameters.
|
456 |
+
NdfGGXParams p;
|
457 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
458 |
+
|
459 |
+
// Choose launch parameters.
|
460 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
461 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
462 |
+
|
463 |
+
torch::Tensor alphaSqr_grad, cosTheta_grad;
|
464 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
465 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
|
466 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
467 |
+
|
468 |
+
// Launch CUDA kernel.
|
469 |
+
void* args[] = { &p };
|
470 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream));
|
471 |
+
|
472 |
+
return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
|
473 |
+
}
|
474 |
+
|
475 |
+
//------------------------------------------------------------------------
|
476 |
+
// lambda_ggx
|
477 |
+
|
478 |
+
torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)
|
479 |
+
{
|
480 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
481 |
+
CHECK_TENSOR(cosTheta, 4, 1);
|
482 |
+
|
483 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
484 |
+
|
485 |
+
// Extract input parameters.
|
486 |
+
NdfGGXParams p;
|
487 |
+
p.out.fp16 = fp16;
|
488 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
489 |
+
|
490 |
+
// Allocate output tensors.
|
491 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
492 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
493 |
+
|
494 |
+
// Choose launch parameters.
|
495 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
496 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
497 |
+
|
498 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
499 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);
|
500 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
501 |
+
|
502 |
+
// Launch CUDA kernel.
|
503 |
+
void* args[] = { &p };
|
504 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream));
|
505 |
+
|
506 |
+
return out;
|
507 |
+
}
|
508 |
+
|
509 |
+
std::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)
|
510 |
+
{
|
511 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
512 |
+
|
513 |
+
// Extract input parameters.
|
514 |
+
NdfGGXParams p;
|
515 |
+
update_grid(p.gridSize, alphaSqr, cosTheta);
|
516 |
+
|
517 |
+
// Choose launch parameters.
|
518 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
519 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
520 |
+
|
521 |
+
torch::Tensor alphaSqr_grad, cosTheta_grad;
|
522 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
523 |
+
p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);
|
524 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
525 |
+
|
526 |
+
// Launch CUDA kernel.
|
527 |
+
void* args[] = { &p };
|
528 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream));
|
529 |
+
|
530 |
+
return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);
|
531 |
+
}
|
532 |
+
|
533 |
+
//------------------------------------------------------------------------
|
534 |
+
// masking_smith
|
535 |
+
|
536 |
+
torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16)
|
537 |
+
{
|
538 |
+
CHECK_TENSOR(alphaSqr, 4, 1);
|
539 |
+
CHECK_TENSOR(cosThetaI, 4, 1);
|
540 |
+
CHECK_TENSOR(cosThetaO, 4, 1);
|
541 |
+
|
542 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
543 |
+
|
544 |
+
// Extract input parameters.
|
545 |
+
MaskingSmithParams p;
|
546 |
+
p.out.fp16 = fp16;
|
547 |
+
update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
|
548 |
+
|
549 |
+
// Allocate output tensors.
|
550 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
551 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);
|
552 |
+
|
553 |
+
// Choose launch parameters.
|
554 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
555 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
556 |
+
|
557 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);
|
558 |
+
p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize);
|
559 |
+
p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize);
|
560 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
561 |
+
|
562 |
+
// Launch CUDA kernel.
|
563 |
+
void* args[] = { &p };
|
564 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream));
|
565 |
+
|
566 |
+
return out;
|
567 |
+
}
|
568 |
+
|
569 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad)
|
570 |
+
{
|
571 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
572 |
+
|
573 |
+
// Extract input parameters.
|
574 |
+
MaskingSmithParams p;
|
575 |
+
update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);
|
576 |
+
|
577 |
+
// Choose launch parameters.
|
578 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
579 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
580 |
+
|
581 |
+
torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad;
|
582 |
+
p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);
|
583 |
+
p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad);
|
584 |
+
p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad);
|
585 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
586 |
+
|
587 |
+
// Launch CUDA kernel.
|
588 |
+
void* args[] = { &p };
|
589 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream));
|
590 |
+
|
591 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad);
|
592 |
+
}
|
593 |
+
|
594 |
+
//------------------------------------------------------------------------
|
595 |
+
// pbr_specular
|
596 |
+
|
597 |
+
torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16)
|
598 |
+
{
|
599 |
+
CHECK_TENSOR(col, 4, 3);
|
600 |
+
CHECK_TENSOR(nrm, 4, 3);
|
601 |
+
CHECK_TENSOR(wo, 4, 3);
|
602 |
+
CHECK_TENSOR(wi, 4, 3);
|
603 |
+
CHECK_TENSOR(alpha, 4, 1);
|
604 |
+
|
605 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
606 |
+
|
607 |
+
// Extract input parameters.
|
608 |
+
PbrSpecular p;
|
609 |
+
p.out.fp16 = fp16;
|
610 |
+
p.min_roughness = min_roughness;
|
611 |
+
update_grid(p.gridSize, col, nrm, wo, wi, alpha);
|
612 |
+
|
613 |
+
// Allocate output tensors.
|
614 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
615 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
616 |
+
|
617 |
+
// Choose launch parameters.
|
618 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
619 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
620 |
+
|
621 |
+
p.col = make_cuda_tensor(col, p.gridSize);
|
622 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
623 |
+
p.wo = make_cuda_tensor(wo, p.gridSize);
|
624 |
+
p.wi = make_cuda_tensor(wi, p.gridSize);
|
625 |
+
p.alpha = make_cuda_tensor(alpha, p.gridSize);
|
626 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
627 |
+
|
628 |
+
// Launch CUDA kernel.
|
629 |
+
void* args[] = { &p };
|
630 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream));
|
631 |
+
|
632 |
+
return out;
|
633 |
+
}
|
634 |
+
|
635 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad)
|
636 |
+
{
|
637 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
638 |
+
|
639 |
+
// Extract input parameters.
|
640 |
+
PbrSpecular p;
|
641 |
+
update_grid(p.gridSize, col, nrm, wo, wi, alpha);
|
642 |
+
p.min_roughness = min_roughness;
|
643 |
+
|
644 |
+
// Choose launch parameters.
|
645 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
646 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
647 |
+
|
648 |
+
torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad;
|
649 |
+
p.col = make_cuda_tensor(col, p.gridSize, &col_grad);
|
650 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
651 |
+
p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);
|
652 |
+
p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);
|
653 |
+
p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad);
|
654 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
655 |
+
|
656 |
+
// Launch CUDA kernel.
|
657 |
+
void* args[] = { &p };
|
658 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream));
|
659 |
+
|
660 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad);
|
661 |
+
}
|
662 |
+
|
663 |
+
//------------------------------------------------------------------------
|
664 |
+
// pbr_bsdf
|
665 |
+
|
666 |
+
torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16)
|
667 |
+
{
|
668 |
+
CHECK_TENSOR(kd, 4, 3);
|
669 |
+
CHECK_TENSOR(arm, 4, 3);
|
670 |
+
CHECK_TENSOR(pos, 4, 3);
|
671 |
+
CHECK_TENSOR(nrm, 4, 3);
|
672 |
+
CHECK_TENSOR(view_pos, 4, 3);
|
673 |
+
CHECK_TENSOR(light_pos, 4, 3);
|
674 |
+
|
675 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
676 |
+
|
677 |
+
// Extract input parameters.
|
678 |
+
PbrBSDF p;
|
679 |
+
p.out.fp16 = fp16;
|
680 |
+
p.min_roughness = min_roughness;
|
681 |
+
p.BSDF = BSDF;
|
682 |
+
update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
|
683 |
+
|
684 |
+
// Allocate output tensors.
|
685 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
686 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
687 |
+
|
688 |
+
// Choose launch parameters.
|
689 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
690 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
691 |
+
|
692 |
+
p.kd = make_cuda_tensor(kd, p.gridSize);
|
693 |
+
p.arm = make_cuda_tensor(arm, p.gridSize);
|
694 |
+
p.pos = make_cuda_tensor(pos, p.gridSize);
|
695 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize);
|
696 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize);
|
697 |
+
p.light_pos = make_cuda_tensor(light_pos, p.gridSize);
|
698 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
699 |
+
|
700 |
+
// Launch CUDA kernel.
|
701 |
+
void* args[] = { &p };
|
702 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream));
|
703 |
+
|
704 |
+
return out;
|
705 |
+
}
|
706 |
+
|
707 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad)
|
708 |
+
{
|
709 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
710 |
+
|
711 |
+
// Extract input parameters.
|
712 |
+
PbrBSDF p;
|
713 |
+
update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);
|
714 |
+
p.min_roughness = min_roughness;
|
715 |
+
p.BSDF = BSDF;
|
716 |
+
|
717 |
+
// Choose launch parameters.
|
718 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
719 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
720 |
+
|
721 |
+
torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad;
|
722 |
+
p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad);
|
723 |
+
p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad);
|
724 |
+
p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);
|
725 |
+
p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);
|
726 |
+
p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);
|
727 |
+
p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad);
|
728 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
729 |
+
|
730 |
+
// Launch CUDA kernel.
|
731 |
+
void* args[] = { &p };
|
732 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream));
|
733 |
+
|
734 |
+
return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad);
|
735 |
+
}
|
736 |
+
|
737 |
+
//------------------------------------------------------------------------
|
738 |
+
// filter_cubemap
|
739 |
+
|
740 |
+
torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap)
|
741 |
+
{
|
742 |
+
CHECK_TENSOR(cubemap, 4, 3);
|
743 |
+
|
744 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
745 |
+
|
746 |
+
// Extract input parameters.
|
747 |
+
DiffuseCubemapKernelParams p;
|
748 |
+
update_grid(p.gridSize, cubemap);
|
749 |
+
|
750 |
+
// Allocate output tensors.
|
751 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
752 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);
|
753 |
+
|
754 |
+
// Choose launch parameters.
|
755 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
756 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
757 |
+
|
758 |
+
// Setup tensors
|
759 |
+
p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
|
760 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
761 |
+
|
762 |
+
// Launch CUDA kernel.
|
763 |
+
void* args[] = { &p };
|
764 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream));
|
765 |
+
|
766 |
+
return out;
|
767 |
+
}
|
768 |
+
|
769 |
+
torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad)
|
770 |
+
{
|
771 |
+
CHECK_TENSOR(cubemap, 4, 3);
|
772 |
+
CHECK_TENSOR(grad, 4, 3);
|
773 |
+
|
774 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
775 |
+
|
776 |
+
// Extract input parameters.
|
777 |
+
DiffuseCubemapKernelParams p;
|
778 |
+
update_grid(p.gridSize, cubemap);
|
779 |
+
|
780 |
+
// Choose launch parameters.
|
781 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
782 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
783 |
+
|
784 |
+
// Setup tensors
|
785 |
+
torch::Tensor cubemap_grad;
|
786 |
+
p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
|
787 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
788 |
+
|
789 |
+
cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
|
790 |
+
p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();
|
791 |
+
|
792 |
+
// Launch CUDA kernel.
|
793 |
+
void* args[] = { &p };
|
794 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream));
|
795 |
+
|
796 |
+
return cubemap_grad;
|
797 |
+
}
|
798 |
+
|
799 |
+
torch::Tensor specular_bounds(int resolution, float costheta_cutoff)
|
800 |
+
{
|
801 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
802 |
+
|
803 |
+
// Extract input parameters.
|
804 |
+
SpecularBoundsKernelParams p;
|
805 |
+
p.costheta_cutoff = costheta_cutoff;
|
806 |
+
p.gridSize = dim3(resolution, resolution, 6);
|
807 |
+
|
808 |
+
// Allocate output tensors.
|
809 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
810 |
+
torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts);
|
811 |
+
|
812 |
+
// Choose launch parameters.
|
813 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
814 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
815 |
+
|
816 |
+
// Setup tensors
|
817 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
818 |
+
|
819 |
+
// Launch CUDA kernel.
|
820 |
+
void* args[] = { &p };
|
821 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream));
|
822 |
+
|
823 |
+
return out;
|
824 |
+
}
|
825 |
+
|
826 |
+
torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff)
|
827 |
+
{
|
828 |
+
CHECK_TENSOR(cubemap, 4, 3);
|
829 |
+
CHECK_TENSOR(bounds, 4, 6*4);
|
830 |
+
|
831 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
832 |
+
|
833 |
+
// Extract input parameters.
|
834 |
+
SpecularCubemapKernelParams p;
|
835 |
+
p.roughness = roughness;
|
836 |
+
p.costheta_cutoff = costheta_cutoff;
|
837 |
+
update_grid(p.gridSize, cubemap);
|
838 |
+
|
839 |
+
// Allocate output tensors.
|
840 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
841 |
+
torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts);
|
842 |
+
|
843 |
+
// Choose launch parameters.
|
844 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
845 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
846 |
+
|
847 |
+
// Setup tensors
|
848 |
+
p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
|
849 |
+
p.bounds = make_cuda_tensor(bounds, p.gridSize);
|
850 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
851 |
+
|
852 |
+
// Launch CUDA kernel.
|
853 |
+
void* args[] = { &p };
|
854 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream));
|
855 |
+
|
856 |
+
return out;
|
857 |
+
}
|
858 |
+
|
859 |
+
torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff)
|
860 |
+
{
|
861 |
+
CHECK_TENSOR(cubemap, 4, 3);
|
862 |
+
CHECK_TENSOR(bounds, 4, 6*4);
|
863 |
+
|
864 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
865 |
+
|
866 |
+
// Extract input parameters.
|
867 |
+
SpecularCubemapKernelParams p;
|
868 |
+
p.roughness = roughness;
|
869 |
+
p.costheta_cutoff = costheta_cutoff;
|
870 |
+
update_grid(p.gridSize, cubemap);
|
871 |
+
|
872 |
+
// Choose launch parameters.
|
873 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
874 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
875 |
+
|
876 |
+
// Setup tensors
|
877 |
+
torch::Tensor cubemap_grad;
|
878 |
+
p.cubemap = make_cuda_tensor(cubemap, p.gridSize);
|
879 |
+
p.bounds = make_cuda_tensor(bounds, p.gridSize);
|
880 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
881 |
+
|
882 |
+
cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
|
883 |
+
p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();
|
884 |
+
|
885 |
+
// Launch CUDA kernel.
|
886 |
+
void* args[] = { &p };
|
887 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream));
|
888 |
+
|
889 |
+
return cubemap_grad;
|
890 |
+
}
|
891 |
+
|
892 |
+
//------------------------------------------------------------------------
|
893 |
+
// loss function
|
894 |
+
|
895 |
+
LossType strToLoss(std::string str)
|
896 |
+
{
|
897 |
+
if (str == "mse")
|
898 |
+
return LOSS_MSE;
|
899 |
+
else if (str == "relmse")
|
900 |
+
return LOSS_RELMSE;
|
901 |
+
else if (str == "smape")
|
902 |
+
return LOSS_SMAPE;
|
903 |
+
else
|
904 |
+
return LOSS_L1;
|
905 |
+
}
|
906 |
+
|
907 |
+
torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16)
|
908 |
+
{
|
909 |
+
CHECK_TENSOR(img, 4, 3);
|
910 |
+
CHECK_TENSOR(target, 4, 3);
|
911 |
+
|
912 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
913 |
+
|
914 |
+
// Extract input parameters.
|
915 |
+
LossKernelParams p;
|
916 |
+
p.out.fp16 = fp16;
|
917 |
+
p.loss = strToLoss(loss);
|
918 |
+
p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
|
919 |
+
update_grid(p.gridSize, img, target);
|
920 |
+
|
921 |
+
// Choose launch parameters.
|
922 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
923 |
+
dim3 warpSize = getWarpSize(blockSize);
|
924 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
925 |
+
|
926 |
+
// Allocate output tensors.
|
927 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
928 |
+
torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts);
|
929 |
+
|
930 |
+
p.img = make_cuda_tensor(img, p.gridSize);
|
931 |
+
p.target = make_cuda_tensor(target, p.gridSize);
|
932 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
933 |
+
|
934 |
+
// Launch CUDA kernel.
|
935 |
+
void* args[] = { &p };
|
936 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream));
|
937 |
+
|
938 |
+
return out;
|
939 |
+
}
|
940 |
+
|
941 |
+
std::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper)
|
942 |
+
{
|
943 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
944 |
+
|
945 |
+
// Extract input parameters.
|
946 |
+
LossKernelParams p;
|
947 |
+
p.loss = strToLoss(loss);
|
948 |
+
p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;
|
949 |
+
update_grid(p.gridSize, img, target);
|
950 |
+
|
951 |
+
// Choose launch parameters.
|
952 |
+
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);
|
953 |
+
dim3 warpSize = getWarpSize(blockSize);
|
954 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
955 |
+
|
956 |
+
torch::Tensor img_grad, target_grad;
|
957 |
+
p.img = make_cuda_tensor(img, p.gridSize, &img_grad);
|
958 |
+
p.target = make_cuda_tensor(target, p.gridSize, &target_grad);
|
959 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
960 |
+
|
961 |
+
// Launch CUDA kernel.
|
962 |
+
void* args[] = { &p };
|
963 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream));
|
964 |
+
|
965 |
+
return std::tuple<torch::Tensor, torch::Tensor>(img_grad, target_grad);
|
966 |
+
}
|
967 |
+
|
968 |
+
//------------------------------------------------------------------------
|
969 |
+
// transform function
|
970 |
+
|
971 |
+
torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16)
|
972 |
+
{
|
973 |
+
CHECK_TENSOR(points, 3, 3);
|
974 |
+
CHECK_TENSOR(matrix, 3, 4);
|
975 |
+
|
976 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
977 |
+
|
978 |
+
// Extract input parameters.
|
979 |
+
XfmKernelParams p;
|
980 |
+
p.out.fp16 = fp16;
|
981 |
+
p.isPoints = isPoints;
|
982 |
+
p.gridSize.x = points.size(1);
|
983 |
+
p.gridSize.y = 1;
|
984 |
+
p.gridSize.z = std::max(matrix.size(0), points.size(0));
|
985 |
+
|
986 |
+
// Choose launch parameters.
|
987 |
+
dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
|
988 |
+
dim3 warpSize = getWarpSize(blockSize);
|
989 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
990 |
+
|
991 |
+
// Allocate output tensors.
|
992 |
+
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);
|
993 |
+
torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts);
|
994 |
+
|
995 |
+
p.points = make_cuda_tensor(points, p.gridSize);
|
996 |
+
p.matrix = make_cuda_tensor(matrix, p.gridSize);
|
997 |
+
p.out = make_cuda_tensor(out, p.gridSize);
|
998 |
+
|
999 |
+
// Launch CUDA kernel.
|
1000 |
+
void* args[] = { &p };
|
1001 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream));
|
1002 |
+
|
1003 |
+
return out;
|
1004 |
+
}
|
1005 |
+
|
1006 |
+
torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints)
|
1007 |
+
{
|
1008 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
1009 |
+
|
1010 |
+
// Extract input parameters.
|
1011 |
+
XfmKernelParams p;
|
1012 |
+
p.isPoints = isPoints;
|
1013 |
+
p.gridSize.x = points.size(1);
|
1014 |
+
p.gridSize.y = 1;
|
1015 |
+
p.gridSize.z = std::max(matrix.size(0), points.size(0));
|
1016 |
+
|
1017 |
+
// Choose launch parameters.
|
1018 |
+
dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);
|
1019 |
+
dim3 warpSize = getWarpSize(blockSize);
|
1020 |
+
dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);
|
1021 |
+
|
1022 |
+
torch::Tensor points_grad;
|
1023 |
+
p.points = make_cuda_tensor(points, p.gridSize, &points_grad);
|
1024 |
+
p.matrix = make_cuda_tensor(matrix, p.gridSize);
|
1025 |
+
p.out = make_cuda_tensor(grad, p.gridSize);
|
1026 |
+
|
1027 |
+
// Launch CUDA kernel.
|
1028 |
+
void* args[] = { &p };
|
1029 |
+
NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream));
|
1030 |
+
|
1031 |
+
return points_grad;
|
1032 |
+
}
|
1033 |
+
|
1034 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
1035 |
+
m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd");
|
1036 |
+
m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd");
|
1037 |
+
m.def("lambert_fwd", &lambert_fwd, "lambert_fwd");
|
1038 |
+
m.def("lambert_bwd", &lambert_bwd, "lambert_bwd");
|
1039 |
+
m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd");
|
1040 |
+
m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd");
|
1041 |
+
m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd");
|
1042 |
+
m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd");
|
1043 |
+
m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd");
|
1044 |
+
m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd");
|
1045 |
+
m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd");
|
1046 |
+
m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd");
|
1047 |
+
m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd");
|
1048 |
+
m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd");
|
1049 |
+
m.def("pbr_specular_fwd", &pbr_specular_fwd, "pbr_specular_fwd");
|
1050 |
+
m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd");
|
1051 |
+
m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd");
|
1052 |
+
m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd");
|
1053 |
+
m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd");
|
1054 |
+
m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd");
|
1055 |
+
m.def("specular_bounds", &specular_bounds, "specular_bounds");
|
1056 |
+
m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd");
|
1057 |
+
m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd");
|
1058 |
+
m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd");
|
1059 |
+
m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd");
|
1060 |
+
m.def("xfm_fwd", &xfm_fwd, "xfm_fwd");
|
1061 |
+
m.def("xfm_bwd", &xfm_bwd, "xfm_bwd");
|
1062 |
+
}
|
video3d/render/renderutils/c_src/vec3f.h
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
struct vec3f
|
15 |
+
{
|
16 |
+
float x, y, z;
|
17 |
+
|
18 |
+
#ifdef __CUDACC__
|
19 |
+
__device__ vec3f() { }
|
20 |
+
__device__ vec3f(float v) { x = v; y = v; z = v; }
|
21 |
+
__device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
|
22 |
+
__device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }
|
23 |
+
|
24 |
+
__device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
|
25 |
+
__device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
|
26 |
+
__device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
|
27 |
+
__device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
|
28 |
+
#endif
|
29 |
+
};
|
30 |
+
|
31 |
+
#ifdef __CUDACC__
|
32 |
+
__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
|
33 |
+
__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
|
34 |
+
__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
|
35 |
+
__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
|
36 |
+
__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }
|
37 |
+
|
38 |
+
__device__ static inline float sum(vec3f a)
|
39 |
+
{
|
40 |
+
return a.x + a.y + a.z;
|
41 |
+
}
|
42 |
+
|
43 |
+
__device__ static inline vec3f cross(vec3f a, vec3f b)
|
44 |
+
{
|
45 |
+
vec3f out;
|
46 |
+
out.x = a.y * b.z - a.z * b.y;
|
47 |
+
out.y = a.z * b.x - a.x * b.z;
|
48 |
+
out.z = a.x * b.y - a.y * b.x;
|
49 |
+
return out;
|
50 |
+
}
|
51 |
+
|
52 |
+
__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
|
53 |
+
{
|
54 |
+
d_a.x += d_out.z * b.y - d_out.y * b.z;
|
55 |
+
d_a.y += d_out.x * b.z - d_out.z * b.x;
|
56 |
+
d_a.z += d_out.y * b.x - d_out.x * b.y;
|
57 |
+
|
58 |
+
d_b.x += d_out.y * a.z - d_out.z * a.y;
|
59 |
+
d_b.y += d_out.z * a.x - d_out.x * a.z;
|
60 |
+
d_b.z += d_out.x * a.y - d_out.y * a.x;
|
61 |
+
}
|
62 |
+
|
63 |
+
__device__ static inline float dot(vec3f a, vec3f b)
|
64 |
+
{
|
65 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
66 |
+
}
|
67 |
+
|
68 |
+
__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
|
69 |
+
{
|
70 |
+
d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
|
71 |
+
d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
|
72 |
+
}
|
73 |
+
|
74 |
+
__device__ static inline vec3f reflect(vec3f x, vec3f n)
|
75 |
+
{
|
76 |
+
return n * 2.0f * dot(n, x) - x;
|
77 |
+
}
|
78 |
+
|
79 |
+
__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
|
80 |
+
{
|
81 |
+
d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
|
82 |
+
d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
|
83 |
+
d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
|
84 |
+
|
85 |
+
d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
|
86 |
+
d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
|
87 |
+
d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
|
88 |
+
}
|
89 |
+
|
90 |
+
__device__ static inline vec3f safeNormalize(vec3f v)
|
91 |
+
{
|
92 |
+
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
93 |
+
return l > 0.0f ? (v / l) : vec3f(0.0f);
|
94 |
+
}
|
95 |
+
|
96 |
+
__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
|
97 |
+
{
|
98 |
+
|
99 |
+
float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
|
100 |
+
if (l > 0.0f)
|
101 |
+
{
|
102 |
+
float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
|
103 |
+
d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
|
104 |
+
d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
|
105 |
+
d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
|
106 |
+
}
|
107 |
+
}
|
108 |
+
|
109 |
+
#endif
|
video3d/render/renderutils/c_src/vec4f.h
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#pragma once
|
13 |
+
|
14 |
+
struct vec4f
|
15 |
+
{
|
16 |
+
float x, y, z, w;
|
17 |
+
|
18 |
+
#ifdef __CUDACC__
|
19 |
+
__device__ vec4f() { }
|
20 |
+
__device__ vec4f(float v) { x = v; y = v; z = v; w = v; }
|
21 |
+
__device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }
|
22 |
+
__device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }
|
23 |
+
#endif
|
24 |
+
};
|
25 |
+
|
video3d/render/renderutils/loss.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
#----------------------------------------------------------------------------
|
13 |
+
# HDR image losses
|
14 |
+
#----------------------------------------------------------------------------
|
15 |
+
|
16 |
+
def _tonemap_srgb(f):
|
17 |
+
return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)
|
18 |
+
|
19 |
+
def _SMAPE(img, target, eps=0.01):
|
20 |
+
nom = torch.abs(img - target)
|
21 |
+
denom = torch.abs(img) + torch.abs(target) + 0.01
|
22 |
+
return torch.mean(nom / denom)
|
23 |
+
|
24 |
+
def _RELMSE(img, target, eps=0.1):
|
25 |
+
nom = (img - target) * (img - target)
|
26 |
+
denom = img * img + target * target + 0.1
|
27 |
+
return torch.mean(nom / denom)
|
28 |
+
|
29 |
+
def image_loss_fn(img, target, loss, tonemapper):
|
30 |
+
if tonemapper == 'log_srgb':
|
31 |
+
img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))
|
32 |
+
target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))
|
33 |
+
|
34 |
+
if loss == 'mse':
|
35 |
+
return torch.nn.functional.mse_loss(img, target)
|
36 |
+
elif loss == 'smape':
|
37 |
+
return _SMAPE(img, target)
|
38 |
+
elif loss == 'relmse':
|
39 |
+
return _RELMSE(img, target)
|
40 |
+
else:
|
41 |
+
return torch.nn.functional.l1_loss(img, target)
|
video3d/render/renderutils/ops.py
ADDED
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import torch
|
14 |
+
import torch.utils.cpp_extension
|
15 |
+
|
16 |
+
from .bsdf import *
|
17 |
+
from .loss import *
|
18 |
+
|
19 |
+
#----------------------------------------------------------------------------
|
20 |
+
# C++/Cuda plugin compiler/loader.
|
21 |
+
|
22 |
+
_cached_plugin = None
|
23 |
+
def _get_plugin():
|
24 |
+
# Return cached plugin if already loaded.
|
25 |
+
global _cached_plugin
|
26 |
+
if _cached_plugin is not None:
|
27 |
+
return _cached_plugin
|
28 |
+
|
29 |
+
# Make sure we can find the necessary compiler and libary binaries.
|
30 |
+
if os.name == 'nt':
|
31 |
+
def find_cl_path():
|
32 |
+
import glob
|
33 |
+
for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
|
34 |
+
paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True)
|
35 |
+
if paths:
|
36 |
+
return paths[0]
|
37 |
+
|
38 |
+
# If cl.exe is not on path, try to find it.
|
39 |
+
if os.system("where cl.exe >nul 2>nul") != 0:
|
40 |
+
cl_path = find_cl_path()
|
41 |
+
if cl_path is None:
|
42 |
+
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
|
43 |
+
os.environ['PATH'] += ';' + cl_path
|
44 |
+
|
45 |
+
# Compiler options.
|
46 |
+
opts = ['-DNVDR_TORCH']
|
47 |
+
|
48 |
+
# Linker options.
|
49 |
+
if os.name == 'posix':
|
50 |
+
ldflags = ['-lcuda', '-lnvrtc']
|
51 |
+
elif os.name == 'nt':
|
52 |
+
ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']
|
53 |
+
|
54 |
+
# List of sources.
|
55 |
+
source_files = [
|
56 |
+
'c_src/mesh.cu',
|
57 |
+
'c_src/loss.cu',
|
58 |
+
'c_src/bsdf.cu',
|
59 |
+
'c_src/normal.cu',
|
60 |
+
'c_src/cubemap.cu',
|
61 |
+
'c_src/common.cpp',
|
62 |
+
'c_src/torch_bindings.cpp'
|
63 |
+
]
|
64 |
+
|
65 |
+
# Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
|
66 |
+
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
67 |
+
|
68 |
+
# Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.
|
69 |
+
try:
|
70 |
+
lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock')
|
71 |
+
if os.path.exists(lock_fn):
|
72 |
+
print("Warning: Lock file exists in build directory: '%s'" % lock_fn)
|
73 |
+
except:
|
74 |
+
pass
|
75 |
+
|
76 |
+
# Compile and load.
|
77 |
+
source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
|
78 |
+
torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts,
|
79 |
+
extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True)
|
80 |
+
|
81 |
+
# Import, cache, and return the compiled module.
|
82 |
+
import renderutils_plugin
|
83 |
+
_cached_plugin = renderutils_plugin
|
84 |
+
return _cached_plugin
|
85 |
+
|
86 |
+
#----------------------------------------------------------------------------
|
87 |
+
# Internal kernels, just used for testing functionality
|
88 |
+
|
89 |
+
class _fresnel_shlick_func(torch.autograd.Function):
|
90 |
+
@staticmethod
|
91 |
+
def forward(ctx, f0, f90, cosTheta):
|
92 |
+
out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False)
|
93 |
+
ctx.save_for_backward(f0, f90, cosTheta)
|
94 |
+
return out
|
95 |
+
|
96 |
+
@staticmethod
|
97 |
+
def backward(ctx, dout):
|
98 |
+
f0, f90, cosTheta = ctx.saved_variables
|
99 |
+
return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)
|
100 |
+
|
101 |
+
def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
|
102 |
+
if use_python:
|
103 |
+
out = bsdf_fresnel_shlick(f0, f90, cosTheta)
|
104 |
+
else:
|
105 |
+
out = _fresnel_shlick_func.apply(f0, f90, cosTheta)
|
106 |
+
|
107 |
+
if torch.is_anomaly_enabled():
|
108 |
+
assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN"
|
109 |
+
return out
|
110 |
+
|
111 |
+
|
112 |
+
class _ndf_ggx_func(torch.autograd.Function):
|
113 |
+
@staticmethod
|
114 |
+
def forward(ctx, alphaSqr, cosTheta):
|
115 |
+
out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False)
|
116 |
+
ctx.save_for_backward(alphaSqr, cosTheta)
|
117 |
+
return out
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def backward(ctx, dout):
|
121 |
+
alphaSqr, cosTheta = ctx.saved_variables
|
122 |
+
return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
123 |
+
|
124 |
+
def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
|
125 |
+
if use_python:
|
126 |
+
out = bsdf_ndf_ggx(alphaSqr, cosTheta)
|
127 |
+
else:
|
128 |
+
out = _ndf_ggx_func.apply(alphaSqr, cosTheta)
|
129 |
+
|
130 |
+
if torch.is_anomaly_enabled():
|
131 |
+
assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN"
|
132 |
+
return out
|
133 |
+
|
134 |
+
class _lambda_ggx_func(torch.autograd.Function):
|
135 |
+
@staticmethod
|
136 |
+
def forward(ctx, alphaSqr, cosTheta):
|
137 |
+
out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False)
|
138 |
+
ctx.save_for_backward(alphaSqr, cosTheta)
|
139 |
+
return out
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def backward(ctx, dout):
|
143 |
+
alphaSqr, cosTheta = ctx.saved_variables
|
144 |
+
return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)
|
145 |
+
|
146 |
+
def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
|
147 |
+
if use_python:
|
148 |
+
out = bsdf_lambda_ggx(alphaSqr, cosTheta)
|
149 |
+
else:
|
150 |
+
out = _lambda_ggx_func.apply(alphaSqr, cosTheta)
|
151 |
+
|
152 |
+
if torch.is_anomaly_enabled():
|
153 |
+
assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN"
|
154 |
+
return out
|
155 |
+
|
156 |
+
class _masking_smith_func(torch.autograd.Function):
|
157 |
+
@staticmethod
|
158 |
+
def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
|
159 |
+
ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)
|
160 |
+
out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)
|
161 |
+
return out
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def backward(ctx, dout):
|
165 |
+
alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables
|
166 |
+
return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)
|
167 |
+
|
168 |
+
def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
|
169 |
+
if use_python:
|
170 |
+
out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)
|
171 |
+
else:
|
172 |
+
out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)
|
173 |
+
|
174 |
+
if torch.is_anomaly_enabled():
|
175 |
+
assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN"
|
176 |
+
return out
|
177 |
+
|
178 |
+
#----------------------------------------------------------------------------
|
179 |
+
# Shading normal setup (bump mapping + bent normals)
|
180 |
+
|
181 |
+
class _prepare_shading_normal_func(torch.autograd.Function):
|
182 |
+
@staticmethod
|
183 |
+
def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
|
184 |
+
ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl
|
185 |
+
out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)
|
186 |
+
ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)
|
187 |
+
return out
|
188 |
+
|
189 |
+
@staticmethod
|
190 |
+
def backward(ctx, dout):
|
191 |
+
pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables
|
192 |
+
return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)
|
193 |
+
|
194 |
+
def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):
|
195 |
+
'''Takes care of all corner cases and produces a final normal used for shading:
|
196 |
+
- Constructs tangent space
|
197 |
+
- Flips normal direction based on geometric normal for two sided Shading
|
198 |
+
- Perturbs shading normal by normal map
|
199 |
+
- Bends backfacing normals towards the camera to avoid shading artifacts
|
200 |
+
|
201 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
pos: World space g-buffer position.
|
205 |
+
view_pos: Camera position in world space (typically using broadcasting).
|
206 |
+
perturbed_nrm: Trangent-space normal perturbation from normal map lookup.
|
207 |
+
smooth_nrm: Interpolated vertex normals.
|
208 |
+
smooth_tng: Interpolated vertex tangents.
|
209 |
+
geom_nrm: Geometric (face) normals.
|
210 |
+
two_sided_shading: Use one/two sided shading
|
211 |
+
opengl: Use OpenGL/DirectX normal map conventions
|
212 |
+
use_python: Use PyTorch implementation (for validation)
|
213 |
+
Returns:
|
214 |
+
Final shading normal
|
215 |
+
'''
|
216 |
+
|
217 |
+
if perturbed_nrm is None:
|
218 |
+
perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]
|
219 |
+
|
220 |
+
if use_python:
|
221 |
+
out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
222 |
+
else:
|
223 |
+
out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)
|
224 |
+
|
225 |
+
if torch.is_anomaly_enabled():
|
226 |
+
assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN"
|
227 |
+
return out
|
228 |
+
|
229 |
+
#----------------------------------------------------------------------------
|
230 |
+
# BSDF functions
|
231 |
+
|
232 |
+
class _lambert_func(torch.autograd.Function):
|
233 |
+
@staticmethod
|
234 |
+
def forward(ctx, nrm, wi):
|
235 |
+
out = _get_plugin().lambert_fwd(nrm, wi, False)
|
236 |
+
ctx.save_for_backward(nrm, wi)
|
237 |
+
return out
|
238 |
+
|
239 |
+
@staticmethod
|
240 |
+
def backward(ctx, dout):
|
241 |
+
nrm, wi = ctx.saved_variables
|
242 |
+
return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)
|
243 |
+
|
244 |
+
def lambert(nrm, wi, use_python=False):
|
245 |
+
'''Lambertian bsdf.
|
246 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
nrm: World space shading normal.
|
250 |
+
wi: World space light vector.
|
251 |
+
use_python: Use PyTorch implementation (for validation)
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
Shaded diffuse value with shape [minibatch_size, height, width, 1]
|
255 |
+
'''
|
256 |
+
|
257 |
+
if use_python:
|
258 |
+
out = bsdf_lambert(nrm, wi)
|
259 |
+
else:
|
260 |
+
out = _lambert_func.apply(nrm, wi)
|
261 |
+
|
262 |
+
if torch.is_anomaly_enabled():
|
263 |
+
assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
|
264 |
+
return out
|
265 |
+
|
266 |
+
class _frostbite_diffuse_func(torch.autograd.Function):
|
267 |
+
@staticmethod
|
268 |
+
def forward(ctx, nrm, wi, wo, linearRoughness):
|
269 |
+
out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)
|
270 |
+
ctx.save_for_backward(nrm, wi, wo, linearRoughness)
|
271 |
+
return out
|
272 |
+
|
273 |
+
@staticmethod
|
274 |
+
def backward(ctx, dout):
|
275 |
+
nrm, wi, wo, linearRoughness = ctx.saved_variables
|
276 |
+
return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)
|
277 |
+
|
278 |
+
def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
|
279 |
+
'''Frostbite, normalized Disney Diffuse bsdf.
|
280 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
nrm: World space shading normal.
|
284 |
+
wi: World space light vector.
|
285 |
+
wo: World space camera vector.
|
286 |
+
linearRoughness: Material roughness
|
287 |
+
use_python: Use PyTorch implementation (for validation)
|
288 |
+
|
289 |
+
Returns:
|
290 |
+
Shaded diffuse value with shape [minibatch_size, height, width, 1]
|
291 |
+
'''
|
292 |
+
|
293 |
+
if use_python:
|
294 |
+
out = bsdf_frostbite(nrm, wi, wo, linearRoughness)
|
295 |
+
else:
|
296 |
+
out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)
|
297 |
+
|
298 |
+
if torch.is_anomaly_enabled():
|
299 |
+
assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN"
|
300 |
+
return out
|
301 |
+
|
302 |
+
class _pbr_specular_func(torch.autograd.Function):
|
303 |
+
@staticmethod
|
304 |
+
def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
|
305 |
+
ctx.save_for_backward(col, nrm, wo, wi, alpha)
|
306 |
+
ctx.min_roughness = min_roughness
|
307 |
+
out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)
|
308 |
+
return out
|
309 |
+
|
310 |
+
@staticmethod
|
311 |
+
def backward(ctx, dout):
|
312 |
+
col, nrm, wo, wi, alpha = ctx.saved_variables
|
313 |
+
return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)
|
314 |
+
|
315 |
+
def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):
|
316 |
+
'''Physically-based specular bsdf.
|
317 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
col: Specular lobe color
|
321 |
+
nrm: World space shading normal.
|
322 |
+
wo: World space camera vector.
|
323 |
+
wi: World space light vector
|
324 |
+
alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]
|
325 |
+
min_roughness: Scalar roughness clamping threshold
|
326 |
+
|
327 |
+
use_python: Use PyTorch implementation (for validation)
|
328 |
+
Returns:
|
329 |
+
Shaded specular color
|
330 |
+
'''
|
331 |
+
|
332 |
+
if use_python:
|
333 |
+
out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness)
|
334 |
+
else:
|
335 |
+
out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)
|
336 |
+
|
337 |
+
if torch.is_anomaly_enabled():
|
338 |
+
assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN"
|
339 |
+
return out
|
340 |
+
|
341 |
+
class _pbr_bsdf_func(torch.autograd.Function):
|
342 |
+
@staticmethod
|
343 |
+
def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
|
344 |
+
ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)
|
345 |
+
ctx.min_roughness = min_roughness
|
346 |
+
ctx.BSDF = BSDF
|
347 |
+
out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False)
|
348 |
+
return out
|
349 |
+
|
350 |
+
@staticmethod
|
351 |
+
def backward(ctx, dout):
|
352 |
+
kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables
|
353 |
+
return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None)
|
354 |
+
|
355 |
+
def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False):
|
356 |
+
'''Physically-based bsdf, both diffuse & specular lobes
|
357 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
358 |
+
|
359 |
+
Args:
|
360 |
+
kd: Diffuse albedo.
|
361 |
+
arm: Specular parameters (attenuation, linear roughness, metalness).
|
362 |
+
pos: World space position.
|
363 |
+
nrm: World space shading normal.
|
364 |
+
view_pos: Camera position in world space, typically using broadcasting.
|
365 |
+
light_pos: Light position in world space, typically using broadcasting.
|
366 |
+
min_roughness: Scalar roughness clamping threshold
|
367 |
+
bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite'
|
368 |
+
|
369 |
+
use_python: Use PyTorch implementation (for validation)
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
Shaded color.
|
373 |
+
'''
|
374 |
+
|
375 |
+
BSDF = 0
|
376 |
+
if bsdf == 'frostbite':
|
377 |
+
BSDF = 1
|
378 |
+
|
379 |
+
if use_python:
|
380 |
+
out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
|
381 |
+
else:
|
382 |
+
out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)
|
383 |
+
|
384 |
+
if torch.is_anomaly_enabled():
|
385 |
+
assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN"
|
386 |
+
return out
|
387 |
+
|
388 |
+
#----------------------------------------------------------------------------
|
389 |
+
# cubemap filter with filtering across edges
|
390 |
+
|
391 |
+
class _diffuse_cubemap_func(torch.autograd.Function):
|
392 |
+
@staticmethod
|
393 |
+
def forward(ctx, cubemap):
|
394 |
+
out = _get_plugin().diffuse_cubemap_fwd(cubemap)
|
395 |
+
ctx.save_for_backward(cubemap)
|
396 |
+
return out
|
397 |
+
|
398 |
+
@staticmethod
|
399 |
+
def backward(ctx, dout):
|
400 |
+
cubemap, = ctx.saved_variables
|
401 |
+
cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout)
|
402 |
+
return cubemap_grad, None
|
403 |
+
|
404 |
+
def diffuse_cubemap(cubemap, use_python=False):
|
405 |
+
if use_python:
|
406 |
+
assert False
|
407 |
+
else:
|
408 |
+
out = _diffuse_cubemap_func.apply(cubemap)
|
409 |
+
if torch.is_anomaly_enabled():
|
410 |
+
assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN"
|
411 |
+
return out
|
412 |
+
|
413 |
+
class _specular_cubemap(torch.autograd.Function):
|
414 |
+
@staticmethod
|
415 |
+
def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
|
416 |
+
out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff)
|
417 |
+
ctx.save_for_backward(cubemap, bounds)
|
418 |
+
ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff
|
419 |
+
return out
|
420 |
+
|
421 |
+
@staticmethod
|
422 |
+
def backward(ctx, dout):
|
423 |
+
cubemap, bounds = ctx.saved_variables
|
424 |
+
cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff)
|
425 |
+
return cubemap_grad, None, None, None
|
426 |
+
|
427 |
+
# Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy
|
428 |
+
def __ndfBounds(res, roughness, cutoff):
|
429 |
+
def ndfGGX(alphaSqr, costheta):
|
430 |
+
costheta = np.clip(costheta, 0.0, 1.0)
|
431 |
+
d = (costheta * alphaSqr - costheta) * costheta + 1.0
|
432 |
+
return alphaSqr / (d * d * np.pi)
|
433 |
+
|
434 |
+
# Sample out cutoff angle
|
435 |
+
nSamples = 1000000
|
436 |
+
costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples))
|
437 |
+
D = np.cumsum(ndfGGX(roughness**4, costheta))
|
438 |
+
idx = np.argmax(D >= D[..., -1] * cutoff)
|
439 |
+
|
440 |
+
# Brute force compute lookup table with bounds
|
441 |
+
bounds = _get_plugin().specular_bounds(res, costheta[idx])
|
442 |
+
|
443 |
+
return costheta[idx], bounds
|
444 |
+
__ndfBoundsDict = {}
|
445 |
+
|
446 |
+
def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
|
447 |
+
assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape)
|
448 |
+
|
449 |
+
if use_python:
|
450 |
+
assert False
|
451 |
+
else:
|
452 |
+
key = (cubemap.shape[1], roughness, cutoff)
|
453 |
+
if key not in __ndfBoundsDict:
|
454 |
+
__ndfBoundsDict[key] = __ndfBounds(*key)
|
455 |
+
out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key])
|
456 |
+
if torch.is_anomaly_enabled():
|
457 |
+
assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN"
|
458 |
+
return out[..., 0:3] / out[..., 3:]
|
459 |
+
|
460 |
+
#----------------------------------------------------------------------------
|
461 |
+
# Fast image loss function
|
462 |
+
|
463 |
+
class _image_loss_func(torch.autograd.Function):
|
464 |
+
@staticmethod
|
465 |
+
def forward(ctx, img, target, loss, tonemapper):
|
466 |
+
ctx.loss, ctx.tonemapper = loss, tonemapper
|
467 |
+
ctx.save_for_backward(img, target)
|
468 |
+
out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False)
|
469 |
+
return out
|
470 |
+
|
471 |
+
@staticmethod
|
472 |
+
def backward(ctx, dout):
|
473 |
+
img, target = ctx.saved_variables
|
474 |
+
return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)
|
475 |
+
|
476 |
+
def image_loss(img, target, loss='l1', tonemapper='none', use_python=False):
|
477 |
+
'''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.
|
478 |
+
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.
|
479 |
+
|
480 |
+
Args:
|
481 |
+
img: Input image.
|
482 |
+
target: Target (reference) image.
|
483 |
+
loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']
|
484 |
+
tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb']
|
485 |
+
use_python: Use PyTorch implementation (for validation)
|
486 |
+
|
487 |
+
Returns:
|
488 |
+
Image space loss (scalar value).
|
489 |
+
'''
|
490 |
+
if use_python:
|
491 |
+
out = image_loss_fn(img, target, loss, tonemapper)
|
492 |
+
else:
|
493 |
+
out = _image_loss_func.apply(img, target, loss, tonemapper)
|
494 |
+
out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])
|
495 |
+
|
496 |
+
if torch.is_anomaly_enabled():
|
497 |
+
assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN"
|
498 |
+
return out
|
499 |
+
|
500 |
+
#----------------------------------------------------------------------------
|
501 |
+
# Transform points function
|
502 |
+
|
503 |
+
class _xfm_func(torch.autograd.Function):
|
504 |
+
@staticmethod
|
505 |
+
def forward(ctx, points, matrix, isPoints):
|
506 |
+
ctx.save_for_backward(points, matrix)
|
507 |
+
ctx.isPoints = isPoints
|
508 |
+
return _get_plugin().xfm_fwd(points, matrix, isPoints, False)
|
509 |
+
|
510 |
+
@staticmethod
|
511 |
+
def backward(ctx, dout):
|
512 |
+
points, matrix = ctx.saved_variables
|
513 |
+
return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)
|
514 |
+
|
515 |
+
def xfm_points(points, matrix, use_python=False):
|
516 |
+
'''Transform points.
|
517 |
+
Args:
|
518 |
+
points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
519 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
520 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
521 |
+
Returns:
|
522 |
+
Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
523 |
+
'''
|
524 |
+
if use_python:
|
525 |
+
out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
|
526 |
+
else:
|
527 |
+
out = _xfm_func.apply(points, matrix, True)
|
528 |
+
|
529 |
+
if torch.is_anomaly_enabled():
|
530 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
|
531 |
+
return out
|
532 |
+
|
533 |
+
def xfm_vectors(vectors, matrix, use_python=False):
|
534 |
+
'''Transform vectors.
|
535 |
+
Args:
|
536 |
+
vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
|
537 |
+
matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
|
538 |
+
use_python: Use PyTorch's torch.matmul (for validation)
|
539 |
+
|
540 |
+
Returns:
|
541 |
+
Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
|
542 |
+
'''
|
543 |
+
|
544 |
+
if use_python:
|
545 |
+
out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()
|
546 |
+
else:
|
547 |
+
out = _xfm_func.apply(vectors, matrix, False)
|
548 |
+
|
549 |
+
if torch.is_anomaly_enabled():
|
550 |
+
assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN"
|
551 |
+
return out
|
552 |
+
|
553 |
+
|
554 |
+
|
video3d/render/renderutils/tests/test_bsdf.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
15 |
+
import renderutils as ru
|
16 |
+
|
17 |
+
RES = 4
|
18 |
+
DTYPE = torch.float32
|
19 |
+
|
20 |
+
def relative_loss(name, ref, cuda):
|
21 |
+
ref = ref.float()
|
22 |
+
cuda = cuda.float()
|
23 |
+
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
|
24 |
+
|
25 |
+
def test_normal():
|
26 |
+
pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
27 |
+
pos_ref = pos_cuda.clone().detach().requires_grad_(True)
|
28 |
+
view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
29 |
+
view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)
|
30 |
+
perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
31 |
+
perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)
|
32 |
+
smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
33 |
+
smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)
|
34 |
+
smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
35 |
+
smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)
|
36 |
+
geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
37 |
+
geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)
|
38 |
+
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
39 |
+
|
40 |
+
ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)
|
41 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
42 |
+
ref_loss.backward()
|
43 |
+
|
44 |
+
cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)
|
45 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
46 |
+
cuda_loss.backward()
|
47 |
+
|
48 |
+
print("-------------------------------------------------------------")
|
49 |
+
print(" bent normal")
|
50 |
+
print("-------------------------------------------------------------")
|
51 |
+
relative_loss("res:", ref, cuda)
|
52 |
+
relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
|
53 |
+
relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad)
|
54 |
+
relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)
|
55 |
+
relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)
|
56 |
+
relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad)
|
57 |
+
relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad)
|
58 |
+
|
59 |
+
def test_schlick():
|
60 |
+
f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
61 |
+
f0_ref = f0_cuda.clone().detach().requires_grad_(True)
|
62 |
+
f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
63 |
+
f90_ref = f90_cuda.clone().detach().requires_grad_(True)
|
64 |
+
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0
|
65 |
+
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
66 |
+
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
67 |
+
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
68 |
+
|
69 |
+
ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)
|
70 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
71 |
+
ref_loss.backward()
|
72 |
+
|
73 |
+
cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)
|
74 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
75 |
+
cuda_loss.backward()
|
76 |
+
|
77 |
+
print("-------------------------------------------------------------")
|
78 |
+
print(" Fresnel shlick")
|
79 |
+
print("-------------------------------------------------------------")
|
80 |
+
relative_loss("res:", ref, cuda)
|
81 |
+
relative_loss("f0:", f0_ref.grad, f0_cuda.grad)
|
82 |
+
relative_loss("f90:", f90_ref.grad, f90_cuda.grad)
|
83 |
+
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
84 |
+
|
85 |
+
def test_ndf_ggx():
|
86 |
+
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
87 |
+
alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
88 |
+
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
89 |
+
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
|
90 |
+
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
91 |
+
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
92 |
+
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
93 |
+
|
94 |
+
ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)
|
95 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
96 |
+
ref_loss.backward()
|
97 |
+
|
98 |
+
cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)
|
99 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
100 |
+
cuda_loss.backward()
|
101 |
+
|
102 |
+
print("-------------------------------------------------------------")
|
103 |
+
print(" Ndf GGX")
|
104 |
+
print("-------------------------------------------------------------")
|
105 |
+
relative_loss("res:", ref, cuda)
|
106 |
+
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
107 |
+
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
108 |
+
|
109 |
+
def test_lambda_ggx():
|
110 |
+
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
111 |
+
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
112 |
+
cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1
|
113 |
+
cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)
|
114 |
+
cosT_ref = cosT_cuda.clone().detach().requires_grad_(True)
|
115 |
+
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
116 |
+
|
117 |
+
ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)
|
118 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
119 |
+
ref_loss.backward()
|
120 |
+
|
121 |
+
cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)
|
122 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
123 |
+
cuda_loss.backward()
|
124 |
+
|
125 |
+
print("-------------------------------------------------------------")
|
126 |
+
print(" Lambda GGX")
|
127 |
+
print("-------------------------------------------------------------")
|
128 |
+
relative_loss("res:", ref, cuda)
|
129 |
+
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
130 |
+
relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad)
|
131 |
+
|
132 |
+
def test_masking_smith():
|
133 |
+
alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
134 |
+
alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)
|
135 |
+
cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
136 |
+
cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)
|
137 |
+
cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
138 |
+
cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)
|
139 |
+
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
140 |
+
|
141 |
+
ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)
|
142 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
143 |
+
ref_loss.backward()
|
144 |
+
|
145 |
+
cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)
|
146 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
147 |
+
cuda_loss.backward()
|
148 |
+
|
149 |
+
print("-------------------------------------------------------------")
|
150 |
+
print(" Smith masking term")
|
151 |
+
print("-------------------------------------------------------------")
|
152 |
+
relative_loss("res:", ref, cuda)
|
153 |
+
relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad)
|
154 |
+
relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad)
|
155 |
+
relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad)
|
156 |
+
|
157 |
+
def test_lambert():
|
158 |
+
normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
159 |
+
normals_ref = normals_cuda.clone().detach().requires_grad_(True)
|
160 |
+
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
161 |
+
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
162 |
+
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
163 |
+
|
164 |
+
ref = ru.lambert(normals_ref, wi_ref, use_python=True)
|
165 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
166 |
+
ref_loss.backward()
|
167 |
+
|
168 |
+
cuda = ru.lambert(normals_cuda, wi_cuda)
|
169 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
170 |
+
cuda_loss.backward()
|
171 |
+
|
172 |
+
print("-------------------------------------------------------------")
|
173 |
+
print(" Lambert")
|
174 |
+
print("-------------------------------------------------------------")
|
175 |
+
relative_loss("res:", ref, cuda)
|
176 |
+
relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
|
177 |
+
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
178 |
+
|
179 |
+
def test_frostbite():
|
180 |
+
normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
181 |
+
normals_ref = normals_cuda.clone().detach().requires_grad_(True)
|
182 |
+
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
183 |
+
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
184 |
+
wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
185 |
+
wo_ref = wo_cuda.clone().detach().requires_grad_(True)
|
186 |
+
rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
187 |
+
rough_ref = rough_cuda.clone().detach().requires_grad_(True)
|
188 |
+
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')
|
189 |
+
|
190 |
+
ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)
|
191 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
192 |
+
ref_loss.backward()
|
193 |
+
|
194 |
+
cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)
|
195 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
196 |
+
cuda_loss.backward()
|
197 |
+
|
198 |
+
print("-------------------------------------------------------------")
|
199 |
+
print(" Frostbite")
|
200 |
+
print("-------------------------------------------------------------")
|
201 |
+
relative_loss("res:", ref, cuda)
|
202 |
+
relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
|
203 |
+
relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
|
204 |
+
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
205 |
+
relative_loss("rough:", rough_ref.grad, rough_cuda.grad)
|
206 |
+
|
207 |
+
def test_pbr_specular():
|
208 |
+
col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
209 |
+
col_ref = col_cuda.clone().detach().requires_grad_(True)
|
210 |
+
nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
211 |
+
nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
|
212 |
+
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
213 |
+
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
|
214 |
+
wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
215 |
+
wo_ref = wo_cuda.clone().detach().requires_grad_(True)
|
216 |
+
alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
|
217 |
+
alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
|
218 |
+
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
219 |
+
|
220 |
+
ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
|
221 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
222 |
+
ref_loss.backward()
|
223 |
+
|
224 |
+
cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
|
225 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
226 |
+
cuda_loss.backward()
|
227 |
+
|
228 |
+
print("-------------------------------------------------------------")
|
229 |
+
print(" Pbr specular")
|
230 |
+
print("-------------------------------------------------------------")
|
231 |
+
|
232 |
+
relative_loss("res:", ref, cuda)
|
233 |
+
if col_ref.grad is not None:
|
234 |
+
relative_loss("col:", col_ref.grad, col_cuda.grad)
|
235 |
+
if nrm_ref.grad is not None:
|
236 |
+
relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
|
237 |
+
if wi_ref.grad is not None:
|
238 |
+
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
|
239 |
+
if wo_ref.grad is not None:
|
240 |
+
relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
|
241 |
+
if alpha_ref.grad is not None:
|
242 |
+
relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)
|
243 |
+
|
244 |
+
def test_pbr_bsdf(bsdf):
|
245 |
+
kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
246 |
+
kd_ref = kd_cuda.clone().detach().requires_grad_(True)
|
247 |
+
arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
248 |
+
arm_ref = arm_cuda.clone().detach().requires_grad_(True)
|
249 |
+
pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
250 |
+
pos_ref = pos_cuda.clone().detach().requires_grad_(True)
|
251 |
+
nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
252 |
+
nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
|
253 |
+
view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
254 |
+
view_ref = view_cuda.clone().detach().requires_grad_(True)
|
255 |
+
light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
256 |
+
light_ref = light_cuda.clone().detach().requires_grad_(True)
|
257 |
+
target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')
|
258 |
+
|
259 |
+
ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)
|
260 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
261 |
+
ref_loss.backward()
|
262 |
+
|
263 |
+
cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)
|
264 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
265 |
+
cuda_loss.backward()
|
266 |
+
|
267 |
+
print("-------------------------------------------------------------")
|
268 |
+
print(" Pbr BSDF")
|
269 |
+
print("-------------------------------------------------------------")
|
270 |
+
|
271 |
+
relative_loss("res:", ref, cuda)
|
272 |
+
if kd_ref.grad is not None:
|
273 |
+
relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
|
274 |
+
if arm_ref.grad is not None:
|
275 |
+
relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
|
276 |
+
if pos_ref.grad is not None:
|
277 |
+
relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
|
278 |
+
if nrm_ref.grad is not None:
|
279 |
+
relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
|
280 |
+
if view_ref.grad is not None:
|
281 |
+
relative_loss("view:", view_ref.grad, view_cuda.grad)
|
282 |
+
if light_ref.grad is not None:
|
283 |
+
relative_loss("light:", light_ref.grad, light_cuda.grad)
|
284 |
+
|
285 |
+
test_normal()
|
286 |
+
|
287 |
+
test_schlick()
|
288 |
+
test_ndf_ggx()
|
289 |
+
test_lambda_ggx()
|
290 |
+
test_masking_smith()
|
291 |
+
|
292 |
+
test_lambert()
|
293 |
+
test_frostbite()
|
294 |
+
test_pbr_specular()
|
295 |
+
test_pbr_bsdf('lambert')
|
296 |
+
test_pbr_bsdf('frostbite')
|
video3d/render/renderutils/tests/test_cubemap.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
4 |
+
# property and proprietary rights in and to this material, related
|
5 |
+
# documentation and any modifications thereto. Any use, reproduction,
|
6 |
+
# disclosure or distribution of this material and related documentation
|
7 |
+
# without an express license agreement from NVIDIA CORPORATION or
|
8 |
+
# its affiliates is strictly prohibited.
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
import os
|
13 |
+
import sys
|
14 |
+
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
|
15 |
+
import renderutils as ru
|
16 |
+
|
17 |
+
RES = 4
|
18 |
+
DTYPE = torch.float32
|
19 |
+
|
20 |
+
def relative_loss(name, ref, cuda):
|
21 |
+
ref = ref.float()
|
22 |
+
cuda = cuda.float()
|
23 |
+
print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())
|
24 |
+
|
25 |
+
def test_cubemap():
|
26 |
+
cubemap_cuda = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
|
27 |
+
cubemap_ref = cubemap_cuda.clone().detach().requires_grad_(True)
|
28 |
+
weights = torch.rand(3, 3, 1, dtype=DTYPE, device='cuda')
|
29 |
+
target = torch.rand(6, RES, RES, 3, dtype=DTYPE, device='cuda')
|
30 |
+
|
31 |
+
ref = ru.filter_cubemap(cubemap_ref, weights, use_python=True)
|
32 |
+
ref_loss = torch.nn.MSELoss()(ref, target)
|
33 |
+
ref_loss.backward()
|
34 |
+
|
35 |
+
cuda = ru.filter_cubemap(cubemap_cuda, weights, use_python=False)
|
36 |
+
cuda_loss = torch.nn.MSELoss()(cuda, target)
|
37 |
+
cuda_loss.backward()
|
38 |
+
|
39 |
+
print("-------------------------------------------------------------")
|
40 |
+
print(" Cubemap:")
|
41 |
+
print("-------------------------------------------------------------")
|
42 |
+
|
43 |
+
relative_loss("flt:", ref, cuda)
|
44 |
+
relative_loss("cubemap:", cubemap_ref.grad, cubemap_cuda.grad)
|
45 |
+
|
46 |
+
|
47 |
+
test_cubemap()
|