import torch
import torch.nn as nn
from utils import *
import torch.nn.functional as F
import dino.vision_transformer as vits
import pdb
class LambdaLayer(nn.Module):

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
class DinoFeaturizer(nn.Module):

    def __init__(self, dim, cfg):
        super().__init__()
        self.cfg = cfg
        self.dim = dim
        patch_size = self.cfg.dino_patch_size
        self.patch_size = patch_size
        self.feat_type = self.cfg.dino_feat_type
        arch = self.cfg.model_type
        self.model = vits.__dict__[arch](
            patch_size=patch_size,
            num_classes=0)
        for p in self.model.parameters():
            p.requires_grad = False
        # pdb.set_trace()
        self.model = self.model.cpu()
        self.model.eval()
        self.dropout = torch.nn.Dropout2d(p=.1)

        if arch == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif arch == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        elif arch == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif arch == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        else:
            raise ValueError("Unknown arch and patch size")

        if cfg.pretrained_weights is not None:
            state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
            state_dict = state_dict["teacher"]
            # remove `module.` prefix
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            # remove `backbone.` prefix induced by multicrop wrapper
            state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
            # state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
            # state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}
            msg = self.model.load_state_dict(state_dict, strict=False)
            print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
        else:
            print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)

        if arch == "vit_small":
            self.n_feats = 384
        else:
            self.n_feats = 768
        self.cluster1 = self.make_clusterer(self.n_feats)
        self.proj_type = cfg.projection_type
        if self.proj_type == "nonlinear":
            self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)
    def make_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))  # ,

    def make_nonlinear_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)))
    def forward(self, img, n=1, return_class_feat=False):
        self.model.eval()
        with torch.no_grad():
            assert (img.shape[2] % self.patch_size == 0)
            assert (img.shape[3] % self.patch_size == 0)

            # get selected layer activations
            feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
            feat, attn, qkv = feat[0], attn[0], qkv[0]

            feat_h = img.shape[2] // self.patch_size
            feat_w = img.shape[3] // self.patch_size

            if self.feat_type == "feat":
                image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
            elif self.feat_type == "KK":
                # qkv is (3, B, n_heads, n_tokens, head_dim); take the keys and drop the CLS token.
                # Note the head count 6 is hard-coded for vit_small (384 / 64 = 6 heads).
                image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1)
                B, H, I, J, D = image_k.shape
                image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
            else:
                raise ValueError("Unknown feat type:{}".format(self.feat_type))

            if return_class_feat:
                return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)

        if self.proj_type is not None:
            code = self.cluster1(self.dropout(image_feat))
            if self.proj_type == "nonlinear":
                code += self.cluster2(self.dropout(image_feat))
        else:
            code = image_feat

        if self.cfg.dropout:
            return self.dropout(image_feat), code
        else:
            return image_feat, code
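# Added note (hedged; not part of the original file): for an illustrative 224x224 input
# with arch "vit_small" and patch_size 8, the forward pass above yields image_feat of
# shape (B, self.n_feats, 28, 28) and code of shape (B, dim, 28, 28), since 224 // 8 = 28.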
class ResizeAndClassify(nn.Module):

    def __init__(self, dim: int, size: int, n_classes: int):
        super(ResizeAndClassify, self).__init__()
        self.size = size
        self.predictor = torch.nn.Sequential(
            torch.nn.Conv2d(dim, n_classes, (1, 1)),
            torch.nn.LogSoftmax(1))

    def forward(self, x):
        return F.interpolate(self.predictor.forward(x), self.size, mode="bilinear", align_corners=False)
class ClusterLookup(nn.Module):

    def __init__(self, dim: int, n_classes: int):
        super(ClusterLookup, self).__init__()
        self.n_classes = n_classes
        self.dim = dim
        self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim))

    def reset_parameters(self):
        with torch.no_grad():
            self.clusters.copy_(torch.randn(self.n_classes, self.dim))

    def forward(self, x, alpha, log_probs=False):
        normed_clusters = F.normalize(self.clusters, dim=1)
        normed_features = F.normalize(x, dim=1)
        inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters)

        if alpha is None:
            cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) \
                .permute(0, 3, 1, 2).to(torch.float32)
        else:
            cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1)

        cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
        if log_probs:
            return nn.functional.log_softmax(inner_products * alpha, dim=1)
        else:
            return cluster_loss, cluster_probs
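# Added usage note (hedged; not part of the original file). ClusterLookup is a
# cosine-similarity cluster probe over a (B, dim, H, W) code tensor; the shapes and
# alpha values below are illustrative assumptions only:
#
#     probe = ClusterLookup(dim=70, n_classes=27)
#     loss, probs = probe(torch.randn(2, 70, 28, 28), alpha=None)  # hard one-hot probs, (2, 27, 28, 28)
#     loss, probs = probe(torch.randn(2, 70, 28, 28), alpha=3.0)   # soft assignments via softmax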
class FeaturePyramidNet(nn.Module):

    @staticmethod
    def _helper(x):
        # TODO remove this hard coded 56
        return F.interpolate(x, 56, mode="bilinear", align_corners=False).unsqueeze(-1)

    def make_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
            LambdaLayer(FeaturePyramidNet._helper))

    def make_nonlinear_clusterer(self, in_channels):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
            LambdaLayer(FeaturePyramidNet._helper))

    def __init__(self, granularity, cut_model, dim, continuous):
        super(FeaturePyramidNet, self).__init__()
        self.layer_nums = [5, 6, 7]
        self.spatial_resolutions = [7, 14, 28, 56]
        self.feat_channels = [2048, 1024, 512, 3]
        self.extra_channels = [128, 64, 32, 32]
        self.granularity = granularity
        self.encoder = NetWithActivations(cut_model, self.layer_nums)
        self.dim = dim
        self.continuous = continuous
        self.n_feats = self.dim
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        assert granularity in {1, 2, 3, 4}
        self.cluster1 = self.make_clusterer(self.feat_channels[0])
        self.cluster1_nl = self.make_nonlinear_clusterer(self.feat_channels[0])

        if granularity >= 2:
            # self.conv1 = DoubleConv(self.feat_channels[0], self.extra_channels[0])
            # self.conv2 = DoubleConv(self.extra_channels[0] + self.feat_channels[1], self.extra_channels[1])
            self.conv2 = DoubleConv(self.feat_channels[0] + self.feat_channels[1], self.extra_channels[1])
            self.cluster2 = self.make_clusterer(self.extra_channels[1])
        if granularity >= 3:
            self.conv3 = DoubleConv(self.extra_channels[1] + self.feat_channels[2], self.extra_channels[2])
            self.cluster3 = self.make_clusterer(self.extra_channels[2])
        if granularity >= 4:
            self.conv4 = DoubleConv(self.extra_channels[2] + self.feat_channels[3], self.extra_channels[3])
            self.cluster4 = self.make_clusterer(self.extra_channels[3])

    def c(self, x, y):
        return torch.cat([x, y], dim=1)

    def forward(self, x):
        with torch.no_grad():
            feats = self.encoder(x)
        low_res_feats = feats[self.layer_nums[-1]]

        all_clusters = []

        # all_clusters.append(self.cluster1(low_res_feats) + self.cluster1_nl(low_res_feats))
        all_clusters.append(self.cluster1(low_res_feats))

        if self.granularity >= 2:
            # f1 = self.conv1(low_res_feats)
            # f1_up = self.up(f1)
            f1_up = self.up(low_res_feats)
            f2 = self.conv2(self.c(f1_up, feats[self.layer_nums[-2]]))
            all_clusters.append(self.cluster2(f2))
        if self.granularity >= 3:
            f2_up = self.up(f2)
            f3 = self.conv3(self.c(f2_up, feats[self.layer_nums[-3]]))
            all_clusters.append(self.cluster3(f3))
        if self.granularity >= 4:
            f3_up = self.up(f3)
            final_size = self.spatial_resolutions[-1]
            f4 = self.conv4(self.c(f3_up, F.interpolate(
                x, (final_size, final_size), mode="bilinear", align_corners=False)))
            all_clusters.append(self.cluster4(f4))

        avg_code = torch.cat(all_clusters, 4).mean(4)

        if self.continuous:
            clusters = avg_code
        else:
            clusters = torch.log_softmax(avg_code, 1)

        return low_res_feats, clusters
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        return self.double_conv(x)
def norm(t):
    return F.normalize(t, dim=1, eps=1e-10)


def average_norm(t):
    return t / t.square().sum(1, keepdim=True).sqrt().mean()


def tensor_correlation(a, b):
    return torch.einsum("nchw,ncij->nhwij", a, b)


def sample(t: torch.Tensor, coords: torch.Tensor):
    return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)


def super_perm(size: int, device: torch.device):
    perm = torch.randperm(size, device=device, dtype=torch.long)
    perm[perm == torch.arange(size, device=device)] += 1
    return perm % size
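# Added note (hedged; not part of the original file): super_perm builds a random index
# map over [0, size) with no fixed points (after the fix-up it may no longer be a strict
# permutation), so each batch element gets paired with a *different* element when sampling
# negatives, e.g. for size 4 it might map [0, 1, 2, 3] -> [2, 0, 3, 1].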
def sample_nonzero_locations(t, target_size):
    # Sample (row, col) locations where t is nonzero (falling back to uniform random
    # locations for images with no nonzero entries), then convert them to normalized
    # [-1, 1] (x, y) coordinates suitable for F.grid_sample.
    nonzeros = torch.nonzero(t)
    coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device)
    n = target_size[1] * target_size[2]
    for i in range(t.shape[0]):
        selected_nonzeros = nonzeros[nonzeros[:, 0] == i]
        if selected_nonzeros.shape[0] == 0:
            selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device)
        else:
            selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:]
        coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2)
    coords = coords.to(torch.float32) / t.shape[1]
    coords = coords * 2 - 1
    return torch.flip(coords, dims=[-1])
class ContrastiveCorrelationLoss(nn.Module):

    def __init__(self, cfg, ):
        super(ContrastiveCorrelationLoss, self).__init__()
        self.cfg = cfg

    def standard_scale(self, t):
        t1 = t - t.mean()
        t2 = t1 / t1.std()
        return t2

    def helper(self, f1, f2, c1, c2, shift):
        with torch.no_grad():
            # Comes straight from backbone which is currently frozen. this saves mem.
            fd = tensor_correlation(norm(f1), norm(f2))

            if self.cfg.pointwise:
                old_mean = fd.mean()
                fd -= fd.mean([3, 4], keepdim=True)
                fd = fd - fd.mean() + old_mean

        cd = tensor_correlation(norm(c1), norm(c2))

        if self.cfg.zero_clamp:
            min_val = 0.0
        else:
            min_val = -9999.0

        if self.cfg.stabalize:
            loss = - cd.clamp(min_val, .8) * (fd - shift)
        else:
            loss = - cd.clamp(min_val) * (fd - shift)

        return loss, cd

    def forward(self,
                orig_feats: torch.Tensor, orig_feats_pos: torch.Tensor,
                orig_salience: torch.Tensor, orig_salience_pos: torch.Tensor,
                orig_code: torch.Tensor, orig_code_pos: torch.Tensor,
                ):

        coord_shape = [orig_feats.shape[0], self.cfg.feature_samples, self.cfg.feature_samples, 2]

        if self.cfg.use_salience:
            coords1_nonzero = sample_nonzero_locations(orig_salience, coord_shape)
            coords2_nonzero = sample_nonzero_locations(orig_salience_pos, coord_shape)
            coords1_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            coords2_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            mask = (torch.rand(coord_shape[:-1], device=orig_feats.device) > .1).unsqueeze(-1).to(torch.float32)
            coords1 = coords1_nonzero * mask + coords1_reg * (1 - mask)
            coords2 = coords2_nonzero * mask + coords2_reg * (1 - mask)
        else:
            coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
            coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1

        feats = sample(orig_feats, coords1)
        code = sample(orig_code, coords1)

        feats_pos = sample(orig_feats_pos, coords2)
        code_pos = sample(orig_code_pos, coords2)

        pos_intra_loss, pos_intra_cd = self.helper(
            feats, feats, code, code, self.cfg.pos_intra_shift)
        pos_inter_loss, pos_inter_cd = self.helper(
            feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift)

        neg_losses = []
        neg_cds = []
        for i in range(self.cfg.neg_samples):
            perm_neg = super_perm(orig_feats.shape[0], orig_feats.device)
            feats_neg = sample(orig_feats[perm_neg], coords2)
            code_neg = sample(orig_code[perm_neg], coords2)
            neg_inter_loss, neg_inter_cd = self.helper(
                feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift)
            neg_losses.append(neg_inter_loss)
            neg_cds.append(neg_inter_cd)
        neg_inter_loss = torch.cat(neg_losses, axis=0)
        neg_inter_cd = torch.cat(neg_cds, axis=0)

        return (pos_intra_loss.mean(),
                pos_intra_cd,
                pos_inter_loss.mean(),
                pos_inter_cd,
                neg_inter_loss,
                neg_inter_cd)
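# Added summary (hedged; not part of the original file): helper() compares the backbone
# feature correlation fd (computed under no_grad) with the learned code correlation cd;
# the loss pushes cd up wherever fd exceeds the given shift and down where it falls below
# it, subject to the clamp. forward() returns the mean positive intra/inter losses, their
# correlation tensors, and the stacked negative losses/correlations for the sampled pairs.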
class Decoder(nn.Module):

    def __init__(self, code_channels, feat_channels):
        super().__init__()
        self.linear = torch.nn.Conv2d(code_channels, feat_channels, (1, 1))
        self.nonlinear = torch.nn.Sequential(
            torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(code_channels, feat_channels, (1, 1)))

    def forward(self, x):
        return self.linear(x) + self.nonlinear(x)
class NetWithActivations(torch.nn.Module):

    def __init__(self, model, layer_nums):
        super(NetWithActivations, self).__init__()
        self.layers = nn.ModuleList(model.children())
        self.layer_nums = []
        for l in layer_nums:
            if l < 0:
                self.layer_nums.append(len(self.layers) + l)
            else:
                self.layer_nums.append(l)
        self.layer_nums = set(sorted(self.layer_nums))

    def forward(self, x):
        activations = {}
        for ln, l in enumerate(self.layers):
            x = l(x)
            if ln in self.layer_nums:
                activations[ln] = x
        return activations
class ContrastiveCRFLoss(nn.Module):

    def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift):
        super(ContrastiveCRFLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.w1 = w1
        self.w2 = w2
        self.n_samples = n_samples
        self.shift = shift

    def forward(self, guidance, clusters):
        device = clusters.device
        assert (guidance.shape[0] == clusters.shape[0])
        assert (guidance.shape[2:] == clusters.shape[2:])
        h = guidance.shape[2]
        w = guidance.shape[3]

        coords = torch.cat([
            torch.randint(0, h, size=[1, self.n_samples], device=device),
            torch.randint(0, w, size=[1, self.n_samples], device=device)], 0)

        selected_guidance = guidance[:, :, coords[0, :], coords[1, :]]
        coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0)
        guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1)

        sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \
                     self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift

        selected_clusters = clusters[:, :, coords[0, :], coords[1, :]]
        cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters)
        return -(cluster_sims * sim_kernel)
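

if __name__ == "__main__":
    # Added smoke test (a hedged sketch, not part of the original repository). It only
    # exercises components that need no pretrained weights; every shape and hyperparameter
    # below is an illustrative assumption.
    torch.manual_seed(0)

    # ClusterLookup: cosine-similarity cluster probe over a random (B, dim, H, W) code tensor.
    code = torch.randn(2, 70, 28, 28)
    probe = ClusterLookup(dim=70, n_classes=27)
    cluster_loss, cluster_probs = probe(code, alpha=3.0)
    print("ClusterLookup:", float(cluster_loss), tuple(cluster_probs.shape))

    # super_perm: negative-pair index map with no fixed points.
    print("super_perm:", super_perm(4, torch.device("cpu")))

    # ContrastiveCRFLoss: CRF-style similarity kernel over randomly sampled pixel pairs.
    crf = ContrastiveCRFLoss(n_samples=50, alpha=.5, beta=.15, gamma=.05, w1=10.0, w2=3.0, shift=0.0)
    guidance = torch.rand(2, 3, 28, 28)
    crf_term = crf(guidance, torch.softmax(code, dim=1))
    print("ContrastiveCRFLoss:", tuple(crf_term.shape))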