Davidzhangyuanhan commited on
Commit
34d86b5
β€’
1 Parent(s): cc8b572

Add application file

Browse files
Bamboo_v0-1_ViT-B16.pth.tar.convert ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6d30c823ba2fc764291e65a06747390a81b15a1e655dd02b45d58528e08c937
3
+ size 697651655
README copy.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Bamboo ViT-B16 Demo
3
+ emoji: πŸŽ‹
4
+ colorFrom: blue
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.0.17
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-4.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import gradio as gr
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+ import torchvision
10
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.data import create_transform
12
+
13
+ from timmvit import timmvit
14
+ import json
15
+ from timm.models.hub import download_cached_file
16
+ from PIL import Image
17
+
18
+ def pil_loader(filepath):
19
+ with Image.open(filepath) as img:
20
+ img = img.convert('RGB')
21
+ return img
22
+
23
+ def build_transforms(input_size, center_crop=True):
24
+ transform = torchvision.transforms.Compose([
25
+ torchvision.transforms.ToPILImage(),
26
+ torchvision.transforms.Resize(input_size * 8 // 7),
27
+ torchvision.transforms.CenterCrop(input_size),
28
+ torchvision.transforms.ToTensor(),
29
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30
+ ])
31
+ return transform
32
+
33
+ # Download human-readable labels for Bamboo.
34
+ with open('./trainid2name.json') as f:
35
+ id2name = json.load(f)
36
+
37
+
38
+ '''
39
+ build model
40
+ '''
41
+ model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
42
+ model.eval()
43
+
44
+ '''
45
+ borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
46
+ '''
47
+ def show_cam_on_image(img: np.ndarray,
48
+ mask: np.ndarray,
49
+ use_rgb: bool = False,
50
+ colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
51
+ """ This function overlays the cam mask on the image as an heatmap.
52
+ By default the heatmap is in BGR format.
53
+ :param img: The base image in RGB or BGR format.
54
+ :param mask: The cam mask.
55
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
56
+ :param colormap: The OpenCV colormap to be used.
57
+ :returns: The default image with the cam overlay.
58
+ """
59
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
60
+ if use_rgb:
61
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
62
+ heatmap = np.float32(heatmap) / 255
63
+
64
+ if np.max(img) > 1:
65
+ raise Exception(
66
+ "The input image should np.float32 in the range [0, 1]")
67
+
68
+ cam = 0.7*heatmap + 0.3*img
69
+ # cam = cam / np.max(cam)
70
+ return np.uint8(255 * cam)
71
+
72
+
73
+
74
+
75
+ def recognize_image(image):
76
+ img_t = eval_transforms(image)
77
+ # compute output
78
+ output = model(img_t.unsqueeze(0))
79
+ prediction = output.softmax(-1).flatten()
80
+ _,top5_idx = torch.topk(prediction, 5)
81
+ return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
82
+
83
+ eval_transforms = build_transforms(224)
84
+
85
+
86
+ image = gr.inputs.Image()
87
+ label = gr.outputs.Label(num_top_classes=5)
88
+
89
+ gr.Interface(
90
+ description="Bamboo for Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo). Bamboo knows what this object is and what you are doing in a very fine-grain granularity: fratercula arctica (fig.5) and dribbler (fig.2)).",
91
+ fn=recognize_image,
92
+ inputs=["image"],
93
+ outputs=[
94
+ label,
95
+ ],
96
+ examples=[
97
+ ["./examples/playing_mahjong.jpg"],
98
+ ["./examples/dribbler.jpg"],
99
+ ["./examples/Ferrari-F355.jpg"],
100
+ ["./examples/northern_oriole.jpg"],
101
+ ["./examples/fratercula_arctica.jpg"],
102
+ ["./examples/husky.jpg"],
103
+ ["./examples/taraxacum_erythrospermum.jpg"],
104
+ ],
105
+ ).launch()
examples/Ferrari-F355.jpg ADDED
examples/basketball.jpg ADDED
examples/dribbler.jpg ADDED
examples/fratercula_arctica.jpg ADDED
examples/husky.jpg ADDED
examples/northern_oriole.jpg ADDED
examples/playing_mahjong.jpg ADDED
examples/taraxacum_erythrospermum.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ torchvision==0.11.2
2
+ torch==1.10.1
3
+ opencv-python-headless==4.5.3.56
4
+ timm==0.4.12
5
+ numpy
6
+
timmvit.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # SenseTime VTAB
3
+ # Copyright (c) 2021 SenseTime. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8
+ # ------------------------------------------------------------------------
9
+
10
+ import timm
11
+ import torch
12
+ import copy
13
+ import torch.nn as nn
14
+ import torchvision
15
+ import json
16
+ from timm.models.hub import download_cached_file
17
+ from PIL import Image
18
+
19
+
20
+
21
+ class MyViT(nn.Module):
22
+ def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
23
+ super().__init__()
24
+ print('initializing ViT model as backbone using ckpt:', pretrain_path)
25
+ self.model = timm.create_model('vit_base_patch16_224',checkpoint_path=pretrain_path,num_classes=num_classes)# pretrained=True)
26
+ # def forward_features(self, x):
27
+ # x = self.model.patch_embed(x)
28
+ # cls_token = self.model.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
29
+ # if self.model.dist_token is None:
30
+ # x = torch.cat((cls_token, x), dim=1)
31
+ # else:
32
+ # x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
33
+
34
+ # x = self.model.pos_drop(x + self.model.pos_embed)
35
+ # x = self.model.blocks(x)
36
+ # x = self.model.norm(x)
37
+
38
+ # return self.model.pre_logits(x[:, 0])
39
+
40
+
41
+ def forward(self, x):
42
+ x = self.model.forward(x)
43
+ return x
44
+
45
+
46
+ def timmvit(**kwargs):
47
+ default_kwargs={}
48
+ default_kwargs.update(**kwargs)
49
+ return MyViT(**default_kwargs)
50
+
51
+
52
+ def build_transforms(input_size, center_crop=True):
53
+ transform = torchvision.transforms.Compose([
54
+ torchvision.transforms.Resize(input_size * 8 // 7),
55
+ torchvision.transforms.CenterCrop(input_size),
56
+ torchvision.transforms.ToTensor(),
57
+ torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
58
+ ])
59
+ return transform
60
+
61
+ def pil_loader(filepath):
62
+ with Image.open(filepath) as img:
63
+ img = img.convert('RGB')
64
+ return img
65
+
66
+ def test_build():
67
+ with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
68
+ id2name = json.load(f)
69
+ img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
70
+ eval_transforms = build_transforms(224)
71
+ img_t = eval_transforms(img)
72
+ img_t = img_t[None, :]
73
+ model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
74
+ # image = torch.rand(1, 3, 224, 224)
75
+ output = model(img_t)
76
+ # import pdb;pdb.set_trace()
77
+ prediction = output.softmax(-1).flatten()
78
+ _,top5_idx = torch.topk(prediction, 5)
79
+ # import pdb;pdb.set_trace()
80
+ print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})
81
+
82
+ if __name__ == '__main__':
83
+ test_build()
trainid2name.json ADDED
The diff for this file is too large to render. See raw diff