Spaces:
Build error
Build error
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
import timm | |
import torch | |
import copy | |
import torch.nn as nn | |
import torchvision | |
import json | |
from timm.models.hub import download_cached_file | |
from PIL import Image | |
class MyViT(nn.Module):
    """Thin ``nn.Module`` wrapper around timm's ``vit_base_patch16_224``.

    Args:
        num_classes: Output size of the classification head, passed to
            ``timm.create_model`` (default 115217).
        pretrain_path: Optional checkpoint path forwarded to timm as
            ``checkpoint_path``; ``None`` builds the model without loading one.
        enable_fc: Currently unused; accepted only so existing callers that
            pass it keep working.
    """

    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model(
            'vit_base_patch16_224',
            checkpoint_path=pretrain_path,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Run the wrapped ViT on ``x`` and return its raw head output.

        The submodule is invoked through ``__call__`` (not ``.forward``) so
        that any forward/backward hooks registered on it are honored.
        """
        return self.model(x)
def timmvit(**kwargs):
    """Construct a :class:`MyViT`, forwarding every keyword argument.

    Caller-supplied keywords are layered over an (currently empty) set of
    factory defaults, so callers always win on conflicts.
    """
    defaults = {}
    return MyViT(**{**defaults, **kwargs})
def build_transforms(input_size, center_crop=True):
    """Build the evaluation preprocessing pipeline for a PIL image.

    Args:
        input_size: Side length in pixels of the square output.
        center_crop: If True (default), resize the shorter side to
            ``input_size * 8 // 7`` (256 for 224) and center-crop to
            ``input_size``. If False, resize directly to
            ``(input_size, input_size)`` with no crop; this may distort the
            aspect ratio.

    Returns:
        A ``torchvision.transforms.Compose`` that yields a normalized float
        tensor. Assumes a 3-channel (RGB) input, since ``Normalize`` is given
        three per-channel statistics.
    """
    if center_crop:
        # Shorter side -> input_size / 0.875, then keep the central square.
        resize_steps = [
            torchvision.transforms.Resize(input_size * 8 // 7),
            torchvision.transforms.CenterCrop(input_size),
        ]
    else:
        # Previously this flag was ignored and the crop was always applied.
        resize_steps = [torchvision.transforms.Resize((input_size, input_size))]
    return torchvision.transforms.Compose(resize_steps + [
        torchvision.transforms.ToTensor(),
        # Commonly used ImageNet channel statistics.
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def pil_loader(filepath):
    """Load the image at ``filepath`` and return it converted to RGB.

    ``Image.convert`` returns a converted copy, so the result remains usable
    after the file opened by ``Image.open`` is closed on leaving the ``with``.
    """
    with Image.open(filepath) as source:
        rgb = source.convert('RGB')
    return rgb
def test_build(
    demo_dir='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo',
    image_name='142520422_6ad756ddf6_w_d.jpg',
    ckpt_name='Bamboo_v0-1_ViT-B16.pth.tar.convert',
    label_map_name='trainid2name.json',
    input_size=224,
    topk=5,
):
    """Smoke-test the pipeline: classify one demo image and print top-k labels.

    All arguments default to the original hard-coded demo assets, so calling
    ``test_build()`` with no arguments behaves as before.

    Args:
        demo_dir: Directory holding the label map, image and checkpoint.
        image_name: Image file name inside ``demo_dir``.
        ckpt_name: Checkpoint file name inside ``demo_dir``.
        label_map_name: JSON label map inside ``demo_dir``. Its values are
            indexed with ``[0]``, so each is presumably a list whose first
            entry is the display name (TODO confirm).
        input_size: Square input resolution given to ``build_transforms``.
        topk: Number of highest-probability classes to print.
    """
    import os

    with open(os.path.join(demo_dir, label_map_name), encoding='utf-8') as f:
        id2name = json.load(f)
    img = pil_loader(os.path.join(demo_dir, image_name))
    eval_transforms = build_transforms(input_size)
    img_t = eval_transforms(img)[None, :]  # prepend a batch dimension
    model = MyViT(pretrain_path=os.path.join(demo_dir, ckpt_name))
    # Inference only: switch to eval mode and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        output = model(img_t)
    prediction = output.softmax(-1).flatten()
    _, top_idx = torch.topk(prediction, topk)
    print({id2name[str(i)][0]: float(prediction[i]) for i in top_idx.tolist()})
if __name__ == "__main__":
    # Run the demo smoke test only when executed as a script, never on import.
    test_build()