# selfmask/utils.py
from argparse import Namespace
from typing import Optional
import torch


def get_model(
    arch: str,
    patch_size: Optional[int] = None,
    training_method: Optional[str] = None,
    configs: Optional[Namespace] = None,
    **kwargs
):
    if arch == "maskformer":
        assert configs is not None
        from networks.maskformer.maskformer import MaskFormer

        # MaskFormer-style model; all hyperparameters come from the configs Namespace.
        model = MaskFormer(
            n_queries=configs.n_queries,
            n_decoder_layers=configs.n_decoder_layers,
            learnable_pixel_decoder=configs.learnable_pixel_decoder,
            lateral_connection=configs.lateral_connection,
            return_intermediate=configs.loss_every_decoder_layer,
            scale_factor=configs.scale_factor,
            abs_2d_pe_init=configs.abs_2d_pe_init,
            use_binary_classifier=configs.use_binary_classifier,
            arch=configs.arch,
            training_method=configs.training_method,
            patch_size=configs.patch_size
        )

        # make sure the encoder parameters are trainable
        for n, p in model.encoder.named_parameters():
            p.requires_grad_(True)

    elif "vit" in arch:
        import networks.vision_transformer as vits
        import networks.timm_deit as timm_deit

        if training_method == "dino":
            # DINO releases the small model under the "deit_small" name.
            arch = arch.replace("vit", "deit") if arch.find("small") != -1 else arch
            model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
            load_model(model, arch, patch_size)

        elif training_method == "deit":
            assert patch_size == 16
            model = timm_deit.deit_small_distilled_patch16_224(True)

        elif training_method == "supervised":
            assert patch_size == 16
            # supervised DeiT-S/16 weights from a local checkpoint
            state_dict: dict = torch.load(
                "/users/gyungin/selfmask/networks/pretrained/deit_small_patch16_224-cd65a155.pth"
            )["model"]

            for k in list(state_dict.keys()):
                if k in ["head.weight", "head.bias"]:  # classifier head, which is not used in our network
                    state_dict.pop(k)

            model = get_model(arch="vit_small", patch_size=16, training_method="dino")
            model.load_state_dict(state_dict=state_dict, strict=True)

        else:
            raise NotImplementedError

        print(f"{arch}_p{patch_size}_{training_method} is built.")

    elif arch == "resnet50":
        from networks.resnet import ResNet50
        assert training_method in ["mocov2", "swav", "supervised"]
        model = ResNet50(training_method)

    else:
        raise ValueError(f"{arch} is not a supported arch. Choose from [maskformer, resnet50, vit_small, vit_base].")

    return model


def load_model(model, arch: str, patch_size: int) -> None:
    """Load the reference DINO pretrained weights for the given arch and patch size, if available."""
    url = None
    if arch == "deit_small" and patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    elif arch == "deit_small" and patch_size == 8:
        # model used for visualizations in our paper
        url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
    elif arch == "vit_base" and patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    elif arch == "vit_base" and patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"

    if url is not None:
        print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
        model.load_state_dict(state_dict, strict=True)
    else:
        print("There are no reference weights available for this model => we use random weights.")