File size: 13,574 Bytes
19677a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training script for Nerf."""
import functools
import gc
import os
import time

from absl import app
from absl import flags
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import config
from jax import random
import jax.numpy as jnp
import numpy as np
# import wandb
from tqdm import tqdm

from jaxnerf.nerf import datasets
from jaxnerf.nerf import models
from jaxnerf.nerf import utils
from jaxnerf.nerf import clip_utils
from jaxnerf.nerf import clip_utils

FLAGS = flags.FLAGS

# Register this project's flags, then let JAX read its config options from
# the absl flags.
utils.define_flags()
config.parse_flags_with_absl()

# Set up the TPU when running in a Colab TPU runtime. This runs before the
# first device query below.
if "COLAB_TPU_ADDR" in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
print(f"detected device: {jax.local_devices()}")


def train_step(model, clip_model, rng, state, batch, lr, step, K):  # , clip_grad):
    """One optimization step (run under `jax.pmap` with axis name "batch").

    Args:
        model: The linen model.
        clip_model: CLIP model or None. Not used inside this function.
        rng: jnp.ndarray, random number generator.
        state: utils.TrainState, state of the model/optimizer.
        batch: dict, a mini-batch of data for training. It must contain
            "rays" and "pixels"; only the first 3 channels of "pixels" are
            used as the RGB target.
        lr: float, real-time learning rate.
        step: current training step. Not used inside this function.
        K: semantic-loss period (the caller passes FLAGS.sc_loss_every).
            Not used inside this function.

    Returns:
        new_state: utils.TrainState, new training state.
        stats: utils.Stats with fields loss, psnr, loss_c, psnr_c and
            weight_l2, averaged over the "batch" pmap axis.
        rng: jnp.ndarray, updated random number generator.

    Raises:
        ValueError: if the model returns neither 1 nor 2 sets of outputs.
    """
    # TODO make clip_grad input enable
    del clip_model, step, K  # Unused; kept so the pmap call signature is stable.
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        """Return (total_loss, utils.Stats) for the given model variables."""
        rays = batch["rays"]
        ret = model.apply(variables, key_0, key_1, rays, FLAGS.randomized)
        if len(ret) not in (1, 2):
            raise ValueError(
                "ret should contain either 1 set of output (coarse only), or 2 sets "
                "of output (coarse as ret[0] and fine as ret[1]).")
        target_rgb = batch["pixels"][Ellipsis, :3]

        def mse(pred_rgb):
            return ((pred_rgb - target_rgb) ** 2).mean()

        # The main prediction is always at the end of the ret list.
        rgb, unused_disp, unused_acc = ret[-1]
        loss = mse(rgb)
        psnr = utils.compute_psnr(loss)
        if len(ret) > 1:
            # If there are both coarse and fine predictions, we compute the loss for
            # the coarse prediction (ret[0]) as well.
            rgb_c, unused_disp_c, unused_acc_c = ret[0]
            loss_c = mse(rgb_c)
            psnr_c = utils.compute_psnr(loss_c)
        else:
            loss_c = 0.
            psnr_c = 0.

        def tree_sum_fn(fn):
            return jax.tree_util.tree_reduce(lambda x, y: x + fn(y),
                                             variables, initializer=0)

        # Mean squared parameter value: sum of squares over all leaves divided
        # by the total number of parameter elements.
        weight_l2 = (tree_sum_fn(lambda z: jnp.sum(z ** 2)) /
                     tree_sum_fn(lambda z: jnp.prod(jnp.array(z.shape))))

        total_loss = loss + loss_c + FLAGS.weight_decay_mult * weight_l2
        stats = utils.Stats(loss=loss, psnr=psnr, loss_c=loss_c,
                            psnr_c=psnr_c, weight_l2=weight_l2)
        return total_loss, stats

    (_, stats), grad = (
        jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
    # Average gradients and statistics across all devices on the pmap axis.
    grad = jax.lax.pmean(grad, axis_name="batch")
    stats = jax.lax.pmean(stats, axis_name="batch")

    # Clip the gradient by value.
    if FLAGS.grad_max_val > 0:
        clip_fn = lambda z: jnp.clip(z, -FLAGS.grad_max_val, FLAGS.grad_max_val)
        grad = jax.tree_util.tree_map(clip_fn, grad)

    # Clip the (possibly value-clipped) gradient by its global L2 norm.
    if FLAGS.grad_max_norm > 0:
        grad_norm = jnp.sqrt(
            jax.tree_util.tree_reduce(
                lambda x, y: x + jnp.sum(y ** 2), grad, initializer=0))
        mult = jnp.minimum(1, FLAGS.grad_max_norm / (1e-7 + grad_norm))
        grad = jax.tree_util.tree_map(lambda z: mult * z, grad)

    new_optimizer = state.optimizer.apply_gradient(grad, learning_rate=lr)
    new_state = state.replace(optimizer=new_optimizer)
    return new_state, stats, rng

def update_step(state, grad, lr):
    """Apply an externally computed gradient to the training state.

    Args:
        state: training state exposing `.optimizer` and `.replace(...)`.
        grad: gradient pytree passed to `optimizer.apply_gradient`.
        lr: learning rate forwarded as `learning_rate`.

    Returns:
        A copy of `state` whose optimizer has taken one step with `grad`.
    """
    stepped = state.optimizer.apply_gradient(grad, learning_rate=lr)
    return state.replace(optimizer=stepped)


def main(unused_argv):
    """Train a NeRF model, optionally with a CLIP-based semantic loss.

    All configuration is read from FLAGS. The loop periodically writes
    training summaries and checkpoints, and renders one test view; only host 0
    records results.

    Args:
        unused_argv: positional command-line arguments from `app.run`; ignored.

    Raises:
        ValueError: if the batch size is not divisible by the device count, or
            if `train_dir` / `data_dir` is not set.
    """
    # wandb.init(project="hf-flax-clip-nerf", entity="wandb", sync_tensorboard=True)
    rng = random.PRNGKey(20200823)
    # Shift the numpy random seed by host_id() to shuffle data loaded by different
    # hosts.
    np.random.seed(20201473 + jax.host_id())

    if FLAGS.config is not None:
        utils.update_flags(FLAGS)
    if FLAGS.batch_size % jax.device_count() != 0:
        raise ValueError("Batch size must be divisible by the number of devices.")
    if FLAGS.train_dir is None:
        raise ValueError("train_dir must be set. None set now.")
    if FLAGS.data_dir is None:
        raise ValueError("data_dir must be set. None set now.")

    # Set up the CLIP model used by the semantic loss (None when disabled).
    if FLAGS.use_semantic_loss:
        clip_model = clip_utils.init_CLIP(FLAGS.clip_output_dtype,
                                          FLAGS.clip_model_name)
        print('semantic loss ACTIVATED, CLIP is set up')
    else:
        clip_model = None
        print('semantic loss DEACTIVATED, CLIP is set to None')

    dataset = datasets.get_dataset("train", FLAGS, clip_model)
    test_dataset = datasets.get_dataset("test", FLAGS, clip_model)

    # Set up the NeRF model and its optimizer.
    rng, key = random.split(rng)
    model, variables = models.get_model(key, dataset.peek(), FLAGS)
    optimizer = flax.optim.Adam(FLAGS.lr_init).create(variables)
    state = utils.TrainState(optimizer=optimizer)
    del optimizer, variables
    learning_rate_fn = functools.partial(
        utils.learning_rate_decay,
        lr_init=FLAGS.lr_init,
        lr_final=FLAGS.lr_final,
        max_steps=FLAGS.max_steps,
        lr_delay_steps=FLAGS.lr_delay_steps,
        lr_delay_mult=FLAGS.lr_delay_mult)

    # Mapped args are (rng, state, batch, lr, step, K). rng, state and batch
    # are split across devices; lr, step and K are broadcast. The batch buffer
    # is donated.
    train_pstep = jax.pmap(
        functools.partial(train_step, model, clip_model),
        axis_name="batch",
        in_axes=(0, 0, 0, None, None, None),
        donate_argnums=(2,))

    # Mapped args are (state, grad, lr). Only the state is per-device and
    # donated; the semantic-loss gradient and lr are broadcast.
    update_pstep = jax.pmap(
        update_step,
        axis_name="batch",
        in_axes=(0, None, None),
        donate_argnums=(0,))

    def render_fn(variables, key_0, key_1, rays):
        return jax.lax.all_gather(
            model.apply(variables, key_0, key_1, rays, FLAGS.randomized),
            axis_name="batch")

    render_pfn = jax.pmap(
        render_fn,
        in_axes=(None, None, None, 0),  # Only distribute the data input.
        donate_argnums=(3,),
        axis_name="batch")

    # Compiling to the CPU because it's faster and more accurate.
    ssim_fn = jax.jit(
        functools.partial(utils.compute_ssim, max_val=1.), backend="cpu")

    if not utils.isdir(FLAGS.train_dir):
        utils.makedirs(FLAGS.train_dir)
    state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
    # Resume training at the step of the last checkpoint.
    # NOTE(review): on semantic-loss steps, update_pstep calls apply_gradient a
    # second time. If the optimizer increments its step counter on every
    # apply_gradient, that counter runs ahead of the loop step, and resuming
    # from it may skip steps. TODO confirm.
    init_step = state.optimizer.state.step + 1

    # Replicate the state onto every local device for pmap.
    state = flax.jax_utils.replicate(state)
    if jax.host_id() == 0:
        summary_writer = tensorboard.SummaryWriter(FLAGS.train_dir)

    # Prefetch_buffer_size = 3 x batch_size
    pdataset = flax.jax_utils.prefetch_to_device(dataset, 3)
    n_local_devices = jax.local_device_count()
    rng = rng + jax.host_id()  # Make random seed separate across hosts.
    keys = random.split(rng, n_local_devices)  # For pmapping RNG keys.
    gc.disable()  # Disable automatic garbage collection for efficiency.
    stats_trace = []
    reset_timer = True
    # Width of the step counter in the progress printout (loop invariant).
    precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1

    steps = range(init_step, FLAGS.max_steps + 1)
    for step, batch in tqdm(zip(steps, pdataset), total=len(steps)):
        if reset_timer:
            t_loop_start = time.time()
            reset_timer = False
        lr = learning_rate_fn(step)

        # Check the flag first so `%` is never evaluated, and cannot divide by
        # zero, when the semantic loss is disabled.
        use_sc_loss = (FLAGS.use_semantic_loss and
                       step % FLAGS.sc_loss_every == 0)
        if use_sc_loss:
            sc_batch = dataset.get_clip_data()
            # The semantic loss uses only the first device's key. Replicating
            # and then taking slice 0 leaves a single copy of the gradient,
            # which update_pstep broadcasts to all devices (in_axes=None).
            _, sc_grad = clip_utils.update_semantic_loss(
                model, clip_model, keys[0], state, sc_batch, lr)
            sc_grad = flax.jax_utils.replicate(sc_grad)
            sc_grad = jax.tree_map(lambda x: x[0], sc_grad)

        state, stats, keys = train_pstep(keys, state, batch, lr, step,
                                         FLAGS.sc_loss_every)

        if use_sc_loss:
            state = update_pstep(state, sc_grad, lr)

        if jax.host_id() == 0:
            stats_trace.append(stats)
        if step % FLAGS.gc_every == 0:
            gc.collect()

        # Log training summaries. This is put behind a host_id check because in
        # multi-host evaluation, all hosts need to run inference even though we
        # only use host 0 to record results.
        if jax.host_id() == 0:
            if step % FLAGS.print_every == 0:
                summary_writer.scalar("train_loss", stats.loss[0], step)
                summary_writer.scalar("train_psnr", stats.psnr[0], step)
                summary_writer.scalar("train_loss_coarse", stats.loss_c[0], step)
                summary_writer.scalar("train_psnr_coarse", stats.psnr_c[0], step)
                summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
                avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
                avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
                stats_trace = []
                summary_writer.scalar("train_avg_loss", avg_loss, step)
                summary_writer.scalar("train_avg_psnr", avg_psnr, step)
                summary_writer.scalar("learning_rate", lr, step)
                steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
                reset_timer = True
                rays_per_sec = FLAGS.batch_size * steps_per_sec
                summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
                summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
                print(f"{step:{precision}d}" +
                      f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
                      f"avg_loss={avg_loss:0.4f}, " +
                      f"weight_l2={stats.weight_l2[0]:0.2e}, " +
                      f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
            if step % FLAGS.save_every == 0:
                state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
                checkpoints.save_checkpoint(
                    FLAGS.train_dir, state_to_save, int(step), keep=100)

        # Test-set evaluation.
        if FLAGS.render_every > 0 and step % FLAGS.render_every == 0:
            # We reuse the same random number generator from the optimization step
            # here on purpose so that the visualization matches what happened in
            # training.
            t_eval_start = time.time()
            eval_variables = jax.device_get(jax.tree_map(lambda x: x[0],
                                                         state)).optimizer.target
            test_case = next(test_dataset)
            pred_color, pred_disp, pred_acc = utils.render_image(
                functools.partial(render_pfn, eval_variables),
                test_case["rays"],
                keys[0],
                FLAGS.dataset == "llff",
                chunk=FLAGS.chunk)

            # Log eval summaries on host 0.
            if jax.host_id() == 0:
                psnr = utils.compute_psnr(
                    ((pred_color - test_case["pixels"]) ** 2).mean())
                ssim = ssim_fn(pred_color, test_case["pixels"])
                eval_time = time.time() - t_eval_start
                num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
                rays_per_sec = num_rays / eval_time
                summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
                print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
                summary_writer.scalar("test_psnr", psnr, step)
                summary_writer.scalar("test_ssim", ssim, step)
                summary_writer.image("test_pred_color", pred_color, step)
                summary_writer.image("test_pred_disp", pred_disp, step)
                summary_writer.image("test_pred_acc", pred_acc, step)
                summary_writer.image("test_target", test_case["pixels"], step)

    # Save a final checkpoint unless the last step already triggered one.
    if FLAGS.max_steps % FLAGS.save_every != 0:
        state = jax.device_get(jax.tree_map(lambda x: x[0], state))
        checkpoints.save_checkpoint(
            FLAGS.train_dir, state, int(FLAGS.max_steps), keep=100)


if __name__ == "__main__":
    # Script entry point: hand `main` to absl's app runner.
    app.run(main)