# alignedthreeattn_backbone.py
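"""Backbone wrappers that expose intermediate features from frozen pretrained
vision transformers: OpenCLIP ViTs (via PRS hooks), DINOv2, DINO v1, SAM, MAE,
MoCo v3, and a supervised ImageNet ViT. Each wrapper freezes its model and
returns per-layer tokens, per-head attention contributions, or residual-stream
summands ("nodes") for downstream analysis."""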
import open_clip
from open_clip.transformer import VisionTransformer
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from typing import List, Optional
from utils.factory import create_model_and_transforms, get_tokenizer
from prs_hook import hook_prs_logger
class CLIPPerHead(nn.Module):
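    """Per-head attention contributions from an OpenCLIP visual transformer.

    The PRS hooks (hook_prs_logger) record each layer's attention output during
    encode_image; forward returns the raw stack of self.prs.attentions along a
    new layer axis, without any reduction. The hook is pinned to "cuda:0", so a
    CUDA device is assumed.

    Usage sketch (mirrors the __main__ smoke test at the bottom of this file):
        model = CLIPPerHead().cuda()
        heads = model(torch.randn(1, 3, 224, 224).cuda())
    """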
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
# return attentions, mlps
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
return attentions
class CLIPAttnNode(nn.Module):
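    """Per-layer attention outputs from OpenCLIP, collapsed over heads.

    Same PRS hooking as CLIPPerHead, but the stacked attentions are reduced
    with .sum(dim=-2) so each layer contributes a single attention node.
    """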
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
# mlps = torch.stack(self.prs.mlps, axis=1).to(x.device)
# return attentions, mlps
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
attentions = attentions.sum(dim=-2)
return attentions
class CLIPMLPNode(nn.Module):
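    """Per-layer MLP contributions recorded by the PRS hooks.

    Drops the first entry of self.prs.mlps (presumably the pre-transformer
    embedding term) and, when spatial=False, keeps only the class-token slice.
    """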
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
# attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
mlps = torch.stack(self.prs.mlps[1:], axis=1).to(x.device)
# return attentions, mlps
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
# attentions = attentions.sum(dim=2)
return mlps if self.spatial else mlps[:, :, 0, :]
class CLIPDebug(nn.Module):
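    """Debug variant that returns the stacked non-spatial MLP records with the
    first recorded entry dropped."""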
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=False)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
mlps = torch.stack(self.prs.mlps, axis=1).to(x.device)
# return attentions, mlps
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
return mlps[:, 1:, :]
class CLIPLastLayer(nn.Module):
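    """Reconstructs the final CLIP representation from the recorded nodes.

    Sums every per-head attention contribution and every MLP record (including
    the first, presumably the embedding term) and returns the result with a
    singleton layer axis via unsqueeze(1).
    """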
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
mlps = torch.stack(self.prs.mlps, axis=1).to(x.device)
mlps = mlps if self.spatial else mlps[:, :, 0, :]
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
ret = attentions[:, :].sum(2).sum(1) + mlps[:, :].sum(1)
return ret.unsqueeze(1)
class SlowCLIPEndNode(nn.Module):
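    """Per-layer partial sums of the CLIP residual stream ("end nodes").

    Entry i sums the attention contributions of layers 0..i and the MLP records
    up to index i+1, approximating the residual-stream state after layer i.
    """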
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
mlps = torch.stack(self.prs.mlps, axis=1).to(x.device)
mlps = mlps if self.spatial else mlps[:, :, 0, :]
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
rets = []
for i in range(attentions.shape[1]):
ret = attentions[:, : i + 1].sum(2).sum(1) + mlps[:, : i + 2].sum(1)
rets.append(ret)
rets = torch.stack(rets, dim=1)
return rets
class CLIPEverything(nn.Module):
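    """Returns every recorded CLIP quantity in one call:
    (attentions, mlps, end_nodes, attn_mats): the per-head attention
    contributions, the MLP contributions, the per-layer residual partial sums,
    and the attention matrices recorded by the PRS hooks (prs.attn_mats).
    """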
def __init__(
self, pretrained="openai", model_name="ViT-B-16", spatial=False
) -> None:
super().__init__()
self.spatial = spatial
model, _, preprocess = create_model_and_transforms(
model_name, pretrained=pretrained
)
model.eval()
model.requires_grad_(False)
self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
self.model = model
def forward(self, x):
self.prs.reinit()
with torch.no_grad():
attn_method = "head" if self.spatial else "head_no_spatial"
representation = self.model.encode_image(
x, attn_method=attn_method, normalize=False
)
# attentions, mlps = self.prs.finalize(representation)
attentions = torch.stack(self.prs.attentions, axis=1).to(x.device)
mlps = torch.stack(self.prs.mlps, axis=1).to(x.device)
# attentions = rearrange(attentions, "b l h d -> b (l h) d")
end_nodes = []
for i in range(attentions.shape[1]):
ret = attentions[:, : i + 1].sum(-2).sum(1) + mlps[:, : i + 2].sum(1)
end_nodes.append(ret)
end_nodes = torch.stack(end_nodes, dim=1)
attn_mats = torch.stack(self.prs.attn_mats, axis=1).to(x.device)
return attentions, mlps, end_nodes, attn_mats
class EasyCLIPLastLayer(nn.Module):
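    """CLS token of the last transformer block, computed without PRS hooks.

    Re-implements the OpenCLIP VisionTransformer stem (patchify, class and
    positional embeddings, ln_pre), walks the resblocks manually, and returns
    the final block's class token with a singleton layer axis. ln_post and the
    output projection are not applied.
    """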
def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None:
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
self.vision_model: VisionTransformer = model.visual
self.vision_model.requires_grad_(False)
self.vision_model.eval()
def forward(
self,
x,
):
#### original code #### begin
##############################
### patchify ###
##############################
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.vision_model.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(
x.shape[0],
x.shape[1],
self.vision_model.grid_size[0],
self.vision_model.patch_size[0],
self.vision_model.grid_size[1],
self.vision_model.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0],
self.vision_model.grid_size[0] * self.vision_model.grid_size[1],
-1,
)
x = self.vision_model.patchnorm_pre_ln(x)
x = self.vision_model.conv1(x)
else:
x = self.vision_model.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[
self.vision_model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.vision_model.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.vision_model.patch_dropout(x)
x = self.vision_model.ln_pre(x)
#### original code #### end
#### modified code #### begin
##############################
### transformer ###
##############################
x = x.permute(1, 0, 2) # NLD -> LND
local_tokens = {}
global_tokens = {}
tokens = []
for i, r in enumerate(self.vision_model.transformer.resblocks):
x = r(x) # [1+p**2, B, D]
x_save = x.clone()
x_save = x_save[1:, :, :] # [p**2, B, D]
p = int(np.sqrt(x_save.shape[0]))
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
local_tokens[str(i)] = x_save
global_tokens[str(i)] = x[0, :, :] # [B, D]
tokens.append(x[0, :, :])
return tokens[-1].unsqueeze(1)
class CLIPSumResidual(nn.Module):
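    """Full residual-stream tokens after every OpenCLIP block.

    Stacks each block's token tensor along a layer axis (roughly
    [B, layers, 1 + patches, D]); with output_text=True it also returns the
    pooled, projected image embedding.
    """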
def __init__(self, ver="ViT-B-16", data="openai", output_text=False, **kwargs) -> None:
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
self.vision_model: VisionTransformer = model.visual
self.vision_model.requires_grad_(False)
self.vision_model.eval()
self.output_text = output_text
def forward(
self,
x,
):
#### original code #### begin
##############################
### patchify ###
##############################
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.vision_model.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(
x.shape[0],
x.shape[1],
self.vision_model.grid_size[0],
self.vision_model.patch_size[0],
self.vision_model.grid_size[1],
self.vision_model.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0],
self.vision_model.grid_size[0] * self.vision_model.grid_size[1],
-1,
)
x = self.vision_model.patchnorm_pre_ln(x)
x = self.vision_model.conv1(x)
else:
x = self.vision_model.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[
self.vision_model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.vision_model.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.vision_model.patch_dropout(x)
x = self.vision_model.ln_pre(x)
#### original code #### end
#### modified code #### begin
##############################
### transformer ###
##############################
x = x.permute(1, 0, 2) # NLD -> LND
tokens = []
for i, r in enumerate(self.vision_model.transformer.resblocks):
x = r(x) # [1+p**2, B, D]
tokens.append(x.permute(1, 0, 2))
mytokens = torch.stack(tokens, dim=1)
x = x.permute(1, 0, 2) # LND -> NLD
if self.vision_model.attn_pool is not None:
x = self.vision_model.attn_pool(x)
x = self.vision_model.ln_post(x)
pooled, tokens = self.vision_model._global_pool(x)
else:
pooled, tokens = self.vision_model._global_pool(x)
pooled = self.vision_model.ln_post(pooled)
if self.vision_model.proj is not None:
pooled = pooled @ self.vision_model.proj
if self.output_text:
return pooled, mytokens
return mytokens
class CLIPEndNode(nn.Module):
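    """Per-block tokens from the OpenCLIP ViT: the class token of every block
    stacked along a layer axis, or all tokens per block when spatial=True."""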
def __init__(self, ver="ViT-B-16", data="openai", spatial=False, **kwargs) -> None:
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
self.vision_model: VisionTransformer = model.visual
self.vision_model.requires_grad_(False)
self.vision_model.eval()
self.spatial = spatial
def forward(
self,
x,
):
#### original code #### begin
##############################
### patchify ###
##############################
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.vision_model.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(
x.shape[0],
x.shape[1],
self.vision_model.grid_size[0],
self.vision_model.patch_size[0],
self.vision_model.grid_size[1],
self.vision_model.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0],
self.vision_model.grid_size[0] * self.vision_model.grid_size[1],
-1,
)
x = self.vision_model.patchnorm_pre_ln(x)
x = self.vision_model.conv1(x)
else:
x = self.vision_model.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[
self.vision_model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.vision_model.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.vision_model.patch_dropout(x)
x = self.vision_model.ln_pre(x)
#### original code #### end
#### modified code #### begin
##############################
### transformer ###
##############################
x = x.permute(1, 0, 2) # NLD -> LND
local_tokens = {}
global_tokens = {}
tokens = []
for i, r in enumerate(self.vision_model.transformer.resblocks):
x = r(x) # [1+p**2, B, D]
x_save = x.clone()
x_save = x_save[1:, :, :] # [p**2, B, D]
p = int(np.sqrt(x_save.shape[0]))
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
local_tokens[str(i)] = x_save
global_tokens[str(i)] = x[0, :, :] # [B, D]
if self.spatial:
tokens.append(rearrange(x, "p b d -> b p d"))
else:
tokens.append(x[0, :, :])
return torch.stack(tokens, dim=1)
# return local_tokens, global_tokens
class ModifiedCLIP(nn.Module):
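    """Dict-style per-block features from the OpenCLIP ViT.

    get_tokens returns (local_tokens, global_tokens): per-block patch tokens
    reshaped to [B, D, p, p] and per-block class tokens [B, D], keyed by the
    block index as a string.
    """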
def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None:
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
self.vision_model: VisionTransformer = model.visual
self.vision_model.requires_grad_(False)
self.vision_model.eval()
def get_tokens(
self,
x,
):
#### original code #### begin
##############################
### patchify ###
##############################
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.vision_model.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(
x.shape[0],
x.shape[1],
self.vision_model.grid_size[0],
self.vision_model.patch_size[0],
self.vision_model.grid_size[1],
self.vision_model.patch_size[1],
)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0],
self.vision_model.grid_size[0] * self.vision_model.grid_size[1],
-1,
)
x = self.vision_model.patchnorm_pre_ln(x)
x = self.vision_model.conv1(x)
else:
x = self.vision_model.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat(
[
self.vision_model.class_embedding.to(x.dtype)
+ torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
),
x,
],
dim=1,
) # shape = [*, grid ** 2 + 1, width]
x = x + self.vision_model.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.vision_model.patch_dropout(x)
x = self.vision_model.ln_pre(x)
#### original code #### end
#### modified code #### begin
##############################
### transformer ###
##############################
x = x.permute(1, 0, 2) # NLD -> LND
local_tokens = {}
global_tokens = {}
for i, r in enumerate(self.vision_model.transformer.resblocks):
x = r(x) # [1+p**2, B, D]
x_save = x.clone()
x_save = x_save[1:, :, :] # [p**2, B, D]
p = int(np.sqrt(x_save.shape[0]))
x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
local_tokens[str(i)] = x_save
global_tokens[str(i)] = x[0, :, :] # [B, D]
return local_tokens, global_tokens
# from dinov2.models.vision_transformer import DinoVisionTransformer
class ModifiedDiNOv2(nn.Module):
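    """Dict-style per-block features from a torch.hub DINOv2 model, with the
    same (local_tokens, global_tokens) interface as ModifiedCLIP.get_tokens."""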
def __init__(self, ver="dinov2_vitb14", **kwargs) -> None:
super().__init__()
vision_model = torch.hub.load("facebookresearch/dinov2", ver)
# self.vision_model: DinoVisionTransformer = vision_model
self.vision_model = vision_model
self.vision_model.requires_grad_(False)
self.vision_model.eval()
def get_tokens(
self,
x,
):
#### original code #### begin
x = self.vision_model.prepare_tokens_with_masks(x)
#### original code #### end
#### modified code #### begin
local_tokens = {}
global_tokens = {}
for i, blk in enumerate(self.vision_model.blocks):
x = blk(x)
saved_x = x.clone()
global_tokens[str(i)] = saved_x[:, 0, :] # [B, C]
saved_x = saved_x[:, 1:, :] # remove cls token, [B, N, C]
p = int(np.sqrt(saved_x.shape[1]))
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
local_tokens[str(i)] = saved_x
return local_tokens, global_tokens
class DiNOv2EndNode(nn.Module):
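    """Per-layer DINOv2 tokens via get_intermediate_layers: stacked class
    tokens [B, L, C], or class plus spatial tokens concatenated along the
    token axis when spatial=True."""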
def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=False) -> None:
super().__init__()
self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
self.dinov2.requires_grad_(False)
self.dinov2.eval()
self.num_layers = num_layers
self.spatial = spatial
def forward(self, x):
out = self.dinov2.get_intermediate_layers(
x, self.num_layers, return_class_token=True, norm=False
)
class_tokens, spatial_tokens = [], []
for i, (sp, cls) in enumerate(out):
class_tokens.append(cls)
spatial_tokens.append(sp)
if self.spatial:
c = torch.stack(class_tokens, dim=1) # [B, L, C]
p = torch.stack(spatial_tokens, dim=1) # [B, L, P, C]
c = repeat(c, "b l c -> b l p c", p=1)
return torch.cat([c, p], dim=2)
else:
return torch.stack(class_tokens, dim=1)
class DiNOv2SumResidual(nn.Module):
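    """Same layout as DiNOv2EndNode, but inputs are first resized to 196x196
    (14x14 patches at patch size 14) and spatial defaults to True."""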
def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=True) -> None:
super().__init__()
self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
self.dinov2.requires_grad_(False)
self.dinov2.eval()
self.num_layers = num_layers
self.spatial = spatial
def forward(self, x):
# resample to 196x196
x = torch.nn.functional.interpolate(x, size=(196, 196), mode="bilinear")
out = self.dinov2.get_intermediate_layers(
x, self.num_layers, return_class_token=True, norm=False
)
class_tokens, spatial_tokens = [], []
for i, (sp, cls) in enumerate(out):
class_tokens.append(cls)
spatial_tokens.append(sp)
if self.spatial:
c = torch.stack(class_tokens, dim=1) # [B, L, C]
p = torch.stack(spatial_tokens, dim=1) # [B, L, P, C]
c = repeat(c, "b l c -> b l p c", p=1)
return torch.cat([c, p], dim=2)
else:
return torch.stack(class_tokens, dim=1)
class DiNOv2AttnMlpNode(nn.Module):
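    """Attention and MLP residual contributions from every DINOv2 block.

    Overrides the block class's forward (via setattr) to cache
    ls1(attn(norm1(x))) and ls2(mlp(norm2(x))), stacks attention nodes followed
    by MLP nodes along dim=1, and removes the register tokens (keeping CLS and
    patch tokens).
    """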
def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None:
super().__init__()
dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
dinov2.requires_grad_(False)
dinov2.eval()
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
self.saved_attn_node = attn_residual_func(x)
x = x + self.saved_attn_node
self.saved_mlp_node = ffn_residual_func(x)
x = x + self.saved_mlp_node
return x
setattr(dinov2.blocks[0].__class__, "forward", forward)
self.dinov2 = dinov2
self.num_reg = num_reg
def forward(self, x: Tensor) -> Tensor:
out = self.dinov2(x)
attn_nodes = [block.saved_attn_node for block in self.dinov2.blocks]
mlp_nodes = [block.saved_mlp_node for block in self.dinov2.blocks]
nodes = torch.stack(attn_nodes + mlp_nodes, dim=1)
# remove register tokens
nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2)
return nodes
class DiNOv2AttnNode(nn.Module):
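    """Attention nodes only: runs DiNOv2AttnMlpNode and returns just the cached
    per-block attention contributions, with register tokens removed."""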
def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None:
super().__init__()
self.dino = DiNOv2AttnMlpNode(ver=ver, num_reg=num_reg)
self.num_reg = num_reg
def forward(self, x: Tensor) -> Tensor:
# resample to 196x196
# x = torch.nn.functional.interpolate(x, size=(196, 196), mode="bilinear")
out = self.dino(x)
nodes = [block.saved_attn_node for block in self.dino.dinov2.blocks]
nodes = torch.stack(nodes, dim=1)
# remove register tokens
nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2)
return nodes
class DINOv1AttnNode(nn.Module):
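    """Per-block attention-branch outputs from DINO v1 (dino_vits16 default).

    Overrides the block forward to cache the attention output, stacks it over
    blocks, and zero-pads the channel dimension to 768 (ViT-S/16 features are
    384-dim), presumably so they stack with the ViT-B backbones above.
    """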
def __init__(self, ver='dino_vits16'):
super().__init__()
dino = torch.hub.load('facebookresearch/dino:main', ver)
dino.requires_grad_(False)
dino.eval()
def forward(self, x, return_attention=False):
y, attn = self.attn(self.norm1(x))
if return_attention:
return attn
self.saved_attn = y
x = x + self.drop_path(y)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
setattr(dino.blocks[0].__class__, 'forward', forward)
self.dino = dino
def forward(self, x):
out = self.dino(x)
attn_nodes = [block.saved_attn for block in self.dino.blocks]
out = torch.stack(attn_nodes, dim=1)
d = out.shape[-1]
if d < 768:
out = F.pad(out, (0, 768 - d), 'constant', 0)
return out
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.modeling.sam import Sam
class ModifiedSAM(torch.nn.Module):
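    """Per-block features from SAM's ViT-B image encoder.

    Loads the official vit_b checkpoint, replaces the image encoder's forward
    to collect every block's feature map as [B, C, H, W] (local_tokens) and its
    spatial mean (global_tokens), and resizes inputs to 1024x1024.
    """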
def __init__(self, **kwargs):
super().__init__(**kwargs)
sam: Sam = sam_model_registry["vit_b"](checkpoint=None)
sd = torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
)
sam.load_state_dict(sd)
def new_forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
local_tokens, global_tokens = {}, {}
for i, blk in enumerate(self.blocks):
x = blk(x)
x_save = x.clone()
x_save = x_save.permute(0, 3, 1, 2)
local_tokens[f"{i}"] = x_save
global_tokens[f"{i}"] = x_save.mean(dim=(2, 3))
return local_tokens, global_tokens
setattr(sam.image_encoder.__class__, "forward", new_forward)
self.image_encoder = sam.image_encoder
self.image_encoder.requires_grad_(False)
self.image_encoder.eval()
def get_tokens(
self,
x,
):
with torch.no_grad():
x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
local_tokens, global_tokens = self.image_encoder(x)
return local_tokens, global_tokens
import timm
class ModifiedMAE(timm.models.vision_transformer.VisionTransformer):
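    """timm ViT-B/16 loaded with the official MAE pretraining checkpoint.

    get_tokens mirrors ModifiedCLIP / ModifiedDiNOv2: per-block patch tokens
    reshaped to [B, C, p, p] plus per-block class tokens.
    """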
def __init__(self, **kwargs):
super(ModifiedMAE, self).__init__(**kwargs)
sd = torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
)
checkpoint_model = sd["model"]
state_dict = self.state_dict()
for k in ["head.weight", "head.bias"]:
if (
k in checkpoint_model
and checkpoint_model[k].shape != state_dict[k].shape
):
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# load pre-trained model
msg = self.load_state_dict(checkpoint_model, strict=False)
print(msg)
self.requires_grad_(False)
self.eval()
def get_tokens(
self,
x,
):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
local_tokens = {}
global_tokens = {}
for i, blk in enumerate(self.blocks):
x = blk(x)
saved_x = x.clone()
saved_x = saved_x[:, 1:, :] # remove cls token, [B, N, C]
p = int(np.sqrt(saved_x.shape[1]))
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
local_tokens[str(i)] = saved_x
global_tokens[str(i)] = x[:, 0, :] # [B, C]
return local_tokens, global_tokens
class MAEEndNode(nn.Module):
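    """Per-layer MAE features as one stacked tensor: spatially averaged patch
    tokens when spatial=False, otherwise the class token concatenated with the
    flattened patch tokens for every layer."""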
def __init__(self, spatial=False, **kwargs):
super().__init__(**kwargs)
model = ModifiedMAE()
model.requires_grad_(False)
model.eval()
self.model = model
self.spatial = spatial
def forward(self, x):
local_tokens, global_tokens = self.model.get_tokens(x)
# global_tokens = torch.stack(list(global_tokens.values()), dim=1)
# return global_tokens
if not self.spatial:
local_tokens = [tk.mean(dim=(2, 3)) for tk in local_tokens.values()]
local_tokens = torch.stack(local_tokens, dim=1)
return local_tokens
else:
local_tokens = [
rearrange(tk, "b c p1 p2 -> b (p1 p2) c")
for tk in local_tokens.values()
]
local_tokens = torch.stack(local_tokens, dim=1)
global_tokens = torch.stack(list(global_tokens.values()), dim=1)
global_tokens = repeat(global_tokens, "b l c -> b l p c", p=1)
return torch.cat([global_tokens, local_tokens], dim=2)
class MAEEndNodePatch(nn.Module):
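    """Spatially averaged per-block MAE patch tokens, stacked to [B, L, C]."""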
def __init__(self, **kwargs):
super().__init__(**kwargs)
model = ModifiedMAE()
model.requires_grad_(False)
model.eval()
self.model = model
def forward(self, x):
local_tokens, global_tokens = self.model.get_tokens(x)
for k, v in local_tokens.items():
local_tokens[k] = v.mean(dim=(2, 3))
local_tokens = torch.stack(list(local_tokens.values()), dim=1)
return local_tokens
class MAEAttnMlpNode(timm.models.vision_transformer.VisionTransformer):
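    """Attention and MLP residual contributions from every MAE block.

    Same block-patching scheme as DiNOv2AttnMlpNode, applied to the timm
    blocks: caches ls1(attn(norm1(x))) and ls2(mlp(norm2(x))) and stacks
    attention nodes followed by MLP nodes along dim=1.
    """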
def __init__(self, **kwargs):
super(MAEAttnMlpNode, self).__init__(**kwargs)
sd = torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
)
checkpoint_model = sd["model"]
state_dict = self.state_dict()
for k in ["head.weight", "head.bias"]:
if (
k in checkpoint_model
and checkpoint_model[k].shape != state_dict[k].shape
):
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# load pre-trained model
msg = self.load_state_dict(checkpoint_model, strict=False)
print(msg)
self.requires_grad_(False)
self.eval()
def forward(self, x):
self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
x = x + self.saved_attn_node
self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
x = x + self.saved_mlp_node
# x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
# x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
setattr(self.blocks[0].__class__, "forward", forward)
def forward(self, x):
out = super().forward(x)
attn_nodes = [block.saved_attn_node for block in self.blocks]
mlp_nodes = [block.saved_mlp_node for block in self.blocks]
nodes = torch.stack(attn_nodes + mlp_nodes, dim=1)
return nodes
class MAEAttnNode(nn.Module):
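    """Attention nodes only: runs MAEAttnMlpNode and returns the cached
    per-block attention contributions."""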
def __init__(self, **kwargs):
super().__init__(**kwargs)
model = MAEAttnMlpNode()
self.model = model
def forward(self, x):
out = self.model(x)
attn_nodes = [block.saved_attn_node for block in self.model.blocks]
return torch.stack(attn_nodes, dim=1)
from torchvision.models import ViT_B_16_Weights, ViT_L_16_Weights, ViT_H_14_Weights
from torchvision.models import vit_b_16, vit_l_16, vit_h_14
from torchvision.models import list_models, get_model
from torchvision.models.feature_extraction import (
create_feature_extractor,
get_graph_node_names,
)
class ModifiedImgNet(nn.Module):
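    """Per-block features from torchvision's ImageNet-supervised ViT-B/16.

    Extracts the encoder_layer_{i}.add_1 graph nodes with
    create_feature_extractor and reshapes each output into patch tokens
    [B, C, p, p] plus a class token, matching the get_tokens interface above.
    """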
def __init__(self, **kwargs) -> None:
super().__init__()
model = get_model("vit_b_16", weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.requires_grad_(False)
model.eval()
layers = [f"encoder.layers.encoder_layer_{i}.add_1" for i in range(12)]
model = create_feature_extractor(model, layers)
self.model = model
def get_tokens(
self,
x,
):
em = self.model(x)
out_list = list(em.values())
local_tokens = {}
global_tokens = {}
for i, out in enumerate(out_list):
saved_x = out.clone()
saved_x = saved_x[:, 1:, :] # remove cls token, [B, N, C]
p = int(np.sqrt(saved_x.shape[1]))
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
local_tokens[str(i)] = saved_x
global_tokens[str(i)] = out[:, 0, :] # [B, C]
return local_tokens, global_tokens
import math
import torch
import torch.nn as nn
from functools import partial, reduce
from operator import mul
from timm.models.layers import PatchEmbed
class ModifiedMoCov3(timm.models.vision_transformer.VisionTransformer):
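    """timm ViT-B/16 built as in MoCo v3 (2D sin-cos position embedding,
    MoCo-style init) and loaded with the official vit-b-300ep checkpoint
    (base_encoder weights only). get_tokens mirrors ModifiedMAE.
    """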
def __init__(
self,
stop_grad_conv1=False,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
):
super().__init__(norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
# Use fixed 2D sin-cos position embedding
self.build_2d_sincos_position_embedding()
# weight initialization
for name, m in self.named_modules():
if isinstance(m, nn.Linear):
if "qkv" in name:
# treat the weights of Q, K, V separately
val = math.sqrt(
6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
)
nn.init.uniform_(m.weight, -val, val)
else:
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
nn.init.normal_(self.cls_token, std=1e-6)
if isinstance(self.patch_embed, PatchEmbed):
# xavier_uniform initialization
val = math.sqrt(
6.0
/ float(
3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim
)
)
nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
nn.init.zeros_(self.patch_embed.proj.bias)
if stop_grad_conv1:
self.patch_embed.proj.weight.requires_grad = False
self.patch_embed.proj.bias.requires_grad = False
checkpoint = torch.hub.load_state_dict_from_url(
"https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar"
)
linear_keyword = "head"
# rename moco pre-trained keys
state_dict = checkpoint["state_dict"]
for k in list(state_dict.keys()):
# retain only base_encoder up to before the embedding layer
if k.startswith("module.base_encoder") and not k.startswith(
"module.base_encoder.%s" % linear_keyword
):
# remove prefix
state_dict[k[len("module.base_encoder.") :]] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
msg = self.load_state_dict(state_dict, strict=False)
assert set(msg.missing_keys) == {
"%s.weight" % linear_keyword,
"%s.bias" % linear_keyword,
}
# print("=> loaded pre-trained self '{}'".format(checkpoint))
self.requires_grad_(False)
self.eval()
def build_2d_sincos_position_embedding(self, temperature=10000.0):
h, w = self.patch_embed.grid_size
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
assert (
self.embed_dim % 4 == 0
), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
pos_dim = self.embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
pos_emb = torch.cat(
[torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
dim=1,
)[None, :, :]
# assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
self.pos_embed.requires_grad = False
def get_tokens(
self,
x,
):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
local_tokens = {}
global_tokens = {}
for i, blk in enumerate(self.blocks):
x = blk(x)
saved_x = x.clone()
saved_x = saved_x[:, 1:, :] # remove cls token, [B, N, C]
p = int(np.sqrt(saved_x.shape[1]))
saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
local_tokens[str(i)] = saved_x
global_tokens[str(i)] = x[:, 0, :] # [B, C]
return local_tokens, global_tokens
if __name__ == "__main__":
# clip = CLIPAttnNode().cuda()
# dino = DiNOv2AttnNode().cuda()
dinov1 = DINOv1AttnNode().cuda()
# mae = MAEAttnNode().cuda()
x = torch.randn(1, 3, 224, 224).cuda()
# print(clip(x).shape)
# print(dino(x).shape)
# print(mae(x).shape)
print(dinov1(x).shape)
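    # For dino_vits16 with a 224x224 input this should print
    # torch.Size([1, 12, 197, 768]): 12 blocks, 196 patch tokens + CLS,
    # 384-dim features zero-padded to 768.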