# Biomap: biomap/modules.py
import torch
import torch.nn as nn
import torch.nn.functional as F

import dino.vision_transformer as vits
from utils import *
class LambdaLayer(nn.Module):
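    """Wrap an arbitrary callable as an nn.Module so it can be placed inside nn.Sequential."""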
def __init__(self, lambd):
super(LambdaLayer, self).__init__()
self.lambd = lambd
def forward(self, x):
return self.lambd(x)
class DinoFeaturizer(nn.Module):
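    """Frozen DINO ViT backbone producing dense patch features, plus a trainable 1x1-conv
    projection ("clusterer") head that maps the backbone features to `dim` channels,
    optionally with an additional nonlinear head."""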
def __init__(self, dim, cfg):
super().__init__()
self.cfg = cfg
self.dim = dim
patch_size = self.cfg.dino_patch_size
self.patch_size = patch_size
self.feat_type = self.cfg.dino_feat_type
arch = self.cfg.model_type
self.model = vits.__dict__[arch](
patch_size=patch_size,
num_classes=0)
for p in self.model.parameters():
p.requires_grad = False
        self.model = self.model.cpu()
self.model.eval()
self.dropout = torch.nn.Dropout2d(p=.1)
if arch == "vit_small" and patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif arch == "vit_small" and patch_size == 8:
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
elif arch == "vit_base" and patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
elif arch == "vit_base" and patch_size == 8:
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
else:
raise ValueError("Unknown arch and patch size")
if cfg.pretrained_weights is not None:
state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
state_dict = state_dict["teacher"]
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
# state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
# state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}
msg = self.model.load_state_dict(state_dict, strict=False)
print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
else:
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
self.model.load_state_dict(state_dict, strict=True)
if arch == "vit_small":
self.n_feats = 384
else:
self.n_feats = 768
self.cluster1 = self.make_clusterer(self.n_feats)
self.proj_type = cfg.projection_type
if self.proj_type == "nonlinear":
self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)
def make_clusterer(self, in_channels):
return torch.nn.Sequential(
torch.nn.Conv2d(in_channels, self.dim, (1, 1))) # ,
def make_nonlinear_clusterer(self, in_channels):
return torch.nn.Sequential(
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels, self.dim, (1, 1)))
def forward(self, img, n=1, return_class_feat=False):
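        """Return (image_feat, code): frozen backbone patch features and their learned projection,
        both at 1/patch_size of the input resolution. When return_class_feat is True, the class
        token feature is returned instead."""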
self.model.eval()
with torch.no_grad():
assert (img.shape[2] % self.patch_size == 0)
assert (img.shape[3] % self.patch_size == 0)
# get selected layer activations
feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
feat, attn, qkv = feat[0], attn[0], qkv[0]
feat_h = img.shape[2] // self.patch_size
feat_w = img.shape[3] // self.patch_size
if self.feat_type == "feat":
image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
elif self.feat_type == "KK":
image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1)
B, H, I, J, D = image_k.shape
image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
else:
raise ValueError("Unknown feat type:{}".format(self.feat_type))
if return_class_feat:
return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)
if self.proj_type is not None:
code = self.cluster1(self.dropout(image_feat))
if self.proj_type == "nonlinear":
code += self.cluster2(self.dropout(image_feat))
else:
code = image_feat
if self.cfg.dropout:
return self.dropout(image_feat), code
else:
return image_feat, code
class ResizeAndClassify(nn.Module):
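    """1x1-conv classifier producing per-pixel log-probabilities, bilinearly resized to `size`."""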
def __init__(self, dim: int, size: int, n_classes: int):
super(ResizeAndClassify, self).__init__()
self.size = size
self.predictor = torch.nn.Sequential(
torch.nn.Conv2d(dim, n_classes, (1, 1)),
torch.nn.LogSoftmax(1))
def forward(self, x):
return F.interpolate(self.predictor.forward(x), self.size, mode="bilinear", align_corners=False)
class ClusterLookup(nn.Module):
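    """Learnable cluster centroids; assignments are made by cosine similarity between
    L2-normalized features and centroids, and forward returns (cluster_loss, cluster_probs)."""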
def __init__(self, dim: int, n_classes: int):
super(ClusterLookup, self).__init__()
self.n_classes = n_classes
self.dim = dim
self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim))
def reset_parameters(self):
with torch.no_grad():
self.clusters.copy_(torch.randn(self.n_classes, self.dim))
def forward(self, x, alpha, log_probs=False):
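        # alpha is the inverse softmax temperature for soft assignments; alpha=None yields hard
        # one-hot assignments (log_probs=True therefore requires a non-None alpha).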
normed_clusters = F.normalize(self.clusters, dim=1)
normed_features = F.normalize(x, dim=1)
inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters)
if alpha is None:
cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) \
.permute(0, 3, 1, 2).to(torch.float32)
else:
cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1)
cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
if log_probs:
return nn.functional.log_softmax(inner_products * alpha, dim=1)
else:
return cluster_loss, cluster_probs
class FeaturePyramidNet(nn.Module):
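    """Feature pyramid over a frozen CNN backbone: coarse activations are progressively upsampled,
    fused with earlier-layer activations via DoubleConv blocks, projected to `dim` channels, and the
    per-level cluster maps are averaged at a common 56x56 resolution."""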
@staticmethod
def _helper(x):
# TODO remove this hard coded 56
return F.interpolate(x, 56, mode="bilinear", align_corners=False).unsqueeze(-1)
def make_clusterer(self, in_channels):
return torch.nn.Sequential(
torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
LambdaLayer(FeaturePyramidNet._helper))
def make_nonlinear_clusterer(self, in_channels):
return torch.nn.Sequential(
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
LambdaLayer(FeaturePyramidNet._helper))
def __init__(self, granularity, cut_model, dim, continuous):
super(FeaturePyramidNet, self).__init__()
self.layer_nums = [5, 6, 7]
self.spatial_resolutions = [7, 14, 28, 56]
self.feat_channels = [2048, 1024, 512, 3]
self.extra_channels = [128, 64, 32, 32]
self.granularity = granularity
self.encoder = NetWithActivations(cut_model, self.layer_nums)
self.dim = dim
self.continuous = continuous
self.n_feats = self.dim
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
assert granularity in {1, 2, 3, 4}
self.cluster1 = self.make_clusterer(self.feat_channels[0])
self.cluster1_nl = self.make_nonlinear_clusterer(self.feat_channels[0])
if granularity >= 2:
# self.conv1 = DoubleConv(self.feat_channels[0], self.extra_channels[0])
# self.conv2 = DoubleConv(self.extra_channels[0] + self.feat_channels[1], self.extra_channels[1])
self.conv2 = DoubleConv(self.feat_channels[0] + self.feat_channels[1], self.extra_channels[1])
self.cluster2 = self.make_clusterer(self.extra_channels[1])
if granularity >= 3:
self.conv3 = DoubleConv(self.extra_channels[1] + self.feat_channels[2], self.extra_channels[2])
self.cluster3 = self.make_clusterer(self.extra_channels[2])
if granularity >= 4:
self.conv4 = DoubleConv(self.extra_channels[2] + self.feat_channels[3], self.extra_channels[3])
self.cluster4 = self.make_clusterer(self.extra_channels[3])
def c(self, x, y):
return torch.cat([x, y], dim=1)
def forward(self, x):
with torch.no_grad():
feats = self.encoder(x)
low_res_feats = feats[self.layer_nums[-1]]
all_clusters = []
# all_clusters.append(self.cluster1(low_res_feats) + self.cluster1_nl(low_res_feats))
all_clusters.append(self.cluster1(low_res_feats))
if self.granularity >= 2:
# f1 = self.conv1(low_res_feats)
# f1_up = self.up(f1)
f1_up = self.up(low_res_feats)
f2 = self.conv2(self.c(f1_up, feats[self.layer_nums[-2]]))
all_clusters.append(self.cluster2(f2))
if self.granularity >= 3:
f2_up = self.up(f2)
f3 = self.conv3(self.c(f2_up, feats[self.layer_nums[-3]]))
all_clusters.append(self.cluster3(f3))
if self.granularity >= 4:
f3_up = self.up(f3)
final_size = self.spatial_resolutions[-1]
f4 = self.conv4(self.c(f3_up, F.interpolate(
x, (final_size, final_size), mode="bilinear", align_corners=False)))
all_clusters.append(self.cluster4(f4))
avg_code = torch.cat(all_clusters, 4).mean(4)
if self.continuous:
clusters = avg_code
else:
clusters = torch.log_softmax(avg_code, 1)
return low_res_feats, clusters
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x):
return self.double_conv(x)
def norm(t):
return F.normalize(t, dim=1, eps=1e-10)
def average_norm(t):
return t / t.square().sum(1, keepdim=True).sqrt().mean()
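# All-pairs dot products over channels between spatial positions of a and b: output (n, h, w, i, j).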
def tensor_correlation(a, b):
return torch.einsum("nchw,ncij->nhwij", a, b)
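# Bilinearly sample feature map t at normalized coordinates in [-1, 1] via grid_sample (border padding).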
def sample(t: torch.Tensor, coords: torch.Tensor):
return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)
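# Random pairing of batch indices used to draw negatives; any fixed point is shifted by one so an
# element is never paired with itself (the result is not guaranteed to remain a strict permutation).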
@torch.jit.script
def super_perm(size: int, device: torch.device):
perm = torch.randperm(size, device=device, dtype=torch.long)
perm[perm == torch.arange(size, device=device)] += 1
return perm % size
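# Build grid_sample coordinates concentrated on the nonzero (salient) locations of t; images with
# no nonzero entries fall back to uniformly random locations.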
def sample_nonzero_locations(t, target_size):
nonzeros = torch.nonzero(t)
coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device)
n = target_size[1] * target_size[2]
for i in range(t.shape[0]):
selected_nonzeros = nonzeros[nonzeros[:, 0] == i]
if selected_nonzeros.shape[0] == 0:
selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device)
else:
selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:]
coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2)
coords = coords.to(torch.float32) / t.shape[1]
coords = coords * 2 - 1
return torch.flip(coords, dims=[-1])
class ContrastiveCorrelationLoss(nn.Module):
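    """Correspondence distillation loss: pushes the pairwise correlation of the learned codes toward
    the (shifted) pairwise correlation of the frozen backbone features over sampled pixel locations,
    for self, positive-pair, and negative-pair comparisons."""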
def __init__(self, cfg, ):
super(ContrastiveCorrelationLoss, self).__init__()
self.cfg = cfg
def standard_scale(self, t):
t1 = t - t.mean()
t2 = t1 / t1.std()
return t2
def helper(self, f1, f2, c1, c2, shift):
with torch.no_grad():
            # These come straight from the backbone, which is currently frozen; this saves memory.
fd = tensor_correlation(norm(f1), norm(f2))
if self.cfg.pointwise:
old_mean = fd.mean()
fd -= fd.mean([3, 4], keepdim=True)
fd = fd - fd.mean() + old_mean
cd = tensor_correlation(norm(c1), norm(c2))
if self.cfg.zero_clamp:
min_val = 0.0
else:
min_val = -9999.0
if self.cfg.stabalize:
loss = - cd.clamp(min_val, .8) * (fd - shift)
else:
loss = - cd.clamp(min_val) * (fd - shift)
return loss, cd
def forward(self,
orig_feats: torch.Tensor, orig_feats_pos: torch.Tensor,
orig_salience: torch.Tensor, orig_salience_pos: torch.Tensor,
orig_code: torch.Tensor, orig_code_pos: torch.Tensor,
):
coord_shape = [orig_feats.shape[0], self.cfg.feature_samples, self.cfg.feature_samples, 2]
if self.cfg.use_salience:
coords1_nonzero = sample_nonzero_locations(orig_salience, coord_shape)
coords2_nonzero = sample_nonzero_locations(orig_salience_pos, coord_shape)
coords1_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
coords2_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
mask = (torch.rand(coord_shape[:-1], device=orig_feats.device) > .1).unsqueeze(-1).to(torch.float32)
coords1 = coords1_nonzero * mask + coords1_reg * (1 - mask)
coords2 = coords2_nonzero * mask + coords2_reg * (1 - mask)
else:
coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
feats = sample(orig_feats, coords1)
code = sample(orig_code, coords1)
feats_pos = sample(orig_feats_pos, coords2)
code_pos = sample(orig_code_pos, coords2)
pos_intra_loss, pos_intra_cd = self.helper(
feats, feats, code, code, self.cfg.pos_intra_shift)
pos_inter_loss, pos_inter_cd = self.helper(
feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift)
neg_losses = []
neg_cds = []
for i in range(self.cfg.neg_samples):
perm_neg = super_perm(orig_feats.shape[0], orig_feats.device)
feats_neg = sample(orig_feats[perm_neg], coords2)
code_neg = sample(orig_code[perm_neg], coords2)
neg_inter_loss, neg_inter_cd = self.helper(
feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift)
neg_losses.append(neg_inter_loss)
neg_cds.append(neg_inter_cd)
        neg_inter_loss = torch.cat(neg_losses, dim=0)
        neg_inter_cd = torch.cat(neg_cds, dim=0)
return (pos_intra_loss.mean(),
pos_intra_cd,
pos_inter_loss.mean(),
pos_inter_cd,
neg_inter_loss,
neg_inter_cd)
class Decoder(nn.Module):
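    """Map codes back to backbone feature space as the sum of a linear and a small nonlinear 1x1-conv head."""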
def __init__(self, code_channels, feat_channels):
super().__init__()
self.linear = torch.nn.Conv2d(code_channels, feat_channels, (1, 1))
self.nonlinear = torch.nn.Sequential(
torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
torch.nn.ReLU(),
torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
torch.nn.ReLU(),
torch.nn.Conv2d(code_channels, feat_channels, (1, 1)))
def forward(self, x):
return self.linear(x) + self.nonlinear(x)
class NetWithActivations(torch.nn.Module):
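    """Run a model's child modules sequentially and record the activations at the requested layer
    indices (negative indices count from the end)."""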
def __init__(self, model, layer_nums):
super(NetWithActivations, self).__init__()
self.layers = nn.ModuleList(model.children())
self.layer_nums = []
for l in layer_nums:
if l < 0:
self.layer_nums.append(len(self.layers) + l)
else:
self.layer_nums.append(l)
self.layer_nums = set(sorted(self.layer_nums))
def forward(self, x):
activations = {}
for ln, l in enumerate(self.layers):
x = l(x)
if ln in self.layer_nums:
activations[ln] = x
return activations
class ContrastiveCRFLoss(nn.Module):
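    """CRF-style pairwise loss: random pixel locations are sampled and their cluster similarities are
    weighted by a bilateral kernel built from spatial distance and guidance (e.g. color) distance."""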
def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift):
super(ContrastiveCRFLoss, self).__init__()
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.w1 = w1
self.w2 = w2
self.n_samples = n_samples
self.shift = shift
def forward(self, guidance, clusters):
device = clusters.device
assert (guidance.shape[0] == clusters.shape[0])
assert (guidance.shape[2:] == clusters.shape[2:])
h = guidance.shape[2]
w = guidance.shape[3]
coords = torch.cat([
torch.randint(0, h, size=[1, self.n_samples], device=device),
torch.randint(0, w, size=[1, self.n_samples], device=device)], 0)
selected_guidance = guidance[:, :, coords[0, :], coords[1, :]]
coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0)
guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1)
sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \
self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift
selected_clusters = clusters[:, :, coords[0, :], coords[1, :]]
cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters)
return -(cluster_sims * sim_kernel)
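# Illustrative usage sketch, assuming a config object exposing the fields read in
# DinoFeaturizer.__init__ (dino_patch_size, dino_feat_type, model_type, projection_type,
# pretrained_weights, dropout); the dim and n_classes values below are arbitrary:
#
#   featurizer = DinoFeaturizer(dim=70, cfg=cfg)
#   cluster_probe = ClusterLookup(dim=70, n_classes=27)
#   feats, code = featurizer(img)          # img: (B, 3, H, W) with H, W divisible by the patch size
#   cluster_loss, cluster_probs = cluster_probe(code, alpha=None)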