xieyizheng commited on
Commit
71fa249
1 Parent(s): 549caa6

Upload HyperNeRF_Training_clean.ipynb

Browse files
Files changed (1) hide show
  1. HyperNeRF_Training_clean.ipynb +693 -0
HyperNeRF_Training_clean.ipynb ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "EZ_wkNVdTz-C"
7
+ },
8
+ "source": [
9
+ "# Let's train HyperNeRF!\n",
10
+ "\n",
11
+ "**Author**: [Keunhong Park](https://keunhong.com)\n",
12
+ "\n",
13
+ "[[Project Page](https://hypernerf.github.io)]\n",
14
+ "[[Paper](https://arxiv.org/abs/2106.13228)]\n",
15
+ "[[GitHub](https://github.com/google/hypernerf)]\n",
16
+ "\n",
17
+ "This notebook provides a demo for training HyperNeRF.\n",
18
+ "\n",
19
+ "### Instructions\n",
20
+ "\n",
21
+ "1. Convert a video into our dataset format using the Nerfies [dataset processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n",
22
+ "2. Set the `data_dir` below to where you saved the dataset.\n",
23
+ "3. Come back to this notebook to train HyperNeRF.\n",
24
+ "\n",
25
+ "\n",
26
+ "### Notes\n",
27
+ " * To accommodate the limited compute power of Colab runtimes, this notebook defaults to a \"toy\" version of our method. The number of samples has been reduced and the elastic regularization has been turned off.\n",
28
+ "\n",
29
+ " * To train a high-quality model, please look at the CLI options we provide in the [GitHub repository](https://github.com/google/hypernerf).\n",
30
+ "\n",
31
+ "\n",
32
+ "\n",
33
+ " * Please report issues on the [GitHub issue tracker](https://github.com/google/hypernerf/issues).\n",
34
+ "\n",
35
+ "\n",
36
+ "If you find this work useful, please consider citing:\n",
37
+ "```bibtex\n",
38
+ "@article{park2021hypernerf,\n",
39
+ " author = {Park, Keunhong and Sinha, Utkarsh and Hedman, Peter and Barron, Jonathan T. and Bouaziz, Sofien and Goldman, Dan B and Martin-Brualla, Ricardo and Seitz, Steven M.},\n",
40
+ " title = {HyperNeRF: A Higher-Dimensional Representation for Topologically Varying Neural Radiance Fields},\n",
41
+ " journal = {arXiv preprint arXiv:2106.13228},\n",
42
+ " year = {2021},\n",
43
+ "}\n",
44
+ "```\n"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {
50
+ "id": "OlW1gF_djH6H"
51
+ },
52
+ "source": [
53
+ "## Environment Setup"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {
60
+ "colab": {
61
+ "base_uri": "https://localhost:8080/"
62
+ },
63
+ "id": "lMGu9ctBT-MD",
64
+ "outputId": "41a8dd06-943a-4820-c2cf-e98a25a167e7"
65
+ },
66
+ "outputs": [],
67
+ "source": [
68
+ "#!wget https://raw.githubusercontent.com/google/hypernerf/main/requirements.txt\n",
69
+ "!wget https://raw.githubusercontent.com/xieyizheng/hypernerf/main/requirements.txt\n",
70
+ "!python --version\n",
71
+ "!pip install -r requirements.txt"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {
78
+ "colab": {
79
+ "base_uri": "https://localhost:8080/"
80
+ },
81
+ "id": "ns2J1yBAsYgt",
82
+ "outputId": "6c73222d-8643-4fe7-8f90-1b2ab79465df"
83
+ },
84
+ "outputs": [],
85
+ "source": [
86
+ "!nvidia-smi"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "\n",
96
+ "# If you have only just installed the requirements, restart the runtime before continuing!\n"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "colab": {
104
+ "base_uri": "https://localhost:8080/"
105
+ },
106
+ "id": "zGJux-m5Xp3Z",
107
+ "outputId": "58e386b0-44be-4741-8dbf-43cb46dade40"
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "# @title Configure notebook runtime\n",
112
+ "# NOTE: only the GPU runtime works.\n",
113
+ "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`. \n",
114
+ "# @markdown You will have to use a smaller batch size on GPU.\n",
115
+ "import jax\n",
116
+ "#jax.config.update('jax_platform_name', 'gpu')\n",
117
+ "runtime_type = 'gpu' # @param ['gpu', 'tpu']\n",
118
+ "if runtime_type == 'tpu':\n",
119
+ " import jax.tools.colab_tpu\n",
120
+ " jax.tools.colab_tpu.setup_tpu()\n",
121
+ "\n",
122
+ "print('Detected Devices:', jax.devices())"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {
129
+ "cellView": "form",
130
+ "colab": {
131
+ "base_uri": "https://localhost:8080/"
132
+ },
133
+ "id": "afUtLfRWULEi",
134
+ "outputId": "2919d242-fa49-447d-934e-877fbb42e5de"
135
+ },
136
+ "outputs": [],
137
+ "source": [
138
+ "# @title Mount Google Drive\n",
139
+ "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n",
140
+ "\n",
141
+ "# Colab only: if running locally, comment out the two lines below.\n",
142
+ "from google.colab import drive\n",
143
+ "drive.mount('/content/gdrive')"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {
150
+ "id": "ENOfbG3AkcVN"
151
+ },
152
+ "outputs": [],
153
+ "source": [
154
+ "# @title Define imports and utility functions.\n",
155
+ "\n",
156
+ "import jax\n",
157
+ "from jax.config import config as jax_config\n",
158
+ "import jax.numpy as jnp\n",
159
+ "from jax import grad, jit, vmap\n",
160
+ "from jax import random\n",
161
+ "\n",
162
+ "import flax\n",
163
+ "import flax.linen as nn\n",
164
+ "from flax import jax_utils\n",
165
+ "from flax import optim\n",
166
+ "from flax.metrics import tensorboard\n",
167
+ "from flax.training import checkpoints\n",
168
+ "#jax_config.enable_omnistaging() # Linen requires enabling omnistaging\n",
169
+ "\n",
170
+ "from absl import logging\n",
171
+ "from io import BytesIO\n",
172
+ "import random as pyrandom\n",
173
+ "import numpy as np\n",
174
+ "import PIL\n",
175
+ "import IPython\n",
176
+ "\n",
177
+ "\n",
178
+ "# Monkey patch logging.\n",
179
+ "def myprint(msg, *args, **kwargs):\n",
180
+ " print(msg % args)\n",
181
+ "\n",
182
+ "logging.info = myprint \n",
183
+ "logging.warn = myprint\n",
184
+ "logging.error = myprint\n",
185
+ "\n",
186
+ "\n",
187
+ "def show_image(image, fmt='png'):\n",
188
+ " image = image_utils.image_to_uint8(image)\n",
189
+ " f = BytesIO()\n",
190
+ " PIL.Image.fromarray(image).save(f, fmt)\n",
191
+ " IPython.display.display(IPython.display.Image(data=f.getvalue()))\n",
192
+ "\n"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {
198
+ "id": "wW7FsSB-jORB"
199
+ },
200
+ "source": [
201
+ "## Configuration"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {
208
+ "cellView": "form",
209
+ "colab": {
210
+ "base_uri": "https://localhost:8080/",
211
+ "height": 1000
212
+ },
213
+ "id": "rz7wRm7YT9Ka",
214
+ "outputId": "8d185175-603a-4903-ad18-7626eb8d1d91"
215
+ },
216
+ "outputs": [],
217
+ "source": [
218
+ "# @title Model and dataset configuration\n",
219
+ "\n",
220
+ "from pathlib import Path\n",
221
+ "from pprint import pprint\n",
222
+ "import gin\n",
223
+ "from IPython.display import display, Markdown\n",
224
+ "\n",
225
+ "from hypernerf import models\n",
226
+ "from hypernerf import modules\n",
227
+ "from hypernerf import warping\n",
228
+ "from hypernerf import datasets\n",
229
+ "from hypernerf import configs\n",
230
+ "\n",
231
+ "\n",
232
+ "# @markdown The working directory.\n",
233
+ "train_dir = '/content/gdrive/My Drive/nerfies/hypernerf_experiments/hand/exp1' # @param {type: \"string\"}\n",
234
+ "# @markdown The directory to the dataset capture.\n",
235
+ "data_dir = '/content/gdrive/My Drive/nerfies/captures/hand' # @param {type: \"string\"}\n",
236
+ "\n",
237
+ "# @markdown Training configuration.\n",
238
+ "max_steps = 200000 # @param {type: 'number'}\n",
239
+ "batch_size = 2048 # @param {type: 'number'}\n",
240
+ "image_scale = 8 # @param {type: 'number'}\n",
241
+ "\n",
242
+ "# @markdown Model configuration.\n",
243
+ "use_viewdirs = True #@param {type: 'boolean'}\n",
244
+ "use_appearance_metadata = True #@param {type: 'boolean'}\n",
245
+ "num_coarse_samples = 64 # @param {type: 'number'}\n",
246
+ "num_fine_samples = 64 # @param {type: 'number'}\n",
247
+ "\n",
248
+ "# @markdown Deformation configuration.\n",
249
+ "use_warp = True #@param {type: 'boolean'}\n",
250
+ "warp_field_type = '@SE3Field' #@param['@SE3Field', '@TranslationField']\n",
251
+ "warp_min_deg = 0 #@param{type:'number'}\n",
252
+ "warp_max_deg = 6 #@param{type:'number'}\n",
253
+ "\n",
254
+ "# @markdown Hyper-space configuration.\n",
255
+ "hyper_num_dims = 8 #@param{type:'number'}\n",
256
+ "hyper_point_min_deg = 0 #@param{type:'number'}\n",
257
+ "hyper_point_max_deg = 1 #@param{type:'number'}\n",
258
+ "hyper_slice_method = 'bendy_sheet' #@param['none', 'axis_aligned_plane', 'bendy_sheet']\n",
259
+ "\n",
260
+ "\n",
261
+ "checkpoint_dir = Path(train_dir, 'checkpoints')\n",
262
+ "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n",
263
+ "\n",
264
+ "config_str = f\"\"\"\n",
265
+ "DELAYED_HYPER_ALPHA_SCHED = {{\n",
266
+ " 'type': 'piecewise',\n",
267
+ " 'schedules': [\n",
268
+ " (1000, ('constant', 0.0)),\n",
269
+ " (0, ('linear', 0.0, %hyper_point_max_deg, 10000))\n",
270
+ " ],\n",
271
+ "}}\n",
272
+ "\n",
273
+ "ExperimentConfig.image_scale = {image_scale}\n",
274
+ "ExperimentConfig.datasource_cls = @NerfiesDataSource\n",
275
+ "NerfiesDataSource.data_dir = '{data_dir}'\n",
276
+ "NerfiesDataSource.image_scale = {image_scale}\n",
277
+ "\n",
278
+ "NerfModel.use_viewdirs = {int(use_viewdirs)}\n",
279
+ "NerfModel.use_rgb_condition = {int(use_appearance_metadata)}\n",
280
+ "NerfModel.num_coarse_samples = {num_coarse_samples}\n",
281
+ "NerfModel.num_fine_samples = {num_fine_samples}\n",
282
+ "\n",
283
+ "NerfModel.use_viewdirs = True\n",
284
+ "NerfModel.use_stratified_sampling = True\n",
285
+ "NerfModel.use_posenc_identity = False\n",
286
+ "NerfModel.nerf_trunk_width = 128\n",
287
+ "NerfModel.nerf_trunk_depth = 8\n",
288
+ "\n",
289
+ "TrainConfig.max_steps = {max_steps}\n",
290
+ "TrainConfig.batch_size = {batch_size}\n",
291
+ "TrainConfig.print_every = 100\n",
292
+ "TrainConfig.use_elastic_loss = False\n",
293
+ "TrainConfig.use_background_loss = False\n",
294
+ "\n",
295
+ "# Warp configs.\n",
296
+ "warp_min_deg = {warp_min_deg}\n",
297
+ "warp_max_deg = {warp_max_deg}\n",
298
+ "NerfModel.use_warp = {use_warp}\n",
299
+ "SE3Field.min_deg = %warp_min_deg\n",
300
+ "SE3Field.max_deg = %warp_max_deg\n",
301
+ "SE3Field.use_posenc_identity = False\n",
302
+ "NerfModel.warp_field_cls = @SE3Field\n",
303
+ "\n",
304
+ "TrainConfig.warp_alpha_schedule = {{\n",
305
+ " 'type': 'linear',\n",
306
+ " 'initial_value': {warp_min_deg},\n",
307
+ " 'final_value': {warp_max_deg},\n",
308
+ " 'num_steps': {int(max_steps*0.8)},\n",
309
+ "}}\n",
310
+ "\n",
311
+ "# Hyper configs.\n",
312
+ "hyper_num_dims = {hyper_num_dims}\n",
313
+ "hyper_point_min_deg = {hyper_point_min_deg}\n",
314
+ "hyper_point_max_deg = {hyper_point_max_deg}\n",
315
+ "\n",
316
+ "NerfModel.hyper_embed_cls = @hyper/GLOEmbed\n",
317
+ "hyper/GLOEmbed.num_dims = %hyper_num_dims\n",
318
+ "NerfModel.hyper_point_min_deg = %hyper_point_min_deg\n",
319
+ "NerfModel.hyper_point_max_deg = %hyper_point_max_deg\n",
320
+ "\n",
321
+ "TrainConfig.hyper_alpha_schedule = %DELAYED_HYPER_ALPHA_SCHED\n",
322
+ "\n",
323
+ "hyper_sheet_min_deg = 0\n",
324
+ "hyper_sheet_max_deg = 6\n",
325
+ "HyperSheetMLP.min_deg = %hyper_sheet_min_deg\n",
326
+ "HyperSheetMLP.max_deg = %hyper_sheet_max_deg\n",
327
+ "HyperSheetMLP.output_channels = %hyper_num_dims\n",
328
+ "\n",
329
+ "NerfModel.hyper_slice_method = '{hyper_slice_method}'\n",
330
+ "NerfModel.hyper_sheet_mlp_cls = @HyperSheetMLP\n",
331
+ "NerfModel.hyper_use_warp_embed = True\n",
332
+ "\n",
333
+ "TrainConfig.hyper_sheet_alpha_schedule = ('constant', %hyper_sheet_max_deg)\n",
334
+ "\"\"\"\n",
335
+ "\n",
336
+ "gin.parse_config(config_str)\n",
337
+ "\n",
338
+ "config_path = Path(train_dir, 'config.gin')\n",
339
+ "with open(config_path, 'w') as f:\n",
340
+ " logging.info('Saving config to %s', config_path)\n",
341
+ " f.write(config_str)\n",
342
+ "\n",
343
+ "exp_config = configs.ExperimentConfig()\n",
344
+ "train_config = configs.TrainConfig()\n",
345
+ "eval_config = configs.EvalConfig()\n",
346
+ "\n",
347
+ "display(Markdown(\n",
348
+ " gin.config.markdown(gin.config_str())))"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "metadata": {
355
+ "cellView": "form",
356
+ "colab": {
357
+ "base_uri": "https://localhost:8080/",
358
+ "height": 533
359
+ },
360
+ "id": "r872r6hiVUVS",
361
+ "outputId": "f8794983-1165-4e93-8236-6cac48bbd552"
362
+ },
363
+ "outputs": [],
364
+ "source": [
365
+ "# @title Create datasource and show an example.\n",
366
+ "\n",
367
+ "from hypernerf import datasets\n",
368
+ "from hypernerf import image_utils\n",
369
+ "\n",
370
+ "dummy_model = models.NerfModel({}, 0, 0)\n",
371
+ "datasource = exp_config.datasource_cls(\n",
372
+ " image_scale=exp_config.image_scale,\n",
373
+ " random_seed=exp_config.random_seed,\n",
374
+ " # Enable metadata based on model needs.\n",
375
+ " use_warp_id=dummy_model.use_warp,\n",
376
+ " use_appearance_id=(\n",
377
+ " dummy_model.nerf_embed_key == 'appearance'\n",
378
+ " or dummy_model.hyper_embed_key == 'appearance'),\n",
379
+ " use_camera_id=dummy_model.nerf_embed_key == 'camera',\n",
380
+ " use_time=dummy_model.warp_embed_key == 'time')\n",
381
+ "\n",
382
+ "show_image(datasource.load_rgb(datasource.train_ids[0]))"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {
389
+ "colab": {
390
+ "base_uri": "https://localhost:8080/"
391
+ },
392
+ "id": "XC3PIY74XB05",
393
+ "outputId": "b2f57210-07ff-49c5-b51b-87864d4a1f17"
394
+ },
395
+ "outputs": [],
396
+ "source": [
397
+ "# @title Create training iterators\n",
398
+ "\n",
399
+ "devices = jax.local_devices()\n",
400
+ "\n",
401
+ "train_iter = datasource.create_iterator(\n",
402
+ " datasource.train_ids,\n",
403
+ " flatten=True,\n",
404
+ " shuffle=True,\n",
405
+ " batch_size=train_config.batch_size,\n",
406
+ " prefetch_size=3,\n",
407
+ " shuffle_buffer_size=train_config.shuffle_buffer_size,\n",
408
+ " devices=devices,\n",
409
+ ")\n",
410
+ "\n",
411
+ "def shuffled(l):\n",
412
+ " import random as r\n",
413
+ " import copy\n",
414
+ " l = copy.copy(l)\n",
415
+ " r.shuffle(l)\n",
416
+ " return l\n",
417
+ "\n",
418
+ "train_eval_iter = datasource.create_iterator(\n",
419
+ " shuffled(datasource.train_ids), batch_size=0, devices=devices)\n",
420
+ "val_eval_iter = datasource.create_iterator(\n",
421
+ " shuffled(datasource.val_ids), batch_size=0, devices=devices)"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "markdown",
426
+ "metadata": {
427
+ "id": "erY9l66KjYYW"
428
+ },
429
+ "source": [
430
+ "## Training"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": null,
436
+ "metadata": {
437
+ "colab": {
438
+ "base_uri": "https://localhost:8080/"
439
+ },
440
+ "id": "nZnS8BhcXe5E",
441
+ "outputId": "980fda5d-863e-4b7e-d3dd-aff48dd4950a"
442
+ },
443
+ "outputs": [],
444
+ "source": [
445
+ "# @title Initialize model\n",
446
+ "# @markdown Defines the model and initializes its parameters.\n",
447
+ "\n",
448
+ "from flax.training import checkpoints\n",
449
+ "from hypernerf import models\n",
450
+ "from hypernerf import model_utils\n",
451
+ "from hypernerf import schedules\n",
452
+ "from hypernerf import training\n",
453
+ "\n",
454
+ "# @markdown Restore a checkpoint if one exists.\n",
455
+ "restore_checkpoint = True # @param{type:'boolean'}\n",
456
+ "\n",
457
+ "\n",
458
+ "rng = random.PRNGKey(exp_config.random_seed)\n",
459
+ "np.random.seed(exp_config.random_seed + jax.process_index())\n",
460
+ "devices_to_use = jax.devices()\n",
461
+ "\n",
462
+ "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n",
463
+ "nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)\n",
464
+ "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n",
465
+ "elastic_loss_weight_sched = schedules.from_config(\n",
466
+ "train_config.elastic_loss_weight_schedule)\n",
467
+ "hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)\n",
468
+ "hyper_sheet_alpha_sched = schedules.from_config(\n",
469
+ " train_config.hyper_sheet_alpha_schedule)\n",
470
+ "\n",
471
+ "rng, key = random.split(rng)\n",
472
+ "params = {}\n",
473
+ "model, params['model'] = models.construct_nerf(\n",
474
+ " key,\n",
475
+ " batch_size=train_config.batch_size,\n",
476
+ " embeddings_dict=datasource.embeddings_dict,\n",
477
+ " near=datasource.near,\n",
478
+ " far=datasource.far)\n",
479
+ "\n",
480
+ "optimizer_def = optim.Adam(learning_rate_sched(0))\n",
481
+ "optimizer = optimizer_def.create(params)\n",
482
+ "\n",
483
+ "state = model_utils.TrainState(\n",
484
+ " optimizer=optimizer,\n",
485
+ " nerf_alpha=nerf_alpha_sched(0),\n",
486
+ " warp_alpha=warp_alpha_sched(0),\n",
487
+ " hyper_alpha=hyper_alpha_sched(0),\n",
488
+ " hyper_sheet_alpha=hyper_sheet_alpha_sched(0))\n",
489
+ "scalar_params = training.ScalarParams(\n",
490
+ " learning_rate=learning_rate_sched(0),\n",
491
+ " elastic_loss_weight=elastic_loss_weight_sched(0),\n",
492
+ " warp_reg_loss_weight=train_config.warp_reg_loss_weight,\n",
493
+ " warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,\n",
494
+ " warp_reg_loss_scale=train_config.warp_reg_loss_scale,\n",
495
+ " background_loss_weight=train_config.background_loss_weight,\n",
496
+ " hyper_reg_loss_weight=train_config.hyper_reg_loss_weight)\n",
497
+ "\n",
498
+ "if restore_checkpoint:\n",
499
+ " logging.info('Restoring checkpoint from %s', checkpoint_dir)\n",
500
+ " state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n",
501
+ "step = state.optimizer.state.step + 1\n",
502
+ "state = jax_utils.replicate(state, devices=devices)\n",
503
+ "del params"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "code",
508
+ "execution_count": null,
509
+ "metadata": {
510
+ "id": "at2CL5DRZ7By"
511
+ },
512
+ "outputs": [],
513
+ "source": [
514
+ "# @title Define pmapped functions\n",
515
+ "# @markdown This parallelizes the training and evaluation step functions using `jax.pmap`.\n",
516
+ "\n",
517
+ "import functools\n",
518
+ "from hypernerf import evaluation\n",
519
+ "\n",
520
+ "\n",
521
+ "def _model_fn(key_0, key_1, params, rays_dict, extra_params):\n",
522
+ " out = model.apply({'params': params},\n",
523
+ " rays_dict,\n",
524
+ " extra_params=extra_params,\n",
525
+ " rngs={\n",
526
+ " 'coarse': key_0,\n",
527
+ " 'fine': key_1\n",
528
+ " },\n",
529
+ " mutable=False)\n",
530
+ " return jax.lax.all_gather(out, axis_name='batch')\n",
531
+ "\n",
532
+ "pmodel_fn = jax.pmap(\n",
533
+ " # Note rng_keys are useless in eval mode since there's no randomness.\n",
534
+ " _model_fn,\n",
535
+ " in_axes=(0, 0, 0, 0, 0), # Only distribute the data input.\n",
536
+ " devices=devices_to_use,\n",
537
+ " axis_name='batch',\n",
538
+ ")\n",
539
+ "\n",
540
+ "render_fn = functools.partial(evaluation.render_image,\n",
541
+ " model_fn=pmodel_fn,\n",
542
+ " device_count=len(devices),\n",
543
+ " chunk=eval_config.chunk)\n",
544
+ "train_step = functools.partial(\n",
545
+ " training.train_step,\n",
546
+ " model,\n",
547
+ " elastic_reduce_method=train_config.elastic_reduce_method,\n",
548
+ " elastic_loss_type=train_config.elastic_loss_type,\n",
549
+ " use_elastic_loss=train_config.use_elastic_loss,\n",
550
+ " use_background_loss=train_config.use_background_loss,\n",
551
+ " use_warp_reg_loss=train_config.use_warp_reg_loss,\n",
552
+ " use_hyper_reg_loss=train_config.use_hyper_reg_loss,\n",
553
+ ")\n",
554
+ "ptrain_step = jax.pmap(\n",
555
+ " train_step,\n",
556
+ " axis_name='batch',\n",
557
+ " devices=devices,\n",
558
+ " # rng_key, state, batch, scalar_params.\n",
559
+ " in_axes=(0, 0, 0, None),\n",
560
+ " # Treat use_elastic_loss as compile-time static.\n",
561
+ " donate_argnums=(2,), # Donate the 'batch' argument.\n",
562
+ ")"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "execution_count": null,
568
+ "metadata": {
569
+ "colab": {
570
+ "base_uri": "https://localhost:8080/",
571
+ "height": 1000
572
+ },
573
+ "id": "vbc7cMr5aR_1",
574
+ "outputId": "d35e110d-7dbc-41ca-acae-4f81d0a5af22"
575
+ },
576
+ "outputs": [],
577
+ "source": [
578
+ "# @title Train!\n",
579
+ "# @markdown This runs the training loop!\n",
580
+ "\n",
581
+ "import mediapy\n",
582
+ "from hypernerf import utils\n",
583
+ "from hypernerf import visualization as viz\n",
584
+ "\n",
585
+ "\n",
586
+ "print_every_n_iterations = 100 # @param{type:'number'}\n",
587
+ "visualize_results_every_n_iterations = 500 # @param{type:'number'}\n",
588
+ "save_checkpoint_every_n_iterations = 1000 # @param{type:'number'}\n",
589
+ "\n",
590
+ "\n",
591
+ "logging.info('Starting training')\n",
592
+ "rng = rng + jax.process_index() # Make random seed separate across hosts.\n",
593
+ "keys = random.split(rng, len(devices))\n",
594
+ "time_tracker = utils.TimeTracker()\n",
595
+ "time_tracker.tic('data', 'total')\n",
596
+ "\n",
597
+ "for step, batch in zip(range(step, train_config.max_steps + 1), train_iter):\n",
598
+ " time_tracker.toc('data')\n",
599
+ " scalar_params = scalar_params.replace(\n",
600
+ " learning_rate=learning_rate_sched(step),\n",
601
+ " elastic_loss_weight=elastic_loss_weight_sched(step))\n",
602
+ " # pytype: enable=attribute-error\n",
603
+ " nerf_alpha = jax_utils.replicate(nerf_alpha_sched(step), devices)\n",
604
+ " warp_alpha = jax_utils.replicate(warp_alpha_sched(step), devices)\n",
605
+ " hyper_alpha = jax_utils.replicate(hyper_alpha_sched(step), devices)\n",
606
+ " hyper_sheet_alpha = jax_utils.replicate(\n",
607
+ " hyper_sheet_alpha_sched(step), devices)\n",
608
+ " state = state.replace(nerf_alpha=nerf_alpha,\n",
609
+ " warp_alpha=warp_alpha,\n",
610
+ " hyper_alpha=hyper_alpha,\n",
611
+ " hyper_sheet_alpha=hyper_sheet_alpha)\n",
612
+ "\n",
613
+ " with time_tracker.record_time('train_step'):\n",
614
+ " state, stats, keys, _ = ptrain_step(keys, state, batch, scalar_params)\n",
615
+ " time_tracker.toc('total')\n",
616
+ "\n",
617
+ " if step % print_every_n_iterations == 0:\n",
618
+ " logging.info(\n",
619
+ " 'step=%d, warp_alpha=%.04f, hyper_alpha=%.04f, hyper_sheet_alpha=%.04f, %s',\n",
620
+ " step, \n",
621
+ " warp_alpha_sched(step), \n",
622
+ " hyper_alpha_sched(step), \n",
623
+ " hyper_sheet_alpha_sched(step), \n",
624
+ " time_tracker.summary_str('last'))\n",
625
+ " coarse_metrics_str = ', '.join(\n",
626
+ " [f'{k}={v.mean():.04f}' for k, v in stats['coarse'].items()])\n",
627
+ " fine_metrics_str = ', '.join(\n",
628
+ " [f'{k}={v.mean():.04f}' for k, v in stats['fine'].items()])\n",
629
+ " logging.info('\\tcoarse metrics: %s', coarse_metrics_str)\n",
630
+ " if 'fine' in stats:\n",
631
+ " logging.info('\\tfine metrics: %s', fine_metrics_str)\n",
632
+ " \n",
633
+ " if step % visualize_results_every_n_iterations == 0:\n",
634
+ " print(f'[step={step}] Training set visualization')\n",
635
+ " eval_batch = next(train_eval_iter)\n",
636
+ " render = render_fn(state, eval_batch, rng=rng)\n",
637
+ " rgb = render['rgb']\n",
638
+ " acc = render['acc']\n",
639
+ " depth_exp = render['depth']\n",
640
+ " depth_med = render['med_depth']\n",
641
+ " rgb_target = eval_batch['rgb']\n",
642
+ " depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n",
643
+ " mediapy.show_images([rgb_target, rgb, depth_med_viz],\n",
644
+ " titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n",
645
+ "\n",
646
+ " print(f'[step={step}] Validation set visualization')\n",
647
+ " eval_batch = next(val_eval_iter)\n",
648
+ " render = render_fn(state, eval_batch, rng=rng)\n",
649
+ " rgb = render['rgb']\n",
650
+ " acc = render['acc']\n",
651
+ " depth_exp = render['depth']\n",
652
+ " depth_med = render['med_depth']\n",
653
+ " rgb_target = eval_batch['rgb']\n",
654
+ " depth_med_viz = viz.colorize(depth_med, cmin=datasource.near, cmax=datasource.far)\n",
655
+ " mediapy.show_images([rgb_target, rgb, depth_med_viz],\n",
656
+ " titles=['GT RGB', 'Pred RGB', 'Pred Depth'])\n",
657
+ "\n",
658
+ " if step % save_checkpoint_every_n_iterations == 0:\n",
659
+ " training.save_checkpoint(checkpoint_dir, state)\n",
660
+ "\n",
661
+ " time_tracker.tic('data', 'total')\n"
662
+ ]
663
+ }
664
+ ],
665
+ "metadata": {
666
+ "accelerator": "GPU",
667
+ "colab": {
668
+ "gpuType": "V100",
669
+ "machine_shape": "hm",
670
+ "provenance": []
671
+ },
672
+ "gpuClass": "standard",
673
+ "kernelspec": {
674
+ "display_name": "Python 3 (ipykernel)",
675
+ "language": "python",
676
+ "name": "python3"
677
+ },
678
+ "language_info": {
679
+ "codemirror_mode": {
680
+ "name": "ipython",
681
+ "version": 3
682
+ },
683
+ "file_extension": ".py",
684
+ "mimetype": "text/x-python",
685
+ "name": "python",
686
+ "nbconvert_exporter": "python",
687
+ "pygments_lexer": "ipython3",
688
+ "version": "3.10.10"
689
+ }
690
+ },
691
+ "nbformat": 4,
692
+ "nbformat_minor": 1
693
+ }