lyndonzheng commited on
Commit
13398c7
1 Parent(s): 3afeb70

add photos

Browse files
app.py CHANGED
@@ -93,7 +93,7 @@ def main():
93
  gr.Markdown(
94
  """
95
  # Flash3D
96
- **Flash3D** [[project page](https://www.robots.ox.ac.uk/~vgg/research/flash3d/)] is a fast, super efficient, trinable on a single GPU in a day for dense 3D reconstruction from a single image.
97
  The model used in the demo was trained on only **RealEstate10k dataset on a single A6000 GPU within 1 day**.
98
  Upload an image of a scene or click on one of the provided examples to see how the Flash3D does.
99
  The 3D viewer will render a .ply scene exported from the 3D Gaussians, which is only an approximation.
 
93
  gr.Markdown(
94
  """
95
  # Flash3D
96
+ **Flash3D** [[project page](https://www.robots.ox.ac.uk/~vgg/research/flash3d/)] is a fast, super efficient, trinable on a single GPU in a day for scene 3D reconstruction from a single image.
97
  The model used in the demo was trained on only **RealEstate10k dataset on a single A6000 GPU within 1 day**.
98
  Upload an image of a scene or click on one of the provided examples to see how the Flash3D does.
99
  The 3D viewer will render a .ply scene exported from the 3D Gaussians, which is only an approximation.
demo_examples/blenheim_palace.JPG ADDED
demo_examples/blenheim_palace_bedroom.png ADDED
demo_examples/blenheim_palace_living.png ADDED
demo_examples/christ_church_cathedral.png ADDED
demo_examples/radcliffe.png ADDED
flash3d/networks/gaussian_predictor.py CHANGED
@@ -73,23 +73,25 @@ class GaussianPredictor(nn.Module):
73
  self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
74
 
75
  self.models = nn.ModuleDict(models)
 
76
 
 
77
  backproject_depth = {}
78
- H = cfg.dataset.height
79
- W = cfg.dataset.width
80
- for scale in cfg.model.scales:
81
  h = H // (2 ** scale)
82
  w = W // (2 ** scale)
83
- if cfg.model.shift_rays_half_pixel == "zero":
84
  shift_rays_half_pixel = 0
85
- elif cfg.model.shift_rays_half_pixel == "forward":
86
  shift_rays_half_pixel = 0.5
87
- elif cfg.model.shift_rays_half_pixel == "backward":
88
  shift_rays_half_pixel = -0.5
89
  else:
90
  raise NotImplementedError
91
  backproject_depth[str(scale)] = BackprojectDepth(
92
- cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel,
93
  # backprojection can be different if padding was used
94
  h + 2 * self.cfg.dataset.pad_border_aug,
95
  w + 2 * self.cfg.dataset.pad_border_aug,
 
73
  self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
74
 
75
  self.models = nn.ModuleDict(models)
76
+ self.set_backproject()
77
 
78
+ def set_backproject(self):
79
  backproject_depth = {}
80
+ H = self.cfg.dataset.height
81
+ W = self.cfg.dataset.width
82
+ for scale in self.cfg.model.scales:
83
  h = H // (2 ** scale)
84
  w = W // (2 ** scale)
85
+ if self.cfg.model.shift_rays_half_pixel == "zero":
86
  shift_rays_half_pixel = 0
87
+ elif self.cfg.model.shift_rays_half_pixel == "forward":
88
  shift_rays_half_pixel = 0.5
89
+ elif self.cfg.model.shift_rays_half_pixel == "backward":
90
  shift_rays_half_pixel = -0.5
91
  else:
92
  raise NotImplementedError
93
  backproject_depth[str(scale)] = BackprojectDepth(
94
+ self.cfg.optimiser.batch_size * self.cfg.model.gaussians_per_pixel,
95
  # backprojection can be different if padding was used
96
  h + 2 * self.cfg.dataset.pad_border_aug,
97
  w + 2 * self.cfg.dataset.pad_border_aug,
flash3d/util/vis3d.py CHANGED
@@ -107,11 +107,9 @@ def export_ply(
107
  PlyData([PlyElement.describe(elements, "vertex")]).write(path)
108
 
109
 
110
- def save_ply(outputs, path, num_gauss=3):
111
- pad = 32
112
 
113
  def crop_r(t):
114
- h, w = 256, 384
115
  H = h + pad * 2
116
  W = w + pad * 2
117
  t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
@@ -120,14 +118,11 @@ def save_ply(outputs, path, num_gauss=3):
120
  return t
121
 
122
  def crop(t):
123
- h, w = 256, 384
124
  H = h + pad * 2
125
  W = w + pad * 2
126
  t = t[..., pad:H-pad, pad:W-pad]
127
  return t
128
 
129
- # import pdb
130
- # pdb.set_trace()
131
  means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
132
  scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
133
  rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
 
107
  PlyData([PlyElement.describe(elements, "vertex")]).write(path)
108
 
109
 
110
+ def save_ply(outputs, path, num_gauss=3, h=256, w=384, pad=32):
 
111
 
112
  def crop_r(t):
 
113
  H = h + pad * 2
114
  W = w + pad * 2
115
  t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
 
118
  return t
119
 
120
  def crop(t):
 
121
  H = h + pad * 2
122
  W = w + pad * 2
123
  t = t[..., pad:H-pad, pad:W-pad]
124
  return t
125
 
 
 
126
  means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
127
  scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
128
  rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]