ritwikraha committed
Commit 49e0d56
Parent: 0bfcfed

add: main script added

app.py ADDED
@@ -0,0 +1,267 @@
+ import streamlit as st
+ # Setting random seed to obtain reproducible results.
+ import tensorflow as tf
+
+ tf.random.set_seed(42)
+
+ import os
+ import glob
+ import imageio
+ from PIL import Image
+ import numpy as np
+ from tqdm import tqdm
+ from tensorflow import keras
+ from tensorflow.keras import layers
+ import matplotlib.pyplot as plt
+
+ # Initialize global variables.
+ AUTO = tf.data.AUTOTUNE
+ BATCH_SIZE = 1
+ NUM_SAMPLES = 32
+ POS_ENCODE_DIMS = 16
+ EPOCHS = 20
+ H = 100
+ W = 100
+ focal = 138.88
+
+ def encode_position(x):
+     """Encodes the position into its corresponding Fourier feature.
+
+     Args:
+         x: The input coordinate.
+
+     Returns:
+         Fourier feature tensors of the position.
+     """
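+     # Map x to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)], with L = POS_ENCODE_DIMS.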
+     positions = [x]
+     for i in range(POS_ENCODE_DIMS):
+         for fn in [tf.sin, tf.cos]:
+             positions.append(fn(2.0 ** i * x))
+     return tf.concat(positions, axis=-1)
+
+
+ def get_rays(height, width, focal, pose):
+     """Computes the origin points and direction vectors of the rays.
+
+     Args:
+         height: Height of the image.
+         width: Width of the image.
+         focal: The focal length of the camera.
+         pose: The pose matrix of the camera.
+
+     Returns:
+         Tuple of origin points and direction vectors for the rays.
+     """
+     # Build a meshgrid for the rays.
+     i, j = tf.meshgrid(
+         tf.range(width, dtype=tf.float32),
+         tf.range(height, dtype=tf.float32),
+         indexing="xy",
+     )
+
+     # Normalize the x axis coordinates.
+     transformed_i = (i - width * 0.5) / focal
+
+     # Normalize the y axis coordinates.
+     transformed_j = (j - height * 0.5) / focal
+
+     # Create the direction unit vectors.
+     directions = tf.stack([transformed_i, -transformed_j, -tf.ones_like(i)], axis=-1)
+
+     # Get the camera matrix.
+     camera_matrix = pose[:3, :3]
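+     # The last column of the pose matrix holds the camera origin (translation) in world coordinates.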
+     height_width_focal = pose[:3, -1]
+
+     # Get origins and directions for the rays.
+     transformed_dirs = directions[..., None, :]
+     camera_dirs = transformed_dirs * camera_matrix
+     ray_directions = tf.reduce_sum(camera_dirs, axis=-1)
+     ray_origins = tf.broadcast_to(height_width_focal, tf.shape(ray_directions))
+
+     # Return the origins and directions.
+     return (ray_origins, ray_directions)
+
+
+ def render_flat_rays(ray_origins, ray_directions, near, far, num_samples, rand=False):
+     """Renders the rays and flattens them.
+
+     Args:
+         ray_origins: The origin points for the rays.
+         ray_directions: The direction unit vectors for the rays.
+         near: The near bound of the volumetric scene.
+         far: The far bound of the volumetric scene.
+         num_samples: Number of sample points in a ray.
+         rand: Choice for randomising the sampling strategy.
+
+     Returns:
+         Tuple of flattened rays and sample points on each ray.
+     """
+     # Compute 3D query points.
+     # Equation: r(t) = o + td -> Building the "t" here.
+     t_vals = tf.linspace(near, far, num_samples)
+     if rand:
+         # Inject uniform noise into sample space to make the sampling
+         # continuous.
+         shape = list(ray_origins.shape[:-1]) + [num_samples]
+         noise = tf.random.uniform(shape=shape) * (far - near) / num_samples
+         t_vals = t_vals + noise
+
+     # Equation: r(t) = o + td -> Building the "r" here.
+     rays = ray_origins[..., None, :] + (
+         ray_directions[..., None, :] * t_vals[..., None]
+     )
+     rays_flat = tf.reshape(rays, [-1, 3])
+     rays_flat = encode_position(rays_flat)
+     return (rays_flat, t_vals)
+
+
+ def map_fn(pose):
+     """Maps an individual pose to flattened rays and sample points.
+
+     Args:
+         pose: The pose matrix of the camera.
+
+     Returns:
+         Tuple of flattened rays and sample points corresponding to the
+         camera pose.
+     """
+     (ray_origins, ray_directions) = get_rays(height=H, width=W, focal=focal, pose=pose)
+     (rays_flat, t_vals) = render_flat_rays(
+         ray_origins=ray_origins,
+         ray_directions=ray_directions,
+         near=2.0,
+         far=6.0,
+         num_samples=NUM_SAMPLES,
+         rand=True,
+     )
+     return (rays_flat, t_vals)
+
+ def render_rgb_depth(model, rays_flat, t_vals, rand=True, train=True):
+     """Generates the RGB image and depth map from the model prediction.
+
+     Args:
+         model: The MLP model that is trained to predict the rgb and
+             volume density of the volumetric scene.
+         rays_flat: The flattened rays that serve as the input to
+             the NeRF model.
+         t_vals: The sample points for the rays.
+         rand: Choice to randomise the sampling strategy.
+         train: Whether the model is in the training or testing phase.
+
+     Returns:
+         Tuple of rgb image and depth map.
+     """
+     # Get the predictions from the nerf model and reshape them.
+     if train:
+         predictions = model(rays_flat)
+     else:
+         predictions = model.predict(rays_flat)
+     predictions = tf.reshape(predictions, shape=(BATCH_SIZE, H, W, NUM_SAMPLES, 4))
+
+     # Slice the predictions into rgb and sigma.
+     rgb = tf.sigmoid(predictions[..., :-1])
+     sigma_a = tf.nn.relu(predictions[..., -1])
+
+     # Get the distance of adjacent intervals.
+     delta = t_vals[..., 1:] - t_vals[..., :-1]
+     # delta shape = (num_samples)
+     if rand:
+         delta = tf.concat(
+             [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, H, W, 1))], axis=-1
+         )
+         alpha = 1.0 - tf.exp(-sigma_a * delta)
+     else:
+         delta = tf.concat(
+             [delta, tf.broadcast_to([1e10], shape=(BATCH_SIZE, 1))], axis=-1
+         )
+         alpha = 1.0 - tf.exp(-sigma_a * delta[:, None, None, :])
+
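+     # Composite along each ray: weight_i = T_i * alpha_i, where the transmittance
+     # T_i = prod_{j < i} (1 - alpha_j) is the probability that the ray reaches
+     # sample i without being blocked.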
+     # Get transmittance.
+     exp_term = 1.0 - alpha
+     epsilon = 1e-10
+     transmittance = tf.math.cumprod(exp_term + epsilon, axis=-1, exclusive=True)
+     weights = alpha * transmittance
+     rgb = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
+
+     if rand:
+         depth_map = tf.reduce_sum(weights * t_vals, axis=-1)
+     else:
+         depth_map = tf.reduce_sum(weights * t_vals[:, None, None], axis=-1)
+     return (rgb, depth_map)
+
+ nerf_loaded = tf.keras.models.load_model("nerf", compile=False)
+
+ def get_translation_t(t):
+     """Get the translation matrix for movement in t."""
+     matrix = [
+         [1, 0, 0, 0],
+         [0, 1, 0, 0],
+         [0, 0, 1, t],
+         [0, 0, 0, 1],
+     ]
+     return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+ def get_rotation_phi(phi):
+     """Get the rotation matrix for movement in phi."""
+     matrix = [
+         [1, 0, 0, 0],
+         [0, tf.cos(phi), -tf.sin(phi), 0],
+         [0, tf.sin(phi), tf.cos(phi), 0],
+         [0, 0, 0, 1],
+     ]
+     return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+ def get_rotation_theta(theta):
+     """Get the rotation matrix for movement in theta."""
+     matrix = [
+         [tf.cos(theta), 0, -tf.sin(theta), 0],
+         [0, 1, 0, 0],
+         [tf.sin(theta), 0, tf.cos(theta), 0],
+         [0, 0, 0, 1],
+     ]
+     return tf.convert_to_tensor(matrix, dtype=tf.float32)
+
+
+ def pose_spherical(theta, phi, t):
+     """
+     Get the camera to world matrix for the corresponding theta, phi
+     and t.
+     """
+     c2w = get_translation_t(t)
+     c2w = get_rotation_phi(phi / 180.0 * np.pi) @ c2w
+     c2w = get_rotation_theta(theta / 180.0 * np.pi) @ c2w
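+     # Flip the x axis and swap the y and z axes of the pose (coordinate-convention change).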
+     c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
+     return c2w
+
+
+ def show_rendered_image(r, theta, phi):
+     # Get the camera to world matrix.
+     c2w = pose_spherical(theta, phi, r)
+
+     ray_oris, ray_dirs = get_rays(H, W, focal, c2w)
+     rays_flat, t_vals = render_flat_rays(
+         ray_oris, ray_dirs, near=2.0, far=6.0, num_samples=NUM_SAMPLES, rand=False
+     )
+
+     rgb, depth = render_rgb_depth(
+         nerf_loaded, rays_flat[None, ...], t_vals[None, ...], rand=False, train=False
+     )
+     return (rgb[0], depth[0])
+
+ # The Streamlit UI starts here.
+ st.title('NeRF: Neural Radiance Fields')
+ st.subheader('')
+ # Set the values of r, theta and phi.
+ r = -30.0
+ theta = st.slider('Enter a value for theta', 0.0, 360.0, 1.0)
+ phi = st.slider('Enter a value for phi', 0.0, 360.0, 1.0)
+
+ color, depth = show_rendered_image(r, theta, phi)
+
+ # Convert to NumPy and scale the depth map to [0, 1] so st.image can display it.
+ color = color.numpy()
+ depth = depth.numpy()
+ depth = depth / (depth.max() + 1e-8)
+
+ st.image(color, caption="Color")
+ st.image(depth, caption="Depth")
+
+
+
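Note: app.py loads the network with tf.keras.models.load_model("nerf", compile=False), so the nerf/ SavedModel directory added below must sit next to the script. As a rough sketch, assuming the imported packages (streamlit, tensorflow, imageio, tqdm, matplotlib, Pillow) are installed, the demo can then be launched locally with: streamlit run app.py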
nerf/keras_metadata.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e8da49eeec070f24b87d869c2005bdec6fdbd1a1bc1fb6d44c73eb8f89321c6c
+ size 21754
nerf/saved_model.pb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fe43d8f799d56fc7ecdc964172f56541f7cbdaed8644559ed9c7bac553e826e
+ size 272106
nerf/variables/variables.data-00000-of-00001 ADDED
Binary file (174 kB).
 
nerf/variables/variables.index ADDED
Binary file (1.24 kB).