import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class INP_Former(nn.Module):
    def __init__(
            self,
            encoder,
            bottleneck,
            aggregation,
            decoder,
            target_layers=[2, 3, 4, 5, 6, 7, 8, 9],
            fuse_layer_encoder=[[0, 1, 2, 3, 4, 5, 6, 7]],
            fuse_layer_decoder=[[0, 1, 2, 3, 4, 5, 6, 7]],
            remove_class_token=False,
            encoder_require_grad_layer=[],
            prototype_token=None,
    ) -> None:
        super(INP_Former, self).__init__()
        self.encoder = encoder
        self.bottleneck = bottleneck
        self.aggregation = aggregation
        self.decoder = decoder
        self.target_layers = target_layers
        self.fuse_layer_encoder = fuse_layer_encoder
        self.fuse_layer_decoder = fuse_layer_decoder
        self.remove_class_token = remove_class_token
        self.encoder_require_grad_layer = encoder_require_grad_layer
        # prototype_token is passed as a one-element container (e.g. an nn.ParameterList);
        # keep only the learnable prototype (INP) tokens themselves.
        self.prototype_token = prototype_token[0]
        # Backbones without register tokens do not expose this attribute; default to 0.
        if not hasattr(self.encoder, 'num_register_tokens'):
            self.encoder.num_register_tokens = 0
    def gather_loss(self, query, keys):
        # Cosine distance between every fused encoder token (query) and every
        # aggregated prototype token (keys); each token is assigned to its nearest
        # prototype, and the mean nearest distance is the gather loss.
        self.distribution = 1. - F.cosine_similarity(query.unsqueeze(2), keys.unsqueeze(1), dim=-1)
        self.distance, self.cluster_index = torch.min(self.distribution, dim=2)
        gather_loss = self.distance.mean()
        return gather_loss
    def forward(self, x):
        x = self.encoder.prepare_tokens(x)
        B, L, _ = x.shape

        # Run the encoder blocks (frozen unless listed in encoder_require_grad_layer)
        # and collect the intermediate features of the target layers.
        en_list = []
        for i, blk in enumerate(self.encoder.blocks):
            if i <= self.target_layers[-1]:
                if i in self.encoder_require_grad_layer:
                    x = blk(x)
                else:
                    with torch.no_grad():
                        x = blk(x)
            else:
                continue
            if i in self.target_layers:
                en_list.append(x)
        # Spatial side length of the patch grid (class/register tokens excluded).
        side = int(math.sqrt(en_list[0].shape[1] - 1 - self.encoder.num_register_tokens))

        if self.remove_class_token:
            en_list = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en_list]

        x = self.fuse_feature(en_list)

        # Aggregate the learnable prototype tokens against the fused encoder features.
        agg_prototype = self.prototype_token
        for i, blk in enumerate(self.aggregation):
            agg_prototype = blk(agg_prototype.unsqueeze(0).repeat((B, 1, 1)), x)
        g_loss = self.gather_loss(x, agg_prototype)

        for i, blk in enumerate(self.bottleneck):
            x = blk(x)

        # Decode, with each decoder block conditioned on the aggregated prototypes.
        de_list = []
        for i, blk in enumerate(self.decoder):
            x = blk(x, agg_prototype)
            de_list.append(x)
        # Reverse decoder outputs before layer-wise fusion.
        de_list = de_list[::-1]

        en = [self.fuse_feature([en_list[idx] for idx in idxs]) for idxs in self.fuse_layer_encoder]
        de = [self.fuse_feature([de_list[idx] for idx in idxs]) for idxs in self.fuse_layer_decoder]

        if not self.remove_class_token:  # class/register tokens were not removed above; drop them now
            en = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en]
            de = [d[:, 1 + self.encoder.num_register_tokens:, :] for d in de]

        # Reshape token sequences back into (B, C, side, side) feature maps.
        en = [e.permute(0, 2, 1).reshape([x.shape[0], -1, side, side]).contiguous() for e in en]
        de = [d.permute(0, 2, 1).reshape([x.shape[0], -1, side, side]).contiguous() for d in de]
        return en, de, g_loss
    def fuse_feature(self, feat_list):
        # Average a list of same-shaped token sequences into a single fused feature.
        return torch.stack(feat_list, dim=1).mean(dim=1)
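

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original model code). In the real
# project the encoder / bottleneck / aggregation / decoder modules come from
# elsewhere in the repo (e.g. a pretrained ViT backbone and its attention
# blocks); the toy stubs below are hypothetical stand-ins that only mimic the
# interfaces this class relies on:
#   encoder.prepare_tokens(x), encoder.blocks, encoder.num_register_tokens,
#   aggregation/decoder blocks called as blk(query, context),
#   bottleneck blocks called as blk(x).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dim, patch, img = 64, 16, 224

    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
            self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(10)])
            self.num_register_tokens = 0

        def prepare_tokens(self, x):
            x = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, dim)
            cls = self.cls_token.expand(x.shape[0], -1, -1)
            return torch.cat([cls, x], dim=1)                    # (B, 1 + N, dim)

    class ToyCrossBlock(nn.Module):
        # Stand-in for an aggregation / decoder block called as blk(query, context).
        def __init__(self):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

        def forward(self, q, ctx):
            out, _ = self.attn(q, ctx, ctx)
            return q + out

    model = INP_Former(
        encoder=ToyEncoder(),
        bottleneck=nn.ModuleList([nn.Linear(dim, dim)]),
        aggregation=nn.ModuleList([ToyCrossBlock()]),
        decoder=nn.ModuleList([ToyCrossBlock() for _ in range(8)]),
        prototype_token=nn.ParameterList([nn.Parameter(torch.randn(6, dim))]),
    )
    en, de, g_loss = model(torch.randn(2, 3, img, img))
    print([e.shape for e in en], [d.shape for d in de], g_loss.item())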