Stanislaw Szymanowicz commited on
Commit
bfc135c
1 Parent(s): 5b91d79

cpu compatibility

Browse files
Files changed (1) hide show
  1. scene/gaussian_predictor.py +4 -2
scene/gaussian_predictor.py CHANGED
@@ -746,7 +746,8 @@ class GaussianSplatPredictor(nn.Module):
746
  # Pos prediction is in camera space - compute the positions in the world space
747
  pos = self.flatten_vector(pos)
748
  pos = torch.cat([pos,
749
- torch.ones((pos.shape[0], pos.shape[1], 1), device="cuda", dtype=torch.float32)
 
750
  ], dim=2)
751
  pos = torch.bmm(pos, source_cameras_view_to_world)
752
  pos = pos[:, :, :3] / (pos[:, :, 3:] + 1e-10)
@@ -781,7 +782,8 @@ class GaussianSplatPredictor(nn.Module):
781
  out_dict["features_rest"] = torch.zeros((out_dict["features_dc"].shape[0],
782
  out_dict["features_dc"].shape[1],
783
  (self.cfg.model.max_sh_degree + 1) ** 2 - 1,
784
- 3), dtype=out_dict["features_dc"].dtype, device="cuda")
 
785
 
786
  out_dict = self.multi_view_union(out_dict, B, N_views)
787
  out_dict = self.make_contiguous(out_dict)
 
746
  # Pos prediction is in camera space - compute the positions in the world space
747
  pos = self.flatten_vector(pos)
748
  pos = torch.cat([pos,
749
+ torch.ones((pos.shape[0], pos.shape[1], 1),
750
+ device=pos.device, dtype=torch.float32)
751
  ], dim=2)
752
  pos = torch.bmm(pos, source_cameras_view_to_world)
753
  pos = pos[:, :, :3] / (pos[:, :, 3:] + 1e-10)
 
782
  out_dict["features_rest"] = torch.zeros((out_dict["features_dc"].shape[0],
783
  out_dict["features_dc"].shape[1],
784
  (self.cfg.model.max_sh_degree + 1) ** 2 - 1,
785
+ 3), dtype=out_dict["features_dc"].dtype,
786
+ device=out_dict["xyz"].device)
787
 
788
  out_dict = self.multi_view_union(out_dict, B, N_views)
789
  out_dict = self.make_contiguous(out_dict)