Plachta commited on
Commit
76ce233
1 Parent(s): b298540

Update models/grasp_mods.py

Browse files
Files changed (1) hide show
  1. models/grasp_mods.py +1 -0
models/grasp_mods.py CHANGED
@@ -336,6 +336,7 @@ def add_inference_method(self):
336
  # repeat to batchsize
337
  grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
338
  pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
 
339
  downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64),
340
  mode='nearest').squeeze(1).bool()
341
  downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
 
336
  # repeat to batchsize
337
  grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
338
  pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
339
+ pixel_masks = pixel_masks.repeat(n_queries, 1, 1)
340
  downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64),
341
  mode='nearest').squeeze(1).bool()
342
  downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()