xieyizheng committed on
Commit
b27a485
1 Parent(s): 87286e6

Upload HyperNeRF_Render_Video_clean.ipynb

Browse files
Files changed (1) hide show
  1. HyperNeRF_Render_Video_clean.ipynb +473 -0
HyperNeRF_Render_Video_clean.ipynb ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "QMMWf9AQcdlp"
7
+ },
8
+ "source": [
9
+ "# Render a HyperNeRF video!\n",
10
+ "\n",
11
+ "**Author**: [Keunhong Park](https://keunhong.com)\n",
12
+ "\n",
13
+ "[[Project Page](https://hypernerf.github.io)]\n",
14
+ "[[Paper](https://arxiv.org/abs/2106.13228)]\n",
15
+ "[[GitHub](https://github.com/google/hypernerf)]\n",
16
+ "\n",
17
+ "This notebook renders a video using the test cameras generated in the capture processing notebook.\n",
18
+ "\n",
19
+ "You can also load your own custom cameras by modifying the code slightly.\n",
20
+ "\n",
21
+ "### Instructions\n",
22
+ "\n",
23
+ "1. Convert a video into our dataset format using the [capture processing notebook](https://colab.sandbox.google.com/github/google/nerfies/blob/main/notebooks/Nerfies_Capture_Processing.ipynb).\n",
24
+ "2. Train a HyperNeRF model using the [training notebook](https://colab.sandbox.google.com/github/google/hypernerf/blob/main/notebooks/HyperNeRF_Training.ipynb).\n",
25
+ "3. Run this notebook!\n",
26
+ "\n",
27
+ "\n",
28
+ "### Notes\n",
29
+ " * Please report issues on the [GitHub issue tracker](https://github.com/google/hypernerf/issues)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "gHqkIo4hcGou"
36
+ },
37
+ "source": [
38
+ "## Environment Setup"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "colab": {
46
+ "base_uri": "https://localhost:8080/"
47
+ },
48
+ "id": "ws81Eje47SuV",
49
+ "outputId": "2fa89ef5-4030-46d4-e2d9-2eebffd1b0f9"
50
+ },
51
+ "outputs": [],
52
+ "source": [
53
+ "#!wget https://raw.githubusercontent.com/google/hypernerf/main/requirements.txt\n",
54
+ "!wget https://raw.githubusercontent.com/xieyizheng/hypernerf/main/requirements.txt\n",
55
+ "!python --version\n",
56
+ "!pip install -r requirements.txt\n",
57
+ "\n",
58
+ "# If the requirements were freshly installed, restarting the runtime is recommended!"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {
65
+ "colab": {
66
+ "base_uri": "https://localhost:8080/"
67
+ },
68
+ "id": "-3T2lBKBcIGP",
69
+ "outputId": "6bcc5d9c-108a-4c2b-bef5-fe140c87b3fb"
70
+ },
71
+ "outputs": [],
72
+ "source": [
73
+ "# @title Configure notebook runtime\n",
74
+ "# @markdown If you would like to use a GPU runtime instead, change the runtime type by going to `Runtime > Change runtime type`. \n",
75
+ "# @markdown You will have to use a smaller batch size on GPU.\n",
76
+ "import jax\n",
77
+ "runtime_type = 'gpu' # @param ['gpu', 'tpu']\n",
78
+ "if runtime_type == 'tpu':\n",
79
+ " import jax.tools.colab_tpu\n",
80
+ " jax.tools.colab_tpu.setup_tpu()\n",
81
+ "\n",
82
+ "print('Detected Devices:', jax.devices())"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {
89
+ "colab": {
90
+ "base_uri": "https://localhost:8080/"
91
+ },
92
+ "id": "82kU-W1NcNTW",
93
+ "outputId": "08a21bab-c3cc-43a0-9f1f-fb7e843a8aaa"
94
+ },
95
+ "outputs": [],
96
+ "source": [
97
+ "# @title Mount Google Drive\n",
98
+ "# @markdown Mount Google Drive onto `/content/gdrive`. You can skip this if running locally.\n",
99
+ "\n",
100
+ "from google.colab import drive\n",
101
+ "drive.mount('/content/gdrive')"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {
108
+ "id": "YIDbV769cPn1"
109
+ },
110
+ "outputs": [],
111
+ "source": [
112
+ "# @title Define imports and utility functions.\n",
113
+ "\n",
114
+ "import jax\n",
115
+ "from jax.config import config as jax_config\n",
116
+ "import jax.numpy as jnp\n",
117
+ "from jax import grad, jit, vmap\n",
118
+ "from jax import random\n",
119
+ "\n",
120
+ "import flax\n",
121
+ "import flax.linen as nn\n",
122
+ "from flax import jax_utils\n",
123
+ "from flax import optim\n",
124
+ "from flax.metrics import tensorboard\n",
125
+ "from flax.training import checkpoints\n",
126
+ "\n",
127
+ "from absl import logging\n",
128
+ "from io import BytesIO\n",
129
+ "import random as pyrandom\n",
130
+ "import numpy as np\n",
131
+ "import PIL\n",
132
+ "import IPython\n",
133
+ "import tempfile\n",
134
+ "import imageio\n",
135
+ "import mediapy\n",
136
+ "from IPython.display import display, HTML\n",
137
+ "from base64 import b64encode\n",
138
+ "\n",
139
+ "\n",
140
+ "# Monkey patch logging.\n",
141
+ "def myprint(msg, *args, **kwargs):\n",
142
+ " print(msg % args)\n",
143
+ "\n",
144
+ "logging.info = myprint \n",
145
+ "logging.warn = myprint\n",
146
+ "logging.error = myprint"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {
153
+ "colab": {
154
+ "base_uri": "https://localhost:8080/",
155
+ "height": 1000
156
+ },
157
+ "id": "2QYJ7dyMcw2f",
158
+ "outputId": "73b49855-05f6-4377-a57a-3a5a061c980a"
159
+ },
160
+ "outputs": [],
161
+ "source": [
162
+ "# @title Model and dataset configuration\n",
163
+ "# @markdown Change the directories to where you saved your capture and experiment.\n",
164
+ "\n",
165
+ "\n",
166
+ "from pathlib import Path\n",
167
+ "from pprint import pprint\n",
168
+ "import gin\n",
169
+ "from IPython.display import display, Markdown\n",
170
+ "\n",
171
+ "from hypernerf import models\n",
172
+ "from hypernerf import modules\n",
173
+ "from hypernerf import warping\n",
174
+ "from hypernerf import datasets\n",
175
+ "from hypernerf import configs\n",
176
+ "\n",
177
+ "\n",
178
+ "# @markdown The working directory where the trained model is.\n",
179
+ "train_dir = '/content/gdrive/My Drive/nerfies/hypernerf_experiments/dvd/exp2' # @param {type: \"string\"}\n",
180
+ "# @markdown The directory to the dataset capture.\n",
181
+ "data_dir = '/content/gdrive/My Drive/nerfies/captures/dvd' # @param {type: \"string\"}\n",
182
+ "\n",
183
+ "checkpoint_dir = Path(train_dir, 'checkpoints')\n",
184
+ "checkpoint_dir.mkdir(exist_ok=True, parents=True)\n",
185
+ "\n",
186
+ "config_path = Path(train_dir, 'config.gin')\n",
187
+ "with open(config_path, 'r') as f:\n",
188
+ " logging.info('Loading config from %s', config_path)\n",
189
+ " config_str = f.read()\n",
190
+ "gin.parse_config(config_str)\n",
191
+ "\n",
192
+ "config_path = Path(train_dir, 'config.gin')\n",
193
+ "with open(config_path, 'w') as f:\n",
194
+ " logging.info('Saving config to %s', config_path)\n",
195
+ " f.write(config_str)\n",
196
+ "\n",
197
+ "exp_config = configs.ExperimentConfig()\n",
198
+ "train_config = configs.TrainConfig()\n",
199
+ "eval_config = configs.EvalConfig()\n",
200
+ "\n",
201
+ "display(Markdown(\n",
202
+ " gin.config.markdown(gin.config_str())))"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "metadata": {
209
+ "cellView": "form",
210
+ "colab": {
211
+ "base_uri": "https://localhost:8080/",
212
+ "height": 439
213
+ },
214
+ "id": "6T7LQ5QSmu4o",
215
+ "outputId": "399c441e-b125-4a99-b36e-7b58e0256858"
216
+ },
217
+ "outputs": [],
218
+ "source": [
219
+ "# @title Create datasource and show an example.\n",
220
+ "\n",
221
+ "from hypernerf import datasets\n",
222
+ "from hypernerf import image_utils\n",
223
+ "\n",
224
+ "dummy_model = models.NerfModel({}, 0, 0)\n",
225
+ "datasource = exp_config.datasource_cls(\n",
226
+ " image_scale=exp_config.image_scale,\n",
227
+ " random_seed=exp_config.random_seed,\n",
228
+ " # Enable metadata based on model needs.\n",
229
+ " use_warp_id=dummy_model.use_warp,\n",
230
+ " use_appearance_id=(\n",
231
+ " dummy_model.nerf_embed_key == 'appearance'\n",
232
+ " or dummy_model.hyper_embed_key == 'appearance'),\n",
233
+ " use_camera_id=dummy_model.nerf_embed_key == 'camera',\n",
234
+ " use_time=dummy_model.warp_embed_key == 'time')\n",
235
+ "\n",
236
+ "mediapy.show_image(datasource.load_rgb(datasource.train_ids[0]))"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "metadata": {
243
+ "colab": {
244
+ "base_uri": "https://localhost:8080/"
245
+ },
246
+ "id": "jEO3xcxpnCqx",
247
+ "outputId": "15e2646e-cf00-4c86-f110-86e21b686813"
248
+ },
249
+ "outputs": [],
250
+ "source": [
251
+ "# @title Load model\n",
252
+ "# @markdown Defines the model and initializes its parameters.\n",
253
+ "\n",
254
+ "from flax.training import checkpoints\n",
255
+ "from hypernerf import models\n",
256
+ "from hypernerf import model_utils\n",
257
+ "from hypernerf import schedules\n",
258
+ "from hypernerf import training\n",
259
+ "\n",
260
+ "rng = random.PRNGKey(exp_config.random_seed)\n",
261
+ "np.random.seed(exp_config.random_seed + jax.process_index())\n",
262
+ "devices_to_use = jax.devices()\n",
263
+ "\n",
264
+ "learning_rate_sched = schedules.from_config(train_config.lr_schedule)\n",
265
+ "nerf_alpha_sched = schedules.from_config(train_config.nerf_alpha_schedule)\n",
266
+ "warp_alpha_sched = schedules.from_config(train_config.warp_alpha_schedule)\n",
267
+ "elastic_loss_weight_sched = schedules.from_config(\n",
268
+ "train_config.elastic_loss_weight_schedule)\n",
269
+ "hyper_alpha_sched = schedules.from_config(train_config.hyper_alpha_schedule)\n",
270
+ "hyper_sheet_alpha_sched = schedules.from_config(\n",
271
+ " train_config.hyper_sheet_alpha_schedule)\n",
272
+ "\n",
273
+ "rng, key = random.split(rng)\n",
274
+ "params = {}\n",
275
+ "model, params['model'] = models.construct_nerf(\n",
276
+ " key,\n",
277
+ " batch_size=train_config.batch_size,\n",
278
+ " embeddings_dict=datasource.embeddings_dict,\n",
279
+ " near=datasource.near,\n",
280
+ " far=datasource.far)\n",
281
+ "\n",
282
+ "optimizer_def = optim.Adam(learning_rate_sched(0))\n",
283
+ "optimizer = optimizer_def.create(params)\n",
284
+ "\n",
285
+ "state = model_utils.TrainState(\n",
286
+ " optimizer=optimizer,\n",
287
+ " nerf_alpha=nerf_alpha_sched(0),\n",
288
+ " warp_alpha=warp_alpha_sched(0),\n",
289
+ " hyper_alpha=hyper_alpha_sched(0),\n",
290
+ " hyper_sheet_alpha=hyper_sheet_alpha_sched(0))\n",
291
+ "scalar_params = training.ScalarParams(\n",
292
+ " learning_rate=learning_rate_sched(0),\n",
293
+ " elastic_loss_weight=elastic_loss_weight_sched(0),\n",
294
+ " warp_reg_loss_weight=train_config.warp_reg_loss_weight,\n",
295
+ " warp_reg_loss_alpha=train_config.warp_reg_loss_alpha,\n",
296
+ " warp_reg_loss_scale=train_config.warp_reg_loss_scale,\n",
297
+ " background_loss_weight=train_config.background_loss_weight,\n",
298
+ " hyper_reg_loss_weight=train_config.hyper_reg_loss_weight)\n",
299
+ "\n",
300
+ "logging.info('Restoring checkpoint from %s', checkpoint_dir)\n",
301
+ "state = checkpoints.restore_checkpoint(checkpoint_dir, state)\n",
302
+ "step = state.optimizer.state.step + 1\n",
303
+ "state = jax_utils.replicate(state, devices=devices_to_use)\n",
304
+ "del params"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": null,
310
+ "metadata": {
311
+ "cellView": "form",
312
+ "id": "2KYhbpsklwAy"
313
+ },
314
+ "outputs": [],
315
+ "source": [
316
+ "# @title Define pmapped render function.\n",
317
+ "\n",
318
+ "import functools\n",
319
+ "from hypernerf import evaluation\n",
320
+ "\n",
321
+ "devices = jax.devices()\n",
322
+ "\n",
323
+ "\n",
324
+ "def _model_fn(key_0, key_1, params, rays_dict, extra_params):\n",
325
+ " out = model.apply({'params': params},\n",
326
+ " rays_dict,\n",
327
+ " extra_params=extra_params,\n",
328
+ " rngs={\n",
329
+ " 'coarse': key_0,\n",
330
+ " 'fine': key_1\n",
331
+ " },\n",
332
+ " mutable=False)\n",
333
+ " return jax.lax.all_gather(out, axis_name='batch')\n",
334
+ "\n",
335
+ "pmodel_fn = jax.pmap(\n",
336
+ " # Note rng_keys are useless in eval mode since there's no randomness.\n",
337
+ " _model_fn,\n",
338
+ " in_axes=(0, 0, 0, 0, 0), # Every argument is mapped over its leading axis.\n",
339
+ " devices=devices_to_use,\n",
340
+ " axis_name='batch',\n",
341
+ ")\n",
342
+ "\n",
343
+ "render_fn = functools.partial(evaluation.render_image,\n",
344
+ " model_fn=pmodel_fn,\n",
345
+ " device_count=len(devices),\n",
346
+ " chunk=eval_config.chunk)"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "code",
351
+ "execution_count": null,
352
+ "metadata": {
353
+ "colab": {
354
+ "base_uri": "https://localhost:8080/"
355
+ },
356
+ "id": "73Fq0kNcmAra",
357
+ "outputId": "01f7bcee-833f-47fb-d2ab-0a9a2c15837f"
358
+ },
359
+ "outputs": [],
360
+ "source": [
361
+ "# @title Load cameras.\n",
362
+ "\n",
363
+ "from hypernerf import utils\n",
364
+ "\n",
365
+ "camera_path = 'camera-paths/orbit-mild' # @param {type: 'string'}\n",
366
+ "\n",
367
+ "camera_dir = Path(data_dir, camera_path)\n",
368
+ "print(f'Loading cameras from {camera_dir}')\n",
369
+ "test_camera_paths = datasource.glob_cameras(camera_dir)\n",
370
+ "test_cameras = utils.parallel_map(datasource.load_camera, test_camera_paths, show_pbar=True)"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "metadata": {
377
+ "colab": {
378
+ "base_uri": "https://localhost:8080/",
379
+ "height": 1000
380
+ },
381
+ "id": "aP9LjiAZmoRc",
382
+ "outputId": "811dfbc3-ccbc-4748-dee8-92281ea01b2c"
383
+ },
384
+ "outputs": [],
385
+ "source": [
386
+ "# @title Render video frames.\n",
387
+ "from hypernerf import visualization as viz\n",
388
+ "\n",
389
+ "\n",
390
+ "rng = rng + jax.process_index() # Make random seed separate across hosts.\n",
391
+ "keys = random.split(rng, len(devices))\n",
392
+ "\n",
393
+ "results = []\n",
394
+ "for i in range(len(test_cameras)):\n",
395
+ " print(f'Rendering frame {i+1}/{len(test_cameras)}')\n",
396
+ " camera = test_cameras[i]\n",
397
+ " batch = datasets.camera_to_rays(camera)\n",
398
+ " batch['metadata'] = {\n",
399
+ " 'appearance': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n",
400
+ " 'warp': jnp.zeros_like(batch['origins'][..., 0, jnp.newaxis], jnp.uint32),\n",
401
+ " }\n",
402
+ " # These two are the \"ambient dimensions\" or \"time axis\" for rendering; frame i uses index i.\n",
403
+ " batch['metadata']['appearance'] += i\n",
404
+ " batch['metadata']['warp'] += i\n",
405
+ "\n",
406
+ " render = render_fn(state, batch, rng=rng)\n",
407
+ " rgb = np.array(render['rgb'])\n",
408
+ " depth_med = np.array(render['med_depth'])\n",
409
+ " results.append((rgb, depth_med))\n",
410
+ " depth_viz = viz.colorize(depth_med.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)\n",
411
+ " mediapy.show_images([rgb, depth_viz])"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": null,
417
+ "metadata": {
418
+ "id": "_5hHR9XVm8Ix"
419
+ },
420
+ "outputs": [],
421
+ "source": [
422
+ "# @title Show rendered video.\n",
423
+ "\n",
424
+ "fps = 30 # @param {type:'number'}\n",
425
+ "\n",
426
+ "frames = []\n",
427
+ "for rgb, depth in results:\n",
428
+ " depth_viz = viz.colorize(depth.squeeze(), cmin=datasource.near, cmax=datasource.far, invert=True)\n",
429
+ " frame = np.concatenate([rgb, depth_viz], axis=1)\n",
430
+ " frames.append(image_utils.image_to_uint8(frame))\n",
431
+ "\n",
432
+ "mediapy.show_video(frames, fps=fps)"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": null,
438
+ "metadata": {
439
+ "id": "WW32AVGR0Vwh"
440
+ },
441
+ "outputs": [],
442
+ "source": []
443
+ }
444
+ ],
445
+ "metadata": {
446
+ "accelerator": "GPU",
447
+ "colab": {
448
+ "gpuType": "T4",
449
+ "machine_shape": "hm",
450
+ "provenance": []
451
+ },
452
+ "gpuClass": "standard",
453
+ "kernelspec": {
454
+ "display_name": "Python 3 (ipykernel)",
455
+ "language": "python",
456
+ "name": "python3"
457
+ },
458
+ "language_info": {
459
+ "codemirror_mode": {
460
+ "name": "ipython",
461
+ "version": 3
462
+ },
463
+ "file_extension": ".py",
464
+ "mimetype": "text/x-python",
465
+ "name": "python",
466
+ "nbconvert_exporter": "python",
467
+ "pygments_lexer": "ipython3",
468
+ "version": "3.10.10"
469
+ }
470
+ },
471
+ "nbformat": 4,
472
+ "nbformat_minor": 1
473
+ }