| | import argparse |
| | from typing import Any, List, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from timm.models.layers import trunc_normal_ |
| | from copy import deepcopy |
| | import os |
| | import torch.backends.cudnn as cudnn |
| |
|
| | import models.vision_transformer as vits |
| |
|
class vit(nn.Module):
    """Thin ``nn.Module`` wrapper around a backbone built by ``models.vision_transformer``.

    The backbone is created with ``patch_size=16``, optionally frozen, and
    optionally initialised from a checkpoint file.

    Args:
        model_size: Backbone name. Either the factory name as exposed by
            ``models.vision_transformer`` (e.g. ``"vit_base"``) or the short
            form without the ``vit_`` prefix (e.g. ``"base"``).
        freeze_transformer: When true, ``requires_grad`` is set to ``False``
            on every backbone parameter.
        pretrained_weights: Optional path to a checkpoint. It is loaded only
            when the path points to an existing file.

    Attributes:
        embedding_size: Width associated with the chosen size (384 / 768 /
            1024 / 1536 for small / base / large / giant). For other sizes it
            falls back to the backbone's ``embed_dim`` attribute if it has
            one, otherwise ``None``.
        transformer: The backbone module.

    Raises:
        KeyError: If no factory for ``model_size`` exists in
            ``models.vision_transformer``.
    """

    # Embedding width per backbone size, keyed by the name without "vit_".
    _EMBEDDING_SIZES = {"small": 384, "base": 768, "large": 1024, "giant": 1536}
    _SIZE_PREFIX = "vit_"
    # Key prefixes removed from checkpoint entries, in this order.
    _CHECKPOINT_PREFIXES = ("teacher.", "backbone.")

    def __init__(self, model_size="base", freeze_transformer=True, pretrained_weights=None):
        # Zero-argument super(): the previous call named a class ("ibotvit")
        # that is not defined here and raised NameError on construction.
        super().__init__()
        self.model_size = model_size
        self.freeze_transformer = freeze_transformer
        self.pretrained_weights = pretrained_weights

        self.transformer = self._resolve_factory(model_size)(patch_size=16)

        short_name = model_size
        if short_name.startswith(self._SIZE_PREFIX):
            short_name = short_name[len(self._SIZE_PREFIX):]
        if short_name in self._EMBEDDING_SIZES:
            self.embedding_size = self._EMBEDDING_SIZES[short_name]
        else:
            # Unknown size: use the backbone's own value if it exposes one.
            self.embedding_size = getattr(self.transformer, "embed_dim", None)

        if self.freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

        if self.pretrained_weights:
            if os.path.isfile(self.pretrained_weights):
                self._load_pretrained(self.pretrained_weights)
            else:
                # Keep the best-effort behaviour (no exception), but make the
                # skipped load visible instead of silently ignoring it.
                print(
                    "pretrained weights not found at {!r}; "
                    "keeping the backbone's initial weights".format(self.pretrained_weights)
                )

    @classmethod
    def _resolve_factory(cls, model_size):
        """Return the backbone factory for ``model_size`` (exact name first, then ``vit_`` + name)."""
        registry = vits.__dict__
        candidates = [model_size]
        if not model_size.startswith(cls._SIZE_PREFIX):
            candidates.append(cls._SIZE_PREFIX + model_size)
        for name in candidates:
            if name in registry:
                return registry[name]
        raise KeyError(
            "unknown model_size {!r}; tried {}".format(model_size, candidates)
        )

    @classmethod
    def _strip_checkpoint_prefixes(cls, state_dict):
        """Return a copy of ``state_dict`` with a leading ``teacher.`` and then ``backbone.`` removed from each key."""
        cleaned = dict(state_dict)
        for prefix in cls._CHECKPOINT_PREFIXES:
            cleaned = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in cleaned.items()
            }
        return cleaned

    def _load_pretrained(self, path):
        """Load the checkpoint at ``path`` into the backbone (non-strict) and print the result."""
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(path, map_location="cpu")

        # Checkpoints may nest the weights under a 'teacher' or 'model' entry.
        if "teacher" in state_dict:
            state_dict = state_dict["teacher"]
        elif "model" in state_dict:
            state_dict = state_dict["model"]

        state_dict = self._strip_checkpoint_prefixes(state_dict)
        # strict=False: missing/unexpected keys are reported in msg, not raised.
        msg = self.transformer.load_state_dict(state_dict, strict=False)
        print(self.model_size, msg)

    def forward(self, x):
        """Run ``x`` through the backbone and return its output unchanged."""
        return self.transformer(x)
| |
|
| |
|
| |
|
def build_model(args):
    """Build the ``vit_base`` wrapper with a frozen backbone and move it to the GPU.

    Args:
        args: Object exposing a ``pretrained_weights`` attribute, which is
            forwarded to ``vit`` as the checkpoint path (presumably an
            argparse namespace; confirm against the caller).

    Returns:
        The constructed ``vit`` module after ``.cuda()`` has been applied.
    """
    backbone = vit(
        model_size="vit_base",
        freeze_transformer=True,
        pretrained_weights=args.pretrained_weights,
    )
    # nn.Module.cuda() moves the module in place and returns the same object.
    return backbone.cuda()
| |
|