HarborYuan committed
Commit a78077d
1 Parent(s): c7fd587
app/models/heads/mask2former_vid.py CHANGED
@@ -190,7 +190,7 @@ class Mask2FormerVideoHead(AnchorFreeHead):
         # back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
         # if world_size > 1:
         #     dist.broadcast(back_token, src=0)
-        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
         # back_token = back_token.to(device='cpu')
         cls_embed = torch.cat([
             cls_embed, back_token.repeat(_prototypes, 1)[None]
@@ -597,7 +597,7 @@ class Mask2FormerVideoHead(AnchorFreeHead):
         input_query_bbox = input_bbox_embed
 
         tgt_size = pad_size
-        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
+        attn_mask = torch.ones(tgt_size, tgt_size).to(input_bbox_embed.device) < 0
         # match query cannot see the reconstruct
         attn_mask[pad_size:, :pad_size] = True
         # reconstruct cannot see each other
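Both hunks replace a hardcoded `device='cuda'` with a CPU or input-derived device, so the head can run on machines without a GPU. A minimal sketch of the pattern, assuming nothing beyond what the diff shows (the helper name and shapes below are illustrative, not from the repo):

```python
import torch

def zeros_on_device_of(ref: torch.Tensor, *shape: int) -> torch.Tensor:
    # Allocate on whatever device `ref` lives on (CPU or CUDA),
    # instead of hardcoding device='cuda'.
    return torch.zeros(*shape, dtype=torch.float32, device=ref.device)

input_bbox_embed = torch.randn(4, 256)                     # CPU tensor by default
attn_mask = torch.ones(4, 4).to(input_bbox_embed.device) < 0   # all-False bool mask
back_token = zeros_on_device_of(input_bbox_embed, 1, 256)
assert attn_mask.device == back_token.device == input_bbox_embed.device
```

Deriving the device from `input_bbox_embed` keeps the attention mask colocated with the queries it masks, which is what the second hunk does.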
app/models/heads/yoso_head.py CHANGED
@@ -376,7 +376,7 @@ class CrossAttenHead(nn.Module):
         # back_token = torch.empty(1, _dim, dtype=torch.float32, device='cuda')
         # if world_size > 1:
         # dist.broadcast(back_token, src=0)
-        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cuda')
+        back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
         # back_token = back_token.to(device='cpu')
         cls_embed = torch.cat([
             cls_embed, back_token.repeat(_prototypes, 1)[None]
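The same `back_token` fix applies here. One caveat worth noting: `torch.cat` requires all operands on the same device, so allocating `back_token` on CPU assumes `cls_embed` is also CPU-resident at this point. A hedged sketch of the concatenation with made-up sizes (the `dim` argument is an assumption; the diff truncates the call):

```python
import torch

_dim, _prototypes = 256, 10                      # illustrative sizes only
cls_embed = torch.randn(1, _prototypes, _dim)    # assumed CPU-resident here
back_token = torch.zeros(1, _dim, dtype=torch.float32, device='cpu')
# (1, _dim) -> repeat -> (_prototypes, _dim) -> [None] -> (1, _prototypes, _dim),
# then append along the prototype axis.
cls_embed = torch.cat([
    cls_embed, back_token.repeat(_prototypes, 1)[None]
], dim=1)
assert cls_embed.shape == (1, 2 * _prototypes, _dim)
```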