import os

import torch
import torch.nn as nn

import dino.vision_transformer as vits
from bkg_seg import compute_img_bkg_seg
from misc import batch_apply_bilateral_solver


class FoundModel(nn.Module):
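    """A frozen self-supervised ViT encoder followed by a lightweight 1x1
    convolutional decoder that predicts per-patch foreground logits."""
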
    def __init__(
        self,
        vit_model="dino",
        vit_arch="vit_small",
        vit_patch_size=8,
        enc_type_feats="k",
        bkg_type_feats="k",
        bkg_th=0.3,
    ):
        super().__init__()

        # Frozen self-supervised ViT encoder; hook_features caches the qkv
        # activations of the last attention block at every forward pass.
        self.vit_encoder, self.initial_dim, self.hook_features = get_vit_encoder(
            vit_arch, vit_model, vit_patch_size, enc_type_feats
        )
        self.vit_patch_size = vit_patch_size
        self.enc_type_feats = enc_type_feats

        # Background-discovery parameters
        self.bkg_type_feats = bkg_type_feats
        self.bkg_th = bkg_th

        # Decoder: a single 1x1 convolution mapping patch features to a logit
        self.previous_dim = self.initial_dim
        self.decoder = nn.Conv2d(self.previous_dim, 1, (1, 1))

    def forward_step(self, batch, decoder=None, for_eval=False):
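        """Encode the batch with the frozen ViT and decode per-patch foreground
        logits. Returns (preds, feats, (w_featmap, h_featmap), att)."""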
        # Make the input divisible by the patch size
        if for_eval:
            # Evaluation: pad so that the full image is kept
            batch = self.make_input_divisible(batch)
        else:
            # Training: crop to the largest patch-aligned size
            w, h = (
                batch.shape[-2] - batch.shape[-2] % self.vit_patch_size,
                batch.shape[-1] - batch.shape[-1] % self.vit_patch_size,
            )
            batch = batch[:, :, :w, :h]

        w_featmap = batch.shape[-2] // self.vit_patch_size
        h_featmap = batch.shape[-1] // self.vit_patch_size

        # Forward pass through the frozen encoder
        with torch.no_grad():
            # Self-attention of the last block; also triggers the qkv hook
            att = self.vit_encoder.get_last_selfattention(batch)

        # Recover the patch features, drop the CLS token and fold into a map
        feats = self.extract_feats(dims=att.shape, type_feats=self.enc_type_feats)
        feats = feats[:, 1:, :, :].reshape(att.shape[0], w_featmap, h_featmap, -1)
        feats = feats.permute(0, 3, 1, 2)

        # Apply the decoder (a 1x1 convolution) to get per-patch logits
        if decoder is None:
            decoder = self.decoder
        preds = decoder(feats)

        return preds, feats, (w_featmap, h_featmap), att

    def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
        """Pad some pixels to make the input size divisible by the patch size."""
        B, _, H_0, W_0 = x.shape
        pad_w = (self.vit_patch_size - W_0 % self.vit_patch_size) % self.vit_patch_size
        pad_h = (self.vit_patch_size - H_0 % self.vit_patch_size) % self.vit_patch_size

        x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
        return x

    def compute_background_batch(self, att, shape_f):
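        """Compute a coarse background mask for each image of the batch from
        the last-block attentions and the selected ViT features."""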
        w_f, h_f = shape_f

        nb_im = att.shape[0]
        nh = att.shape[1]
        nb_tokens = att.shape[2]

        # Features used for the background discovery
        feats = self.extract_feats(dims=att.shape, type_feats=self.bkg_type_feats)
        feats = feats.reshape(nb_im, nb_tokens, -1)

        bkg_mask = compute_img_bkg_seg(
            att,
            feats,
            (w_f, h_f),
            th_bkg=self.bkg_th,
            dim=self.initial_dim // nh,  # feature dimension per head
        )

        return bkg_mask

    def get_bkg_pseudo_labels_batch(
        self,
        att,
        shape_f,
        data,
        use_bilateral_solver=True,
        shape=None,
    ):
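        """Build foreground pseudo-labels as the complement of the predicted
        background mask, optionally refined with a bilateral solver."""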
        bkg_mask_pred = self.compute_background_batch(att, shape_f)

        # Foreground objects are the complement of the background mask
        obj_mask = (~bkg_mask_pred.bool()).float()

        if use_bilateral_solver:
            pseudo_labels, cnt_bs = batch_apply_bilateral_solver(data, obj_mask, shape)
            return pseudo_labels, cnt_bs
        else:
            return obj_mask, 0

    @torch.no_grad()
    def decoder_load_weights(self, weights_path):
        print(f"Loading model from weights {weights_path}.")

        # Load on CPU first so loading also works on machines without a GPU
        state_dict = torch.load(weights_path, map_location="cpu")

        self.decoder.load_state_dict(state_dict["decoder"])
        self.decoder.eval()
        self.decoder.to("cuda")

    @torch.no_grad()
    def decoder_save_weights(self, save_dir, n_iter):
        state_dict = {}
        state_dict["decoder"] = self.decoder.state_dict()
        fname = os.path.join(save_dir, f"decoder_weights_niter{n_iter}.pt")
        torch.save(state_dict, fname)
        print(f"\n----\nModel saved at {fname}")

    @torch.no_grad()
    def extract_feats(self, dims, type_feats="k"):
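        """Recover the q, k or v features of the last block from the qkv hook;
        output shape is (nb_im, nb_tokens, nh, dim_per_head)."""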
        nb_im, nh, nb_tokens, _ = dims

        # Split the cached qkv activations per head:
        # (nb_im, nb_tokens, 3 * nh * d) -> (3, nb_im, nh, nb_tokens, d)
        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, -1)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0], qkv[1], qkv[2]

        if type_feats == "q":
            return q.transpose(1, 2).float()
        elif type_feats == "k":
            return k.transpose(1, 2).float()
        elif type_feats == "v":
            return v.transpose(1, 2).float()
        else:
            raise ValueError(f"Unknown feature type: {type_feats}")


def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats):
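    """Build the ViT encoder, load its pretrained weights, freeze it, and
    register a hook caching the qkv activations of the last attention block."""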
    if vit_arch == "vit_small" and vit_patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        initial_dim = 384
    elif vit_arch == "vit_small" and vit_patch_size == 8:
        url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
        initial_dim = 384
    elif vit_arch == "vit_base" and vit_patch_size == 16:
        if vit_model == "clip":
            url = "5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"
        elif vit_model == "dino":
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        initial_dim = 768
    elif vit_arch == "vit_base" and vit_patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        initial_dim = 768
    else:
        raise ValueError(f"Unknown combination: {vit_arch} with patch size {vit_patch_size}")

    if vit_model == "dino":
        vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0)

        # Freeze the encoder: only the 1x1 conv decoder is trained
        for p in vit_encoder.parameters():
            p.requires_grad = False
        vit_encoder.eval().cuda()
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url
        )
        vit_encoder.load_state_dict(state_dict, strict=True)
    else:
        raise ValueError(f"Loading weights for '{vit_model}' is not implemented here.")

    hook_features = {}
    if enc_type_feats in ["k", "q", "v", "qkv", "mlp"]:

        # Cache the qkv activations of the last attention block at each forward
        def hook_fn_forward_qkv(module, input, output):
            hook_features["qkv"] = output

        vit_encoder._modules["blocks"][-1]._modules["attn"]._modules[
            "qkv"
        ].register_forward_hook(hook_fn_forward_qkv)
    else:
        raise ValueError("Not implemented.")

    return vit_encoder, initial_dim, hook_features
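

# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes a CUDA device and a 224x224 input; a real image should be
# normalized the same way as during DINO pretraining.
if __name__ == "__main__":
    model = FoundModel().cuda()  # default: DINO ViT-S/8 encoder + 1x1 conv decoder

    img = torch.rand(1, 3, 224, 224, device="cuda")  # dummy RGB batch
    preds, feats, (w_f, h_f), att = model.forward_step(img, for_eval=True)

    # Per-patch foreground probabilities, here of shape (1, 1, 28, 28)
    saliency = preds.sigmoid()
    print(saliency.shape, feats.shape, att.shape)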