ashawkey committed on
Commit
3de5f93
1 Parent(s): fa5ca19

improve non-cuda-ray mode

Browse files
Files changed (3) hide show
  1. main.py +11 -4
  2. nerf/sd.py +1 -1
  3. nerf/utils.py +15 -13
main.py CHANGED
@@ -28,8 +28,8 @@ if __name__ == '__main__':
28
  parser.add_argument('--ckpt', type=str, default='latest')
29
  parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
30
  parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
31
- parser.add_argument('--num_steps', type=int, default=128, help="num steps sampled per ray (only valid when not using --cuda_ray)")
32
- parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
33
  parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
34
  parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
35
  parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")
@@ -40,8 +40,8 @@ if __name__ == '__main__':
40
  parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
41
  parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
42
  # rendering resolution in training, decrease this if CUDA OOM.
43
- parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
44
- parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
45
  parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
46
 
47
  ### dataset options
@@ -55,6 +55,7 @@ if __name__ == '__main__':
55
  parser.add_argument('--angle_front', type=float, default=30, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
56
 
57
  parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
 
58
  parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
59
 
60
  ### GUI options
@@ -72,10 +73,16 @@ if __name__ == '__main__':
72
  if opt.O:
73
  opt.fp16 = True
74
  opt.dir_text = True
 
75
  opt.cuda_ray = True
 
 
 
76
  elif opt.O2:
77
  opt.fp16 = True
78
  opt.dir_text = True
 
 
79
 
80
  if opt.backbone == 'vanilla':
81
  from nerf.network import NeRFNetwork
 
28
  parser.add_argument('--ckpt', type=str, default='latest')
29
  parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
30
  parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
31
+ parser.add_argument('--num_steps', type=int, default=64, help="num steps sampled per ray (only valid when not using --cuda_ray)")
32
+ parser.add_argument('--upsample_steps', type=int, default=64, help="num steps up-sampled per ray (only valid when not using --cuda_ray)")
33
  parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
34
  parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)")
35
  parser.add_argument('--albedo_iters', type=int, default=15000, help="training iters that only use albedo shading")
 
40
  parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
41
  parser.add_argument('--backbone', type=str, default='grid', help="nerf backbone, choose from [grid, tcnn, vanilla]")
42
  # rendering resolution in training, decrease this if CUDA OOM.
43
+ parser.add_argument('--w', type=int, default=64, help="render width for NeRF in training")
44
+ parser.add_argument('--h', type=int, default=64, help="render height for NeRF in training")
45
  parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
46
 
47
  ### dataset options
 
55
  parser.add_argument('--angle_front', type=float, default=30, help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.")
56
 
57
  parser.add_argument('--lambda_entropy', type=float, default=1e-4, help="loss scale for alpha entropy")
58
+ parser.add_argument('--lambda_opacity', type=float, default=0, help="loss scale for alpha value")
59
  parser.add_argument('--lambda_orient', type=float, default=1e-2, help="loss scale for orientation")
60
 
61
  ### GUI options
 
73
  if opt.O:
74
  opt.fp16 = True
75
  opt.dir_text = True
76
+ # use occupancy grid to prune ray sampling, faster rendering.
77
  opt.cuda_ray = True
78
+ opt.lambda_entropy = 1e-4
79
+ opt.lambda_opacity = 0
80
+
81
  elif opt.O2:
82
  opt.fp16 = True
83
  opt.dir_text = True
84
+ opt.lambda_entropy = 1e-3
85
+ opt.lambda_opacity = 1e-3 # no occupancy grid, so use a stronger opacity loss.
86
 
87
  if opt.backbone == 'vanilla':
88
  from nerf.network import NeRFNetwork
nerf/sd.py CHANGED
@@ -20,7 +20,7 @@ class StableDiffusion(nn.Module):
20
  print(f'[INFO] loaded hugging face access token from ./TOKEN!')
21
  except FileNotFoundError as e:
22
  self.token = True
23
- print(f'[INFO] try to load hugging face access token from the default plase, make sure you have run `huggingface-cli login`.')
24
 
25
  self.device = device
26
  self.num_train_timesteps = 1000
 
20
  print(f'[INFO] loaded hugging face access token from ./TOKEN!')
21
  except FileNotFoundError as e:
22
  self.token = True
23
+ print(f'[INFO] try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')
24
 
25
  self.device = device
26
  self.num_train_timesteps = 1000
nerf/utils.py CHANGED
@@ -330,11 +330,11 @@ class Trainer(object):
330
  if rand > 0.8:
331
  shading = 'albedo'
332
  ambient_ratio = 1.0
333
- elif rand > 0.4:
334
- shading = 'lambertian'
335
- ambient_ratio = 0.1
336
  else:
337
- shading = 'textureless'
338
  ambient_ratio = 0.1
339
 
340
  # _t = time.time()
@@ -355,22 +355,24 @@ class Trainer(object):
355
 
356
  # encode pred_rgb to latents
357
  # _t = time.time()
358
- loss_guidance = self.guidance.train_step(text_z, pred_rgb)
359
  # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')
360
 
361
  # occupancy loss
362
  pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
363
- # mask_ws = outputs['mask'].reshape(B, 1, H, W) # near < far
364
 
365
- # loss_ws = (pred_ws ** 2 + 0.01).sqrt().mean()
 
 
366
 
367
- alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
368
- # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
369
- loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
370
-
371
- loss = loss_guidance + self.opt.lambda_entropy * loss_entropy
 
372
 
373
- if 'loss_orient' in outputs:
374
  loss_orient = outputs['loss_orient']
375
  loss = loss + self.opt.lambda_orient * loss_orient
376
 
 
330
  if rand > 0.8:
331
  shading = 'albedo'
332
  ambient_ratio = 1.0
333
+ # elif rand > 0.4:
334
+ # shading = 'textureless'
335
+ # ambient_ratio = 0.1
336
  else:
337
+ shading = 'lambertian'
338
  ambient_ratio = 0.1
339
 
340
  # _t = time.time()
 
355
 
356
  # encode pred_rgb to latents
357
  # _t = time.time()
358
+ loss = self.guidance.train_step(text_z, pred_rgb)
359
  # torch.cuda.synchronize(); print(f'[TIME] total guiding {time.time() - _t:.4f}s')
360
 
361
  # occupancy loss
362
  pred_ws = outputs['weights_sum'].reshape(B, 1, H, W)
 
363
 
364
+ if self.opt.lambda_opacity > 0:
365
+ loss_opacity = (pred_ws ** 2).mean()
366
+ loss = loss + self.opt.lambda_opacity * loss_opacity
367
 
368
+ if self.opt.lambda_entropy > 0:
369
+ alphas = (pred_ws).clamp(1e-5, 1 - 1e-5)
370
+ # alphas = alphas ** 2 # skewed entropy, favors 0 over 1
371
+ loss_entropy = (- alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean()
372
+
373
+ loss = loss + self.opt.lambda_entropy * loss_entropy
374
 
375
+ if self.opt.lambda_orient > 0 and 'loss_orient' in outputs:
376
  loss_orient = outputs['loss_orient']
377
  loss = loss + self.opt.lambda_orient * loss_orient
378