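# File-view header captured with this source (kept as a comment so the file parses):
# Bamboo_ViT-B16_demo / timmvit.py — author: Davidzhangyuanhan,
# commit 25047b0 "Add application file" (2.9 kB; raw / history / blame views).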
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
import timm
import torch
import copy
import torch.nn as nn
import torchvision
import json
from timm.models.hub import download_cached_file
from PIL import Image
class MyViT(nn.Module):
    """Thin ``nn.Module`` wrapper around timm's ``vit_base_patch16_224``.

    Args:
        num_classes: Output size of the classification head, forwarded to
            ``timm.create_model`` (default 115217).
        pretrain_path: Checkpoint path forwarded to ``timm.create_model`` as
            ``checkpoint_path``; ``None`` passes no checkpoint.
        enable_fc: Currently unused. It is kept only so existing callers that
            pass it keep working.
    """

    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model(
            'vit_base_patch16_224',
            checkpoint_path=pretrain_path,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Run the wrapped ViT on ``x`` and return its output unchanged.

        The output is treated as class logits by callers (softmax over the
        last dim); presumably shape (batch, num_classes), TODO confirm.
        """
        # Call the submodule instance rather than ``.forward`` so forward
        # pre/post hooks registered on ``self.model`` are not silently skipped.
        return self.model(x)
def timmvit(**kwargs):
    """Factory that builds a ``MyViT``, forwarding every keyword argument as-is.

    There are no extra defaults layered on top: a shallow copy of ``kwargs``
    is handed straight to the ``MyViT`` constructor.
    """
    options = dict(kwargs)
    return MyViT(**options)
def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline for the ViT.

    Args:
        input_size: Side length in pixels of the square image fed to the model.
        center_crop: If True (default, unchanged behavior), resize the shorter
            side to ``input_size * 8 // 7`` (256 for 224) and then center-crop
            to ``input_size``. If False, resize the whole image directly to
            ``(input_size, input_size)`` without cropping. Previously this flag
            was ignored.

    Returns:
        A ``torchvision.transforms.Compose``. For an RGB PIL image it yields a
        normalized float tensor of shape ``(3, input_size, input_size)``.
    """
    if center_crop:
        # Resize slightly larger than the target, then keep the central square.
        geometry = [
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        # Squash the full image to a square; a non-square output would not fit
        # the fixed-size ViT input.
        geometry = [torchvision.transforms.Resize((input_size, input_size))]

    return torchvision.transforms.Compose(geometry + [
        torchvision.transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std.
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def pil_loader(filepath):
    """Open the image at ``filepath`` and return an RGB copy of it.

    The file handle is closed on leaving the ``with`` block. ``convert``
    produces a new image object, so the returned image stays usable after
    the close.
    """
    with Image.open(filepath) as raw:
        rgb = raw.convert('RGB')
    return rgb
def test_build(demo_dir='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo', topk=5):
    """Run the demo end to end and print the top-``topk`` predictions.

    Steps:
    - Load the train-id -> name mapping.
    - Preprocess the sample image.
    - Build ``MyViT`` from the converted Bamboo checkpoint.
    - Print a dict mapping the first name of each of the ``topk`` most
      probable classes to its softmax probability.

    Args:
        demo_dir: Directory holding ``trainid2name.json``, the sample image
            and the checkpoint. Defaults to the original hard-coded location,
            so the resulting paths are unchanged.
        topk: Number of predictions to print (default 5, as before).
    """
    import os  # local import: only this demo helper needs path handling

    with open(os.path.join(demo_dir, 'trainid2name.json'), encoding='utf-8') as f:
        # Presumably {str(train_id): [name, ...]}; only element 0 is used below.
        id2name = json.load(f)

    img = pil_loader(os.path.join(demo_dir, '142520422_6ad756ddf6_w_d.jpg'))
    eval_transforms = build_transforms(224)
    img_t = eval_transforms(img).unsqueeze(0)  # add a leading batch dimension

    model = MyViT(pretrain_path=os.path.join(demo_dir, 'Bamboo_v0-1_ViT-B16.pth.tar.convert'))
    # Inference only: switch modules to eval behavior and skip building the
    # autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(img_t)

    prediction = output.softmax(-1).flatten()
    _, top_idx = torch.topk(prediction, topk)
    print({id2name[str(i)][0]: float(prediction[i]) for i in top_idx.tolist()})
if __name__ == '__main__':
    # Run the end-to-end demo only when this file is executed as a script,
    # not when it is imported as a module.
    test_build()