xmutly committed (verified)
Commit b81f863 · 1 Parent(s): 317bfc1

Upload 99 files

open_clip_torch/src/open_clip/transformer.py CHANGED
@@ -312,14 +312,19 @@ class Transformer(nn.Module):
             return self.resblocks[0].mlp.c_fc.int8_original_dtype
         return self.resblocks[0].mlp.c_fc.weight.dtype
 
-    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        for r in self.resblocks:
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, return_all_blocks=False):
+        all_blocks = []
+        for i, r in enumerate(self.resblocks):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-        return x
+            all_blocks.append(x)
+        if return_all_blocks:
+            return x, all_blocks
+        else:
+            return x
 
 
 class VisionTransformer(nn.Module):
@@ -457,7 +462,7 @@ class VisionTransformer(nn.Module):
         else:
             return x[:, 0], x[:, 1:]
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, return_all_blocks=False, need_OT=False, object_token=None):
 
         # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
         if self.input_patchnorm:
@@ -478,12 +483,18 @@
              x], dim=1)  # shape = [*, grid ** 2 + 1, width]
         x = x + self.positional_embedding.to(x.dtype)
 
+        ######################################### For object-centric relation reasoning add ###########################
+        if need_OT:
+            x = torch.cat([x, object_token], dim=1)
         # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
         x = self.patch_dropout(x)
         x = self.ln_pre(x)
 
         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x)
+
+        x = self.transformer(x, return_all_blocks=return_all_blocks)
+        if return_all_blocks:
+            x, all_blocks_feat = x[0], x[1],
         x = x.permute(1, 0, 2)  # LND -> NLD
 
         if self.attn_pool is not None:
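
Below is a minimal usage sketch, not part of the commit, showing what the patched Transformer.forward returns: with return_all_blocks=True it yields the final output together with a list holding every residual block's output. The tiny width/layers/heads configuration is an illustrative assumption, and the import assumes the vendored package is importable as open_clip.

import torch
from open_clip.transformer import Transformer

# Tiny stand-alone Transformer; real CLIP towers use e.g. width=768, layers=12.
t = Transformer(width=64, layers=4, heads=4)
x = torch.randn(10, 2, 64)  # LND layout: 10 tokens, batch of 2, width 64

out = t(x)                                      # unchanged path: a single tensor, as before the patch
out, all_blocks = t(x, return_all_blocks=True)  # new path: also returns the per-block outputs
assert len(all_blocks) == 4                     # one tensor per residual block
assert torch.equal(out, all_blocks[-1])         # the last entry equals the final output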
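
A second sketch, also not part of the commit, for the object-token path in the patched VisionTransformer.forward: when need_OT=True, object_token is concatenated along the sequence dimension right after the positional embedding is added, so it must be a [batch, n_tokens, width] tensor. The 'ViT-B-32' model and the 8-token count are illustrative assumptions; the commit does not show how object_token is produced.

import torch
import open_clip

model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32')  # random init, no download
visual = model.visual                      # the patched VisionTransformer

images = torch.randn(2, 3, 224, 224)       # dummy image batch
width = visual.transformer.width           # 768 for ViT-B-32

# Hypothetical object tokens; in practice they might come from a learned
# embedding or an object detector, which this diff does not show.
object_token = torch.randn(2, 8, width) * 0.02  # [batch, n_obj_tokens, width]

with torch.no_grad():
    feats = visual(images, need_OT=True, object_token=object_token)
print(feats.shape)                         # pooled image embedding, e.g. [2, 512]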