# NOTE(review): removed web-page residue from a Hugging Face file listing
# ("Upload 205 files", commit 4057a1f) that made this file syntactically
# invalid Python. Module content begins below.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class INP_Former(nn.Module):
    """Encoder/decoder anomaly-detection transformer with learned prototype tokens.

    Pipeline (see ``forward``): a frozen-or-partially-trainable ViT-style
    ``encoder`` produces intermediate features at ``target_layers``; an
    ``aggregation`` stack cross-attends a learned prototype token set onto the
    fused encoder features; a ``bottleneck`` transforms the fused features; a
    ``decoder`` stack reconstructs them conditioned on the aggregated
    prototypes.  ``forward`` returns fused encoder maps, fused decoder maps,
    and a prototype "gather" loss.
    """

    def __init__(
        self,
        encoder,
        bottleneck,
        aggregation,
        decoder,
        # NOTE: immutable tuple defaults — the original used mutable list
        # defaults, which are shared across instances (classic Python pitfall).
        # These are only iterated/indexed, so tuples are drop-in compatible.
        target_layers=(2, 3, 4, 5, 6, 7, 8, 9),
        fuse_layer_encoder=((0, 1, 2, 3, 4, 5, 6, 7),),
        fuse_layer_decoder=((0, 1, 2, 3, 4, 5, 6, 7),),
        remove_class_token=False,
        encoder_require_grad_layer=(),
        prototype_token=None,
    ) -> None:
        """
        Args:
            encoder: ViT-like backbone exposing ``prepare_tokens``, ``blocks``
                and (optionally) ``num_register_tokens``.
            bottleneck: iterable of modules applied to the fused encoder feature.
            aggregation: iterable of modules called as ``blk(prototype, x)``.
            decoder: iterable of modules called as ``blk(x, prototype)``.
            target_layers: encoder block indices whose outputs are collected.
            fuse_layer_encoder: groups of indices into the collected encoder
                features; each group is averaged into one output map.
            fuse_layer_decoder: same, for decoder block outputs.
            remove_class_token: if True, strip class/register tokens right
                after the encoder instead of at output time.
            encoder_require_grad_layer: encoder block indices run WITH grad;
                all other encoder blocks run under ``torch.no_grad()``.
            prototype_token: one-element container holding the learnable
                prototype tensor of shape (num_prototypes, dim); element 0 is
                used.  Must not be None (original behavior: subscripting None
                raises TypeError).
        """
        super().__init__()
        self.encoder = encoder
        self.bottleneck = bottleneck
        self.aggregation = aggregation
        self.decoder = decoder
        self.target_layers = target_layers
        self.fuse_layer_encoder = fuse_layer_encoder
        self.fuse_layer_decoder = fuse_layer_decoder
        self.remove_class_token = remove_class_token
        self.encoder_require_grad_layer = encoder_require_grad_layer
        self.prototype_token = prototype_token[0]
        # Older backbones may not define register tokens; treat as zero.
        if not hasattr(self.encoder, 'num_register_tokens'):
            self.encoder.num_register_tokens = 0

    def gather_loss(self, query, keys):
        """Mean cosine distance from each query token to its nearest prototype.

        Side effects kept for compatibility: stores the full distance matrix
        (``self.distribution``), per-token min distance (``self.distance``)
        and nearest-prototype index (``self.cluster_index``) on the instance.

        Args:
            query: (B, N, C) token features.
            keys:  (B, M, C) prototype features.
        Returns:
            Scalar tensor: mean over tokens of ``min_m (1 - cos(query, key_m))``.
        """
        # (B, N, 1, C) vs (B, 1, M, C) -> (B, N, M) pairwise cosine distance.
        self.distribution = 1. - F.cosine_similarity(
            query.unsqueeze(2), keys.unsqueeze(1), dim=-1)
        self.distance, self.cluster_index = torch.min(self.distribution, dim=2)
        return self.distance.mean()

    def forward(self, x):
        """Run encoder -> aggregation -> bottleneck -> decoder.

        Args:
            x: input image batch accepted by ``encoder.prepare_tokens``.
        Returns:
            (en, de, g_loss):
                en: list of fused encoder feature maps, each (B, C, side, side);
                de: list of fused decoder feature maps, same shapes;
                g_loss: scalar prototype gather loss.
        """
        x = self.encoder.prepare_tokens(x)
        B = x.shape[0]

        # Collect features from the requested encoder layers; only layers in
        # encoder_require_grad_layer are run with autograd enabled.
        en_list = []
        for i, blk in enumerate(self.encoder.blocks):
            if i > self.target_layers[-1]:
                # Blocks past the last target layer are never used — stop
                # (the original looped on with `continue`, doing nothing).
                break
            if i in self.encoder_require_grad_layer:
                x = blk(x)
            else:
                with torch.no_grad():
                    x = blk(x)
            if i in self.target_layers:
                en_list.append(x)

        # Spatial side length of the token grid (excludes cls + register tokens).
        side = int(math.sqrt(en_list[0].shape[1] - 1 - self.encoder.num_register_tokens))

        if self.remove_class_token:
            en_list = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en_list]

        x = self.fuse_feature(en_list)

        # Cross-attend the learned prototypes onto the fused encoder feature.
        agg_prototype = self.prototype_token
        for blk in self.aggregation:
            agg_prototype = blk(agg_prototype.unsqueeze(0).repeat((B, 1, 1)), x)
        g_loss = self.gather_loss(x, agg_prototype)

        for blk in self.bottleneck:
            x = blk(x)

        de_list = []
        for blk in self.decoder:
            x = blk(x, agg_prototype)
            de_list.append(x)
        de_list = de_list[::-1]  # deepest decoder output first

        en = [self.fuse_feature([en_list[idx] for idx in idxs]) for idxs in self.fuse_layer_encoder]
        de = [self.fuse_feature([de_list[idx] for idx in idxs]) for idxs in self.fuse_layer_decoder]

        if not self.remove_class_token:  # class tokens have not been removed above
            en = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en]
            de = [d[:, 1 + self.encoder.num_register_tokens:, :] for d in de]

        # (B, N, C) -> (B, C, side, side) spatial maps.
        en = [e.permute(0, 2, 1).reshape([B, -1, side, side]).contiguous() for e in en]
        de = [d.permute(0, 2, 1).reshape([B, -1, side, side]).contiguous() for d in de]
        return en, de, g_loss

    def fuse_feature(self, feat_list):
        """Average a list of same-shaped (B, N, C) features into one tensor."""
        return torch.stack(feat_list, dim=1).mean(dim=1)