# S3Diff / src/my_utils/dino_struct.py
# Uploaded by zhangap ("Upload 213 files"), commit 36d9761 (verified); 7.37 kB.
import torch
import torchvision
import torch.nn.functional as F
def attn_cosine_sim(x, eps=1e-08):
    """Return pairwise cosine similarities between the row vectors of ``x[0]``.

    The leading dimension of ``x`` is dropped by indexing with ``[0]``. The remaining
    tensor must be 3-D (it is transposed with ``permute(0, 2, 1)``), treated as
    (batch, tokens, dim), and the result has shape (batch, tokens, tokens).
    The product of vector norms is clamped from below at ``eps``, so all-zero vectors
    yield a similarity of 0 instead of NaN.
    """
    # Drop the redundant leading dimension (flagged as temporary by the original author).
    tokens = x[0]
    lengths = tokens.norm(dim=2, keepdim=True)
    denominator = (lengths @ lengths.permute(0, 2, 1)).clamp(min=eps)
    dots = tokens @ tokens.permute(0, 2, 1)
    return dots / denominator
class VitExtractor:
    """Capture intermediate activations of a DINO ViT through forward hooks.

    Hooks are attached right before a forward pass and removed right after it (also when
    the pass raises), so between calls the wrapped model carries no hooks added by this
    class. Captured activations are returned as lists in hook-firing order.
    """

    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        """Load ``model_name`` from the ``facebookresearch/dino:main`` hub repo onto ``device``.

        The model is put in eval mode. ``torch.hub.load`` may fetch the repo/weights on
        first use.
        """
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []  # handles from register_forward_hook; removed after each pass
        self.layers_dict = {}  # key -> block indices to hook
        self.outputs_dict = {}  # key -> activations captured during the current pass
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        """Reset every key to hook blocks 0-11 and start fresh, empty capture buffers."""
        for key in VitExtractor.KEY_LIST:
            # Hard-coded 12 layers; presumably matches the 12-block ViT-S/ViT-B DINO models.
            self.layers_dict[key] = list(range(12))
            # Rebind rather than clear, so lists already handed to callers stay intact.
            self.outputs_dict[key] = []

    def _register_hooks(self, keys=None, **kwargs):
        """Attach forward hooks for the requested keys to every configured block.

        Args:
            keys: optional iterable of ``KEY_LIST`` entries. ``None`` (default) hooks all
                four keys, as before; passing only the needed key avoids buffering
                activations that would be discarded.
            **kwargs: ignored; accepted so existing call sites keep working.
        """
        wanted = set(VitExtractor.KEY_LIST if keys is None else keys)

        def should_hook(key, block_idx):
            return key in wanted and block_idx in self.layers_dict[key]

        for block_idx, block in enumerate(self.model.blocks):
            if should_hook(VitExtractor.BLOCK_KEY, block_idx):
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if should_hook(VitExtractor.ATTN_KEY, block_idx):
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if should_hook(VitExtractor.QKV_KEY, block_idx):
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if should_hook(VitExtractor.PATCH_IMD_KEY, block_idx):
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        """Remove every hook registered by ``_register_hooks``."""
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        """Hook storing a transformer block's full output under ``BLOCK_KEY``."""
        def _get_block_output(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)
        return _get_block_output

    def _get_attn_hook(self):
        """Hook storing the output of ``block.attn.attn_drop`` under ``ATTN_KEY``.

        NOTE(review): presumably the (post-softmax) attention map in DINO's Attention.
        """
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)
        return _get_attn_output

    def _get_qkv_hook(self):
        """Hook storing the output of the ``block.attn.qkv`` projection under ``QKV_KEY``."""
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)
        return _get_qkv_output

    def _get_patch_imd_hook(self):
        """Hook storing ``output[0]`` of ``block.attn`` under ``PATCH_IMD_KEY``.

        NOTE(review): assumes the attention module returns a tuple whose first element is
        the attended output (original TODO: "CHECK ATTN OUTPUT TUPLE") — confirm.
        """
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])
        return _get_attn_output

    def _collect_outputs(self, input_img, key):
        """Run one forward pass with hooks for ``key`` and return what they captured.

        Cleanup happens in ``finally``: if the model raises, the hooks are still removed
        and the buffers reset, so later calls neither fire stale hooks nor return
        duplicated activations.
        """
        self._register_hooks(keys=(key,))
        try:
            self.model(input_img)
            return self.outputs_dict[key]
        finally:
            self._clear_hooks()
            self._init_hooks_data()

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        """Return the outputs of the hooked transformer blocks (presumably each (B, N, D))."""
        return self._collect_outputs(input_img, VitExtractor.BLOCK_KEY)

    def get_qkv_feature_from_input(self, input_img):
        """Return the outputs of each hooked block's ``attn.qkv`` projection."""
        return self._collect_outputs(input_img, VitExtractor.QKV_KEY)

    def get_attn_feature_from_input(self, input_img):
        """Return the outputs of each hooked block's ``attn.attn_drop`` module."""
        return self._collect_outputs(input_img, VitExtractor.ATTN_KEY)

    def get_patch_size(self):
        """Patch size guessed from the model name: 8 if it contains "8", else 16."""
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        """Number of patches along the width of a (b, c, h, w) input shape."""
        b, c, h, w = input_img_shape
        return w // self.get_patch_size()

    def get_height_patch_num(self, input_img_shape):
        """Number of patches along the height of a (b, c, h, w) input shape."""
        b, c, h, w = input_img_shape
        return h // self.get_patch_size()

    def get_patch_num(self, input_img_shape):
        """Token count: 1 + patches (the extra token is presumably [CLS])."""
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        """Attention-head count guessed from the model name (6 for small models, else 12)."""
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        """Embedding width guessed from the model name (384 for small models, else 768)."""
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def _split_qkv(self, qkv, input_img_shape):
        """Reshape one image's qkv activation into (3, heads, tokens, head_dim).

        ``qkv`` must contain exactly ``tokens * 3 * embedding_dim`` elements, i.e. a single
        image (batch size 1); index 0/1/2 of the result selects queries/keys/values.
        """
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        head_dim = self.get_embedding_dim() // head_num
        return qkv.reshape(patch_num, 3, head_num, head_dim).permute(1, 2, 0, 3)

    def get_queries_from_qkv(self, qkv, input_img_shape):
        """Queries of a single image, shape (heads, tokens, head_dim)."""
        return self._split_qkv(qkv, input_img_shape)[0]

    def get_keys_from_qkv(self, qkv, input_img_shape):
        """Keys of a single image, shape (heads, tokens, head_dim)."""
        return self._split_qkv(qkv, input_img_shape)[1]

    def get_values_from_qkv(self, qkv, input_img_shape):
        """Values of a single image, shape (heads, tokens, head_dim)."""
        return self._split_qkv(qkv, input_img_shape)[2]

    def get_keys_from_input(self, input_img, layer_num):
        """Keys of block ``layer_num`` for a single (1, c, h, w) image."""
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        return self.get_keys_from_qkv(qkv_features, input_img.shape)

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        """Cosine self-similarity of per-token keys (heads concatenated), shape (1, tokens, tokens)."""
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        # (heads, tokens, d) -> (tokens, heads * d): one descriptor per token.
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        return attn_cosine_sim(concatenated_keys[None, None, ...])
class DinoStructureLoss:
def __init__(self, ):
self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
self.preprocess = torchvision.transforms.Compose([
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
def calculate_global_ssim_loss(self, outputs, inputs):
loss = 0.0
for a, b in zip(inputs, outputs): # avoid memory limitations
with torch.no_grad():
target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
loss += F.mse_loss(keys_ssim, target_keys_self_sim)
return loss