ZhenweiWang committed
Commit a795c20 · verified · 1 Parent(s): 9c4386d

Update src/models/models/visual_transformer.py

src/models/models/visual_transformer.py CHANGED
@@ -347,15 +347,9 @@ class VisualGeometryTransformer(nn.Module):
     def _process_conditioning(self, depth_maps, ray_dirs, poses, b, seq_len, patch_count, embed_dim, images, cond_flags):
         """Process conditioning inputs."""
         h, w = images.shape[-2:]
-        if self.training:
-            assert self.sampling_strategy is not None
-            if self.sampling_strategy == "uniform":
-                pose_prob = depth_prob = rays_prob = 0.5
-            else:
-                raise ValueError(f"Unknown sampling strategy: {self.sampling_strategy}")
-
+
         # Process camera pose embedding
-        use_poses = (self.training and random.random() < pose_prob) or (not self.training and cond_flags[0] == 1 and poses is not None)
+        use_poses = (cond_flags[0] == 1 and poses is not None)
         if use_poses:
             poses = poses.view(b*seq_len, -1)
             pose_tokens = self.pose_embed(poses).unsqueeze(1)
@@ -363,7 +357,7 @@
             pose_tokens = torch.zeros((b*seq_len, 1, embed_dim), device=images.device, dtype=images.dtype)
 
         # Process depth map embedding
-        use_depth = (self.training and random.random() < depth_prob) or (not self.training and cond_flags[1] == 1 and depth_maps is not None)
+        use_depth = cond_flags[1] == 1 and depth_maps is not None
         if use_depth:
             depth_maps = depth_maps.view(b*seq_len, 1, h, w)
             depth_tokens = self.depth_embed(depth_maps).reshape(b * seq_len, patch_count, embed_dim)
@@ -371,7 +365,7 @@
             depth_tokens = torch.zeros((b*seq_len, patch_count, embed_dim), device=images.device, dtype=images.dtype)
 
         # Process ray direction embedding
-        use_rays = (self.training and random.random() < rays_prob) or (not self.training and cond_flags[2] == 1 and ray_dirs is not None)
+        use_rays = cond_flags[2] == 1 and ray_dirs is not None
         if use_rays:
             ray_dirs = ray_dirs.view(b*seq_len, -1)
             ray_tokens = self.ray_embed(ray_dirs).unsqueeze(1)
@@ -396,15 +390,7 @@
         if pos is not None and pos.shape != pos_target_shape:
             pos = pos.view(*pos_target_shape)
 
-        if self.training:
-            tokens = checkpoint(
-                blocks[block_idx],
-                tokens,
-                pos=pos,
-                use_reentrant=self.use_reentrant_checkpointing,
-            )
-        else:
-            tokens = blocks[block_idx](tokens, pos=pos)
+        tokens = blocks[block_idx](tokens, pos=pos)
 
         return tokens.view(*token_shape)
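The substance of the first three hunks: the training-time random sampling of conditioning signals (each previously kept with probability 0.5 under the "uniform" strategy) is removed, so whether pose, depth, or ray conditioning is used now depends only on cond_flags and on the corresponding input actually being provided. A minimal sketch of the resulting gating, with select_conditioning as a hypothetical standalone helper rather than code from the repository:

def select_conditioning(cond_flags, poses, depth_maps, ray_dirs):
    """Use a conditioning signal only when its flag is set AND the
    corresponding tensor was actually provided (hypothetical helper)."""
    use_poses = cond_flags[0] == 1 and poses is not None
    use_depth = cond_flags[1] == 1 and depth_maps is not None
    use_rays = cond_flags[2] == 1 and ray_dirs is not None
    return use_poses, use_depth, use_rays

When a signal is gated off, _process_conditioning falls back to zero tokens of the matching shape (the torch.zeros branches above), so the token layout fed to the transformer is the same either way.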
 
 
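The last hunk drops the gradient-checkpointing branch, so every transformer block is now called directly in both training and inference. For reference, a sketch of the pattern the removed branch followed, built on torch.utils.checkpoint (apply_block is a hypothetical helper, not code from the repository); note that forwarding keyword arguments such as pos through checkpoint requires the non-reentrant variant:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def apply_block(block: nn.Module, tokens: torch.Tensor, pos, training: bool) -> torch.Tensor:
    """During training, recompute the block's activations in the backward
    pass instead of storing them, trading extra compute for lower memory;
    otherwise call the block directly (hypothetical helper)."""
    if training:
        # use_reentrant=False is needed when passing kwargs such as `pos`.
        return checkpoint(block, tokens, pos=pos, use_reentrant=False)
    return block(tokens, pos=pos)

After this commit the direct call is taken unconditionally, which simplifies the forward pass at the usual cost of higher activation memory during training.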