Upload 99 files
open_clip_torch/src/open_clip/transformer.py
CHANGED
@@ -312,14 +312,19 @@ class Transformer(nn.Module):
             return self.resblocks[0].mlp.c_fc.int8_original_dtype
         return self.resblocks[0].mlp.c_fc.weight.dtype

-    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
-        for r in self.resblocks:
+    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, return_all_blocks=False):
+        all_blocks = []
+        for i, r in enumerate(self.resblocks):
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                 x = checkpoint(r, x, None, None, attn_mask)
             else:
                 x = r(x, attn_mask=attn_mask)
-        return x
+            all_blocks.append(x)
+        if return_all_blocks:
+            return x, all_blocks
+        else:
+            return x


 class VisionTransformer(nn.Module):
@@ -457,7 +462,7 @@ class VisionTransformer(nn.Module):
         else:
             return x[:, 0], x[:, 1:]

-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, return_all_blocks=False, need_OT=False, object_token=None):

         # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
         if self.input_patchnorm:
@@ -478,12 +483,18 @@ class VisionTransformer(nn.Module):
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
         x = x + self.positional_embedding.to(x.dtype)

+        # For object-centric relation reasoning: optionally append object tokens to the patch sequence.
+        if need_OT:
+            x = torch.cat([x, object_token], dim=1)
         # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
         x = self.patch_dropout(x)
         x = self.ln_pre(x)

         x = x.permute(1, 0, 2)  # NLD -> LND
-        x = self.transformer(x)
+
+        x = self.transformer(x, return_all_blocks=return_all_blocks)
+        if return_all_blocks:
+            x, all_blocks_feat = x[0], x[1]
         x = x.permute(1, 0, 2)  # LND -> NLD

         if self.attn_pool is not None:
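For context, a minimal usage sketch of the patched forward signatures. It assumes open_clip is installed from this repo so the modified open_clip.transformer.Transformer is importable; the width/depth/head counts and the object-token shape below are illustrative assumptions, not values fixed by the diff.

import torch
from open_clip.transformer import Transformer  # assumes this patched file is on the import path

# Small illustrative trunk: width=64, 4 residual blocks, 4 heads (arbitrary values).
trunk = Transformer(64, 4, 4)

# Transformer.forward operates on LND tensors (seq_len, batch, width),
# matching the NLD -> LND permute done in VisionTransformer.forward.
x = torch.randn(50, 2, 64)

final = trunk(x)                                      # default path: unchanged behaviour
final, all_blocks = trunk(x, return_all_blocks=True)  # patched path: also collect per-block outputs
assert len(all_blocks) == 4 and all_blocks[-1].shape == final.shape

# The need_OT branch in VisionTransformer.forward simply concatenates extra object
# tokens onto the patch sequence before patch_dropout / ln_pre; shapes are assumed.
patch_tokens = torch.randn(2, 197, 64)   # [batch, grid**2 + 1 (cls), width]
object_token = torch.randn(2, 8, 64)     # [batch, n_object_tokens, width] -- assumed shape
with_objects = torch.cat([patch_tokens, object_token], dim=1)
assert with_objects.shape == (2, 205, 64)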