ashawkey commited on
Commit
d718cf7
1 Parent(s): 30e1aa8

add option to enable camera pose jitterring

Browse files
Files changed (2) hide show
  1. main.py +1 -0
  2. nerf/provider.py +12 -5
main.py CHANGED
@@ -42,6 +42,7 @@ if __name__ == '__main__':
42
  # rendering resolution in training, decrease this if CUDA OOM.
43
  parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
44
  parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
 
45
 
46
  ### dataset options
47
  parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
 
42
  # rendering resolution in training, decrease this if CUDA OOM.
43
  parser.add_argument('--w', type=int, default=128, help="render width for NeRF in training")
44
  parser.add_argument('--h', type=int, default=128, help="render height for NeRF in training")
45
+ parser.add_argument('--jitter_pose', action='store_true', help="add jitters to the randomly sampled camera poses")
46
 
47
  ### dataset options
48
  parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box(-bound, bound)")
nerf/provider.py CHANGED
@@ -55,7 +55,7 @@ def get_view_direction(thetas, phis, overhead, front):
55
  return res
56
 
57
 
58
- def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60):
59
  ''' generate random poses from an orbit camera
60
  Args:
61
  size: batch size of generated poses.
@@ -82,16 +82,23 @@ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_ra
82
  radius * torch.sin(thetas) * torch.cos(phis),
83
  ], dim=-1) # [B, 3]
84
 
 
 
85
  # jitters
86
- centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
87
- targets = torch.randn_like(centers) * 0.2
 
88
 
89
  # lookat
90
  forward_vector = safe_normalize(targets - centers)
91
  up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
92
  right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
 
 
 
 
 
93
 
94
- up_noise = torch.randn_like(up_vector) * 0.02
95
  up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
96
 
97
  poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
@@ -170,7 +177,7 @@ class NeRFDataset:
170
 
171
  if self.training:
172
  # random pose on the fly
173
- poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front)
174
 
175
  # random focal
176
  fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]
 
55
  return res
56
 
57
 
58
+ def rand_poses(size, device, radius_range=[1, 1.5], theta_range=[0, 150], phi_range=[0, 360], return_dirs=False, angle_overhead=30, angle_front=60, jitter=False):
59
  ''' generate random poses from an orbit camera
60
  Args:
61
  size: batch size of generated poses.
 
82
  radius * torch.sin(thetas) * torch.cos(phis),
83
  ], dim=-1) # [B, 3]
84
 
85
+ targets = 0
86
+
87
  # jitters
88
+ if jitter:
89
+ centers = centers + (torch.rand_like(centers) * 0.2 - 0.1)
90
+ targets = targets + torch.randn_like(centers) * 0.2
91
 
92
  # lookat
93
  forward_vector = safe_normalize(targets - centers)
94
  up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
95
  right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
96
+
97
+ if jitter:
98
+ up_noise = torch.randn_like(up_vector) * 0.02
99
+ else:
100
+ up_noise = 0
101
 
 
102
  up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise)
103
 
104
  poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
 
177
 
178
  if self.training:
179
  # random pose on the fly
180
+ poses, dirs = rand_poses(B, self.device, radius_range=self.radius_range, return_dirs=self.opt.dir_text, angle_overhead=self.opt.angle_overhead, angle_front=self.opt.angle_front, jitter=self.opt.jitter_pose)
181
 
182
  # random focal
183
  fov = random.random() * (self.fovy_range[1] - self.fovy_range[0]) + self.fovy_range[0]