first deploy demo

- .gitignore +3 -0
- LICENSE +201 -0
- app.py +40 -0
- demo/__init__.py +0 -0
- demo/src/__init__.py +0 -0
- demo/src/config.py +44 -0
- demo/src/models.py +38 -0
- demo/src/utils.py +62 -0
- jaxnerf/README.md +205 -0
- jaxnerf/__init__.py +15 -0
- jaxnerf/configs/blender.yaml +9 -0
- jaxnerf/configs/demo.yaml +10 -0
- jaxnerf/configs/diet_nerf_tpu_vm_few_shot.yaml +20 -0
- jaxnerf/configs/diet_nerf_tpu_vm_test.yaml +20 -0
- jaxnerf/configs/eval_diet_nerf_tpu_vm_few_shot.yaml +22 -0
- jaxnerf/configs/llff.yaml +13 -0
- jaxnerf/configs/llff_360.yaml +15 -0
- jaxnerf/configs/nerf_tpu_vm_few_shot.yaml +20 -0
- jaxnerf/configs/orig_nerf_tpu_vm_full.yaml +13 -0
- jaxnerf/configs/orig_nerf_tpu_vm_test.yaml +13 -0
- jaxnerf/eval.py +192 -0
- jaxnerf/eval.sh +44 -0
- jaxnerf/example_data/imgs/r_0.png +0 -0
- jaxnerf/example_data/transforms_test.json +1 -0
- jaxnerf/example_data/transforms_train.json +1 -0
- jaxnerf/nerf/__init__.py +15 -0
- jaxnerf/nerf/clip_utils.py +134 -0
- jaxnerf/nerf/datasets.py +565 -0
- jaxnerf/nerf/model_utils.py +321 -0
- jaxnerf/nerf/models.py +256 -0
- jaxnerf/nerf/precompute.py +59 -0
- jaxnerf/nerf/utils.py +457 -0
- jaxnerf/requirements.txt +14 -0
- jaxnerf/run.sh +33 -0
- jaxnerf/train.py +326 -0
- jaxnerf/train.sh +34 -0
- requirements.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+.idea
+__pycache__
+models
LICENSE
ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
app.py
ADDED
@@ -0,0 +1,40 @@
+import os
+import math
+import streamlit as st
+from google_drive_downloader import GoogleDriveDownloader as gdd
+
+from demo.src.models import load_trained_model
+from demo.src.utils import render_predict_from_pose, predict_to_image
+from demo.src.config import MODEL_DIR, MODEL_NAME, FILE_ID
+
+
+model_path = os.path.join(MODEL_DIR, MODEL_NAME)
+if not os.path.isfile(model_path):  # skip the download if the checkpoint is already present
+    gdd.download_file_from_google_drive(file_id=FILE_ID,
+                                        dest_path=model_path,
+                                        unzip=False)
+    print(f'model downloaded from google drive: {model_path}')
+
+
+@st.cache(show_spinner=False, allow_output_mutation=True)
+def fetch_model():
+    model, state = load_trained_model(MODEL_DIR, MODEL_NAME)
+    return model, state
+
+
+model, state = fetch_model()
+pi = math.pi
+st.set_page_config(page_title="DietNeRF Demo")
+st.sidebar.header('SELECT YOUR VIEW DIRECTION')
+theta = st.sidebar.slider("Theta", min_value=0., max_value=2.*pi,
+                          step=0.5, value=0.)
+phi = st.sidebar.slider("Phi", min_value=0., max_value=0.5*pi,
+                        step=0.1, value=1.)
+radius = st.sidebar.slider("Radius", min_value=2., max_value=6.,
+                           step=1., value=3.)
+
+
+pred_color, _ = render_predict_from_pose(state, theta, phi, radius)
+im = predict_to_image(pred_color)
+
+st.image(im, use_column_width=False)
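For reference, the same helpers can drive a render without Streamlit. A minimal headless sketch, assuming the checkpoint has already been downloaded to `MODEL_DIR` (the output file name `view.png` is illustrative):

```python
from demo.src.models import load_trained_model
from demo.src.utils import render_predict_from_pose, predict_to_image
from demo.src.config import MODEL_DIR, MODEL_NAME

# Restore the trained model once, then render a single view; theta, phi and
# radius are the same camera triplet the sidebar sliders control.
model, state = load_trained_model(MODEL_DIR, MODEL_NAME)
pred_color, _ = render_predict_from_pose(state, 0.0, 0.5, 3.0)
predict_to_image(pred_color).save('view.png')
```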
demo/__init__.py
ADDED
File without changes

demo/src/__init__.py
ADDED
File without changes
demo/src/config.py
ADDED
@@ -0,0 +1,44 @@
+# for downloading model from google drive
+FILE_ID = "1iytA1n2z4go3uVCwE__vIKouTKyIDjEq"
+MODEL_DIR = './models'
+MODEL_NAME = 'trained_model'
+
+
+class NerfConfig:
+    # MODEL CONFIG
+    model = "nerf"
+    net_activation = "relu"
+    rgb_activation = "sigmoid"
+    sigma_activation = "relu"
+    min_deg_point = 0
+    max_deg_point = 10
+    deg_view = 4
+    # reduce num_coarse_samples, num_fine_samples for speedup
+    num_coarse_samples = 32
+    num_fine_samples = 64
+    use_viewdirs = True
+    near = 2
+    far = 6
+    noise_std = None
+    # TODO @Alex: set white_bkgd as flag if we add LLFF dataset
+    white_bkgd = True
+    net_depth = 8
+    net_width = 256
+    net_depth_condition = 1
+    net_width_condition = 128
+    skip_layer = 4
+    num_rgb_channels = 3
+    num_sigma_channels = 1
+    lindisp = True
+    legacy_posenc_order = False
+    randomized = True
+
+    # DATA CONFIG
+    W = 800
+    H = 800
+    IMAGE_SHAPE = (W, H, 3)
+    # TODO @Alex: flexible focal if we add LLFF dataset
+    FOCAL = 555.5555155968841
+    # reduce CHUNK if OOM
+    CHUNK = 4096
+    DOWNSAMPLE = 2
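The `CHUNK`/`DOWNSAMPLE` trade-off above is easy to quantify. A small sketch of the arithmetic, using the values from this config:

```python
# Rays per rendered frame at the demo settings: the image plane is sampled on
# a strided grid, then rendered CHUNK rays at a time.
W, H, DOWNSAMPLE, CHUNK = 800, 800, 2, 4096

rays_per_frame = (W // DOWNSAMPLE) * (H // DOWNSAMPLE)  # 160,000 rays
chunks_per_frame = -(-rays_per_frame // CHUNK)          # ceil -> 40 model calls
print(rays_per_frame, chunks_per_frame)
```

Halving `CHUNK` doubles the number of model calls but halves the peak memory per call, which is why the comment suggests reducing it on OOM.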
demo/src/models.py
ADDED
@@ -0,0 +1,38 @@
+import os
+import flax
+from jax import random
+from flax.training import checkpoints
+
+from jaxnerf.nerf import models
+from jaxnerf.nerf import utils
+from demo.src.config import NerfConfig
+
+rng = random.PRNGKey(0)
+# TODO @Alex: make image size flexible if needed
+dummy_rays = random.uniform(rng, shape=NerfConfig.IMAGE_SHAPE)
+dummy_batch = {"rays": utils.Rays(dummy_rays, dummy_rays, dummy_rays)}
+dummy_lr = 1e-2
+
+
+def load_trained_model(model_dir, model_fn):
+    model, init_variables = init_model()
+    optimizer = flax.optim.Adam(dummy_lr).create(init_variables)
+    state = utils.TrainState(optimizer=optimizer)
+    del optimizer, init_variables
+    assert os.path.isfile(os.path.join(model_dir, model_fn))
+    state = checkpoints.restore_checkpoint(model_dir, state,
+                                           prefix=model_fn)
+    return model, state
+
+
+def init_model():
+    _, key = random.split(rng)
+    model, init_variables = models.get_model(key, dummy_batch,
+                                             NerfConfig)
+    return model, init_variables
+
+
+if __name__ == '__main__':
+    _model_dir = '../ship_fewshot_wsc'
+    _model_fn = 'checkpoint_345000'
+    _model, _state = load_trained_model(_model_dir, _model_fn)
demo/src/utils.py
ADDED
@@ -0,0 +1,62 @@
+from functools import partial
+import jax
+from jax import random
+import numpy as np
+from PIL import Image
+
+from jaxnerf.nerf import clip_utils
+from jaxnerf.nerf import utils
+from demo.src.config import NerfConfig
+from demo.src.models import init_model
+
+model, _ = init_model()
+
+
+def render_predict_from_pose(state, theta, phi, radius):
+    rng = random.PRNGKey(0)
+    partial_render_fn = partial(render_pfn, state.optimizer.target)
+    rays = _render_rays_from_pose(theta, phi, radius)
+    pred_color, pred_disp, _ = utils.render_image(
+        partial_render_fn, rays,
+        rng, False, chunk=NerfConfig.CHUNK)
+    return pred_color, pred_disp
+
+
+def predict_to_image(pred_out):
+    image_arr = np.array(np.clip(pred_out, 0., 1.) * 255.).astype(np.uint8)
+    return Image.fromarray(image_arr)
+
+
+def _render_rays_from_pose(theta, phi, radius):
+    camtoworld = np.array(clip_utils.pose_spherical(theta, phi, radius))
+    rays = _camtoworld_matrix_to_rays(camtoworld)
+    return rays
+
+
+def _camtoworld_matrix_to_rays(camtoworld):
+    """ render one instance of rays given a camera to world matrix (4, 4) """
+    pixel_center = 0.
+    w, h = NerfConfig.W, NerfConfig.H
+    focal, downsample = NerfConfig.FOCAL, NerfConfig.DOWNSAMPLE
+    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
+        np.arange(0, w, downsample, dtype=np.float32) + pixel_center,  # X-Axis (columns)
+        np.arange(0, h, downsample, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
+        indexing="xy")
+    camera_dirs = np.stack([(x - w * 0.5) / focal,
+                            -(y - h * 0.5) / focal,
+                            -np.ones_like(x)],
+                           axis=-1)
+    directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
+    origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
+    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
+    return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)
+
+
+def _render_fn(variables, key_0, key_1, rays):
+    return jax.lax.all_gather(model.apply(
+        variables, key_0, key_1, rays, False),
+        axis_name="batch")
+
+
+render_pfn = jax.pmap(_render_fn, in_axes=(None, None, None, 0),
+                      donate_argnums=3, axis_name="batch")
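`_camtoworld_matrix_to_rays` is a standard pinhole-camera ray generator. As a quick sanity check (a sketch, assuming the function above is in scope), an identity camera-to-world matrix should send the central ray straight down the -z axis:

```python
import numpy as np

# Identity pose: camera at the origin, looking down -z.
rays = _camtoworld_matrix_to_rays(np.eye(4))
h, w = rays.viewdirs.shape[:2]
print(rays.viewdirs[h // 2, w // 2])  # ~[0., 0., -1.] at the optical axis
```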
jaxnerf/README.md
ADDED
@@ -0,0 +1,205 @@
+# JaxNeRF
+
+This is a [JAX](https://github.com/google/jax) implementation of
+[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](http://www.matthewtancik.com/nerf).
+This code is created and maintained by
+[Boyang Deng](https://boyangdeng.com/),
+[Jon Barron](https://jonbarron.info/),
+and [Pratul Srinivasan](https://people.eecs.berkeley.edu/~pratul/).
+
+<div align="center">
+  <img width="95%" alt="NeRF Teaser" src="https://raw.githubusercontent.com/bmild/nerf/master/imgs/pipeline.jpg">
+</div>
+
+Our JAX implementation currently supports:
+
+<table class="tg">
+<thead>
+  <tr>
+    <th class="tg-0lax"><span style="font-weight:bold">Platform</span></th>
+    <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Single-Host GPU</span></th>
+    <th class="tg-0lax" colspan="2"><span style="font-weight:bold">Multi-Device TPU</span></th>
+  </tr>
+</thead>
+<tbody>
+  <tr>
+    <td class="tg-0lax"><span style="font-weight:bold">Type</span></td>
+    <td class="tg-0lax">Single-Device</td>
+    <td class="tg-0lax">Multi-Device</td>
+    <td class="tg-0lax">Single-Host</td>
+    <td class="tg-0lax">Multi-Host</td>
+  </tr>
+  <tr>
+    <td class="tg-0lax"><span style="font-weight:bold">Training</span></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+  </tr>
+  <tr>
+    <td class="tg-0lax"><span style="font-weight:bold">Evaluation</span></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+    <td class="tg-0lax"><img src="http://storage.googleapis.com/gresearch/jaxnerf/check.png" alt="Supported" width=18px height=18px></td>
+  </tr>
+</tbody>
+</table>
+
+Training on 128 TPUv2 cores finishes in **2.5 hours (vs. 3 days for TF
+NeRF)** for 1 million optimization steps. In other words, JaxNeRF reaches the
+best quality while training very fast.
+
+As for inference speed, here are the statistics of rendering an image at
+800x800 resolution (numbers are averaged over 50 rendering passes):
+
+| Platform | 1 x NVIDIA V100 | 8 x NVIDIA V100 | 128 x TPUv2 |
+|----------|:---------------:|:---------------:|:-----------:|
+| TF NeRF  | 27.74 secs | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> | <img src="http://storage.googleapis.com/gresearch/jaxnerf/cross.png" alt="Not Supported" width=18px height=18px> |
+| JaxNeRF  | 20.77 secs | 2.65 secs | 0.35 secs |
+
+
+The code is tested and reviewed carefully to match the
+[original TF NeRF implementation](https://github.com/bmild/nerf).
+If you have any issues using this code, please do not open an issue as the repo
+is shared by all projects under Google Research. Instead, just email
+jaxnerf@google.com.
+
+## Installation
+We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set
+up the environment. Run the following commands:
+
+```
+# Clone the repo.
+svn export https://github.com/google-research/google-research/trunk/jaxnerf
+# Create a conda environment. Note you can use python 3.6-3.8, as
+# one of the dependencies (TensorFlow) hasn't supported python 3.9 yet.
+conda create --name jaxnerf python=3.6.12; conda activate jaxnerf
+# Prepare pip.
+conda install pip; pip install --upgrade pip
+# Install requirements.
+pip install -r jaxnerf/requirements.txt
+# [Optional] Install GPU and TPU support for Jax.
+# Remember to change cuda101 to your CUDA version, e.g. cuda110 for CUDA 11.0.
+pip install --upgrade jax jaxlib==0.1.57+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html
+```
+
+Then, you'll need to download the datasets
+from the [NeRF official Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
+Please download the `nerf_synthetic.zip` and `nerf_llff_data.zip` and unzip them
+in the place you like. Let's assume they are placed under `/tmp/jaxnerf/data/`.
+
+That's it for installation. You're good to go. **Notice:** For the following
+instructions, you don't need to enter the jaxnerf folder. Just stay in the
+parent folder.
+
+## Two Commands for Everything
+
+```
+bash jaxnerf/train.sh demo /tmp/jaxnerf/data
+bash jaxnerf/eval.sh demo /tmp/jaxnerf/data
+```
+
+Once both jobs are done running (which may take a while if you only have 1 GPU
+or CPU), you'll have a folder, `/tmp/jaxnerf/data/demo`, with:
+
+* Trained NeRF models for all scenes in the blender dataset.
+* Rendered images and depth maps for all test views.
+* The collected PSNRs of all scenes in a TXT file.
+
+Note that we used the `demo` config here, which is basically the `blender`
+config from the paper except with a smaller batch size and far fewer training
+steps. Of course, you can replace `demo` with other configs and
+`/tmp/jaxnerf/data` with other data locations.
+
+We provide 2 configurations in the folder `configs` which match the original
+configurations used in the paper for the blender dataset and the LLFF dataset.
+Be careful when you use them: their batch sizes are large, so you may get OOM
+errors if you have limited resources, for example 1 GPU with small memory.
+They also run for many training steps, so you may need days to finish training
+all scenes.
+
+## Play with One Scene
+
+You can also train NeRF on only one scene. The easiest way is to use the given
+configs:
+
+```
+python -m jaxnerf.train \
+  --data_dir=/PATH/TO/YOUR/SCENE/DATA \
+  --train_dir=/PATH/TO/THE/PLACE/YOU/WANT/TO/SAVE/CHECKPOINTS \
+  --config=configs/CONFIG_YOU_LIKE
+```
+
+Evaluating NeRF on one scene is similar:
+
+```
+python -m jaxnerf.eval \
+  --data_dir=/PATH/TO/YOUR/SCENE/DATA \
+  --train_dir=/PATH/TO/THE/PLACE/YOU/SAVED/CHECKPOINTS \
+  --config=configs/CONFIG_YOU_LIKE \
+  --chunk=4096
+```
+
+The `chunk` parameter defines how many rays are fed to the model in one go.
+We recommend using the largest value that fits in your device's memory;
+smaller values also work, just a bit more slowly.
+
+You can also define your own configurations by passing command line flags.
+Please refer to the `define_flags` function in `nerf/utils.py` for all the
+flags and their meanings.
+
+**Note**: For the ficus scene in the blender dataset, we noticed that it's
+sensitive to different initializations, e.g. using different random seeds,
+when using the original learning rate schedule in the paper. Therefore, we
+provide a simple tweak (turned off by default) for more stable training:
+`lr_delay_steps` and `lr_delay_mult`. This allows the training to start from a
+smaller learning rate (`lr_init` * `lr_delay_mult`) for the first
+`lr_delay_steps`. We didn't use them for our pretrained models, but we tested
+`lr_delay_steps=5000` with `lr_delay_mult=0.2` and it works quite smoothly.
+
+## Pretrained Models
+
+We provide a collection of pretrained NeRF models that match the numbers
+reported in the [paper](https://arxiv.org/abs/2003.08934). Actually, ours are
+slightly better overall because we trained for more iterations (while still
+being much faster!). You can find our pretrained models
+[here](http://storage.googleapis.com/gresearch/jaxnerf/jaxnerf_pretrained_models.zip).
+The performances (in PSNR) of our pretrained NeRF models are listed below:
+
+### Blender
+
+| Scene   | Chair     | Drums     | Ficus     | Hotdog    | Lego      | Materials | Mic       | Ship      | Mean      |
+|---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
+| TF NeRF | 33.00     | 25.01     | 30.13     | 36.18     | 32.54     | 29.62     | 32.91     | 28.65     | 31.01     |
+| JaxNeRF | **34.08** | **25.03** | **30.43** | **36.92** | **33.28** | **29.91** | **34.53** | **29.36** | **31.69** |
+
+### LLFF
+
+| Scene   | Room      | Fern      | Leaves    | Fortress  | Orchids   | Flower    | T-Rex     | Horns     | Mean      |
+|---------|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|:---------:|
+| TF NeRF | 32.70     | **25.17** | 20.92     | 31.16     | **20.36** | 27.40     | 26.80     | 27.45     | 26.50     |
+| JaxNeRF | **33.04** | 24.83     | **21.23** | **31.76** | 20.27     | **28.07** | **27.42** | **28.10** | **26.84** |
+
+## Citation
+If you use this software package, please cite it as:
+
+```
+@software{jaxnerf2020github,
+  author = {Boyang Deng and Jonathan T. Barron and Pratul P. Srinivasan},
+  title = {{JaxNeRF}: an efficient {JAX} implementation of {NeRF}},
+  url = {https://github.com/google-research/google-research/tree/master/jaxnerf},
+  version = {0.0},
+  year = {2020},
+}
+```
+
+and also cite the original NeRF paper:
+
+```
+@inproceedings{mildenhall2020nerf,
+  title={NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
+  author={Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng},
+  year={2020},
+  booktitle={ECCV},
+}
+```
+
+## Acknowledgement
+We'd like to thank
+[Daniel Duckworth](http://www.stronglyconvex.com/),
+[Dan Gnanapragasam](https://research.google/people/DanGnanapragasam/),
+and [James Bradbury](https://twitter.com/jekbradbury)
+for their help on reviewing and optimizing this code.
+We'd like to also thank the amazing [JAX](https://github.com/google/jax) team for
+very insightful and helpful discussions on how to use JAX for NeRF.
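The `lr_delay_steps`/`lr_delay_mult` tweak described in the README's note scales the learning rate up from `lr_init * lr_delay_mult` over the first `lr_delay_steps`. An illustrative sketch of that behavior (a simple linear ramp; not necessarily the exact curve the training code implements, and `lr_init=5e-4` is just an assumed default):

```python
def delayed_lr(step, lr_init=5e-4, lr_delay_steps=5000, lr_delay_mult=0.2):
    """Learning rate that ramps from lr_init * lr_delay_mult up to lr_init."""
    if lr_delay_steps == 0:
        return lr_init
    t = min(step / lr_delay_steps, 1.0)
    return lr_init * (lr_delay_mult + (1.0 - lr_delay_mult) * t)

print(delayed_lr(0))     # 1e-4: damped start
print(delayed_lr(5000))  # 5e-4: full rate after the delay window
```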
jaxnerf/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

jaxnerf/configs/blender.yaml
ADDED
@@ -0,0 +1,9 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 4096
+randomized: true

jaxnerf/configs/demo.yaml
ADDED
@@ -0,0 +1,10 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 50000

jaxnerf/configs/diet_nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,20 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 500000
+print_every: 100
+render_every: 5000
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_factor: 4
+sc_loss_every: 16
+sc_loss_mult: 10
+few_shot: 8

jaxnerf/configs/diet_nerf_tpu_vm_test.yaml
ADDED
@@ -0,0 +1,20 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 500000
+print_every: 100
+render_every: 5000
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_factor: 4
+sc_loss_every: 16
+sc_loss_mult: 10
+few_shot: -1

jaxnerf/configs/eval_diet_nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,22 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 500000
+print_every: 100
+render_every: 5000
+save_every: 5000
+use_semantic_loss: true
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_factor: 4
+sc_loss_every: 16
+sc_loss_mult: 10
+few_shot: 8
+spherify: True
+lindisp: True

jaxnerf/configs/llff.yaml
ADDED
@@ -0,0 +1,13 @@
+dataset: llff
+batching: all_images
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: false
+batch_size: 4096
+randomized: true
+near: 0.
+far: 1.
+factor: 4
+llffhold: 8
+noise_std: 1.

jaxnerf/configs/llff_360.yaml
ADDED
@@ -0,0 +1,15 @@
+dataset: llff
+batching: all_images
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: false
+batch_size: 4096
+randomized: true
+near: 0.2
+far: 100.
+factor: 8
+llffhold: 8
+noise_std: 1.
+spherify: True
+lindisp: True

jaxnerf/configs/nerf_tpu_vm_few_shot.yaml
ADDED
@@ -0,0 +1,20 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 500000
+print_every: 100
+render_every: 5000
+save_every: 5000
+use_semantic_loss: false
+clip_model_name: openai/clip-vit-base-patch32
+clip_output_dtype: float32
+sc_loss_factor: 4
+sc_loss_every: 16
+sc_loss_mult: 10
+few_shot: 8

jaxnerf/configs/orig_nerf_tpu_vm_full.yaml
ADDED
@@ -0,0 +1,13 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 100000
+print_every: 1000
+render_every: 5000
+save_every: 5000

jaxnerf/configs/orig_nerf_tpu_vm_test.yaml
ADDED
@@ -0,0 +1,13 @@
+dataset: blender
+batching: single_image
+factor: 0
+num_coarse_samples: 64
+num_fine_samples: 128
+use_viewdirs: true
+white_bkgd: true
+batch_size: 1024
+randomized: true
+max_steps: 5000
+print_every: 100
+render_every: 500
+save_every: 500
jaxnerf/eval.py
ADDED
@@ -0,0 +1,192 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Evaluation script for Nerf."""
+import functools
+from os import path
+
+from absl import app
+from absl import flags
+import flax
+from flax.metrics import tensorboard
+from flax.training import checkpoints
+import jax
+from jax import random
+import numpy as np
+import tensorflow as tf
+import tensorflow_hub as tf_hub
+# import wandb
+import glob
+import cv2
+import os
+
+from jaxnerf.nerf import datasets
+from jaxnerf.nerf import models
+from jaxnerf.nerf import utils
+
+FLAGS = flags.FLAGS
+
+utils.define_flags()
+
+# LPIPS_TFHUB_PATH = "@neural-rendering/lpips/distance/1"
+
+
+def compute_lpips(image1, image2, model):
+  """Compute the LPIPS metric."""
+  # The LPIPS model expects a batch dimension.
+  return model(
+      tf.convert_to_tensor(image1[None, Ellipsis]),
+      tf.convert_to_tensor(image2[None, Ellipsis]))[0]
+
+
+def main(unused_argv):
+  # Hide the GPUs and TPUs from TF so it does not reserve memory on them for
+  # LPIPS computation or dataset loading.
+  tf.config.experimental.set_visible_devices([], "GPU")
+  tf.config.experimental.set_visible_devices([], "TPU")
+
+  # wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
+
+  rng = random.PRNGKey(20200823)
+
+  if FLAGS.config is not None:
+    utils.update_flags(FLAGS)
+  if FLAGS.train_dir is None:
+    raise ValueError("train_dir must be set. None set now.")
+  if FLAGS.data_dir is None:
+    raise ValueError("data_dir must be set. None set now.")
+
+  dataset = datasets.get_dataset("test", FLAGS)
+  rng, key = random.split(rng)
+  model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
+  optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
+  state = utils.TrainState(optimizer=optimizer)
+  del optimizer, init_variables
+
+  # lpips_model = tf_hub.load(LPIPS_TFHUB_PATH)
+
+  # Rendering is forced to be deterministic even if training was randomized, as
+  # this eliminates "speckle" artifacts.
+  def render_fn(variables, key_0, key_1, rays):
+    return jax.lax.all_gather(
+        model.apply(variables, key_0, key_1, rays, False), axis_name="batch")
+
+  # pmap over only the data input.
+  render_pfn = jax.pmap(
+      render_fn,
+      in_axes=(None, None, None, 0),
+      donate_argnums=3,
+      axis_name="batch",
+  )
+
+  # Compiling to the CPU because it's faster and more accurate.
+  ssim_fn = jax.jit(
+      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")
+
+  last_step = 0
+  out_dir = path.join(FLAGS.train_dir,
+                      "path_renders" if FLAGS.render_path else "test_preds")
+  if not FLAGS.eval_once:
+    summary_writer = tensorboard.SummaryWriter(
+        path.join(FLAGS.train_dir, "eval"))
+  while True:
+    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
+    step = int(state.optimizer.state.step)
+    if step <= last_step:
+      continue
+    if FLAGS.save_output and (not utils.isdir(out_dir)):
+      utils.makedirs(out_dir)
+    psnr_values = []
+    ssim_values = []
+    # lpips_values = []
+    if not FLAGS.eval_once:
+      showcase_index = np.random.randint(0, dataset.size)
+    for idx in range(dataset.size):
+      print(f"Evaluating {idx + 1}/{dataset.size}")
+      batch = next(dataset)
+      pred_color, pred_disp, pred_acc = utils.render_image(
+          functools.partial(render_pfn, state.optimizer.target),
+          batch["rays"],
+          rng,
+          FLAGS.dataset == "llff",
+          chunk=FLAGS.chunk)
+      if jax.host_id() != 0:  # Only record via host 0.
+        continue
+      if not FLAGS.eval_once and idx == showcase_index:
+        showcase_color = pred_color
+        showcase_disp = pred_disp
+        showcase_acc = pred_acc
+        if not FLAGS.render_path:
+          showcase_gt = batch["pixels"]
+      if not FLAGS.render_path:
+        psnr = utils.compute_psnr(((pred_color - batch["pixels"]) ** 2).mean())
+        ssim = ssim_fn(pred_color, batch["pixels"])
+        # lpips = compute_lpips(pred_color, batch["pixels"], lpips_model)
+        print(f"PSNR = {psnr:.4f}, SSIM = {ssim:.4f}")
+        psnr_values.append(float(psnr))
+        ssim_values.append(float(ssim))
+        # lpips_values.append(float(lpips))
+      if FLAGS.save_output:
+        utils.save_img(pred_color, path.join(out_dir, "{:03d}.png".format(idx)))
+        utils.save_img(pred_disp[Ellipsis, 0],
+                       path.join(out_dir, "disp_{:03d}.png".format(idx)))
+    if (not FLAGS.eval_once) and (jax.host_id() == 0):
+      summary_writer.image("pred_color", showcase_color, step)
+      summary_writer.image("pred_disp", showcase_disp, step)
+      summary_writer.image("pred_acc", showcase_acc, step)
+      if not FLAGS.render_path:
+        summary_writer.scalar("psnr", np.mean(np.array(psnr_values)), step)
+        summary_writer.scalar("ssim", np.mean(np.array(ssim_values)), step)
+        # summary_writer.scalar("lpips", np.mean(np.array(lpips_values)), step)
+        summary_writer.image("target", showcase_gt, step)
+    if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
+      with utils.open_file(path.join(out_dir, f"psnrs_{step}.txt"), "w") as f:
+        f.write(" ".join([str(v) for v in psnr_values]))
+      with utils.open_file(path.join(out_dir, f"ssims_{step}.txt"), "w") as f:
+        f.write(" ".join([str(v) for v in ssim_values]))
+      # with utils.open_file(path.join(out_dir, f"lpips_{step}.txt"), "w") as f:
+      #   f.write(" ".join([str(v) for v in lpips_values]))
+      with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as f:
+        f.write("{}".format(np.mean(np.array(psnr_values))))
+      with utils.open_file(path.join(out_dir, "ssim.txt"), "w") as f:
+        f.write("{}".format(np.mean(np.array(ssim_values))))
+      # with utils.open_file(path.join(out_dir, "lpips.txt"), "w") as f:
+      #   f.write("{}".format(np.mean(np.array(lpips_values))))
+      # Stitch the color and disparity renders side by side into a video.
+      imglist = glob.glob(os.path.join(out_dir, "[0-9][0-9][0-9].png"))
+      sorted_files = sorted(imglist, key=lambda x: int(x.split('/')[-1].split('.')[0]))
+      imglist2 = glob.glob(os.path.join(out_dir, "disp_[0-9][0-9][0-9].png"))
+      sorted_files2 = sorted(imglist2, key=lambda x: int(x.split('/')[-1].split('.')[0].split('_')[-1]))
+      fourcc = cv2.VideoWriter_fourcc(*'MP4V')
+      fps = 10.0
+      img = cv2.imread(sorted_files[0], cv2.IMREAD_COLOR)  # size the video from the first frame
+      out = cv2.VideoWriter(os.path.join(out_dir, "rendering_video.mp4"), fourcc, fps,
+                            (2 * img.shape[1], img.shape[0]))
+
+      for i in range(len(sorted_files)):
+        img = cv2.imread(sorted_files[i], cv2.IMREAD_COLOR)
+        img2 = cv2.imread(sorted_files2[i], cv2.IMREAD_COLOR)
+        catimg = np.concatenate((img, img2), axis=1)
+        out.write(catimg)
+
+      out.release()
+    if FLAGS.eval_once:
+      break
+    if int(step) >= FLAGS.max_steps:
+      break
+    last_step = step
+
+
+if __name__ == "__main__":
+  app.run(main)
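Note that `utils.compute_psnr` above takes the mean squared error, not the two images. For images scaled to [0, 1], PSNR is `-10 * log10(mse)`; a one-line check:

```python
import numpy as np

def psnr_from_mse(mse):
    # Peak signal-to-noise ratio for images in [0, 1].
    return -10.0 * np.log10(mse)

print(psnr_from_mse(1e-3))  # 30.0 dB
```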
jaxnerf/eval.sh
ADDED
@@ -0,0 +1,44 @@
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+#!/bin/bash
+CONFIG=$1
+DATA_ROOT=$2
+ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
+if [ $CONFIG == "llff" ]
+then
+  SCENES="room fern leaves fortress orchids flower trex horns"
+  DATA_FOLDER="nerf_llff_data"
+else
+  SCENES="lego chair drums ficus hotdog materials mic ship"
+  DATA_FOLDER="nerf_synthetic"
+fi
+
+# launch evaluation jobs for all scenes.
+for scene in $SCENES; do
+  python -m jaxnerf.eval \
+    --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
+    --train_dir="$ROOT_DIR"/"$scene" \
+    --chunk=4096 \
+    --config=configs/"$CONFIG"
+done
+
+# collect PSNR of all scenes.
+touch "$ROOT_DIR"/psnr.txt
+for scene in $SCENES; do
+  printf "${scene}: " >> "$ROOT_DIR"/psnr.txt
+  cat "$ROOT_DIR"/"$scene"/test_preds/psnr.txt >> \
+    "$ROOT_DIR"/psnr.txt
+  printf $'\n' >> "$ROOT_DIR"/psnr.txt
+done
jaxnerf/example_data/imgs/r_0.png
ADDED

jaxnerf/example_data/transforms_test.json
ADDED
@@ -0,0 +1 @@
+{"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}

jaxnerf/example_data/transforms_train.json
ADDED
@@ -0,0 +1 @@
+{"camera_angle_x": 0.6911112070083618, "frames": [{"file_path": "./imgs/r_0", "rotation": 0.012566370614359171, "transform_matrix": [[-0.9999021887779236, 0.004192245192825794, -0.013345719315111637, -0.05379832163453102], [-0.013988681137561798, -0.2996590733528137, 0.95394366979599, 3.845470428466797], [-4.656612873077393e-10, 0.9540371894836426, 0.29968830943107605, 1.2080823183059692], [0.0, 0.0, 0.0, 1.0]]}]}

jaxnerf/nerf/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
jaxnerf/nerf/clip_utils.py
ADDED
@@ -0,0 +1,134 @@
import math
from typing import Optional
from absl import flags
from functools import partial

import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel

FLAGS = flags.FLAGS
# import jmp
# my_policy = jmp.Policy(compute_dtype=np.float16,
#                        param_dtype=np.float16,
#                        output_dtype=np.float16)


@partial(jax.jit, static_argnums=[0, 1])
def update_semantic_loss(model, clip_model, rng, state, batch, lr):
    # the batch is without shard
    random_rays = batch["random_rays"]
    rng, key_0, key_1 = random.split(rng, 3)

    def semantic_loss(variables):
        # TODO @Alex: (alt) sample less along a ray / sample on a strided grid (make change on model call)
        # TODO @Alex: (alt) apply mixed precision
        src_ret = model.apply(variables, key_0, key_1, random_rays, False)
        src_image, _, _ = src_ret[-1]
        # reshape flat pixels to an image (assumes 3 channels & a square shape)
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([-1, w, w, 3]).transpose(0, 3, 1, 2)
        src_image = preprocess_for_CLIP(src_image)
        src_embedding = clip_model.get_image_features(pixel_values=src_image)
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        src_embedding = jnp.array(src_embedding)
        target_embedding = batch["embedding"]
        sc_loss = 0.5 * FLAGS.sc_loss_mult * jnp.sum(
            (src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
        return sc_loss * 1e-2

    sc_loss, grad = jax.value_and_grad(semantic_loss)(
        jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target)
    return sc_loss, grad


def trans_t(t):
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_phi(phi):
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), -jnp.sin(phi), 0],
        [0, jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_theta(th):
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * jnp.pi) @ c2w
    c2w = rot_theta(theta / 180. * jnp.pi) @ c2w
    c2w = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w


def random_pose(rng, bds):
    rng, *rng_inputs = jax.random.split(rng, 4)
    # Use a distinct key per sample (the original reused one key for all three).
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    # Sample angles in degrees, since pose_spherical expects degrees.
    theta = random.uniform(rng_inputs[1], minval=0., maxval=360.)
    phi = random.uniform(rng_inputs[2], minval=0., maxval=90.)
    # Argument order fixed to match pose_spherical(theta, phi, radius).
    return pose_spherical(theta, phi, radius)


def preprocess_for_CLIP(image):
    """
    jax-based preprocessing for CLIP
    image [B, 3, H, W]: batch image
    return [B, 3, 224, 224]: pre-processed image for CLIP
    """
    B, D, H, W = image.shape
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')  # resize to CLIP's 224x224 input
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image


# TODO @Alex: VisionModel v.s. original CLIP? (differ by a projection matrix)
def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"unsupported dtype: {dtype}")

    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)


# def SC_loss(rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l):
#     """
#     target_emb [1, D]: pre-computed target embedding vector \phi(I)
#     source_img [1, 3, H, W]: source image \hat{I}
#     l: loss weight lambda
#     return: SC_loss
#     """
#     # _,H,W,D = rays.shape
#     rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l = my_policy.cast_to_compute(
#         (rng_inputs, model, params, bds, rays, N_samples, target_emb, CLIP_model, l))
#     _, H, W, _ = rays.shape
#     source_img = jnp.clip(render_fn(rng_inputs, model, params, None,
#                                     np.reshape(rays, (2, -1, 3)),
#                                     bds[0], bds[1], 1, rand=False),
#                           0, 1)
#     # source_img = np.clip(render_rays(rng_inputs, model, params, None, np.reshape(rays, (2, -1, 3)), bds[0], bds[1], 1, rand=False), 0, 1)
#     source_img = np.reshape(source_img, [1, H, W, 3]).transpose(0, 3, 1, 2)
#     source_img = preprocess_for_CLIP(source_img)
#     source_emb = CLIP_model.get_image_features(pixel_values=source_img)
#     source_emb /= np.linalg.norm(source_emb, axis=-1, keepdims=True)
#     return l/2 * (np.sum((source_emb - target_emb) ** 2) / source_emb.shape[0])
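
For reference, a minimal sketch of how these helpers compose, assuming only that `jax` and `transformers` are installed; the random array stands in for a batch of rendered images:

# Illustrative sketch (not part of the commit): embed a batch of images with CLIP.
# Images are float32 arrays in [0, 1] with shape [B, H, W, 3].
import jax.numpy as jnp
import numpy as np
from jaxnerf.nerf import clip_utils

clip_model = clip_utils.init_CLIP('float32', None)          # openai/clip-vit-base-patch32
images = np.random.rand(2, 100, 100, 3).astype(np.float32)  # stand-in renders
pixels = images.transpose(0, 3, 1, 2)                       # -> [B, 3, H, W], as preprocess expects
pixels = clip_utils.preprocess_for_CLIP(pixels)             # resize to 224x224 and normalize
emb = clip_model.get_image_features(pixel_values=pixels)
emb = emb / jnp.linalg.norm(emb, axis=-1, keepdims=True)    # unit-norm, as the semantic loss assumes
print(emb.shape)                                            # (2, 512) for the ViT-B/32 checkpoint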
jaxnerf/nerf/datasets.py
ADDED
@@ -0,0 +1,565 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Different datasets implementation plus a general port for all the datasets."""
INTERNAL = False  # pylint: disable=g-statement-before-imports
import json
import os
from os import path
import queue
import threading

if not INTERNAL:
    import cv2  # pylint: disable=g-import-not-at-top
import jax
import numpy as np
from PIL import Image

from jaxnerf.nerf import utils
from jaxnerf.nerf import clip_utils


def get_dataset(split, args, clip_model=None):
    return dataset_dict[args.dataset](split, args, clip_model)


def convert_to_ndc(origins, directions, focal, w, h, near=1.):
    """Convert a set of rays to NDC coordinates."""
    # Shift ray origins to near plane
    t = -(near + origins[..., 2]) / directions[..., 2]
    origins = origins + t[..., None] * directions

    dx, dy, dz = tuple(np.moveaxis(directions, -1, 0))
    ox, oy, oz = tuple(np.moveaxis(origins, -1, 0))

    # Projection
    o0 = -((2 * focal) / w) * (ox / oz)
    o1 = -((2 * focal) / h) * (oy / oz)
    o2 = 1 + 2 * near / oz

    d0 = -((2 * focal) / w) * (dx / dz - ox / oz)
    d1 = -((2 * focal) / h) * (dy / dz - oy / oz)
    d2 = -2 * near / oz

    origins = np.stack([o0, o1, o2], -1)
    directions = np.stack([d0, d1, d2], -1)
    return origins, directions


class Dataset(threading.Thread):
    """Dataset Base Class."""

    def __init__(self, split, flags, clip_model):
        super(Dataset, self).__init__()
        self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
        self.daemon = True
        self.use_pixel_centers = flags.use_pixel_centers
        self.split = split

        if split == "train":
            self._train_init(flags, clip_model)
        elif split == "test":
            self._test_init(flags)
        else:
            raise ValueError(
                "the split argument should be either \"train\" or \"test\", "
                "set to {} here.".format(split))
        self.batch_size = flags.batch_size // jax.process_count()
        self.batching = flags.batching
        self.render_path = flags.render_path
        self.far = flags.far
        self.near = flags.near
        self.max_steps = flags.max_steps
        self.sc_loss_factor = flags.sc_loss_factor
        self.start()

    def __iter__(self):
        return self

    def __next__(self):
        """Get the next training batch or test example.

        Returns:
          batch: dict, has "pixels" and "rays".
        """
        x = self.queue.get()
        if self.split == "train":
            return utils.shard(x)
        else:
            return utils.to_device(x)

    def peek(self):
        """Peek at the next training batch or test example without dequeuing it.

        Returns:
          batch: dict, has "pixels" and "rays".
        """
        x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
        if self.split == "train":
            return utils.shard(x)
        else:
            return utils.to_device(x)

    def run(self):
        if self.split == "train":
            next_func = self._next_train
        else:
            next_func = self._next_test
        while True:
            self.queue.put(next_func())

    @property
    def size(self):
        return self.n_examples

    def _train_init(self, flags, clip_model):
        """Initialize training."""
        self._load_renderings(flags, clip_model)
        self._generate_rays()

        if flags.batching == "all_images":
            # flatten the ray and image dimension together.
            self.images = self.images.reshape([-1, 3])
            self.rays = utils.namedtuple_map(lambda r: r.reshape([-1, r.shape[-1]]),
                                             self.rays)
        elif flags.batching == "single_image":
            self.images = self.images.reshape([-1, self.resolution, 3])
            self.rays = utils.namedtuple_map(
                lambda r: r.reshape([-1, self.resolution, r.shape[-1]]), self.rays)
        else:
            raise NotImplementedError(
                f"{flags.batching} batching strategy is not implemented.")

    def _test_init(self, flags):
        self._load_renderings(flags, clip_model=None)
        self._generate_rays()
        self.it = 0

    def _next_train(self):
        """Sample next training batch."""

        if self.batching == "all_images":
            ray_indices = np.random.randint(0, self.rays[0].shape[0],
                                            (self.batch_size,))
            batch_pixels = self.images[ray_indices]
            batch_rays = utils.namedtuple_map(lambda r: r[ray_indices], self.rays)
            raise NotImplementedError("image_index not implemented for batching=all_images")

        elif self.batching == "single_image":
            image_index = np.random.randint(0, self.n_examples, ())
            ray_indices = np.random.randint(0, self.rays[0][0].shape[0],
                                            (self.batch_size,))
            batch_pixels = self.images[image_index][ray_indices]
            batch_rays = utils.namedtuple_map(lambda r: r[image_index][ray_indices],
                                              self.rays)
        else:
            raise NotImplementedError(
                f"{self.batching} batching strategy is not implemented.")
        return {"pixels": batch_pixels, "rays": batch_rays, "image_index": image_index}

    def _next_test(self):
        """Sample next test example."""
        idx = self.it
        self.it = (self.it + 1) % self.n_examples

        if self.render_path:
            return {"rays": utils.namedtuple_map(lambda r: r[idx], self.render_rays)}
        else:
            return {"pixels": self.images[idx],
                    "rays": utils.namedtuple_map(lambda r: r[idx], self.rays),
                    "image_index": idx}

    # TODO(bydeng): Swap this function with a more flexible camera model.
    def _generate_rays(self):
        """Generating rays for all images."""
        pixel_center = 0.5 if self.use_pixel_centers else 0.0
        x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
            np.arange(self.w, dtype=np.float32) + pixel_center,  # X-Axis (columns)
            np.arange(self.h, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
            indexing="xy")
        camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                                -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
                               axis=-1)
        directions = ((camera_dirs[None, ..., None, :] *
                       self.camtoworlds[:, None, None, :3, :3]).sum(axis=-1))
        origins = np.broadcast_to(self.camtoworlds[:, None, None, :3, -1],
                                  directions.shape)
        viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
        self.rays = utils.Rays(
            origins=origins, directions=directions, viewdirs=viewdirs)

    def camtoworld_matrix_to_rays(self, camtoworld, downsample=1):
        """Render one instance of rays given a camera-to-world matrix (4, 4)."""
        pixel_center = 0.5 if self.use_pixel_centers else 0.0
        # TODO @Alex: apply mesh downsampling here
        x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
            np.arange(self.w, step=downsample, dtype=np.float32) + pixel_center,  # X-Axis (columns)
            np.arange(self.h, step=downsample, dtype=np.float32) + pixel_center,  # Y-Axis (rows)
            indexing="xy")
        camera_dirs = np.stack([(x - self.w * 0.5) / self.focal,
                                -(y - self.h * 0.5) / self.focal, -np.ones_like(x)],
                               axis=-1)
        directions = (camera_dirs[..., None, :] * camtoworld[None, None, :3, :3]).sum(axis=-1)
        origins = np.broadcast_to(camtoworld[None, None, :3, -1], directions.shape)
        viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
        return utils.Rays(origins=origins, directions=directions, viewdirs=viewdirs)


class Blender(Dataset):
    """Blender Dataset."""

    def _load_renderings(self, flags, clip_model=None):
        """Load images from disk."""
        if flags.render_path:
            raise ValueError("render_path cannot be used for the blender dataset.")
        cams, images, meta = self.load_files(flags.data_dir, self.split, flags.factor,
                                             flags.few_shot)

        # load in CLIP precomputed image features
        self.images = np.stack(images, axis=0)
        if flags.white_bkgd:
            self.images = (self.images[..., :3] * self.images[..., -1:] +
                           (1. - self.images[..., -1:]))
        else:
            self.images = self.images[..., :3]
        self.h, self.w = self.images.shape[1:3]
        self.resolution = self.h * self.w
        self.camtoworlds = np.stack(cams, axis=0)
        camera_angle_x = float(meta["camera_angle_x"])
        self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
        self.n_examples = self.images.shape[0]

        if flags.use_semantic_loss and clip_model is not None:
            embs = []
            for img in self.images:
                img = np.expand_dims(np.transpose(img, [2, 0, 1]), 0)
                embs.append(clip_model.get_image_features(
                    pixel_values=clip_utils.preprocess_for_CLIP(img)))
            self.embeddings = np.concatenate(embs, 0)

        self.image_idx = np.arange(self.images.shape[0])
        np.random.shuffle(self.image_idx)
        self.image_idx = self.image_idx.tolist()

        # self.embeddings = utils.read_pickle(flags.precompute_pkl_path)
        # self.precompute_pkl_path = flags.precompute_pkl_path

    @staticmethod
    def load_files(data_dir, split, factor, few_shot):
        with utils.open_file(path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
            meta = json.load(fp)
        images = []
        cams = []

        frames = np.arange(len(meta["frames"]))
        if few_shot > 0 and split == 'train':
            np.random.shuffle(frames)
            frames = frames[:few_shot]

        for i in frames:
            frame = meta["frames"][i]
            fname = os.path.join(data_dir, frame["file_path"] + ".png")
            with utils.open_file(fname, "rb") as imgin:
                image = np.array(Image.open(imgin)).astype(np.float32) / 255.
                if factor == 2:
                    [halfres_h, halfres_w] = [hw // 2 for hw in image.shape[:2]]
                    image = cv2.resize(image, (halfres_w, halfres_h),
                                       interpolation=cv2.INTER_AREA)
                elif factor == 4:
                    [halfres_h, halfres_w] = [hw // 4 for hw in image.shape[:2]]
                    image = cv2.resize(image, (halfres_w, halfres_h),
                                       interpolation=cv2.INTER_AREA)
                elif factor > 0:
                    raise ValueError("Blender dataset only supports factor=0 or 2 or 4, {} "
                                     "set.".format(factor))
            cams.append(np.array(frame["transform_matrix"], dtype=np.float32))
            images.append(image)
        return cams, images, meta

    def _next_train(self):
        batch_dict = super(Blender, self)._next_train()
        if self.batching == "single_image":
            image_index = batch_dict.pop("image_index")
            # target image for CLIP
            '''
            batch_dict["embedding"] = self.embeddings[image_index]

            # source rays for CLIP (for constructing source image later)
            src_seed = int(np.random.randint(0, self.max_steps, ()))
            src_rng = jax.random.PRNGKey(src_seed)
            src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
            random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample=16)
            random_rays = utils.Rays(origins=np.reshape(random_rays[0], [-1, 3]),
                                     directions=np.reshape(random_rays[1], [-1, 3]),
                                     viewdirs=np.reshape(random_rays[2], [-1, 3]))
            batch_dict["random_rays"] = random_rays
            '''
        else:
            raise NotImplementedError
        return batch_dict

    def get_clip_data(self):
        if len(self.image_idx) == 0:
            self.image_idx = np.arange(self.images.shape[0])
            np.random.shuffle(self.image_idx)
            self.image_idx = self.image_idx.tolist()
        image_index = self.image_idx.pop()

        batch_dict = {}
        batch_dict["embedding"] = self.embeddings[image_index]

        # source rays for CLIP (for constructing source image later)
        src_seed = int(np.random.randint(0, self.max_steps, ()))
        src_rng = jax.random.PRNGKey(src_seed)
        src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
        random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample=16)
        random_rays = utils.Rays(origins=np.reshape(random_rays[0], [-1, 3]),
                                 directions=np.reshape(random_rays[1], [-1, 3]),
                                 viewdirs=np.reshape(random_rays[2], [-1, 3]))
        batch_dict["random_rays"] = random_rays
        return batch_dict


class LLFF(Dataset):
    """LLFF Dataset."""

    def _load_renderings(self, flags, clip_model=None):
        # clip_model accepted for interface parity with Blender (unused here).
        """Load images from disk."""
        # Load images.
        imgdir_suffix = ""
        if flags.factor > 0:
            imgdir_suffix = "_{}".format(flags.factor)
            factor = flags.factor
        else:
            factor = 1
        imgdir = path.join(flags.data_dir, "images" + imgdir_suffix)
        if not utils.file_exists(imgdir):
            raise ValueError("Image folder {} doesn't exist.".format(imgdir))
        imgfiles = [
            path.join(imgdir, f)
            for f in sorted(utils.listdir(imgdir))
            if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
        ]
        images = []
        for imgfile in imgfiles:
            with utils.open_file(imgfile, "rb") as imgin:
                image = np.array(Image.open(imgin), dtype=np.float32) / 255.
                images.append(image)
        images = np.stack(images, axis=-1)

        # Load poses and bds.
        with utils.open_file(path.join(flags.data_dir, "poses_bounds.npy"),
                             "rb") as fp:
            poses_arr = np.load(fp)
        poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
        bds = poses_arr[:, -2:].transpose([1, 0])
        if poses.shape[-1] != images.shape[-1]:
            raise RuntimeError("Mismatch between imgs {} and poses {}".format(
                images.shape[-1], poses.shape[-1]))

        # Update poses according to downsampling.
        poses[:2, 4, :] = np.array(images.shape[:2]).reshape([2, 1])
        poses[2, 4, :] = poses[2, 4, :] * 1. / factor

        # Correct rotation matrix ordering and move variable dim to axis 0.
        poses = np.concatenate(
            [poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)
        images = np.moveaxis(images, -1, 0)
        bds = np.moveaxis(bds, -1, 0).astype(np.float32)

        # Rescale according to a default bd factor.
        scale = 1. / (bds.min() * .75)
        poses[:, :3, 3] *= scale
        bds *= scale

        # Recenter poses.
        poses = self._recenter_poses(poses)

        # Generate a spiral/spherical ray path for rendering videos.
        if flags.spherify:
            poses = self._generate_spherical_poses(poses, bds)
            self.spherify = True
        else:
            self.spherify = False
        if not flags.spherify and self.split == "test":
            self._generate_spiral_poses(poses, bds)

        # Select the split.
        i_test = np.arange(images.shape[0])[::flags.llffhold]
        i_train = np.array(
            [i for i in np.arange(int(images.shape[0])) if i not in i_test])
        if self.split == "train":
            indices = i_train
        else:
            indices = i_test
        images = images[indices]
        poses = poses[indices]

        self.images = images
        self.camtoworlds = poses[:, :3, :4]
        self.focal = poses[0, -1, -1]
        self.h, self.w = images.shape[1:3]
        self.resolution = self.h * self.w
        if flags.render_path:
            self.n_examples = self.render_poses.shape[0]
        else:
            self.n_examples = images.shape[0]

    def _generate_rays(self):
        """Generate normalized device coordinate rays for llff."""
        if self.split == "test":
            n_render_poses = self.render_poses.shape[0]
            self.camtoworlds = np.concatenate([self.render_poses, self.camtoworlds],
                                              axis=0)

        super()._generate_rays()

        if not self.spherify:
            ndc_origins, ndc_directions = convert_to_ndc(self.rays.origins,
                                                         self.rays.directions,
                                                         self.focal, self.w, self.h)
            self.rays = utils.Rays(
                origins=ndc_origins,
                directions=ndc_directions,
                viewdirs=self.rays.viewdirs)

        # Split poses from the dataset and generated poses
        if self.split == "test":
            self.camtoworlds = self.camtoworlds[n_render_poses:]
            split = [np.split(r, [n_render_poses], 0) for r in self.rays]
            split0, split1 = zip(*split)
            self.render_rays = utils.Rays(*split0)
            self.rays = utils.Rays(*split1)

    def _recenter_poses(self, poses):
        """Recenter poses according to the original NeRF code."""
        poses_ = poses.copy()
        bottom = np.reshape([0, 0, 0, 1.], [1, 4])
        c2w = self._poses_avg(poses)
        c2w = np.concatenate([c2w[:3, :4], bottom], -2)
        bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
        poses = np.concatenate([poses[:, :3, :4], bottom], -2)
        poses = np.linalg.inv(c2w) @ poses
        poses_[:, :3, :4] = poses[:, :3, :4]
        poses = poses_
        return poses

    def _poses_avg(self, poses):
        """Average poses according to the original NeRF code."""
        hwf = poses[0, :3, -1:]
        center = poses[:, :3, 3].mean(0)
        vec2 = self._normalize(poses[:, :3, 2].sum(0))
        up = poses[:, :3, 1].sum(0)
        c2w = np.concatenate([self._viewmatrix(vec2, up, center), hwf], 1)
        return c2w

    def _viewmatrix(self, z, up, pos):
        """Construct lookat view matrix."""
        vec2 = self._normalize(z)
        vec1_avg = up
        vec0 = self._normalize(np.cross(vec1_avg, vec2))
        vec1 = self._normalize(np.cross(vec2, vec0))
        m = np.stack([vec0, vec1, vec2, pos], 1)
        return m

    def _normalize(self, x):
        """Normalization helper function."""
        return x / np.linalg.norm(x)

    def _generate_spiral_poses(self, poses, bds):
        """Generate a spiral path for rendering."""
        c2w = self._poses_avg(poses)
        # Get average pose.
        up = self._normalize(poses[:, :3, 1].sum(0))
        # Find a reasonable "focus depth" for this dataset.
        close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
        dt = .75
        mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
        focal = mean_dz
        # Get radii for spiral path.
        tt = poses[:, :3, 3]
        rads = np.percentile(np.abs(tt), 90, 0)
        c2w_path = c2w
        n_views = 120
        n_rots = 2
        # Generate poses for spiral path.
        render_poses = []
        rads = np.array(list(rads) + [1.])
        hwf = c2w_path[:, 4:5]
        zrate = .5
        for theta in np.linspace(0., 2. * np.pi * n_rots, n_views + 1)[:-1]:
            c = np.dot(c2w[:3, :4], (np.array(
                [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads))
            z = self._normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
            render_poses.append(np.concatenate([self._viewmatrix(z, up, c), hwf], 1))
        self.render_poses = np.array(render_poses).astype(np.float32)[:, :3, :4]

    def _generate_spherical_poses(self, poses, bds):
        """Generate a 360 degree spherical path for rendering."""
        # pylint: disable=g-long-lambda
        p34_to_44 = lambda p: np.concatenate([
            p,
            np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
        ], 1)
        rays_d = poses[:, :3, 2:3]
        rays_o = poses[:, :3, 3:4]

        def min_line_dist(rays_o, rays_d):
            a_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
            b_i = -a_i @ rays_o
            pt_mindist = np.squeeze(-np.linalg.inv(
                (np.transpose(a_i, [0, 2, 1]) @ a_i).mean(0)) @ (b_i).mean(0))
            return pt_mindist

        pt_mindist = min_line_dist(rays_o, rays_d)
        center = pt_mindist
        up = (poses[:, :3, 3] - center).mean(0)
        vec0 = self._normalize(up)
        vec1 = self._normalize(np.cross([.1, .2, .3], vec0))
        vec2 = self._normalize(np.cross(vec0, vec1))
        pos = center
        c2w = np.stack([vec1, vec2, vec0, pos], 1)
        poses_reset = (
            np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]))
        rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
        sc = 1. / rad
        poses_reset[:, :3, 3] *= sc
        bds *= sc
        rad *= sc
        centroid = np.mean(poses_reset[:, :3, 3], 0)
        zh = centroid[2]
        radcircle = np.sqrt(rad ** 2 - zh ** 2)
        new_poses = []

        for th in np.linspace(0., 2. * np.pi, 120):
            camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
            up = np.array([0, 0, -1.])
            vec2 = self._normalize(camorigin)
            vec0 = self._normalize(np.cross(vec2, up))
            vec1 = self._normalize(np.cross(vec2, vec0))
            pos = camorigin
            p = np.stack([vec0, vec1, vec2, pos], 1)
            new_poses.append(p)

        new_poses = np.stack(new_poses, 0)
        new_poses = np.concatenate([
            new_poses,
            np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)
        ], -1)
        poses_reset = np.concatenate([
            poses_reset[:, :3, :4],
            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)
        ], -1)
        if self.split == "test":
            self.render_poses = new_poses[:, :3, :4]
        return poses_reset


dataset_dict = {"blender": Blender,
                "llff": LLFF}
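
A minimal usage sketch, assuming a standard Blender scene on disk; the `SimpleNamespace` stands in for the absl FLAGS that the training script defines, and all field values (including the data path) are hypothetical:

# Illustrative sketch (not part of the commit): build a Blender training set
# and pull one sharded batch through the prefetch thread.
from types import SimpleNamespace
from jaxnerf.nerf import datasets

flags = SimpleNamespace(
    dataset="blender", data_dir="/path/to/nerf_synthetic/lego",
    batch_size=1024, batching="single_image", factor=4, few_shot=0,
    white_bkgd=True, render_path=False, use_pixel_centers=True,
    near=2., far=6., max_steps=100000, sc_loss_factor=4,
    use_semantic_loss=False)

train_ds = datasets.get_dataset("train", flags)  # starts the daemon thread
batch = next(train_ds)  # dict with "pixels" and "rays", sharded across devices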
jaxnerf/nerf/model_utils.py
ADDED
@@ -0,0 +1,321 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Helper functions/classes for model definition."""

import functools
from typing import Any, Callable

from flax import linen as nn
import jax
from jax import lax
from jax import random
import jax.numpy as jnp


class MLP(nn.Module):
    """A simple MLP."""
    net_depth: int = 8  # The depth of the first part of MLP.
    net_width: int = 256  # The width of the first part of MLP.
    net_depth_condition: int = 1  # The depth of the second part of MLP.
    net_width_condition: int = 128  # The width of the second part of MLP.
    net_activation: Callable[..., Any] = nn.relu  # The activation function.
    skip_layer: int = 4  # The layer to add skip layers to.
    num_rgb_channels: int = 3  # The number of RGB channels.
    num_sigma_channels: int = 1  # The number of sigma channels.

    @nn.compact
    def __call__(self, x, condition=None):
        """Evaluate the MLP.

        Args:
          x: jnp.ndarray(float32), [batch, num_samples, feature], points.
          condition: jnp.ndarray(float32), [batch, feature], if not None, this
            variable will be part of the input to the second part of the MLP
            concatenated with the output vector of the first part of the MLP. If
            None, only the first part of the MLP will be used with input x. In the
            original paper, this variable is the view direction.

        Returns:
          raw_rgb: jnp.ndarray(float32), with a shape of
            [batch, num_samples, num_rgb_channels].
          raw_sigma: jnp.ndarray(float32), with a shape of
            [batch, num_samples, num_sigma_channels].
        """
        feature_dim = x.shape[-1]
        num_samples = x.shape[1]
        x = x.reshape([-1, feature_dim])
        dense_layer = functools.partial(
            nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())
        inputs = x
        for i in range(self.net_depth):
            x = dense_layer(self.net_width)(x)
            x = self.net_activation(x)
            if i % self.skip_layer == 0 and i > 0:
                x = jnp.concatenate([x, inputs], axis=-1)
        raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape(
            [-1, num_samples, self.num_sigma_channels])

        if condition is not None:
            # Output of the first part of MLP.
            bottleneck = dense_layer(self.net_width)(x)
            # Broadcast condition from [batch, feature] to
            # [batch, num_samples, feature] since all the samples along the same ray
            # have the same viewdir.
            condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
            # Collapse the [batch, num_samples, feature] tensor to
            # [batch * num_samples, feature] so that it can be fed into nn.Dense.
            condition = condition.reshape([-1, condition.shape[-1]])
            x = jnp.concatenate([bottleneck, condition], axis=-1)
            # Here use 1 extra layer to align with the original nerf model.
            for i in range(self.net_depth_condition):
                x = dense_layer(self.net_width_condition)(x)
                x = self.net_activation(x)
        raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(
            [-1, num_samples, self.num_rgb_channels])
        return raw_rgb, raw_sigma


def cast_rays(z_vals, origins, directions):
    return origins[..., None, :] + z_vals[..., None] * directions[..., None, :]


def sample_along_rays(key, origins, directions, num_samples, near, far,
                      randomized, lindisp):
    """Stratified sampling along the rays.

    Args:
      key: jnp.ndarray, random generator key.
      origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
      directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
      num_samples: int.
      near: float, near clip.
      far: float, far clip.
      randomized: bool, use randomized stratified sampling.
      lindisp: bool, sampling linearly in disparity rather than depth.

    Returns:
      z_vals: jnp.ndarray, [batch_size, num_samples], sampled z values.
      points: jnp.ndarray, [batch_size, num_samples, 3], sampled points.
    """
    batch_size = origins.shape[0]

    t_vals = jnp.linspace(0., 1., num_samples)
    if lindisp:
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
    else:
        z_vals = near * (1. - t_vals) + far * t_vals

    if randomized:
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = jnp.concatenate([mids, z_vals[..., -1:]], -1)
        lower = jnp.concatenate([z_vals[..., :1], mids], -1)
        t_rand = random.uniform(key, [batch_size, num_samples])
        z_vals = lower + (upper - lower) * t_rand
    else:
        # Broadcast z_vals to make the returned shape consistent.
        z_vals = jnp.broadcast_to(z_vals[None, ...], [batch_size, num_samples])

    coords = cast_rays(z_vals, origins, directions)
    return z_vals, coords


def posenc(x, min_deg, max_deg, legacy_posenc_order=False):
    """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].

    Instead of computing [sin(x), cos(x)], we use the trig identity
    cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).

    Args:
      x: jnp.ndarray, variables to be encoded. Note that x should be in [-pi, pi].
      min_deg: int, the minimum (inclusive) degree of the encoding.
      max_deg: int, the maximum (exclusive) degree of the encoding.
      legacy_posenc_order: bool, keep the same ordering as the original tf code.

    Returns:
      encoded: jnp.ndarray, encoded variables.
    """
    if min_deg == max_deg:
        return x
    scales = jnp.array([2 ** i for i in range(min_deg, max_deg)])
    if legacy_posenc_order:
        xb = x[..., None, :] * scales[:, None]
        four_feat = jnp.reshape(
            jnp.sin(jnp.stack([xb, xb + 0.5 * jnp.pi], -2)),
            list(x.shape[:-1]) + [-1])
    else:
        xb = jnp.reshape((x[..., None, :] * scales[:, None]),
                         list(x.shape[:-1]) + [-1])
        four_feat = jnp.sin(jnp.concatenate([xb, xb + 0.5 * jnp.pi], axis=-1))
    return jnp.concatenate([x] + [four_feat], axis=-1)


def volumetric_rendering(rgb, sigma, z_vals, dirs, white_bkgd):
    """Volumetric Rendering Function.

    Args:
      rgb: jnp.ndarray(float32), color, [batch_size, num_samples, 3]
      sigma: jnp.ndarray(float32), density, [batch_size, num_samples, 1].
      z_vals: jnp.ndarray(float32), [batch_size, num_samples].
      dirs: jnp.ndarray(float32), [batch_size, 3].
      white_bkgd: bool.

    Returns:
      comp_rgb: jnp.ndarray(float32), [batch_size, 3].
      disp: jnp.ndarray(float32), [batch_size].
      acc: jnp.ndarray(float32), [batch_size].
      weights: jnp.ndarray(float32), [batch_size, num_samples]
    """
    eps = 1e-10
    dists = jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        jnp.broadcast_to([1e10], z_vals[..., :1].shape)
    ], -1)
    dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
    # Note that we're quietly turning sigma from [..., 0] to [...].
    alpha = 1.0 - jnp.exp(-sigma[..., 0] * dists)
    accum_prod = jnp.concatenate([
        jnp.ones_like(alpha[..., :1], alpha.dtype),
        jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1)
    ],
                                 axis=-1)
    weights = alpha * accum_prod

    comp_rgb = (weights[..., None] * rgb).sum(axis=-2)
    depth = (weights * z_vals).sum(axis=-1)
    acc = weights.sum(axis=-1)
    # Equivalent to (but slightly more efficient and stable than):
    #  disp = 1 / max(eps, where(acc > eps, depth / acc, 0))
    inv_eps = 1 / eps
    disp = acc / depth
    disp = jnp.where((disp > 0) & (disp < inv_eps) & (acc > eps), disp, inv_eps)
    if white_bkgd:
        comp_rgb = comp_rgb + (1. - acc[..., None])
    return comp_rgb, disp, acc, weights


def piecewise_constant_pdf(key, bins, weights, num_samples, randomized):
    """Piecewise-Constant PDF sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples.

    Returns:
      z_samples: jnp.ndarray(float32), [batch_size, num_samples].
    """
    # Pad each weight vector (only if necessary) to bring its sum to `eps`. This
    # avoids NaNs when the input is zeros or small, but has no effect otherwise.
    eps = 1e-5
    weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
    padding = jnp.maximum(0, eps - weight_sum)
    weights += padding / weights.shape[-1]
    weight_sum += padding

    # Compute the PDF and CDF for each weight vector, while ensuring that the CDF
    # starts with exactly 0 and ends with exactly 1.
    pdf = weights / weight_sum
    cdf = jnp.minimum(1, jnp.cumsum(pdf[..., :-1], axis=-1))
    cdf = jnp.concatenate([
        jnp.zeros(list(cdf.shape[:-1]) + [1]), cdf,
        jnp.ones(list(cdf.shape[:-1]) + [1])
    ],
                          axis=-1)

    # Draw uniform samples.
    if randomized:
        # Note that `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u = random.uniform(key, list(cdf.shape[:-1]) + [num_samples])
    else:
        # Match the behavior of random.uniform() by spanning [0, 1-eps].
        u = jnp.linspace(0., 1. - jnp.finfo('float32').eps, num_samples)
        u = jnp.broadcast_to(u, list(cdf.shape[:-1]) + [num_samples])

    # Identify the location in `cdf` that corresponds to a random sample.
    # The final `True` index in `mask` will be the start of the sampled interval.
    mask = u[..., None, :] >= cdf[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = jnp.max(jnp.where(mask, x[..., None], x[..., :1, None]), -2)
        x1 = jnp.min(jnp.where(~mask, x[..., None], x[..., -1:, None]), -2)
        return x0, x1

    bins_g0, bins_g1 = find_interval(bins)
    cdf_g0, cdf_g1 = find_interval(cdf)

    t = jnp.clip(jnp.nan_to_num((u - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1)
    samples = bins_g0 + t * (bins_g1 - bins_g0)

    # Prevent gradient from backprop-ing through `samples`.
    return lax.stop_gradient(samples)


def sample_pdf(key, bins, weights, origins, directions, z_vals, num_samples,
               randomized):
    """Hierarchical sampling.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      bins: jnp.ndarray(float32), [batch_size, num_bins + 1].
      weights: jnp.ndarray(float32), [batch_size, num_bins].
      origins: jnp.ndarray(float32), [batch_size, 3], ray origins.
      directions: jnp.ndarray(float32), [batch_size, 3], ray directions.
      z_vals: jnp.ndarray(float32), [batch_size, num_coarse_samples].
      num_samples: int, the number of samples.
      randomized: bool, use randomized samples.

    Returns:
      z_vals: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples].
      points: jnp.ndarray(float32),
        [batch_size, num_coarse_samples + num_fine_samples, 3].
    """
    z_samples = piecewise_constant_pdf(key, bins, weights, num_samples,
                                       randomized)
    # Compute united z_vals and sample points
    z_vals = jnp.sort(jnp.concatenate([z_vals, z_samples], axis=-1), axis=-1)
    coords = cast_rays(z_vals, origins, directions)
    return z_vals, coords


def add_gaussian_noise(key, raw, noise_std, randomized):
    """Adds gaussian noise to `raw`, which can be used to regularize it.

    Args:
      key: jnp.ndarray(float32), [2,], random number generator.
      raw: jnp.ndarray(float32), arbitrary shape.
      noise_std: float, The standard deviation of the noise to be added.
      randomized: bool, add noise if randomized is True.

    Returns:
      raw + noise: jnp.ndarray(float32), with the same shape as `raw`.
    """
    if (noise_std is not None) and randomized:
        return raw + random.normal(key, raw.shape, dtype=raw.dtype) * noise_std
    else:
        return raw
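
As a quick shape check on the positional encoding above: with min_deg=0 and max_deg=10, each 3-vector maps to 3 + 3*2*10 = 63 features (the identity plus sin/cos pairs at ten frequencies). A minimal sketch:

# Illustrative sketch (not part of the commit): positional-encoding output shape.
import jax.numpy as jnp
from jaxnerf.nerf import model_utils

pts = jnp.zeros((4, 64, 3))           # [batch, num_samples, 3] sample points
enc = model_utils.posenc(pts, 0, 10)  # identity + sin/cos at 10 frequencies
print(enc.shape)                      # (4, 64, 63)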
jaxnerf/nerf/models.py
ADDED
@@ -0,0 +1,256 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Different model implementation plus a general port for all the models."""
from typing import Any, Callable
from flax import linen as nn
from jax import random
import jax.numpy as jnp

from jaxnerf.nerf import model_utils
from jaxnerf.nerf import utils


def get_model(key, example_batch, args):
    """A helper function that wraps around a 'model zoo'."""
    model_dict = {"nerf": construct_nerf}
    return model_dict[args.model](key, example_batch, args)


class NerfModel(nn.Module):
    """Nerf NN Model with both coarse and fine MLPs."""
    num_coarse_samples: int  # The number of samples for the coarse nerf.
    num_fine_samples: int  # The number of samples for the fine nerf.
    use_viewdirs: bool  # If True, use viewdirs as an input.
    near: float  # The distance to the near plane
    far: float  # The distance to the far plane
    noise_std: float  # The std dev of noise added to raw sigma.
    net_depth: int  # The depth of the first part of MLP.
    net_width: int  # The width of the first part of MLP.
    net_depth_condition: int  # The depth of the second part of MLP.
    net_width_condition: int  # The width of the second part of MLP.
    net_activation: Callable[..., Any]  # MLP activation
    skip_layer: int  # How often to add skip connections.
    num_rgb_channels: int  # The number of RGB channels.
    num_sigma_channels: int  # The number of density channels.
    white_bkgd: bool  # If True, use a white background.
    min_deg_point: int  # The minimum degree of positional encoding for positions.
    max_deg_point: int  # The maximum degree of positional encoding for positions.
    deg_view: int  # The degree of positional encoding for viewdirs.
    lindisp: bool  # If True, sample linearly in disparity rather than in depth.
    rgb_activation: Callable[..., Any]  # Output RGB activation.
    sigma_activation: Callable[..., Any]  # Output sigma activation.
    legacy_posenc_order: bool  # Keep the same ordering as the original tf code.

    @nn.compact
    def __call__(self, rng_0, rng_1, rays, randomized):
        """Nerf Model.

        Args:
          rng_0: jnp.ndarray, random number generator for coarse model sampling.
          rng_1: jnp.ndarray, random number generator for fine model sampling.
          rays: util.Rays, a namedtuple of ray origins, directions, and viewdirs.
          randomized: bool, use randomized stratified sampling.

        Returns:
          ret: list, [(rgb_coarse, disp_coarse, acc_coarse), (rgb, disp, acc)]
        """
        # Stratified sampling along rays
        key, rng_0 = random.split(rng_0)
        z_vals, samples = model_utils.sample_along_rays(
            key,
            rays.origins,
            rays.directions,
            self.num_coarse_samples,
            self.near,
            self.far,
            randomized,
            self.lindisp,
        )
        samples_enc = model_utils.posenc(
            samples,
            self.min_deg_point,
            self.max_deg_point,
            self.legacy_posenc_order,
        )

        # Construct the "coarse" MLP.
        coarse_mlp = model_utils.MLP(
            net_depth=self.net_depth,
            net_width=self.net_width,
            net_depth_condition=self.net_depth_condition,
            net_width_condition=self.net_width_condition,
            net_activation=self.net_activation,
            skip_layer=self.skip_layer,
            num_rgb_channels=self.num_rgb_channels,
            num_sigma_channels=self.num_sigma_channels)

        # Point attribute predictions
        if self.use_viewdirs:
            viewdirs_enc = model_utils.posenc(
                rays.viewdirs,
                0,
                self.deg_view,
                self.legacy_posenc_order,
            )
            raw_rgb, raw_sigma = coarse_mlp(samples_enc, viewdirs_enc)
        else:
            viewdirs_enc = None
            raw_rgb, raw_sigma = coarse_mlp(samples_enc)
        # Add noises to regularize the density predictions if needed
        key, rng_0 = random.split(rng_0)
        raw_sigma = model_utils.add_gaussian_noise(
            key,
            raw_sigma,
            self.noise_std,
            randomized,
        )
        rgb = self.rgb_activation(raw_rgb)
        sigma = self.sigma_activation(raw_sigma)
        # Volumetric rendering.
        comp_rgb, disp, acc, weights = model_utils.volumetric_rendering(
            rgb,
            sigma,
            z_vals,
            rays.directions,
            white_bkgd=self.white_bkgd,
        )
        ret = [
            (comp_rgb, disp, acc),
        ]
        # Hierarchical sampling based on coarse predictions
        if self.num_fine_samples > 0:
            z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            key, rng_1 = random.split(rng_1)
            z_vals, samples = model_utils.sample_pdf(
                key,
                z_vals_mid,
                weights[..., 1:-1],
                rays.origins,
                rays.directions,
                z_vals,
                self.num_fine_samples,
                randomized,
            )
            samples_enc = model_utils.posenc(
                samples,
                self.min_deg_point,
                self.max_deg_point,
                self.legacy_posenc_order,
            )

            # Construct the "fine" MLP.
            fine_mlp = model_utils.MLP(
                net_depth=self.net_depth,
                net_width=self.net_width,
                net_depth_condition=self.net_depth_condition,
                net_width_condition=self.net_width_condition,
                net_activation=self.net_activation,
                skip_layer=self.skip_layer,
                num_rgb_channels=self.num_rgb_channels,
                num_sigma_channels=self.num_sigma_channels)

            if self.use_viewdirs:
                raw_rgb, raw_sigma = fine_mlp(samples_enc, viewdirs_enc)
            else:
                raw_rgb, raw_sigma = fine_mlp(samples_enc)
            key, rng_1 = random.split(rng_1)
            raw_sigma = model_utils.add_gaussian_noise(
                key,
                raw_sigma,
                self.noise_std,
                randomized,
            )
            rgb = self.rgb_activation(raw_rgb)
            sigma = self.sigma_activation(raw_sigma)
            comp_rgb, disp, acc, unused_weights = model_utils.volumetric_rendering(
                rgb,
                sigma,
                z_vals,
                rays.directions,
                white_bkgd=self.white_bkgd,
            )
            ret.append((comp_rgb, disp, acc))
        return ret


def construct_nerf(key, example_batch, args):
    """Construct a Neural Radiance Field.

    Args:
      key: jnp.ndarray. Random number generator.
      example_batch: dict, an example of a batch of data.
      args: FLAGS class. Hyperparameters of nerf.

    Returns:
      model: nn.Model. Nerf model with parameters.
      state: flax.Module.state. Nerf model state for stateful parameters.
    """
    net_activation = getattr(nn, str(args.net_activation))
    rgb_activation = getattr(nn, str(args.rgb_activation))
    sigma_activation = getattr(nn, str(args.sigma_activation))

    # Assert that rgb_activation always produces outputs in [0, 1], and
    # sigma_activation always produces non-negative outputs.
    x = jnp.exp(jnp.linspace(-90, 90, 1024))
    x = jnp.concatenate([-x[::-1], x], 0)

    rgb = rgb_activation(x)
    if jnp.any(rgb < 0) or jnp.any(rgb > 1):
        raise NotImplementedError(
            "Choice of rgb_activation `{}` produces colors outside of [0, 1]"
            .format(args.rgb_activation))

    sigma = sigma_activation(x)
    if jnp.any(sigma < 0):
        raise NotImplementedError(
            "Choice of sigma_activation `{}` produces negative densities".format(
                args.sigma_activation))

    model = NerfModel(
        min_deg_point=args.min_deg_point,
        max_deg_point=args.max_deg_point,
        deg_view=args.deg_view,
        num_coarse_samples=args.num_coarse_samples,
        num_fine_samples=args.num_fine_samples,
        use_viewdirs=args.use_viewdirs,
        near=args.near,
        far=args.far,
        noise_std=args.noise_std,
        white_bkgd=args.white_bkgd,
        net_depth=args.net_depth,
        net_width=args.net_width,
        net_depth_condition=args.net_depth_condition,
        net_width_condition=args.net_width_condition,
        skip_layer=args.skip_layer,
        num_rgb_channels=args.num_rgb_channels,
        num_sigma_channels=args.num_sigma_channels,
        lindisp=args.lindisp,
        net_activation=net_activation,
        rgb_activation=rgb_activation,
        sigma_activation=sigma_activation,
        legacy_posenc_order=args.legacy_posenc_order)
    rays = example_batch["rays"]
    key1, key2, key3 = random.split(key, num=3)

    init_variables = model.init(
        key1,
        rng_0=key2,
        rng_1=key3,
        rays=utils.namedtuple_map(lambda x: x[0], rays),
        randomized=args.randomized)

    return model, init_variables
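
A minimal end-to-end sketch of constructing and querying the model; every flag value below is a hypothetical stand-in mirroring the blender config, and the all-ones rays are dummies (real rays come from datasets.py, already sharded with a leading device axis):

# Illustrative sketch (not part of the commit): build the NeRF and render rays.
from types import SimpleNamespace
import numpy as np
from jax import random
from jaxnerf.nerf import models, utils

args = SimpleNamespace(
    model="nerf", net_activation="relu", rgb_activation="sigmoid",
    sigma_activation="relu", min_deg_point=0, max_deg_point=10, deg_view=4,
    num_coarse_samples=64, num_fine_samples=128, use_viewdirs=True,
    near=2., far=6., noise_std=0., white_bkgd=True, net_depth=8, net_width=256,
    net_depth_condition=1, net_width_condition=128, skip_layer=4,
    num_rgb_channels=3, num_sigma_channels=1, lindisp=False,
    legacy_posenc_order=False, randomized=True)

# Dummy rays with a leading device axis of 1: [device, batch, 3] per field.
rays = utils.Rays(*[np.ones((1, 4, 3), np.float32) for _ in range(3)])
model, variables = models.get_model(random.PRNGKey(0), {"rays": rays}, args)
key_0, key_1 = random.split(random.PRNGKey(1))
ret = model.apply(variables, key_0, key_1,
                  utils.namedtuple_map(lambda x: x[0], rays), args.randomized)
rgb, disp, acc = ret[-1]  # fine-model output: shapes (4, 3), (4,), (4,)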
jaxnerf/nerf/precompute.py
ADDED
@@ -0,0 +1,59 @@
"""
command line example:
$ python -i -m jaxnerf.nerf.precompute --data_dir {path-to-data-dir} --split train \
  --dataset blender --factor 4 --dtype float16
"""
import os
import argparse
from typing import Optional

import jax.numpy as np

from jaxnerf.nerf import utils
from jaxnerf.nerf import clip_utils
from jaxnerf.nerf import datasets


def precompute_image_features(data_dir: str, split: str, dataset: str, factor: int, dtype: str,
                              model_name: Optional[str], render_path: Optional[str]):
    if dataset == "blender":
        if render_path:
            raise ValueError("render_path cannot be used for the blender dataset.")

        # images in numpy.ndarray
        _, images, _ = datasets.Blender.load_files(data_dir, split, factor)
        clip_model = clip_utils.init_CLIP(dtype, model_name)

        # CLIP output in jax.numpy.ndarray
        images = np.stack(images).transpose(0, 3, 1, 2)
        images = images[:, :3, :, :]
        images = clip_utils.preprocess_for_CLIP(images)
        embeddings = clip_model.get_image_features(pixel_values=images)
        embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)
        print(f'completed precomputing CLIP embeddings: ({embeddings.shape[0]} images)')

        # write as pickle
        write_path = os.path.join(data_dir, f'clip_cache_{split}_factor{factor}_{dtype}.pkl')
        utils.write_pickle(embeddings, write_path)
        print(f'precompute written as pickle: {write_path}')

    elif dataset == "llff":
        raise NotImplementedError
    else:
        raise ValueError(f"invalid dataset: {dataset}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--split", type=str, required=True, help="train/val/test")
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--factor", type=int, required=True,
                        help="downsampling factor: 0/2/4")
    parser.add_argument("--dtype", type=str, required=True,
                        help="float32/float16 (float16 is used to save memory)")
    parser.add_argument("--model_name", type=str, required=False, default=None)
    parser.add_argument("--render_path", type=str, required=False, default=None)
    args = parser.parse_args()
    precompute_image_features(args.data_dir, args.split, args.dataset, args.factor,
                              args.dtype, args.model_name, args.render_path)
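
The cache written above can be read back with utils.read_pickle. A short sketch, where data_dir is simply the example path used in jaxnerf/run.sh:

# Sketch: reloading the CLIP embedding cache written by precompute.py.
import os
from jaxnerf.nerf import utils

data_dir = "/mnt/data/NeRF_Data/nerf_synthetic/lego"  # example path from run.sh
pkl_path = os.path.join(data_dir, "clip_cache_train_factor4_float32.pkl")
embeddings = utils.read_pickle(pkl_path)  # unit-norm rows, one per image
print(embeddings.shape)  # (num_images, clip_embedding_dim)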
jaxnerf/nerf/utils.py
ADDED
@@ -0,0 +1,457 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utility functions."""
import collections
import os
from os import path
import pickle
from absl import flags
import flax
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from PIL import Image
import yaml
from jaxnerf.nerf import datasets

BASE_DIR = "jaxnerf"
INTERNAL = False


@flax.struct.dataclass
class TrainState:
  optimizer: flax.optim.Optimizer


@flax.struct.dataclass
class Stats:
  loss: float
  psnr: float
  loss_c: float
  psnr_c: float
  weight_l2: float


Rays = collections.namedtuple("Rays", ("origins", "directions", "viewdirs"))


def namedtuple_map(fn, tup):
  """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
  return type(tup)(*map(fn, tup))


def define_flags():
  """Define flags for both training and evaluation modes."""
  flags.DEFINE_string("train_dir", None, "where to store ckpts and logs.")
  flags.DEFINE_string("data_dir", None, "input data directory.")
  flags.DEFINE_string("config", None,
                      "config file used to set hyperparameters.")

  # CLIP flags
  flags.DEFINE_bool("use_semantic_loss", True,
                    "whether to use the semantic loss or not.")
  flags.DEFINE_string("precompute_pkl_path", None,
                      "path to the pickle file of precomputed image features.")
  flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32",
                      "model type for CLIP.")
  flags.DEFINE_string("clip_output_dtype", "float32",
                      "float32/float16 (float16 for memory saving).")
  flags.DEFINE_integer("sc_loss_factor", 4,
                       "factor for downsampling images (0/2/4); "
                       "it is compounded on top of another flag: factor.")
  flags.DEFINE_integer("sc_loss_every", 16,
                       "number of steps between semantic loss evaluations.")
  flags.DEFINE_float("sc_loss_mult", 10.,
                     "weighting for the semantic loss from CLIP.")

  # Dataset flags
  # TODO(pratuls): rename to dataset_loader and consider cleaning up
  flags.DEFINE_enum("dataset", "blender",
                    list(k for k in datasets.dataset_dict.keys()),
                    "The type of dataset fed to nerf.")
  flags.DEFINE_enum(
      "batching", "single_image", ["single_image", "all_images"],
      "source of ray sampling when collecting a training batch: "
      "single_image samples rays from only one image per batch, "
      "all_images samples rays from all the training images.")
  flags.DEFINE_bool(
      "white_bkgd", True, "use white as the default background "
      "(used in the blender dataset only).")
  flags.DEFINE_integer("batch_size", 1024,
                       "the number of rays in a mini-batch (for training).")
  flags.DEFINE_integer("factor", 4,
                       "the downsample factor of images, 0 for no downsampling.")
  flags.DEFINE_bool("spherify", False, "set for spherical 360 scenes.")
  flags.DEFINE_bool(
      "render_path", False, "render the generated path if set to True "
      "(used in the llff dataset only).")
  flags.DEFINE_integer(
      "llffhold", 8, "take every 1/N images as the LLFF test set "
      "(used in the llff dataset only).")
  flags.DEFINE_bool(
      "use_pixel_centers", False,
      "If True, generate rays through the center of each pixel. Note: While "
      "this is the correct way to handle rays, it is not the way rays are "
      "handled in the original NeRF paper. Setting this to True yields ~ +1 PSNR "
      "compared to vanilla NeRF.")

  # Model flags
  flags.DEFINE_string("model", "nerf", "name of the model to use.")
  flags.DEFINE_float("near", 2., "near clip of volumetric rendering.")
  flags.DEFINE_float("far", 6., "far clip of volumetric rendering.")
  flags.DEFINE_integer("net_depth", 8, "depth of the first part of the MLP.")
  flags.DEFINE_integer("net_width", 256, "width of the first part of the MLP.")
  flags.DEFINE_integer("net_depth_condition", 1,
                       "depth of the second part of the MLP.")
  flags.DEFINE_integer("net_width_condition", 128,
                       "width of the second part of the MLP.")
  flags.DEFINE_float("weight_decay_mult", 0, "The multiplier on weight decay.")
  flags.DEFINE_integer(
      "skip_layer", 4, "add a skip connection to the output vector of every "
      "skip_layer layers.")
  flags.DEFINE_integer("num_rgb_channels", 3, "the number of RGB channels.")
  flags.DEFINE_integer("num_sigma_channels", 1,
                       "the number of density channels.")
  flags.DEFINE_bool("randomized", True, "use randomized stratified sampling.")
  flags.DEFINE_integer("min_deg_point", 0,
                       "Minimum degree of positional encoding for points.")
  flags.DEFINE_integer("max_deg_point", 10,
                       "Maximum degree of positional encoding for points.")
  flags.DEFINE_integer("deg_view", 4,
                       "Degree of positional encoding for viewdirs.")
  flags.DEFINE_integer(
      "num_coarse_samples", 64,
      "the number of samples on each ray for the coarse model.")
  flags.DEFINE_integer("num_fine_samples", 128,
                       "the number of samples on each ray for the fine model.")
  flags.DEFINE_bool("use_viewdirs", True, "use view directions as a condition.")
  flags.DEFINE_float(
      "noise_std", None, "std dev of the noise added to regularize sigma output "
      "(used in the llff dataset only).")
  flags.DEFINE_bool("lindisp", False,
                    "sample linearly in disparity rather than depth.")
  flags.DEFINE_string("net_activation", "relu",
                      "activation function used within the MLP.")
  flags.DEFINE_string("rgb_activation", "sigmoid",
                      "activation function used to produce RGB.")
  flags.DEFINE_string("sigma_activation", "relu",
                      "activation function used to produce density.")
  flags.DEFINE_bool(
      "legacy_posenc_order", False,
      "If True, revert the positional encoding feature order to an older "
      "version of this codebase.")

  # Train flags
  flags.DEFINE_float("lr_init", 5e-4, "The initial learning rate.")
  flags.DEFINE_float("lr_final", 5e-6, "The final learning rate.")
  flags.DEFINE_integer(
      "lr_delay_steps", 0, "The number of steps at the beginning of "
      "training to reduce the learning rate by lr_delay_mult.")
  flags.DEFINE_float(
      "lr_delay_mult", 1., "A multiplier on the learning rate when the step "
      "is < lr_delay_steps.")
  flags.DEFINE_float("grad_max_norm", 0.,
                     "The gradient clipping magnitude (disabled if == 0).")
  flags.DEFINE_float("grad_max_val", 0.,
                     "The gradient clipping value (disabled if == 0).")

  flags.DEFINE_integer("max_steps", 1000000,
                       "the number of optimization steps.")
  flags.DEFINE_integer("save_every", 10000,
                       "the number of steps between checkpoint saves.")
  flags.DEFINE_integer("print_every", 100,
                       "the number of steps between reports to tensorboard.")
  flags.DEFINE_integer(
      "render_every", 5000, "the number of steps between test-image renders; "
      "better to be a multiple of 100 for an accurate step-time record.")
  flags.DEFINE_integer("gc_every", 10000,
                       "the number of steps between python garbage collections.")
  flags.DEFINE_integer("few_shot", -1,
                       "the number of training images (few-shot setting).")

  # Eval flags
  flags.DEFINE_bool(
      "eval_once", True,
      "evaluate the model only once if True, otherwise keep evaluating new "
      "checkpoints if there are any.")
  flags.DEFINE_bool("save_output", True,
                    "save predicted images to disk if True.")
  flags.DEFINE_integer(
      "chunk", 8192,
      "the size of chunks for evaluation inference; set to whatever value "
      "fits your GPU/TPU memory.")


def update_flags(args):
  """Update the flags in `args` with the contents of the config YAML file."""
  pth = path.join(BASE_DIR, args.config + ".yaml")
  with open_file(pth, "r") as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)
  # Only allow args to be updated if they already exist.
  invalid_args = list(set(configs.keys()) - set(dir(args)))
  if invalid_args:
    raise ValueError(f"Invalid args {invalid_args} in {pth}.")
  args.__dict__.update(configs)


def open_file(pth, mode="r"):
  if not INTERNAL:
    return open(pth, mode=mode)


def file_exists(pth):
  if not INTERNAL:
    return path.exists(pth)


def listdir(pth):
  if not INTERNAL:
    return os.listdir(pth)


def isdir(pth):
  if not INTERNAL:
    return path.isdir(pth)


def makedirs(pth):
  if not INTERNAL:
    os.makedirs(pth)


def render_image(render_fn, rays, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    render_fn: function, jit-ed render function.
    rays: a `Rays` namedtuple, the rays to be rendered.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = namedtuple_map(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  host_id = jax.host_id()
  results = []
  for i in range(0, num_rays, chunk):
    # pylint: disable=cell-var-from-loop
    chunk_rays = namedtuple_map(lambda r: r[i:i + chunk], rays)
    chunk_size = chunk_rays[0].shape[0]
    rays_remaining = chunk_size % jax.device_count()
    if rays_remaining != 0:
      padding = jax.device_count() - rays_remaining
      chunk_rays = namedtuple_map(
          lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"), chunk_rays)
    else:
      padding = 0
    # After padding the number of chunk_rays is always divisible by
    # host_count.
    rays_per_host = chunk_rays[0].shape[0] // jax.process_count()
    start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
    chunk_rays = namedtuple_map(lambda r: shard(r[start:stop]), chunk_rays)
    chunk_results = render_fn(key_0, key_1, chunk_rays)[-1]
    results.append([unshard(x[0], padding) for x in chunk_results])
    # pylint: enable=cell-var-from-loop
  rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)), disp.reshape(
      (height, width, -1)), acc.reshape((height, width, -1)))


def compute_psnr(mse):
  """Compute psnr value given mse (we assume the maximum pixel value is 1).

  Args:
    mse: float, mean square error of pixels.

  Returns:
    psnr: float, the psnr value.
  """
  return -10. * jnp.log(mse) / jnp.log(10.)


def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
  """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: bool. If True, the per-pixel SSIM "map" is returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
  # Construct a 1D Gaussian blur filter.
  hw = filter_size // 2
  shift = (2 * hw - filter_size + 1) / 2
  f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2
  filt = jnp.exp(-0.5 * f_i)
  filt /= jnp.sum(filt)

  # Blur in x and y (faster than the 2D convolution).
  filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
  filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

  # Vmap the blurs to the tensor size, and then compose them.
  num_dims = len(img0.shape)
  map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
  for d in map_axes:
    filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
    filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
  filt_fn = lambda z: filt_fn1(filt_fn2(z))

  mu0 = filt_fn(img0)
  mu1 = filt_fn(img1)
  mu00 = mu0 * mu0
  mu11 = mu1 * mu1
  mu01 = mu0 * mu1
  sigma00 = filt_fn(img0 ** 2) - mu00
  sigma11 = filt_fn(img1 ** 2) - mu11
  sigma01 = filt_fn(img0 * img1) - mu01

  # Clip the variances and covariances to valid values.
  # Variance must be non-negative:
  sigma00 = jnp.maximum(0., sigma00)
  sigma11 = jnp.maximum(0., sigma11)
  sigma01 = jnp.sign(sigma01) * jnp.minimum(
      jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01))

  c1 = (k1 * max_val) ** 2
  c2 = (k2 * max_val) ** 2
  numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
  denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
  ssim_map = numer / denom
  ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
  return ssim_map if return_map else ssim


def save_img(img, pth):
  """Save an image to disk.

  Args:
    img: jnp.ndarray, [height, width, channels], img will be clipped to [0, 1]
      before being saved to pth.
    pth: string, path to save the image to.
  """
  with open_file(pth, "wb") as imgout:
    Image.fromarray(np.array(
        (np.clip(img, 0., 1.) * 255.).astype(jnp.uint8))).save(imgout, "PNG")


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
  """Continuous learning rate decay function.

  The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
  is log-linearly interpolated elsewhere (equivalent to exponential decay).
  If lr_delay_steps>0 then the learning rate will be scaled by some smooth
  function of lr_delay_mult, such that the initial learning rate is
  lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  to the normal learning rate when steps>lr_delay_steps.

  Args:
    step: int, the current optimization step.
    lr_init: float, the initial learning rate.
    lr_final: float, the final learning rate.
    max_steps: int, the number of steps during optimization.
    lr_delay_steps: int, the number of steps to delay the full learning rate.
    lr_delay_mult: float, the multiplier on the rate when delaying it.

  Returns:
    lr: the learning rate for the current step 'step'.
  """
  if lr_delay_steps > 0:
    # A kind of reverse cosine decay.
    delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
        0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
  else:
    delay_rate = 1.
  t = np.clip(step / max_steps, 0, 1)
  log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  return delay_rate * log_lerp


def shard(xs):
  """Split data into shards for multiple devices along the first dimension."""
  return jax.tree_map(
      lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:])
      if len(x.shape) != 0 else x, xs)


def to_device(xs):
  """Transfer data to devices (GPU/TPU)."""
  return jax.tree_map(jnp.array, xs)


def unshard(x, padding=0):
  """Collect the sharded tensor to the shape before sharding."""
  y = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[2:]))
  if padding > 0:
    y = y[:-padding]
  return y


def write_pickle(data, fn):
  with open(fn, 'wb') as f:
    pickle.dump(data, f)


def read_pickle(fn):
  with open(fn, 'rb') as f:
    data = pickle.load(f)
  return data
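
As a quick sanity check of learning_rate_decay, a small sketch evaluating the schedule at a few steps with the default lr_init/lr_final/max_steps values declared in define_flags():

# Sketch: the log-linear schedule from utils.learning_rate_decay, using the
# flag defaults from define_flags().
from jaxnerf.nerf import utils

for step in (0, 250000, 500000, 1000000):
    lr = utils.learning_rate_decay(step, lr_init=5e-4, lr_final=5e-6,
                                   max_steps=1000000)
    print(step, float(lr))  # 5e-4 at step 0, easing down to 5e-6 at max_steps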
jaxnerf/requirements.txt
ADDED
@@ -0,0 +1,14 @@
numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.57
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
pyyaml>=5.3.1
tensorboard>=2.4.0
tensorflow>=2.3.1
tensorflow-hub>=0.11.0
transformers==4.8.2
wandb==0.10.33
tqdm==4.61.2
# pip install git+https://github.com/deepmind/jmp # mixed precision for JAX
jaxnerf/run.sh
ADDED
@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -e
set -x

virtualenv -p python3 .
source ./bin/activate

pip install -r jaxnerf/requirements.txt
pip uninstall -y jax  # -y so the script does not block on a confirmation prompt
pip install --upgrade pip
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m jaxnerf.train \
  --data_dir=/mnt/data/NeRF_Data/nerf_synthetic/lego \
  --train_dir=test_output \
  --max_steps=5 \
  --factor=2 \
  --batch_size=512 \
  --config=configs/orig_nerf_tpu_vm_test \
  --precompute_pkl_path /mnt/data/NeRF_Data/nerf_synthetic/lego/clip_cache_train_factor4_float32.pkl
jaxnerf/train.py
ADDED
@@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training script for Nerf."""
import functools
import gc
import os
import time
from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import config
from jax import random
import jax.numpy as jnp
import numpy as np
# import wandb
from tqdm import tqdm

from jaxnerf.nerf import datasets
from jaxnerf.nerf import models
from jaxnerf.nerf import utils
from jaxnerf.nerf import clip_utils

FLAGS = flags.FLAGS

utils.define_flags()
config.parse_flags_with_absl()

# Set up the TPU when running on Colab.
if "COLAB_TPU_ADDR" in os.environ:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
print(f"detected device: {jax.local_devices()}")


def train_step(model, clip_model, rng, state, batch, lr, step, K):
  """One optimization step.

  Args:
    model: The linen model.
    clip_model: the CLIP model used for the semantic loss.
    rng: jnp.ndarray, random number generator.
    state: utils.TrainState, state of the model/optimizer.
    batch: dict, a mini-batch of data for training.
    lr: float, real-time learning rate.
    step: int, the current training step.
    K: int, the semantic-loss interval (FLAGS.sc_loss_every).

  Returns:
    new_state: utils.TrainState, new training state.
    stats: utils.Stats, the (coarse and fine) losses and PSNRs.
    rng: jnp.ndarray, updated random number generator.
  """
  # TODO: accept the CLIP gradient (clip_grad) as an input.
  rng, key_0, key_1 = random.split(rng, 3)

  def loss_fn(variables):
    rays = batch["rays"]
    ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
    if len(ret) not in (1, 2):
      raise ValueError(
          "ret should contain either 1 set of output (coarse only), or 2 sets "
          "of output (coarse as ret[0] and fine as ret[1]).")
    # The main prediction is always at the end of the ret list.
    rgb, unused_disp, unused_acc = ret[-1]
    loss = ((rgb - batch["pixels"][Ellipsis, :3]) ** 2).mean()
    psnr = utils.compute_psnr(loss)
    if len(ret) > 1:
      # If there are both coarse and fine predictions, we compute the loss for
      # the coarse prediction (ret[0]) as well.
      rgb_c, unused_disp_c, unused_acc_c = ret[0]
      loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3]) ** 2).mean()
      psnr_c = utils.compute_psnr(loss_c)
    else:
      loss_c = 0.
      psnr_c = 0.

    def tree_sum_fn(fn):
      return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                       variables, initializer=0)

    weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
                 tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

    total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
    stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
                        psnr_c=psnr_c, weight_l2=weight_l2)
    return total_loss, stats

  (_, stats), grad = (
      jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
  grad = jax.lax.pmean(grad, axis_name="batch")
  stats = jax.lax.pmean(stats, axis_name="batch")

  # Clip the gradient by value.
  if FLAGS.grad_max_val > 0:
    clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
    grad = jax.tree_util.tree_map(clip_fn, grad)

  # Clip the (possibly value-clipped) gradient by norm.
  if FLAGS.grad_max_norm > 0:
    grad_norm = jnp.sqrt(
        jax.tree_util.tree_reduce(
            lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
    mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
    grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state, stats, rng


def update_step(state, grad, lr):
  new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
  new_state = state.replace(optimizer=new_optimizer)
  return new_state


def main(unused_argv):
  # wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
  rng = random.PRNGKey(20200823)
  # Shift the numpy random seed by host_id() to shuffle data loaded by
  # different hosts.
  np.random.seed(20201473 + jax.host_id())

  if FLAGS.config is not None:
    utils.update_flags(FLAGS)
  if FLAGS.batch_size % jax.device_count() != 0:
    raise ValueError("Batch size must be divisible by the number of devices.")
  if FLAGS.train_dir is None:
    raise ValueError("train_dir must be set. None set now.")
  if FLAGS.data_dir is None:
    raise ValueError("data_dir must be set. None set now.")

  # Set up the CLIP model.
  if FLAGS.use_semantic_loss:
    clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
                                      FLAGS.clip_model_name)
    print('semantic loss ACTIVATED, CLIP is set up')
  else:
    clip_model = None
    print('semantic loss DEACTIVATED, CLIP is set to None')

  dataset = datasets.get_dataset("train", FLAGS, clip_model)
  test_dataset = datasets.get_dataset("test", FLAGS, clip_model)

  # Set up the NeRF model.
  rng, key = random.split(rng)
  model, variables = models.get_model(key, dataset.peek(), FLAGS)
  optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
  state = utils.TrainState(optimizer=optimizer)
  del optimizer, variables
  learning_rate_fn = functools.partial(
      utils.learning_rate_decay,
      lr_init=FLAGS.lr_init,
      lr_final=FLAGS.lr_final,
      max_steps=FLAGS.max_steps,
      lr_delay_steps=FLAGS.lr_delay_steps,
      lr_delay_mult=FLAGS.lr_delay_mult)

  train_pstep = jax.pmap(
      functools.partial(train_step, model, clip_model),
      axis_name="batch",
      in_axes=(0, 0, 0, None, None, None),
      donate_argnums=(2,))

  update_pstep = jax.pmap(
      update_step,
      axis_name="batch",
      in_axes=(0, None, None),
      donate_argnums=(0,))

  def render_fn(variables, key_0, key_1, rays):
    return jax.lax.all_gather(
        model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
        axis_name="batch")

  render_pfn = jax.pmap(
      render_fn,
      in_axes=(None, None, None, 0),  # Only distribute the data input.
      donate_argnums=(3,),
      axis_name="batch")

  # Compiling to the CPU because it's faster and more accurate.
  ssim_fn = jax.jit(
      functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

  if not utils.isdir(FLAGS.train_dir):
    utils.makedirs(FLAGS.train_dir)
  state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
  # Resume training at the step of the last checkpoint.
  init_step = state.optimizer.state.step + 1

  # For distributed training.
  state = flax.jax_utils.replicate(state)
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

  # prefetch_buffer_size = 3 x batch_size
  pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
  n_local_devices = jax.local_device_count()
  rng = rng + jax.host_id()  # Make random seed separate across hosts.
  keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
  gc.disable()  # Disable automatic garbage collection for efficiency.
  stats_trace = []
  reset_timer = True

  # For the semantic-loss update schedule.
  cnter = 1
  trigger = int(FLAGS.sc_loss_every / n_local_devices)

  for step, batch in tqdm(zip(range(init_step, FLAGS.max_steps + 1), pdataset)):
    if reset_timer:
      t_loop_start = time.time()
      reset_timer = False
    lr = learning_rate_fn(step)

    if step % FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
      # Drop the device dimension, since this runs only on the host core.
      sc_batch = dataset.get_clip_data()
      sc_loss, sc_grad = clip_utils.update_semantic_loss(
          model, clip_model, keys[0], state, sc_batch, lr)
      sc_grad = flax.jax_utils.replicate(sc_grad)
      sc_grad = jax.tree_map(lambda x: x[0], sc_grad)
    else:
      sc_loss = 0.

    state, stats, keys = train_pstep(keys, state, batch, lr, step,
                                     FLAGS.sc_loss_every)

    if step % FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
      state = update_pstep(state, sc_grad, lr)

    if jax.host_id() == 0:
      stats_trace.append(stats)
    if step % FLAGS.gc_every == 0:
      gc.collect()

    # Log training summaries. This is put behind a host_id check because in
    # multi-host evaluation, all hosts need to run inference even though we
    # only use host 0 to record results.
    if jax.host_id() == 0:
      if step % FLAGS.print_every == 0:
        summary_writer.scalar("train_loss", stats.loss[0], step)
        summary_writer.scalar("train_psnr", stats.psnr[0], step)
        summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
        summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg_loss", avg_loss, step)
        summary_writer.scalar("train_avg_psnr", avg_psnr, step)
        summary_writer.scalar("learning_rate", lr, step)
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
              # f"sc_loss={sc_loss:0.4f}, " +
              f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state_to_save, int(step), keep=100)

    # Test-set evaluation.
    if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
      # We reuse the same random number generator from the optimization step
      # here on purpose so that the visualization matches what happened in
      # training.
      t_eval_start = time.time()
      eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                   state)).optimizer.target
      test_case = next(test_dataset)
      pred_color, pred_disp, pred_acc = utils.render_image(
          functools.partial(render_pfn, eval_variables),
          test_case["rays"],
          keys[0],
          FLAGS.dataset == "llff",
          chunk=FLAGS.chunk)

      # Log eval summaries on host 0.
      if jax.host_id() == 0:
        psnr = utils.compute_psnr(
            ((pred_color - test_case["pixels"]) ** 2).mean())
        ssim = ssim_fn(pred_color, test_case["pixels"])
        eval_time = time.time() - t_eval_start
        num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
        rays_per_sec = num_rays / eval_time
        summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
        print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
        summary_writer.scalar("test_psnr", psnr, step)
        summary_writer.scalar("test_ssim", ssim, step)
        summary_writer.image("test_pred_color", pred_color, step)
        summary_writer.image("test_pred_disp", pred_disp, step)
        summary_writer.image("test_pred_acc", pred_acc, step)
        summary_writer.image("test_target", test_case["pixels"], step)

  if FLAGS.max_steps % FLAGS.save_every != 0:
    state = jax.device_get(jax.tree_map(lambda x: x[0], state))
    checkpoints.save_checkpoint(
        FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)


if __name__ == "__main__":
  app.run(main)
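
The global-norm gradient clipping inside train_step is worth seeing standalone. This sketch reproduces that exact logic on a toy gradient pytree (the helper name is ours, not part of the commit):

# Sketch: the by-norm gradient clipping used in train_step, extracted.
import jax
import jax.numpy as jnp

def clip_by_global_norm(grad, grad_max_norm):
    # Global L2 norm over every leaf of the gradient pytree.
    grad_norm = jnp.sqrt(jax.tree_util.tree_reduce(
        lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
    mult = jnp.minimum(1, grad_max_norm / (1e-7 + grad_norm))
    return jax.tree_util.tree_map(lambda z: mult * z, grad)

grad = {"w": jnp.full((3,), 10.0), "b": jnp.zeros((1,))}
clipped = clip_by_global_norm(grad, grad_max_norm=1.0)  # norm rescaled to ~1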
jaxnerf/train.sh
ADDED
@@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CONFIG=$1
DATA_ROOT=$2
ROOT_DIR=/tmp/jaxnerf/"$CONFIG"
if [ "$CONFIG" == "llff" ]
then
  SCENES="room fern leaves fortress orchids flower trex horns"
  DATA_FOLDER="nerf_llff_data"
else
  SCENES="lego chair drums ficus hotdog materials mic ship"
  DATA_FOLDER="nerf_synthetic"
fi

# Launch training jobs for all scenes.
for scene in $SCENES; do
  python -m jaxnerf.train \
    --data_dir="$DATA_ROOT"/"$DATA_FOLDER"/"$scene" \
    --train_dir="$ROOT_DIR"/"$scene" \
    --config=configs/"$CONFIG"
done
requirements.txt
ADDED
@@ -0,0 +1,8 @@
numpy>=1.16.4
jax>=0.2.6
jaxlib>=0.1.57
flax>=0.2.2
opencv-python>=4.4.0
Pillow>=7.2.0
streamlit==0.84.1
googledrivedownloader==0.4