Davidzhangyuanhan committed
Commit 6ab04f7
1 Parent(s): 1414829

Add application file

Files changed (6)
  1. .gitignore +139 -0
  2. 142520422_6ad756ddf6_w_d.jpg +0 -0
  3. README.md +2 -2
  4. app.py +102 -0
  5. timmvit.py +83 -0
  6. trainid2name.json +0 -0
.gitignore ADDED
@@ -0,0 +1,139 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ **/*.pyc
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ #lib/
+ #lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Auto generate documentation
+ docs/en/_build/
+ docs/en/_model_zoo.rst
+ docs/en/modelzoo_statistics.md
+ docs/en/papers/
+ docs/zh_CN/_build/
+ docs/zh_CN/_model_zoo.rst
+ docs/zh_CN/modelzoo_statistics.md
+ docs/zh_CN/papers/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+
+ # custom
+ .vscode
+ .idea
+ *.pkl
+ *.pkl.json
+ *.log.json
+ /work_dirs
+ /mmcls/.mim
+
+ # Pytorch
+ *.pth.*
+
+
+ # work_dir
+ work_dir
+ saves
+
+ #checkpoint
+ weights
+
+ #logs
+ logs
+
+ #DS_Store
+ *DS_Store
+
142520422_6ad756ddf6_w_d.jpg ADDED
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Bamboo ViT-B16 Demo
- emoji: 💻
+ emoji: 🎋
  colorFrom: blue
  colorTo: blue
  sdk: gradio
@@ -10,4 +10,4 @@ pinned: false
  license: cc-by-4.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,102 @@
+ import argparse
+ import requests
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from torchvision import transforms
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from timm.data import create_transform
+
+ from timmvit import timmvit
+ import json
+ from timm.models.hub import download_cached_file
+ from PIL import Image
+
+ def pil_loader(filepath):
+     with Image.open(filepath) as img:
+         img = img.convert('RGB')
+         return img
+
+ def build_transforms(input_size):
+     # Resize, center-crop, convert to tensor, and normalize with ImageNet statistics.
+     transform = transforms.Compose([
+         transforms.Resize(input_size * 8 // 7),
+         transforms.CenterCrop(input_size),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     return transform
+
+ # Load human-readable labels for Bamboo (trainid2name.json is committed at the repo root).
+ with open('./trainid2name.json') as f:
+     id2name = json.load(f)
+
+
+ '''
+ build model
+ '''
+ model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
+ model.eval()
+
+ '''
+ build data transform
+ '''
+ eval_transforms = build_transforms(224)
+
+ '''
+ borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
+ '''
+ def show_cam_on_image(img: np.ndarray,
+                       mask: np.ndarray,
+                       use_rgb: bool = False,
+                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+     """ This function overlays the cam mask on the image as a heatmap.
+     By default the heatmap is in BGR format.
+     :param img: The base image in RGB or BGR format.
+     :param mask: The cam mask.
+     :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
+     :param colormap: The OpenCV colormap to be used.
+     :returns: The image with the cam overlay.
+     """
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+     if use_rgb:
+         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+     heatmap = np.float32(heatmap) / 255
+
+     if np.max(img) > 1:
+         raise Exception(
+             "The input image should be np.float32 in the range [0, 1]")
+
+     cam = 0.7 * heatmap + 0.3 * img
+     # cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
+
+ def recognize_image(image):
+     img_t = eval_transforms(image)
+
+     # compute output
+     output = model(img_t.unsqueeze(0))
+     prediction = output.softmax(-1).flatten()
+     _, top5_idx = torch.topk(prediction, 5)
+     return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
+
+
+ # Return a PIL image so the torchvision transforms above can consume it directly.
+ image = gr.inputs.Image(type='pil')
+ label = gr.outputs.Label(num_top_classes=5)
+
+ gr.Interface(
+     description="Bamboo for Zero-shot Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo)",
+     fn=recognize_image,
+     inputs=[image],
+     outputs=[
+         label,
+     ],
+     # examples=[
+     #     ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
+     #     ["./apple_with_ipod.jpg", "an ipod; an apple with a write note 'ipod'; an apple"],
+     #     ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
+     #     ["./zebras.png", "three zebras on the grass; two zebras on the grass; one zebra on the grass; no zebra on the grass; four zebras on the grass"],
+     # ],
+ ).launch()
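For local debugging outside Gradio, the inference path above can be exercised directly. Below is a minimal sketch (not part of this commit) that mirrors test_build() in timmvit.py; it assumes the Bamboo checkpoint Bamboo_v0-1_ViT-B16.pth.tar.convert and trainid2name.json are available in the working directory alongside the sample image added in this commit.

# Hypothetical smoke test for the demo's inference path; the file paths are assumptions.
import json
import torch
from timmvit import timmvit, build_transforms, pil_loader

model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
model.eval()

with open('./trainid2name.json') as f:
    id2name = json.load(f)

# Preprocess the sample image shipped with this commit.
img_t = build_transforms(224)(pil_loader('./142520422_6ad756ddf6_w_d.jpg'))

with torch.no_grad():
    prediction = model(img_t.unsqueeze(0)).softmax(-1).flatten()

_, top5_idx = torch.topk(prediction, 5)
print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})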
timmvit.py ADDED
@@ -0,0 +1,83 @@
+ # ------------------------------------------------------------------------
+ # SenseTime VTAB
+ # Copyright (c) 2021 SenseTime. All Rights Reserved.
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+ # ------------------------------------------------------------------------
+ # Modified from DETR (https://github.com/facebookresearch/detr)
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ # ------------------------------------------------------------------------
+
+ import timm
+ import torch
+ import copy
+ import torch.nn as nn
+ import torchvision
+ import json
+ from timm.models.hub import download_cached_file
+ from PIL import Image
+
+
+
+ class MyViT(nn.Module):
+     def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
+         super().__init__()
+         print('initializing ViT model as backbone using ckpt:', pretrain_path)
+         self.model = timm.create_model('vit_base_patch16_224', checkpoint_path=pretrain_path, num_classes=num_classes)  # pretrained=True
+     # def forward_features(self, x):
+     #     x = self.model.patch_embed(x)
+     #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+     #     if self.model.dist_token is None:
+     #         x = torch.cat((cls_token, x), dim=1)
+     #     else:
+     #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+     #     x = self.model.pos_drop(x + self.model.pos_embed)
+     #     x = self.model.blocks(x)
+     #     x = self.model.norm(x)
+
+     #     return self.model.pre_logits(x[:, 0])
+
+
+     def forward(self, x):
+         x = self.model.forward(x)
+         return x
+
+
+ def timmvit(**kwargs):
+     default_kwargs = {}
+     default_kwargs.update(**kwargs)
+     return MyViT(**default_kwargs)
+
+
+ def build_transforms(input_size, center_crop=True):
+     transform = torchvision.transforms.Compose([
+         torchvision.transforms.Resize(input_size * 8 // 7),
+         torchvision.transforms.CenterCrop(input_size),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+     return transform
+
+ def pil_loader(filepath):
+     with Image.open(filepath) as img:
+         img = img.convert('RGB')
+         return img
+
+ def test_build():
+     with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
+         id2name = json.load(f)
+     img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
+     eval_transforms = build_transforms(224)
+     img_t = eval_transforms(img)
+     img_t = img_t[None, :]
+     model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
+     # image = torch.rand(1, 3, 224, 224)
+     output = model(img_t)
+     # import pdb;pdb.set_trace()
+     prediction = output.softmax(-1).flatten()
+     _, top5_idx = torch.topk(prediction, 5)
+     # import pdb;pdb.set_trace()
+     print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})
+
+ if __name__ == '__main__':
+     test_build()
trainid2name.json ADDED
The diff for this file is too large to render. See raw diff
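The layout of trainid2name.json can be inferred from how app.py and timmvit.py index it (id2name[str(i)][0]): keys appear to be string training-class ids, and values are lists whose first element is the human-readable class name. A hypothetical excerpt, shown as a Python dict since the actual entries are not rendered here:

# Hypothetical illustration of the assumed mapping; the real ids and names may differ.
id2name_example = {
    "0": ["golden retriever"],
    "1": ["tabby cat"],
}
print(id2name_example["0"][0])  # -> golden retriever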