|
import open_clip |
|
from open_clip.transformer import VisionTransformer |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
import torch.nn.functional as F |
|
|
|
import numpy as np |
|
|
|
from einops import rearrange, repeat |
|
|
|
from typing import List, Optional |
|
|
|
|
|
from utils.factory import create_model_and_transforms, get_tokenizer |
|
from prs_hook import hook_prs_logger |
|
|
|
|
|
class CLIPPerHead(nn.Module): |
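    """Per-head attention contributions of a CLIP ViT, recorded by the PRS hook
    and stacked along a layer dimension (per-position as well when spatial=True).
    The value returned by encode_image is discarded; the pass only fills the hook
    buffers. Note the hook device is hard-coded to "cuda:0".
    """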
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
|
|
|
|
|
return attentions |
|
|
|
|
|
class CLIPAttnNode(nn.Module): |
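    """Like CLIPPerHead, but the PRS contributions are summed over dim -2
    (the head axis), giving one attention node per layer.
    """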
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
|
|
|
|
|
|
|
attentions = attentions.sum(dim=-2) |
|
return attentions |
|
|
|
|
|
class CLIPMLPNode(nn.Module): |
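    """Per-layer MLP contributions of a CLIP ViT from the PRS hook. The first
    entry of prs.mlps is skipped; without spatial, only token 0 (CLS) is kept.
    """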
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
|
|
        mlps = torch.stack(self.prs.mlps[1:], dim=1).to(x.device)
|
|
|
|
|
|
|
return mlps if self.spatial else mlps[:, :, 0, :] |
|
|
|
|
|
class CLIPDebug(nn.Module): |
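    """Debug variant: hooks with spatial=False regardless of the `spatial`
    argument and returns the stacked prs.mlps with the first entry dropped.
    """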
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=False) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
|
|
|
|
|
return mlps[:, 1:, :] |
|
|
|
|
|
class CLIPLastLayer(nn.Module): |
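    """Sums every PRS attention and MLP contribution, which is intended to
    reconstruct the final image representation; returned with a singleton
    layer dimension.
    """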
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)

        mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
|
mlps = mlps if self.spatial else mlps[:, :, 0, :] |
|
|
|
        ret = attentions.sum(2).sum(1) + mlps.sum(1)
|
return ret.unsqueeze(1) |
|
|
|
|
|
class SlowCLIPEndNode(nn.Module): |
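    """For each layer i, accumulates PRS attention contributions up to block i
    and MLP contributions up to i+1, giving the residual-stream state after
    every block.
    """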
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)

        mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
|
mlps = mlps if self.spatial else mlps[:, :, 0, :] |
|
|
|
|
|
rets = [] |
|
for i in range(attentions.shape[1]): |
|
ret = attentions[:, : i + 1].sum(2).sum(1) + mlps[:, : i + 2].sum(1) |
|
rets.append(ret) |
|
rets = torch.stack(rets, dim=1) |
|
return rets |
|
|
|
|
|
class CLIPEverything(nn.Module): |
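    """Returns the PRS attention contributions, MLP contributions, cumulative
    per-layer end nodes, and the raw attention matrices in a single pass.
    """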
|
def __init__( |
|
self, pretrained="openai", model_name="ViT-B-16", spatial=False |
|
) -> None: |
|
super().__init__() |
|
self.spatial = spatial |
|
model, _, preprocess = create_model_and_transforms( |
|
model_name, pretrained=pretrained |
|
) |
|
model.eval() |
|
model.requires_grad_(False) |
|
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial) |
|
self.model = model |
|
|
|
def forward(self, x): |
|
self.prs.reinit() |
|
with torch.no_grad(): |
|
attn_method = "head" if self.spatial else "head_no_spatial" |
|
representation = self.model.encode_image( |
|
x, attn_method=attn_method, normalize=False |
|
) |
|
|
|
        attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)

        mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
|
|
|
|
|
end_nodes = [] |
|
for i in range(attentions.shape[1]): |
|
ret = attentions[:, : i + 1].sum(-2).sum(1) + mlps[:, : i + 2].sum(1) |
|
end_nodes.append(ret) |
|
end_nodes = torch.stack(end_nodes, dim=1) |
|
|
|
        attn_mats = torch.stack(self.prs.attn_mats, dim=1).to(x.device)
|
|
|
return attentions, mlps, end_nodes, attn_mats |
|
|
|
|
|
class EasyCLIPLastLayer(nn.Module): |
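    """Re-runs the OpenCLIP VisionTransformer stem and residual blocks manually
    and returns the CLS token after the final block (before ln_post and the
    projection), shaped (B, 1, D).
    """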
|
def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None: |
|
super().__init__() |
|
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data) |
|
self.vision_model: VisionTransformer = model.visual |
|
self.vision_model.requires_grad_(False) |
|
self.vision_model.eval() |
|
|
|
def forward( |
|
self, |
|
x, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.vision_model.input_patchnorm: |
|
|
|
x = x.reshape( |
|
x.shape[0], |
|
x.shape[1], |
|
self.vision_model.grid_size[0], |
|
self.vision_model.patch_size[0], |
|
self.vision_model.grid_size[1], |
|
self.vision_model.patch_size[1], |
|
) |
|
x = x.permute(0, 2, 4, 1, 3, 5) |
|
x = x.reshape( |
|
x.shape[0], |
|
self.vision_model.grid_size[0] * self.vision_model.grid_size[1], |
|
-1, |
|
) |
|
x = self.vision_model.patchnorm_pre_ln(x) |
|
x = self.vision_model.conv1(x) |
|
else: |
|
x = self.vision_model.conv1(x) |
|
x = x.reshape(x.shape[0], x.shape[1], -1) |
|
x = x.permute(0, 2, 1) |
|
|
|
|
|
x = torch.cat( |
|
[ |
|
self.vision_model.class_embedding.to(x.dtype) |
|
+ torch.zeros( |
|
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device |
|
), |
|
x, |
|
], |
|
dim=1, |
|
) |
|
x = x + self.vision_model.positional_embedding.to(x.dtype) |
|
|
|
|
|
x = self.vision_model.patch_dropout(x) |
|
x = self.vision_model.ln_pre(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.permute(1, 0, 2) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
tokens = [] |
|
for i, r in enumerate(self.vision_model.transformer.resblocks): |
|
x = r(x) |
|
x_save = x.clone() |
|
x_save = x_save[1:, :, :] |
|
p = int(np.sqrt(x_save.shape[0])) |
|
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = x_save |
|
global_tokens[str(i)] = x[0, :, :] |
|
tokens.append(x[0, :, :]) |
|
return tokens[-1].unsqueeze(1) |
|
|
|
|
|
class CLIPSumResidual(nn.Module): |
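    """Runs the OpenCLIP VisionTransformer manually and returns the token
    sequence after every residual block, stacked along dim 1; with
    output_text=True it also returns the pooled, projected CLIP embedding.
    """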
|
def __init__(self, ver="ViT-B-16", data="openai", output_text=False, **kwargs) -> None: |
|
super().__init__() |
|
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data) |
|
self.vision_model: VisionTransformer = model.visual |
|
self.vision_model.requires_grad_(False) |
|
self.vision_model.eval() |
|
|
|
self.output_text = output_text |
|
|
|
def forward( |
|
self, |
|
x, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.vision_model.input_patchnorm: |
|
|
|
x = x.reshape( |
|
x.shape[0], |
|
x.shape[1], |
|
self.vision_model.grid_size[0], |
|
self.vision_model.patch_size[0], |
|
self.vision_model.grid_size[1], |
|
self.vision_model.patch_size[1], |
|
) |
|
x = x.permute(0, 2, 4, 1, 3, 5) |
|
x = x.reshape( |
|
x.shape[0], |
|
self.vision_model.grid_size[0] * self.vision_model.grid_size[1], |
|
-1, |
|
) |
|
x = self.vision_model.patchnorm_pre_ln(x) |
|
x = self.vision_model.conv1(x) |
|
else: |
|
x = self.vision_model.conv1(x) |
|
x = x.reshape(x.shape[0], x.shape[1], -1) |
|
x = x.permute(0, 2, 1) |
|
|
|
|
|
x = torch.cat( |
|
[ |
|
self.vision_model.class_embedding.to(x.dtype) |
|
+ torch.zeros( |
|
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device |
|
), |
|
x, |
|
], |
|
dim=1, |
|
) |
|
x = x + self.vision_model.positional_embedding.to(x.dtype) |
|
|
|
|
|
x = self.vision_model.patch_dropout(x) |
|
x = self.vision_model.ln_pre(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.permute(1, 0, 2) |
|
|
|
|
|
tokens = [] |
|
for i, r in enumerate(self.vision_model.transformer.resblocks): |
|
x = r(x) |
|
tokens.append(x.permute(1, 0, 2)) |
|
mytokens = torch.stack(tokens, dim=1) |
|
x = x.permute(1, 0, 2) |
|
|
|
if self.vision_model.attn_pool is not None: |
|
x = self.vision_model.attn_pool(x) |
|
x = self.vision_model.ln_post(x) |
|
pooled, tokens = self.vision_model._global_pool(x) |
|
else: |
|
pooled, tokens = self.vision_model._global_pool(x) |
|
pooled = self.vision_model.ln_post(pooled) |
|
|
|
if self.vision_model.proj is not None: |
|
pooled = pooled @ self.vision_model.proj |
|
|
|
if self.output_text: |
|
return pooled, mytokens |
|
|
|
return mytokens |
|
|
|
class CLIPEndNode(nn.Module): |
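    """Token states after each OpenCLIP residual block: the CLS token per block
    by default, or the full token sequence when spatial=True.
    """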
|
def __init__(self, ver="ViT-B-16", data="openai", spatial=False, **kwargs) -> None: |
|
super().__init__() |
|
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data) |
|
self.vision_model: VisionTransformer = model.visual |
|
self.vision_model.requires_grad_(False) |
|
self.vision_model.eval() |
|
|
|
self.spatial = spatial |
|
|
|
def forward( |
|
self, |
|
x, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.vision_model.input_patchnorm: |
|
|
|
x = x.reshape( |
|
x.shape[0], |
|
x.shape[1], |
|
self.vision_model.grid_size[0], |
|
self.vision_model.patch_size[0], |
|
self.vision_model.grid_size[1], |
|
self.vision_model.patch_size[1], |
|
) |
|
x = x.permute(0, 2, 4, 1, 3, 5) |
|
x = x.reshape( |
|
x.shape[0], |
|
self.vision_model.grid_size[0] * self.vision_model.grid_size[1], |
|
-1, |
|
) |
|
x = self.vision_model.patchnorm_pre_ln(x) |
|
x = self.vision_model.conv1(x) |
|
else: |
|
x = self.vision_model.conv1(x) |
|
x = x.reshape(x.shape[0], x.shape[1], -1) |
|
x = x.permute(0, 2, 1) |
|
|
|
|
|
x = torch.cat( |
|
[ |
|
self.vision_model.class_embedding.to(x.dtype) |
|
+ torch.zeros( |
|
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device |
|
), |
|
x, |
|
], |
|
dim=1, |
|
) |
|
x = x + self.vision_model.positional_embedding.to(x.dtype) |
|
|
|
|
|
x = self.vision_model.patch_dropout(x) |
|
x = self.vision_model.ln_pre(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.permute(1, 0, 2) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
tokens = [] |
|
for i, r in enumerate(self.vision_model.transformer.resblocks): |
|
x = r(x) |
|
x_save = x.clone() |
|
x_save = x_save[1:, :, :] |
|
p = int(np.sqrt(x_save.shape[0])) |
|
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = x_save |
|
global_tokens[str(i)] = x[0, :, :] |
|
if self.spatial: |
|
tokens.append(rearrange(x, "p b d -> b p d")) |
|
else: |
|
tokens.append(x[0, :, :]) |
|
return torch.stack(tokens, dim=1) |
|
|
|
|
|
|
|
|
|
class ModifiedCLIP(nn.Module): |
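    """OpenCLIP ViT feature extractor. get_tokens returns two dicts keyed by
    block index: patch tokens reshaped to (B, D, H, W) and the CLS token.
    """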
|
def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None: |
|
super().__init__() |
|
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data) |
|
self.vision_model: VisionTransformer = model.visual |
|
self.vision_model.requires_grad_(False) |
|
self.vision_model.eval() |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.vision_model.input_patchnorm: |
|
|
|
x = x.reshape( |
|
x.shape[0], |
|
x.shape[1], |
|
self.vision_model.grid_size[0], |
|
self.vision_model.patch_size[0], |
|
self.vision_model.grid_size[1], |
|
self.vision_model.patch_size[1], |
|
) |
|
x = x.permute(0, 2, 4, 1, 3, 5) |
|
x = x.reshape( |
|
x.shape[0], |
|
self.vision_model.grid_size[0] * self.vision_model.grid_size[1], |
|
-1, |
|
) |
|
x = self.vision_model.patchnorm_pre_ln(x) |
|
x = self.vision_model.conv1(x) |
|
else: |
|
x = self.vision_model.conv1(x) |
|
x = x.reshape(x.shape[0], x.shape[1], -1) |
|
x = x.permute(0, 2, 1) |
|
|
|
|
|
x = torch.cat( |
|
[ |
|
self.vision_model.class_embedding.to(x.dtype) |
|
+ torch.zeros( |
|
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device |
|
), |
|
x, |
|
], |
|
dim=1, |
|
) |
|
x = x + self.vision_model.positional_embedding.to(x.dtype) |
|
|
|
|
|
x = self.vision_model.patch_dropout(x) |
|
x = self.vision_model.ln_pre(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.permute(1, 0, 2) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
for i, r in enumerate(self.vision_model.transformer.resblocks): |
|
x = r(x) |
|
x_save = x.clone() |
|
x_save = x_save[1:, :, :] |
|
p = int(np.sqrt(x_save.shape[0])) |
|
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = x_save |
|
global_tokens[str(i)] = x[0, :, :] |
|
|
|
return local_tokens, global_tokens |
|
|
|
|
|
|
|
|
|
|
|
class ModifiedDiNOv2(nn.Module): |
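    """DINOv2 backbone from torch.hub. get_tokens returns per-block patch-token
    maps (B, C, H, W) and CLS tokens, keyed by block index.
    """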
|
def __init__(self, ver="dinov2_vitb14", **kwargs) -> None: |
|
super().__init__() |
|
vision_model = torch.hub.load("facebookresearch/dinov2", ver) |
|
|
|
self.vision_model = vision_model |
|
self.vision_model.requires_grad_(False) |
|
self.vision_model.eval() |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
|
|
x = self.vision_model.prepare_tokens_with_masks(x) |
|
|
|
|
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
for i, blk in enumerate(self.vision_model.blocks): |
|
x = blk(x) |
|
saved_x = x.clone() |
|
global_tokens[str(i)] = saved_x[:, 0, :] |
|
saved_x = saved_x[:, 1:, :] |
|
p = int(np.sqrt(saved_x.shape[1])) |
|
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = saved_x |
|
return local_tokens, global_tokens |
|
|
|
|
|
class DiNOv2EndNode(nn.Module): |
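    """Per-layer DINOv2 tokens via get_intermediate_layers (no final norm):
    CLS tokens by default, or CLS prepended to the patch tokens when
    spatial=True.
    """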
|
def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=False) -> None: |
|
super().__init__() |
|
self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver) |
|
self.dinov2.requires_grad_(False) |
|
self.dinov2.eval() |
|
|
|
self.num_layers = num_layers |
|
self.spatial = spatial |
|
|
|
def forward(self, x): |
|
out = self.dinov2.get_intermediate_layers( |
|
x, self.num_layers, return_class_token=True, norm=False |
|
) |
|
class_tokens, spatial_tokens = [], [] |
|
for i, (sp, cls) in enumerate(out): |
|
class_tokens.append(cls) |
|
spatial_tokens.append(sp) |
|
if self.spatial: |
|
c = torch.stack(class_tokens, dim=1) |
|
p = torch.stack(spatial_tokens, dim=1) |
|
c = repeat(c, "b l c -> b l p c", p=1) |
|
return torch.cat([c, p], dim=2) |
|
else: |
|
return torch.stack(class_tokens, dim=1) |
|
|
|
|
|
class DiNOv2SumResidual(nn.Module): |
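    """Same as DiNOv2EndNode except the input is first resized to 196x196
    (a 14x14 patch grid for the /14 backbones).
    """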
|
def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=True) -> None: |
|
super().__init__() |
|
self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver) |
|
self.dinov2.requires_grad_(False) |
|
self.dinov2.eval() |
|
|
|
self.num_layers = num_layers |
|
self.spatial = spatial |
|
|
|
def forward(self, x): |
|
|
|
x = torch.nn.functional.interpolate(x, size=(196, 196), mode="bilinear") |
|
out = self.dinov2.get_intermediate_layers( |
|
x, self.num_layers, return_class_token=True, norm=False |
|
) |
|
class_tokens, spatial_tokens = [], [] |
|
for i, (sp, cls) in enumerate(out): |
|
class_tokens.append(cls) |
|
spatial_tokens.append(sp) |
|
if self.spatial: |
|
c = torch.stack(class_tokens, dim=1) |
|
p = torch.stack(spatial_tokens, dim=1) |
|
c = repeat(c, "b l c -> b l p c", p=1) |
|
return torch.cat([c, p], dim=2) |
|
else: |
|
return torch.stack(class_tokens, dim=1) |
|
|
|
|
|
class DiNOv2AttnMlpNode(nn.Module): |
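    """Monkey-patches the DINOv2 block forward to cache the attention and MLP
    residual branches. forward returns them stacked (all attention nodes, then
    all MLP nodes) with the register tokens removed, keeping CLS and patches.
    """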
|
|
|
def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None: |
|
super().__init__() |
|
dinov2 = torch.hub.load("facebookresearch/dinov2", ver) |
|
dinov2.requires_grad_(False) |
|
dinov2.eval() |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
def attn_residual_func(x: Tensor) -> Tensor: |
|
return self.ls1(self.attn(self.norm1(x))) |
|
|
|
def ffn_residual_func(x: Tensor) -> Tensor: |
|
return self.ls2(self.mlp(self.norm2(x))) |
|
|
|
self.saved_attn_node = attn_residual_func(x) |
|
x = x + self.saved_attn_node |
|
self.saved_mlp_node = ffn_residual_func(x) |
|
x = x + self.saved_mlp_node |
|
return x |
|
|
|
setattr(dinov2.blocks[0].__class__, "forward", forward) |
|
|
|
self.dinov2 = dinov2 |
|
self.num_reg = num_reg |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
out = self.dinov2(x) |
|
attn_nodes = [block.saved_attn_node for block in self.dinov2.blocks] |
|
mlp_nodes = [block.saved_mlp_node for block in self.dinov2.blocks] |
|
nodes = torch.stack(attn_nodes + mlp_nodes, dim=1) |
|
|
|
nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2) |
|
return nodes |
|
|
|
|
|
class DiNOv2AttnNode(nn.Module): |
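    """Attention-branch outputs per DINOv2 block, register tokens removed
    (reuses the blocks patched by DiNOv2AttnMlpNode).
    """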
|
def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None: |
|
super().__init__() |
|
self.dino = DiNOv2AttnMlpNode(ver=ver, num_reg=num_reg) |
|
self.num_reg = num_reg |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
|
|
|
out = self.dino(x) |
|
nodes = [block.saved_attn_node for block in self.dino.dinov2.blocks] |
|
nodes = torch.stack(nodes, dim=1) |
|
|
|
nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2) |
|
return nodes |
|
|
|
|
|
|
|
class DINOv1AttnNode(nn.Module): |
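    """Monkey-patches the DINO (v1) block forward to cache the attention output
    per block; outputs are zero-padded to 768 channels for narrower backbones.
    """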
|
def __init__(self, ver='dino_vits16'): |
|
super().__init__() |
|
dino = torch.hub.load('facebookresearch/dino:main', ver) |
|
dino.requires_grad_(False) |
|
dino.eval() |
|
|
|
def forward(self, x, return_attention=False): |
|
y, attn = self.attn(self.norm1(x)) |
|
if return_attention: |
|
return attn |
|
self.saved_attn = y |
|
x = x + self.drop_path(y) |
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
return x |
|
|
|
setattr(dino.blocks[0].__class__, 'forward', forward) |
|
self.dino = dino |
|
|
|
def forward(self, x): |
|
out = self.dino(x) |
|
attn_nodes = [block.saved_attn for block in self.dino.blocks] |
|
out = torch.stack(attn_nodes, dim=1) |
|
d = out.shape[-1] |
|
if d < 768: |
|
out = F.pad(out, (0, 768 - d), 'constant', 0) |
|
return out |
|
|
|
|
|
from segment_anything import sam_model_registry, SamPredictor |
|
from segment_anything.modeling.sam import Sam |
|
|
|
|
|
class ModifiedSAM(torch.nn.Module): |
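    """SAM ViT-B image encoder whose forward is replaced to return per-block
    feature maps (B, C, H, W) and their spatial means, skipping the neck;
    get_tokens first resizes inputs to 1024x1024.
    """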
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
sam: Sam = sam_model_registry["vit_b"](checkpoint=None) |
|
sd = torch.hub.load_state_dict_from_url( |
|
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth" |
|
) |
|
sam.load_state_dict(sd) |
|
|
|
def new_forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.patch_embed(x) |
|
if self.pos_embed is not None: |
|
x = x + self.pos_embed |
|
local_tokens, global_tokens = {}, {} |
|
for i, blk in enumerate(self.blocks): |
|
x = blk(x) |
|
x_save = x.clone() |
|
x_save = x_save.permute(0, 3, 1, 2) |
|
local_tokens[f"{i}"] = x_save |
|
global_tokens[f"{i}"] = x_save.mean(dim=(2, 3)) |
|
|
|
return local_tokens, global_tokens |
|
|
|
setattr(sam.image_encoder.__class__, "forward", new_forward) |
|
|
|
self.image_encoder = sam.image_encoder |
|
self.image_encoder.requires_grad_(False) |
|
self.image_encoder.eval() |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
with torch.no_grad(): |
|
x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear") |
|
local_tokens, global_tokens = self.image_encoder(x) |
|
return local_tokens, global_tokens |
|
|
|
|
|
import timm |
|
|
|
|
|
class ModifiedMAE(timm.models.vision_transformer.VisionTransformer): |
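    """timm VisionTransformer (default ViT-B/16 configuration) initialized from
    the official MAE pretraining checkpoint; get_tokens returns per-block
    patch-token maps and CLS tokens.
    """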
|
def __init__(self, **kwargs): |
|
super(ModifiedMAE, self).__init__(**kwargs) |
|
|
|
sd = torch.hub.load_state_dict_from_url( |
|
"https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth" |
|
) |
|
|
|
checkpoint_model = sd["model"] |
|
state_dict = self.state_dict() |
|
for k in ["head.weight", "head.bias"]: |
|
if ( |
|
k in checkpoint_model |
|
and checkpoint_model[k].shape != state_dict[k].shape |
|
): |
|
print(f"Removing key {k} from pretrained checkpoint") |
|
del checkpoint_model[k] |
|
|
|
|
|
msg = self.load_state_dict(checkpoint_model, strict=False) |
|
print(msg) |
|
|
|
self.requires_grad_(False) |
|
self.eval() |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
B = x.shape[0] |
|
x = self.patch_embed(x) |
|
|
|
cls_tokens = self.cls_token.expand( |
|
B, -1, -1 |
|
) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
x = x + self.pos_embed |
|
x = self.pos_drop(x) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
for i, blk in enumerate(self.blocks): |
|
x = blk(x) |
|
saved_x = x.clone() |
|
saved_x = saved_x[:, 1:, :] |
|
p = int(np.sqrt(saved_x.shape[1])) |
|
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = saved_x |
|
global_tokens[str(i)] = x[:, 0, :] |
|
return local_tokens, global_tokens |
|
|
|
|
|
class MAEEndNode(nn.Module): |
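    """Per-block MAE features: the spatial mean of the patch tokens by default,
    or CLS + patch tokens per block when spatial=True.
    """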
|
def __init__(self, spatial=False, **kwargs): |
|
super().__init__(**kwargs) |
|
model = ModifiedMAE() |
|
model.requires_grad_(False) |
|
model.eval() |
|
self.model = model |
|
|
|
self.spatial = spatial |
|
|
|
def forward(self, x): |
|
local_tokens, global_tokens = self.model.get_tokens(x) |
|
|
|
|
|
if not self.spatial: |
|
local_tokens = [tk.mean(dim=(2, 3)) for tk in local_tokens.values()] |
|
local_tokens = torch.stack(local_tokens, dim=1) |
|
return local_tokens |
|
else: |
|
local_tokens = [ |
|
rearrange(tk, "b c p1 p2 -> b (p1 p2) c") |
|
for tk in local_tokens.values() |
|
] |
|
local_tokens = torch.stack(local_tokens, dim=1) |
|
global_tokens = torch.stack(list(global_tokens.values()), dim=1) |
|
global_tokens = repeat(global_tokens, "b l c -> b l p c", p=1) |
|
return torch.cat([global_tokens, local_tokens], dim=2) |
|
|
|
|
|
class MAEEndNodePatch(nn.Module): |
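    """Per-block MAE features from spatially averaged patch tokens (same output
    as MAEEndNode with spatial=False).
    """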
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
model = ModifiedMAE() |
|
model.requires_grad_(False) |
|
model.eval() |
|
self.model = model |
|
|
|
def forward(self, x): |
|
local_tokens, global_tokens = self.model.get_tokens(x) |
|
for k, v in local_tokens.items(): |
|
local_tokens[k] = v.mean(dim=(2, 3)) |
|
local_tokens = torch.stack(list(local_tokens.values()), dim=1) |
|
return local_tokens |
|
|
|
|
|
class MAEAttnMlpNode(timm.models.vision_transformer.VisionTransformer): |
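    """MAE-initialized ViT whose block forward is patched to cache the attention
    and MLP residual branches; forward returns them stacked (attention nodes
    first, then MLP nodes), with all tokens kept.
    """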
|
def __init__(self, **kwargs): |
|
super(MAEAttnMlpNode, self).__init__(**kwargs) |
|
|
|
sd = torch.hub.load_state_dict_from_url( |
|
"https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth" |
|
) |
|
|
|
checkpoint_model = sd["model"] |
|
state_dict = self.state_dict() |
|
for k in ["head.weight", "head.bias"]: |
|
if ( |
|
k in checkpoint_model |
|
and checkpoint_model[k].shape != state_dict[k].shape |
|
): |
|
print(f"Removing key {k} from pretrained checkpoint") |
|
del checkpoint_model[k] |
|
|
|
|
|
msg = self.load_state_dict(checkpoint_model, strict=False) |
|
print(msg) |
|
|
|
self.requires_grad_(False) |
|
self.eval() |
|
|
|
def forward(self, x): |
|
self.saved_attn_node = self.ls1(self.attn(self.norm1(x))) |
|
x = x + self.saved_attn_node |
|
self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x))) |
|
x = x + self.saved_mlp_node |
|
|
|
|
|
return x |
|
|
|
setattr(self.blocks[0].__class__, "forward", forward) |
|
|
|
def forward(self, x): |
|
out = super().forward(x) |
|
attn_nodes = [block.saved_attn_node for block in self.blocks] |
|
mlp_nodes = [block.saved_mlp_node for block in self.blocks] |
|
nodes = torch.stack(attn_nodes + mlp_nodes, dim=1) |
|
return nodes |
|
|
|
|
|
class MAEAttnNode(nn.Module): |
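    """Attention-branch outputs per block of the MAE-initialized ViT."""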
|
def __init__(self, **kwargs): |
|
super().__init__(**kwargs) |
|
model = MAEAttnMlpNode() |
|
self.model = model |
|
|
|
def forward(self, x): |
|
out = self.model(x) |
|
attn_nodes = [block.saved_attn_node for block in self.model.blocks] |
|
return torch.stack(attn_nodes, dim=1) |
|
|
|
|
|
from torchvision.models import ViT_B_16_Weights, ViT_L_16_Weights, ViT_H_14_Weights |
|
from torchvision.models import vit_b_16, vit_l_16, vit_h_14 |
|
from torchvision.models import list_models, get_model |
|
from torchvision.models.feature_extraction import ( |
|
create_feature_extractor, |
|
get_graph_node_names, |
|
) |
|
|
|
|
|
class ModifiedImgNet(nn.Module): |
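    """Supervised ImageNet ViT-B/16 from torchvision; get_tokens reads the
    per-layer residual outputs via the `add_1` graph nodes and returns
    patch-token maps and CLS tokens.
    """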
|
def __init__(self, **kwargs) -> None: |
|
super().__init__() |
|
model = get_model("vit_b_16", weights=ViT_B_16_Weights.IMAGENET1K_V1) |
|
model.requires_grad_(False) |
|
model.eval() |
|
layers = [f"encoder.layers.encoder_layer_{i}.add_1" for i in range(12)] |
|
model = create_feature_extractor(model, layers) |
|
|
|
self.model = model |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
em = self.model(x) |
|
out_list = list(em.values()) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
for i, out in enumerate(out_list): |
|
saved_x = out.clone() |
|
saved_x = saved_x[:, 1:, :] |
|
p = int(np.sqrt(saved_x.shape[1])) |
|
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = saved_x |
|
global_tokens[str(i)] = out[:, 0, :] |
|
return local_tokens, global_tokens |
|
|
|
|
|
import math

from functools import partial, reduce

from operator import mul



from timm.models.layers import PatchEmbed
|
|
|
|
|
class ModifiedMoCov3(timm.models.vision_transformer.VisionTransformer): |
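    """timm VisionTransformer with MoCo v3's fixed 2D sin-cos positional
    embedding, initialized from the official ViT-B 300-epoch checkpoint;
    get_tokens returns per-block patch-token maps and CLS tokens.
    """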
|
def __init__( |
|
self, |
|
stop_grad_conv1=False, |
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
**kwargs, |
|
): |
|
        super().__init__(norm_layer=norm_layer, **kwargs)
|
|
|
self.build_2d_sincos_position_embedding() |
|
|
|
|
|
for name, m in self.named_modules(): |
|
if isinstance(m, nn.Linear): |
|
if "qkv" in name: |
|
|
|
val = math.sqrt( |
|
6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1]) |
|
) |
|
nn.init.uniform_(m.weight, -val, val) |
|
else: |
|
nn.init.xavier_uniform_(m.weight) |
|
nn.init.zeros_(m.bias) |
|
nn.init.normal_(self.cls_token, std=1e-6) |
|
|
|
if isinstance(self.patch_embed, PatchEmbed): |
|
|
|
val = math.sqrt( |
|
6.0 |
|
/ float( |
|
3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim |
|
) |
|
) |
|
nn.init.uniform_(self.patch_embed.proj.weight, -val, val) |
|
nn.init.zeros_(self.patch_embed.proj.bias) |
|
|
|
if stop_grad_conv1: |
|
self.patch_embed.proj.weight.requires_grad = False |
|
self.patch_embed.proj.bias.requires_grad = False |
|
|
|
checkpoint = torch.hub.load_state_dict_from_url( |
|
"https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar" |
|
) |
|
|
|
linear_keyword = "head" |
|
|
|
state_dict = checkpoint["state_dict"] |
|
for k in list(state_dict.keys()): |
|
|
|
if k.startswith("module.base_encoder") and not k.startswith( |
|
"module.base_encoder.%s" % linear_keyword |
|
): |
|
|
|
state_dict[k[len("module.base_encoder.") :]] = state_dict[k] |
|
|
|
del state_dict[k] |
|
|
|
msg = self.load_state_dict(state_dict, strict=False) |
|
assert set(msg.missing_keys) == { |
|
"%s.weight" % linear_keyword, |
|
"%s.bias" % linear_keyword, |
|
} |
|
|
|
|
|
|
|
self.requires_grad_(False) |
|
self.eval() |
|
|
|
def build_2d_sincos_position_embedding(self, temperature=10000.0): |
|
h, w = self.patch_embed.grid_size |
|
grid_w = torch.arange(w, dtype=torch.float32) |
|
grid_h = torch.arange(h, dtype=torch.float32) |
|
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")  # explicit "ij" matches the previous default and avoids the warning
|
assert ( |
|
self.embed_dim % 4 == 0 |
|
), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" |
|
pos_dim = self.embed_dim // 4 |
|
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim |
|
omega = 1.0 / (temperature**omega) |
|
out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega]) |
|
out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega]) |
|
pos_emb = torch.cat( |
|
[torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], |
|
dim=1, |
|
)[None, :, :] |
|
|
|
|
|
pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32) |
|
self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1)) |
|
self.pos_embed.requires_grad = False |
|
|
|
def get_tokens( |
|
self, |
|
x, |
|
): |
|
B = x.shape[0] |
|
x = self.patch_embed(x) |
|
|
|
cls_tokens = self.cls_token.expand( |
|
B, -1, -1 |
|
) |
|
x = torch.cat((cls_tokens, x), dim=1) |
|
x = x + self.pos_embed |
|
x = self.pos_drop(x) |
|
|
|
local_tokens = {} |
|
global_tokens = {} |
|
for i, blk in enumerate(self.blocks): |
|
x = blk(x) |
|
saved_x = x.clone() |
|
saved_x = saved_x[:, 1:, :] |
|
p = int(np.sqrt(saved_x.shape[1])) |
|
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p) |
|
local_tokens[str(i)] = saved_x |
|
global_tokens[str(i)] = x[:, 0, :] |
|
return local_tokens, global_tokens |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
dinov1 = DINOv1AttnNode().cuda() |
|
|
|
x = torch.randn(1, 3, 224, 224).cuda() |
|
|
|
|
|
|
|
print(dinov1(x).shape) |
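
    # Optional extra checks, left commented out because they download
    # pretrained weights and assume a CUDA device is available
    # (output shapes depend on the chosen backbone):
    # print(CLIPPerHead().cuda()(x).shape)
    # print(MAEEndNode().cuda()(x).shape)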
|
|