jiaweir commited on
Commit
7c6a4e6
1 Parent(s): b73b3dd
Files changed (2) hide show
  1. lgm/core/models.py +6 -5
  2. main_4d_demo.py +1 -1
lgm/core/models.py CHANGED
@@ -128,14 +128,16 @@ class LGM(nn.Module):
128
 
129
  x_orig_res = x.clone()
130
 
131
- x = F.interpolate(x, (self.opt.splat_size // 4, self.opt.splat_size//4), mode='nearest')
132
- x = x.reshape(B, 4, 14, self.opt.splat_size//4, self.opt.splat_size//4)
 
 
133
 
134
  x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
135
 
136
  pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
137
  opacity = self.opacity_act(x[..., 3:4])
138
- scale = self.scale_act(x[..., 4:7]) * 4
139
  rotation = self.rot_act(x[..., 7:11])
140
  rgbs = self.rgb_act(x[..., 11:])
141
 
@@ -155,8 +157,7 @@ class LGM(nn.Module):
155
  gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
156
 
157
 
158
- # return gaussians, gaussians_orig_res
159
- return gaussians_orig_res, gaussians_orig_res
160
 
161
 
162
  def forward(self, data, step_ratio=1):
 
128
 
129
  x_orig_res = x.clone()
130
 
131
+ dowsample_rate = 2
132
+
133
+ x = F.interpolate(x, (self.opt.splat_size // dowsample_rate, self.opt.splat_size//dowsample_rate), mode='nearest')
134
+ x = x.reshape(B, 4, 14, self.opt.splat_size//dowsample_rate, self.opt.splat_size//dowsample_rate)
135
 
136
  x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
137
 
138
  pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
139
  opacity = self.opacity_act(x[..., 3:4])
140
+ scale = self.scale_act(x[..., 4:7]) * dowsample_rate
141
  rotation = self.rot_act(x[..., 7:11])
142
  rgbs = self.rgb_act(x[..., 11:])
143
 
 
157
  gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
158
 
159
 
160
+ return gaussians, gaussians_orig_res
 
161
 
162
 
163
  def forward(self, data, step_ratio=1):
main_4d_demo.py CHANGED
@@ -540,7 +540,7 @@ class GUI:
540
 
541
  # render eval
542
  image_list =[]
543
- fps = 14
544
  delta_time = 1 / 30
545
  self.renderer.prepare_render_4x()
546
  time = 0
 
540
 
541
  # render eval
542
  image_list =[]
543
+ fps = 28
544
  delta_time = 1 / 30
545
  self.renderer.prepare_render_4x()
546
  time = 0