Patrick Ramos committed
Commit 5345282
1 Parent(s): 461e0e8

Create app.py

Files changed (1)
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
+ import gradio as gr
+ from transformers import ViTFeatureExtractor, ViTModel
+ import torch
+ import matplotlib.pyplot as plt
+ from skops import hub_utils
+ from einops import reduce
+ import seaborn as sns
+ import pickle
+ import os
+
+ labels = [
+     'tench',
+     'English springer',
+     'cassette player',
+     'chain saw',
+     'church',
+     'French horn',
+     'garbage truck',
+     'gas pump',
+     'golf ball',
+     'parachute'
+ ]
+
+ # load DINO
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
+ model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
+
+ # download and load the logistic regression from the Hugging Face Hub
+ os.makedirs('emb-gam-dino', exist_ok=True)
+ hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
+
+ with open('emb-gam-dino/model.pkl', 'rb') as file:
+     logistic_regression = pickle.load(file)
+
+ def classify_and_heatmap(input_img):
+     # get patch embeddings (drop the [CLS] token)
+     inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
+     with torch.no_grad():
+         patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
+
+     # get class probabilities from the logistic regression over the summed patch embeddings
+     scores = dict(zip(
+         labels,
+         logistic_regression.predict_proba(reduce(patch_embeddings, 'p d -> () d', 'sum'))[0]
+     ))
+
+     # make plot
+     num_patches_side = model.config.image_size // model.config.patch_size
+
+     # set up figure: first column holds the input image, the remaining 2x5 axes hold per-class heatmaps
+     fig, axs = plt.subplots(2, 6, figsize=(12, 5))
+     gs = axs[0, 0].get_gridspec()
+     for ax in axs[:, 0]:
+         ax.remove()
+     ax_orig_img = fig.add_subplot(gs[:, 0])
+
+     # plot original image (undo the feature extractor's normalization first)
+     img = feature_extractor.to_pil_image(
+         inputs['pixel_values'].squeeze(0).cpu() * torch.tensor(feature_extractor.image_std).view(-1, 1, 1) + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
+     )
+     ax_orig_img.imshow(img)
+     ax_orig_img.axis('off')
+
+     # plot patch contributions: one heatmap per class, on a shared color scale
+     patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
+     vmin = patch_contributions.min()
+     vmax = patch_contributions.max()
+
+     for i, ax in enumerate(axs[:, 1:].flat):
+         sns.heatmap(
+             patch_contributions[i].reshape(num_patches_side, num_patches_side),
+             ax=ax,
+             square=True,
+             vmin=vmin,
+             vmax=vmax,
+         )
+         ax.set_title(labels[i])
+         ax.set_xlabel(f'score={patch_contributions[i].sum():.2f}')
+         ax.set_xticks([])
+         ax.set_yticks([])
+
+     return scores, plt
+
+ description = '''
+ This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It performs image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
+ '''
+
+ article = '''
+ Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).
+
+ Citations for the works this demo builds on (not our papers):
+ ```bibtex
+ @article{singh2022emb,
+   title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
+   author={Singh, Chandan and Gao, Jianfeng},
+   journal={arXiv preprint arXiv:2209.11799},
+   year={2022}
+ }
+
+ @InProceedings{Caron_2021_ICCV,
+   author = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
+   title = {Emerging Properties in Self-Supervised Vision Transformers},
+   booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+   month = {October},
+   year = {2021},
+   pages = {9650-9660}
+ }
+
+ @misc{imagenette,
+   author = {fast.ai},
+   title = {Imagenette},
+   url = {https://github.com/fastai/imagenette},
+ }
+ ```
+ '''
+
+ demo = gr.Interface(
+     fn=classify_and_heatmap,
+     inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
+     outputs=[
+         gr.Label(label='Class'),
+         gr.Plot(label='Patch Contributions')
+     ],
+     title='Emb-GAM DINO',
+     description=description,
+     article=article,
+     examples=['./examples/english_springer.png', './examples/golf_ball.png']
+ )
+ demo.launch(debug=True)
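
For reference, the heatmaps in `classify_and_heatmap` rely on the linearity of the logistic regression: a class's logit over the summed patch embeddings decomposes into one additive term per patch, which is what each heatmap cell shows. Below is a minimal sketch of that identity using random stand-in embeddings and a toy fitted `sklearn.linear_model.LogisticRegression` (not the model the app downloads); the shapes mirror DINO ViT-B/16 at 224x224 but the data is synthetic.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# toy stand-ins: random "patch embeddings" and a classifier fit on random data
rng = np.random.default_rng(0)
num_patches, dim, num_classes = 196, 768, 10
patch_embeddings = rng.normal(size=(num_patches, dim))            # (patches, dim)
clf = LogisticRegression(max_iter=1000).fit(
    rng.normal(size=(100, dim)), rng.integers(0, num_classes, 100)
)

# Emb-GAM-style prediction: classify the sum of the patch embeddings
image_embedding = patch_embeddings.sum(axis=0, keepdims=True)     # (1, dim)
logits = clf.decision_function(image_embedding)[0]                # (classes,)

# per-patch contributions, as computed in classify_and_heatmap above
patch_contributions = clf.coef_ @ patch_embeddings.T              # (classes, patches)

# linearity: per-patch contributions sum back to the class logits (up to the intercept)
assert np.allclose(patch_contributions.sum(axis=1) + clf.intercept_, logits)
```

The `score` printed under each heatmap in the app is this same per-class sum of patch contributions, without the intercept.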