codestella committed
Commit 97ec1af
1 Parent(s): 3c9f729

code change

.gitattributes DELETED
@@ -1,17 +0,0 @@
-*.bin.* filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tar.gz filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE CHANGED
File without changes
__init__.py CHANGED
File without changes
assets/lego-nerf.gif DELETED
Binary file (519 kB)
 
configs/blender.yaml CHANGED
File without changes
configs/demo.yaml CHANGED
File without changes
configs/diet_nerf_tpu_vm_4shot.yaml CHANGED
@@ -8,8 +8,9 @@ white_bkgd: true
 batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
-render_every: 500
+render_every: 1000
 save_every: 5000
 use_semantic_loss: true
 clip_model_name: openai/clip-vit-base-patch32
configs/diet_nerf_tpu_vm_few_shot.yaml CHANGED
@@ -8,8 +8,9 @@ white_bkgd: true
 batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
-render_every: 500
+render_every: 1000
 save_every: 5000
 use_semantic_loss: true
 clip_model_name: openai/clip-vit-base-patch32
configs/diet_nerf_tpu_vm_test.yaml CHANGED
@@ -2,12 +2,13 @@ dataset: blender
 batching: single_image
 factor: 0
 num_coarse_samples: 64
-num_fine_samples: 64
+num_fine_samples: 128
 use_viewdirs: true
 white_bkgd: true
-batch_size: 1026
+batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
 render_every: 1000
 save_every: 5000
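
Note: all three configs above add stop_sc_loss: 160000 and retune render_every. A minimal sketch (not repo code; flag names are from nerf/utils.py below) of the gate this value feeds in train.py:

def apply_sc_loss(step, sc_loss_every=16, use_semantic_loss=True,
                  stop_sc_loss=160000):
    # Semantic-consistency loss fires every sc_loss_every steps, but only
    # while step < stop_sc_loss (here 160k of the 200k max_steps).
    return (step % sc_loss_every == 0 and use_semantic_loss
            and step < stop_sc_loss)

print(apply_sc_loss(80000))   # True  -- inside the semantic-loss phase
print(apply_sc_loss(160000))  # False -- sc loss switched off for the rest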
configs/eval_diet_nerf_tpu_vm_few_shot.yaml CHANGED
File without changes
configs/nerf_tpu_vm_4shot.yaml CHANGED
File without changes
configs/nerf_tpu_vm_few_shot.yaml CHANGED
File without changes
configs/orig_nerf_tpu_vm_full.yaml CHANGED
File without changes
configs/orig_nerf_tpu_vm_test.yaml CHANGED
File without changes
eval.py CHANGED
@@ -112,30 +112,39 @@ def main(unused_argv):
   summary_writer = tensorboard.SummaryWriter(
       path.join(FLAGS.train_dir, "eval"))

-  def generate_spinning_gif(radius, phi, gif_fn, frame_n):
+  def generate_spinning_gif(radius, phi, output_dir, frame_n):
     _rng = random.PRNGKey(0)
     partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
     gif_images = []
+    gif_images2 = []
     for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
       camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
       rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
       _rng, key0, key1 = random.split(_rng, 3)
-      color, _, _ = utils.render_image(partial_render_fn, rays,
+      color, disp, _ = utils.render_image(partial_render_fn, rays,
                                        _rng, False, chunk=4096)
       image = predict_to_image(color)
+      image2 = predict_to_image(disp[Ellipsis, 0])
       gif_images.append(image)
+      gif_images2.append(image2)
+
+    gif_fn = os.path.join(output_dir, 'rgb_spinning.gif')
+    gif_fn2 = os.path.join(output_dir, 'disp_spinning.gif')
     gif_images[0].save(gif_fn, save_all=True,
                        append_images=gif_images,
                        duration=100, loop=0)
-    return gif_images
+    gif_images2[0].save(gif_fn2, save_all=True,
+                        append_images=gif_images2,
+                        duration=100, loop=0)
+
+    #return gif_images, gif_images2

   if FLAGS.generate_gif_only:
     print('generate GIF file only')
     _radius = 4.
     _phi = (30 * math.pi) / 180
-    _gif_fn = os.path.join(out_dir, 'spinning.gif')
-    generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
-    print(f'GIF file for spinning views written: {_gif_fn}')
+    generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
+    print('GIF file for spinning views written')
     return
   else:
     print('generate GIF file AND evaluate model performance')
@@ -149,6 +158,7 @@ def main(unused_argv):
   utils.makedirs(out_dir)
   psnr_values = []
   ssim_values = []
+
   #lpips_values = []
   if not FLAGS.eval_once:
     showcase_index = np.random.randint(0, dataset.size)
@@ -225,9 +235,8 @@ def main(unused_argv):
     if not is_gif_written:
       _radius = 4.
       _phi = (30 * math.pi) / 180
-      _gif_fn = os.path.join(out_dir, 'spinning.gif')
-      generate_spinning_gif(_radius, _phi, _gif_fn, frame_n=30)
-      print(f'GIF file for spinning views written: {_gif_fn}')
+      generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
+      print('GIF file for spinning views written')
      is_gif_written = True

   if FLAGS.eval_once:
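
Note: generate_spinning_gif now writes paired RGB and disparity GIFs through Pillow's animated-GIF path. A self-contained sketch of that save pattern, with dummy frames standing in for the rendered views:

from PIL import Image

# 30 dummy 64x64 frames; eval.py builds these from rendered views instead.
frames = [Image.new("RGB", (64, 64), (i * 8, 0, 0)) for i in range(30)]
frames[0].save("rgb_spinning.gif",
               save_all=True,          # write a multi-frame (animated) GIF
               save=True if False else True,  # no-op; kept minimal below
               append_images=frames,   # eval.py also passes the full list,
                                       # which repeats frame 0 once
               duration=100,           # ms per frame
               loop=0)                 # loop forever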
eval.sh CHANGED
File without changes
example_data/imgs/r_0.png CHANGED
example_data/transforms_test.json CHANGED
File without changes
example_data/transforms_train.json CHANGED
File without changes
fork-of-first-touch-of-nerf-in-jax.ipynb CHANGED
File without changes
nerf/__init__.py CHANGED
File without changes
nerf/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (137 Bytes)
 
nerf/__pycache__/clip_utils.cpython-37.pyc DELETED
Binary file (5.16 kB)
 
nerf/__pycache__/datasets.cpython-37.pyc DELETED
Binary file (18.3 kB)
 
nerf/__pycache__/model_utils.cpython-37.pyc DELETED
Binary file (10 kB)
 
nerf/__pycache__/models.cpython-37.pyc DELETED
Binary file (5.08 kB)
 
nerf/__pycache__/utils.cpython-37.pyc DELETED
Binary file (15.8 kB)
 
nerf/clip_utils.py CHANGED
@@ -15,50 +15,44 @@ FLAGS = flags.FLAGS

 @partial(jax.jit, static_argnums=[0])
 def semantic_loss(clip_model, src_image, target_embedding):
-    #c_image = utils.unshard(src_image[0])
-    f_image = utils.unshard(src_image[-1])
-
-    w = int(math.sqrt(src_image[-1].size//3))
-    #c_image = c_image.reshape([w, w, 3])
+    c_image = utils.unshard(src_image[0])
+    f_image = utils.unshard(src_image[1])
+    w = int(math.sqrt(f_image.shape[0]))
+    c_image = c_image.reshape([w, w, 3])
     f_image = f_image.reshape([w, w, 3])
-
-    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
-    #src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+
+    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
     src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
-    sc_loss = 1 - jnp.sum(src_embedding * target_embedding)
+    sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
     return sc_loss, f_image

 def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
-    random_rays = jax.tree_map(lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
-    target_embedding = batch["embedding"].astype(jnp.float16)
+    random_rays = batch["random_rays"]
+    target_embedding = batch["embedding"]
     rng, key_0, key_1 = random.split(rng,3)

     def loss_fn(variables):
-        src_image = render_pfn(variables, key_0, key_1, random_rays)
-        sc_loss, f_image = semantic_loss(clip_model, src_image, target_embedding)
+        images = render_pfn(variables, key_0, key_1, random_rays)
+        sc_loss, f_image = semantic_loss(clip_model, images, target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, f_image
     (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
     return sc_loss, grad, src_image

 @partial(jax.jit, static_argnums=[0, 1])
 def semantic_step_single(model, clip_model, rng, state, batch, lr):
-    batch = jax.tree_map(lambda x: x.astype(jnp.float16), batch)
-    # the batch is without shard
-    random_rays = batch["random_rays"]
+    random_rays = jax.tree_map(lambda x: x.reshape(-1,3), batch["random_rays"])
+    target_embedding = batch["embedding"]
     rng, key_0, key_1 = random.split(rng,3)

     def semantic_loss(variables):
         c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only = True)
-        # reshape flat pixel to an image (assume 3 channels & square shape)
         w = int(math.sqrt(f_image.shape[0]))
-        # c_image = c_image.reshape([w, w, 3])
+        c_image = c_image.reshape([w, w, 3])
         f_image = f_image.reshape([w, w, 3])

-        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.expand_dims(f_image,0).transpose(0, 3, 1, 2)))
-        # src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
         src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
-        target_embedding = batch["embedding"]
-        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding)**2)
+        sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
         return sc_loss * FLAGS.sc_loss_mult, f_image
     (sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
     return sc_loss, grad, src_image
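
Note: with both the coarse and fine renders now embedded and L2-normalized, the new objective 2 - sum(src * target) is the sum of two cosine distances to the target embedding, assuming the target is also unit-norm. A runnable sketch (not repo code):

import jax.numpy as jnp

def sc_loss(src_embedding, target_embedding):
    # src_embedding: [2, D] stacked coarse+fine CLIP embeddings
    # target_embedding: [D], assumed unit-norm
    src_embedding = src_embedding / jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
    # For unit vectors: 2 - (cos_c + cos_f) = (1 - cos_c) + (1 - cos_f)
    return 2.0 - jnp.sum(src_embedding * target_embedding)

src = jnp.ones((2, 4))
target = jnp.ones((4,)) / 2.0   # unit norm
print(sc_loss(src, target))     # 0.0 -- both renders aligned with the target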
nerf/datasets.py CHANGED
@@ -236,6 +236,7 @@ class Blender(Dataset):
     camera_angle_x = float(meta["camera_angle_x"])
     self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
     self.n_examples = self.images.shape[0]
+    self.dtype = flags.clip_output_dtype

     if flags.use_semantic_loss and clip_model is not None:
       embs = []
@@ -258,8 +259,8 @@ class Blender(Dataset):

     frames = np.arange(len(meta["frames"]))
     if few_shot > 0 and split == 'train':
-      np.random.seed(0)
-      np.random.shuffle(frames)
+      # np.random.seed(0)
+      # np.random.shuffle(frames)
       frames = frames[:few_shot]

     # if split == 'train':
@@ -308,16 +309,21 @@ class Blender(Dataset):
     src_seed = int(time.time())
     src_rng = jax.random.PRNGKey(src_seed)
     src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
-    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 4)
-    cx = np.random.randint(80, 120)
-    cy = np.random.randint(80, 120)
-    d = 70
-    random_rays = jax.tree_map(lambda x: x[cy-d:cy+d,cx-d:cx+d], random_rays)
+
+    cx = np.random.randint(320, 480)
+    cy = np.random.randint(320, 480)
+    d = 140
+
+    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 1)
+    random_rays = jax.tree_map(lambda x: x[cy-d:cy+d:4,cx-d:cx+d:4], random_rays)
+
     w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
     random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
-    batch_dict["random_rays"] = random_rays
+    batch_dict["random_rays"] = utils.shard(random_rays)
+    if self.dtype == 'float16':
+      batch_dict = jax.tree_map(lambda x: x.astype(np.float16), batch_dict)
     return batch_dict

 class LLFF(Dataset):
   """LLFF Dataset."""
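Note: the patch sampling above switches from cropping 4x-downsampled rays to a strided crop of full-resolution rays. A sketch, assuming 800x800 Blender frames (an assumption consistent with the 320-480 center range), showing the resulting 70x70 ray patch:

import numpy as np

H = W = 800                       # assumed full-res Blender frames
cx = np.random.randint(320, 480)  # patch center, kept away from borders
cy = np.random.randint(320, 480)
d = 140

rays = np.zeros((H, W, 3))        # stand-in for one ray component
patch = rays[cy - d:cy + d:4, cx - d:cx + d:4]  # 280-wide window, stride 4
print(patch.shape)                # (70, 70, 3)
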
nerf/model_utils.py CHANGED
File without changes
nerf/models.py CHANGED
@@ -136,7 +136,7 @@ class NerfModel(nn.Module):
         (comp_rgb, disp, acc),
     ]

-    if self.num_fine_samples > 0 and not(rgb_only):
+    if self.num_fine_samples > 0:
       z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
       key, rng_1 = random.split(rng_1)

@@ -191,8 +191,7 @@ class NerfModel(nn.Module):
     )
     ret.append((comp_rgb, disp, acc))
     if rgb_only:
-      #return [ret[0][0], ret[1][0]]
-      return [None, ret[0][0]]
+      return [ret[0][0], ret[1][0]]
     return ret

 def construct_nerf(key, example_batch, args):
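
Note: the rgb_only path now runs the fine network too and returns both RGB maps instead of [None, coarse_rgb], matching the two-image unpacking in clip_utils.py. A toy illustration (not repo code) of the contract change, where each ret entry is a (rgb, disp, acc) tuple:

ret = [("rgb_coarse", "disp_c", "acc_c"),
       ("rgb_fine", "disp_f", "acc_f")]

old_rgb_only = [None, ret[0][0]]       # before: fine pass skipped
new_rgb_only = [ret[0][0], ret[1][0]]  # after: coarse and fine RGB

c_image, f_image = new_rgb_only
print(c_image, f_image)  # rgb_coarse rgb_fine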
nerf/utils.py CHANGED
@@ -66,11 +66,11 @@ def define_flags():
  flags.DEFINE_bool("use_semantic_loss", True,
                    "whether use semantic loss or not")
  flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
-  flags.DEFINE_string("clip_output_dtype", "float32",
+  flags.DEFINE_string("clip_output_dtype", "float16",
                      "float32/ float16 (float16 for memory saving)")
  flags.DEFINE_integer("sc_loss_every", 16,
                       "no. of steps to take before performing semantic loss evaluation")
-  flags.DEFINE_float("sc_loss_mult", 1e-3,
+  flags.DEFINE_float("sc_loss_mult", 1e-2,
                     "weighting for semantic loss from CLIP")

  # Dataset Flags
@@ -166,6 +166,8 @@ def define_flags():

  flags.DEFINE_integer("max_steps", 1000000,
                       "the number of optimization steps.")
+  flags.DEFINE_integer("stop_sc_loss", 1000000,
+                       "the number of sc_loss optimization steps")
  flags.DEFINE_integer("save_every", 10000,
                       "the number of steps to save a checkpoint.")
  flags.DEFINE_integer("print_every", 100,
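
Note: these flag changes flip the CLIP output dtype default to float16, raise sc_loss_mult to 1e-2, and introduce stop_sc_loss. A minimal absl.flags sketch of the new defaults (the flag names are the repo's; the script itself is illustrative):

from absl import app, flags

flags.DEFINE_string("clip_output_dtype", "float16",
                    "float32/ float16 (float16 for memory saving)")
flags.DEFINE_float("sc_loss_mult", 1e-2, "weighting for semantic loss from CLIP")
flags.DEFINE_integer("stop_sc_loss", 1000000, "the number of sc_loss optimization steps")

def main(_):
    # Overridable on the command line, e.g. --sc_loss_mult=1e-3
    print(flags.FLAGS.sc_loss_mult)

if __name__ == "__main__":
    app.run(main)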
requirements.txt CHANGED
File without changes
run.sh CHANGED
File without changes
train.py CHANGED
@@ -50,7 +50,6 @@ print(f"detected device: {jax.local_devices()}")


 def train_step(model, clip_model, rng, state, batch, lr, step, K,):
-  # TODO make clip_grad input enable
  """One optimization step.

  Args:
@@ -102,7 +101,6 @@ def train_step(model, clip_model, rng, state, batch, lr, step, K,):

  (_, stats), grad = (
      jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
-  #grad = jax.lax.pmean(grad, axis_name="batch")
  stats = jax.lax.pmean(stats, axis_name="batch")

  # Clip the gradient by value.
@@ -238,26 +236,16 @@ def main(unused_argv):

    grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)

-    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
+    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss and step < FLAGS.stop_sc_loss:
      sc_batch = dataset.get_clip_data()
      if jax.local_device_count() > 1:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
      else:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)

-      if jax.host_id() == 0 and step%FLAGS.print_every:
-        for mlp_k, mlp in grad['params'].items():
-          for layer_k, layer_g in mlp.items():
-            summary_writer.scalar("%s/%s/kernel_grad"%(mlp_k, layer_k), jnp.linalg.norm(jnp.mean(layer_g['kernel'],0)), step)
-        for mlp_k, mlp in sc_grad['params'].items():
-          for layer_k, layer_g in mlp.items():
-            summary_writer.scalar("%s/%s/kernel_sc_grad"%(mlp_k, layer_k), jnp.linalg.norm(layer_g['kernel']), step)
-
      leaves, treedef = jax.tree_flatten(grad)
      sc_leaves, _ = jax.tree_flatten(sc_grad)
      grad = treedef.unflatten(g+jnp.expand_dims(sc_g,0) for g, sc_g in zip(leaves, sc_leaves))
-
-

    state = update_pstep(state, grad, lr)

@@ -276,24 +264,26 @@ def main(unused_argv):
        summary_writer.scalar("psnr/train", stats.psnr[0], step)
        summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
        summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
-        summary_writer.scalar("weight_l2", stats.weight_l2[0], step)
+
        avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
        avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
        stats_trace = []
        summary_writer.scalar("train_avg/loss", avg_loss, step)
        summary_writer.scalar("train_avg/psnr", avg_psnr, step)
-        summary_writer.scalar("learning_rate", lr, step)
+
        steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
        reset_timer = True
        rays_per_sec = FLAGS.batch_size * steps_per_sec
-        summary_writer.scalar("train_steps_per_sec", steps_per_sec, step)
-        summary_writer.scalar("train_rays_per_sec", rays_per_sec, step)
+        summary_writer.scalar("stats/weight_l2", stats.weight_l2[0], step)
+        summary_writer.scalar("stats/learning_rate", lr, step)
+        summary_writer.scalar("iter_speed/train_steps_per_sec", steps_per_sec, step)
+        summary_writer.scalar("iter_speed/train_rays_per_sec", rays_per_sec, step)
        precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
        print(("{:" + "{:d}".format(precision) + "d}").format(step) +
              f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
              f"avg_loss={avg_loss:0.4f}, " +
              f"weight_l2={stats.weight_l2[0]:0.2e}, " +
-              # f"sc_loss={sc_loss:0.4f}, " +
+              f"sc_loss={sc_loss:0.4f}, " +
              f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
      if step % FLAGS.save_every == 0:
        state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))
@@ -324,12 +314,10 @@ def main(unused_argv):
      eval_time = time.time() - t_eval_start
      num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
      rays_per_sec = num_rays / eval_time
-      summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
+      summary_writer.scalar("iter_speed/test_rays_per_sec", rays_per_sec, step)
      print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
      summary_writer.scalar("psnr/test", psnr, step)
-      summary_writer.scalar("test_psnr", psnr, step)
      summary_writer.scalar("ssim/ssim", ssim, step)
-      summary_writer.scalar("test_ssim", ssim, step)
      if sc_image is not None:
        summary_writer.image("random_ray_image", sc_image, step)
      summary_writer.image("test_pred_color", pred_color, step)
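
Note: the surviving gradient merge adds the single-device semantic gradient into the pmapped gradient leaf by leaf, broadcasting over the leading device axis. A self-contained sketch with toy shapes (not repo code):

import jax
import jax.numpy as jnp

grad = {"params": {"w": jnp.ones((2, 3))}}        # pmapped grad: [device, ...]
sc_grad = {"params": {"w": jnp.full((3,), 0.5)}}  # single-device semantic grad

leaves, treedef = jax.tree_util.tree_flatten(grad)
sc_leaves, _ = jax.tree_util.tree_flatten(sc_grad)
# expand_dims adds the device axis so the add broadcasts across devices
merged = treedef.unflatten(g + jnp.expand_dims(sc_g, 0)
                           for g, sc_g in zip(leaves, sc_leaves))
print(merged["params"]["w"])  # each device row gets +0.5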
train.sh CHANGED
File without changes