Jackoabaad commited on
Commit
9b94fcc
·
1 Parent(s): 0142759

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -77
app.py CHANGED
@@ -1,77 +1,20 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.optim as optim
4
- import torch.utils.data as data
5
-
6
- from nerf import NeRF
7
- from dataset import Dataset
8
-
9
-
10
- def train(nerf, dataloader, optimizer, device):
11
- nerf.train()
12
- for i, data in enumerate(dataloader):
13
- # Get the input and target images.
14
- viewdirs, radiances = data
15
- viewdirs = viewdirs.to(device)
16
- radiances = radiances.to(device)
17
-
18
- # Forward pass.
19
- outputs = nerf(viewdirs)
20
-
21
- # Compute the loss.
22
- loss = nn.functional.mse_loss(outputs, radiances)
23
-
24
- # Backpropagate the loss.
25
- optimizer.zero_grad()
26
- loss.backward()
27
- optimizer.step()
28
-
29
-
30
- def test(nerf, dataloader, device):
31
- nerf.eval()
32
- psnrs = []
33
- for i, data in enumerate(dataloader):
34
- # Get the input and target images.
35
- viewdirs, radiances = data
36
- viewdirs = viewdirs.to(device)
37
- radiances = radiances.to(device)
38
-
39
- # Forward pass.
40
- outputs = nerf(viewdirs)
41
-
42
- # Compute the PSNR.
43
- psnrs.append(
44
- torch.mean(
45
- torch.nn.functional.psnr(outputs, radiances, data["intrinsics"])
46
- )
47
- )
48
-
49
- return np.mean(psnrs)
50
-
51
-
52
- def main():
53
- # Create the dataset.
54
- dataset = Dataset.from_json("data/nerf_synthetic_data.json")
55
- dataloader = data.DataLoader(dataset, batch_size=1, shuffle=True)
56
-
57
- # Create the NeRF model.
58
- nerf = NeRF(32, 64, 8).to(device)
59
-
60
- # Create the optimizer.
61
- optimizer = optim.Adam(nerf.parameters(), lr=1e-3)
62
-
63
- # Train the NeRF model.
64
- for i in range(1000):
65
- train(nerf, dataloader, optimizer, device)
66
-
67
- # Print the loss and PSNR every 100 iterations.
68
- if i % 100 == 0:
69
- loss = test(nerf, dataloader, device)
70
- print(f"Loss: {loss:.4f}")
71
-
72
- # Save the NeRF model.
73
- nerf.save("nerf.pth")
74
-
75
-
76
- if __name__ == "__main__":
77
- main()
 
1
+ import os, sys
2
+ # os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
3
+ os.environ['CUDA_VISIBLE_DEVICES'] = '1'
4
+ import tensorflow as tf
5
+ tf.compat.v1.enable_eager_execution()
6
+
7
+ import numpy as np
8
+ import imageio
9
+ import json
10
+ import random
11
+ import time
12
+ import pprint
13
+
14
+ import matplotlib.pyplot as plt
15
+
16
+ import run_nerf
17
+
18
+ from load_llff import load_llff_data
19
+ from load_deepvoxels import load_dv_data
20
+ from load_blender import load_blender_data