# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
import copy
import json

import timm
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from timm.models.hub import download_cached_file

# Default demo assets consumed by ``test_build``. They point at the original
# author's cluster paths; pass explicit arguments to run the demo elsewhere.
_DEMO_ID2NAME_PATH = '/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json'
_DEMO_IMAGE_PATH = '/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg'
_DEMO_CKPT_PATH = '/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert'


class MyViT(nn.Module):
    """Thin wrapper around timm's ``vit_base_patch16_224`` classifier.

    Args:
        num_classes: Size of the classification head handed to
            ``timm.create_model``.
        pretrain_path: Value forwarded to ``timm.create_model`` as
            ``checkpoint_path``. ``None`` is forwarded unchanged.
        enable_fc: Accepted for interface compatibility; it is not read
            anywhere in this class.
    """

    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model(
            'vit_base_patch16_224',
            checkpoint_path=pretrain_path,
            num_classes=num_classes,
        )

    # NOTE: disabled alternative that returned pre-logit features instead of
    # class logits. Kept for reference only; it is not part of the module.
    #
    # def forward_features(self, x):
    #     x = self.model.patch_embed(x)
    #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    #     if self.model.dist_token is None:
    #         x = torch.cat((cls_token, x), dim=1)
    #     else:
    #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    #     x = self.model.pos_drop(x + self.model.pos_embed)
    #     x = self.model.blocks(x)
    #     x = self.model.norm(x)
    #     return self.model.pre_logits(x[:, 0])

    def forward(self, x):
        """Run the wrapped timm model on ``x`` and return its output."""
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped model are executed.
        return self.model(x)


def timmvit(**kwargs):
    """Build a ``MyViT``, forwarding every keyword argument to its constructor."""
    return MyViT(**kwargs)


def build_transforms(input_size, center_crop=True):
    """Create the evaluation preprocessing pipeline.

    Args:
        input_size: Side length, in pixels, of the square output image.
        center_crop: When true (the default), resize to
            ``input_size * 8 // 7`` and centre-crop to ``input_size``.
            When false, resize directly to ``(input_size, input_size)``
            without cropping. This flag used to be ignored, so the default
            path is unchanged.

    Returns:
        A ``torchvision.transforms.Compose`` that ends with ``ToTensor`` and
        per-channel ``Normalize``.
    """
    if center_crop:
        geometry = [
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        geometry = [
            torchvision.transforms.Resize((input_size, input_size)),
        ]

    return torchvision.transforms.Compose(geometry + [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


def pil_loader(filepath):
    """Open the image at ``filepath`` and return it converted to RGB."""
    with Image.open(filepath) as source:
        # ``convert`` yields a separate image object, so the result stays
        # usable after the context manager closes the opened file.
        return source.convert('RGB')


def test_build(id2name_path=_DEMO_ID2NAME_PATH,
               image_path=_DEMO_IMAGE_PATH,
               ckpt_path=_DEMO_CKPT_PATH,
               topk=5):
    """Classify one demo image and print its top-``topk`` labels.

    Args:
        id2name_path: JSON file mapping stringified class ids to a sequence
            whose first element is used as the printed label.
        image_path: Image file to classify.
        ckpt_path: Checkpoint path passed to ``MyViT`` as ``pretrain_path``.
        topk: Number of highest-probability classes to print.
    """
    with open(id2name_path) as f:
        id2name = json.load(f)

    image = pil_loader(image_path)
    batch = build_transforms(224)(image).unsqueeze(0)  # add a leading batch dimension

    model = MyViT(pretrain_path=ckpt_path)
    # Put the model in inference mode and skip autograd bookkeeping: this is
    # a pure forward pass, so train-mode layers and gradient tracking are
    # unwanted.
    model.eval()
    with torch.no_grad():
        output = model(batch)

    prediction = output.softmax(-1).flatten()
    _, top_idx = torch.topk(prediction, topk)
    print({id2name[str(i)][0]: float(prediction[i]) for i in top_idx.tolist()})


if __name__ == '__main__':
    test_build()