# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# ------------------------------------------------------------------------
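"""ViT-B/16 backbone wrapper built on timm.

Loads a (Bamboo) classification checkpoint, builds standard evaluation
transforms, and runs a small top-5 prediction demo (see ``test_build``).
"""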

import json

import timm
import torch
import torch.nn as nn
import torchvision
from PIL import Image



class MyViT(nn.Module):
    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
        super().__init__()
        print('initializing ViT model as backbone using ckpt:', pretrain_path)
        self.model = timm.create_model(
            'vit_base_patch16_224',
            checkpoint_path=pretrain_path,
            num_classes=num_classes,
        )

    # def forward_features(self, x):
    #     x = self.model.patch_embed(x)
    #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
    #     if self.model.dist_token is None:
    #         x = torch.cat((cls_token, x), dim=1)
    #     else:
    #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
    #     x = self.model.pos_drop(x + self.model.pos_embed)
    #     x = self.model.blocks(x)
    #     x = self.model.norm(x)
    #     return self.model.pre_logits(x[:, 0])

    def forward(self, x):
        return self.model(x)


def timmvit(**kwargs):
    default_kwargs = {}
    default_kwargs.update(**kwargs)
    return MyViT(**default_kwargs)


def build_transforms(input_size, center_crop=True):
    # Standard evaluation pipeline: resize the short side to 8/7 of the target
    # size (256 for input_size=224), then center-crop and normalize with
    # ImageNet statistics.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(input_size * 8 // 7),
        torchvision.transforms.CenterCrop(input_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform

def pil_loader(filepath):
    with Image.open(filepath) as img:
        img = img.convert('RGB')
    return img

def test_build():
    with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
        id2name = json.load(f)
    img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
    eval_transforms = build_transforms(224)
    img_t = eval_transforms(img)[None, :]  # add batch dimension
    model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
    model.eval()
    with torch.no_grad():
        output = model(img_t)
    prediction = output.softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    # Map the top-5 class indices to human-readable names and print their scores.
    print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})

if __name__ == '__main__':
    test_build()