smoothjazzuser committed
Commit
91969e1
1 Parent(s): 1d9880f

Upload 6 files

Files changed (7)
  1. .gitattributes +1 -0
  2. ex1.jpg +0 -0
  3. ex2.jpg +3 -0
  4. ex3.jpg +0 -0
  5. ex4.jpg +0 -0
  6. gradio.py +169 -0
  7. requirements.txt +7 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ex2.jpg filter=lfs diff=lfs merge=lfs -text
ex1.jpg ADDED
ex2.jpg ADDED

Git LFS Details

  • SHA256: e32eebed8f0a773dc4ecd0fb126f196bfc425618e1f22e466c14e029b617ede2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
ex3.jpg ADDED
ex4.jpg ADDED
gradio.py ADDED
@@ -0,0 +1,169 @@
+ import warnings
+ warnings.filterwarnings('ignore')
+ import torch, numpy as np, os
+ from torch import nn
+ from transformers import AutoModelForImageClassification, AutoConfig, AutoImageProcessor
+ import matplotlib.pyplot as plt
+ from PIL import Image
+ import saliency.core as saliency
+ import io
+ import gradio as gr
+ import PIL
+
+ model_choice = 0
+ model_names = ["nvidia/mit-b0", 'facebook/convnext-base-224', 'microsoft/resnet-18', 'microsoft/swin-tiny-patch4-window7-224']
+ model_name = model_names[model_choice]
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ class Model(nn.Module):
+     def __init__(self, MODEL_NAME=model_name):
+         super().__init__()
+         self.config = AutoConfig.from_pretrained(MODEL_NAME, finetuning_task="image-classification")
+         self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
+         self.class_len = self.config.num_labels
+         self.id2label = self.config.id2label
+         self.label2id = self.config.label2id
+
+     def forward(self, x):
+         # accept numpy or torch input, add a missing batch dim, and move a
+         # trailing channel axis (HWC) to channels-first (CHW) for the HF model
+         if isinstance(x, np.ndarray):
+             x = torch.from_numpy(x)
+         if len(x.shape) == 3:
+             x = x.unsqueeze(0)
+         if x.shape[-1] == 3:
+             x = x.permute(0, 3, 1, 2)
+         x = x.to(device)
+         x = self.model(x)
+         return x.logits
+
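For orientation: the wrapper accepts numpy or torch input, with or without a batch dimension, in HWC or CHW layout, and always returns raw logits. A minimal sketch of a call (the random input is illustrative only, and `model` is the global instance created by swap_models further below):

    dummy = np.random.rand(224, 224, 3).astype(np.float32)  # one HWC image, no batch dim
    logits = model(dummy)  # forward() adds the batch dim and permutes HWC -> CHW
    print(logits.shape)    # torch.Size([1, num_labels])
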
+ def conv_layer_forward_hook(module, input, output):
+     """Hook from Examples_pytorch.ipynb in the saliency library, https://github.com/PAIR-code/saliency; caches activations for Grad-CAM."""
+     global last_conv_layer_outputs
+     last_conv_layer_outputs[saliency.base.CONVOLUTION_LAYER_VALUES] = torch.movedim(output, 3, 1).detach().cpu().numpy()
+
+ def conv_layer_backward_hook(module, grad_input, grad_output):
+     """Hook from Examples_pytorch.ipynb in the saliency library, https://github.com/PAIR-code/saliency; caches gradients for Grad-CAM."""
+     global last_conv_layer_outputs
+     last_conv_layer_outputs[saliency.base.CONVOLUTION_OUTPUT_GRADIENTS] = torch.movedim(grad_output[0], 3, 1).detach().cpu().numpy()
+
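These hooks cache the last convolution layer's activations and gradients under the dictionary keys the saliency library expects, so convolution-based methods can reuse the same call-model function. A sketch, assuming the library's GradCam class as shown in its own examples (this app ultimately uses integrated gradients instead, so the hooks stay dormant):

    # hypothetical: im, gradcam_call and call_model_args as constructed inside saliency_graph below
    grad_cam = saliency.GradCam()
    gradcam_mask = grad_cam.GetMask(im, gradcam_call, call_model_args)
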
+ auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs = None, None, None, None, None
+
+ def swap_models(name):
+     global model, auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs
+     auto_transformer = AutoImageProcessor.from_pretrained(name)
+     model = Model(MODEL_NAME=name)
+     model = model.to(device).eval()
+     # register the hooks on the last convolution layer for Grad-CAM
+     named_modules = dict(model.model.named_modules())
+     last_conv_layer_name = None
+     for module_name, module in named_modules.items():
+         if isinstance(module, torch.nn.Conv2d):
+             last_conv_layer_name = module_name
+
+     last_conv_layer = named_modules[last_conv_layer_name]
+     last_conv_layer_outputs = {}
+
+     last_conv_layer.register_forward_hook(conv_layer_forward_hook)
+     # register_backward_hook is deprecated in newer torch; register_full_backward_hook is the modern equivalent
+     last_conv_layer.register_backward_hook(conv_layer_backward_hook)
+     class_to_id = {v: k for k, v in model.model.config.id2label.items()}
+     id_to_class = {k: v for k, v in model.model.config.id2label.items()}
+
+ swap_models(model_name)
+
+ def saliency_graph(img1, steps=120):
+     # preprocess with the model's image processor, then rescale to [0, 1]
+     img1 = auto_transformer(img1)
+     img1 = np.squeeze(np.array(img1.pixel_values))
+     if img1.shape[0] < img1.shape[1]:  # channels-first output: move channels to the last axis
+         img1 = np.moveaxis(img1, 0, -1)
+     img1 = (img1 - np.min(img1)) / (np.max(img1) - np.min(img1))
+
+     class_idx_str = 'class_idx_str'
+     def gradcam_call(images, call_model_args=None, expected_keys=None):
+         if not isinstance(images, (np.ndarray, torch.Tensor, PIL.Image.Image)):
+             # return two blank images
+             im1 = np.zeros((224, 224, 3))
+             im2 = np.zeros((224, 224, 3))
+             return im1, im2
+
+         if len(images.shape) == 3:
+             images = np.expand_dims(images, 0)
+         images = torch.tensor(images, dtype=torch.float32)
+         images = images.requires_grad_(True)
+         target_class_idx = call_model_args[class_idx_str]
+         y_pred = model(images)
+
+         if saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys:
+             out = y_pred[:, target_class_idx]
+             grads = torch.autograd.grad(out, images, grad_outputs=torch.ones_like(out))
+             grads = grads[0].detach().cpu().numpy()
+             return {saliency.base.INPUT_OUTPUT_GRADIENTS: grads}
+         else:
+             # one-hot backward pass; the registered hooks fill last_conv_layer_outputs
+             hot = torch.zeros_like(y_pred)
+             hot[:, target_class_idx] = 1
+             model.zero_grad()
+             y_pred.backward(gradient=hot, retain_graph=True)
+             return last_conv_layer_outputs
+
+     im = img1.astype(np.float32)
+     base = np.zeros(img1.shape)
+
+     pred = model(torch.from_numpy(im))
+     class_pred = pred.argmax(dim=1).item()
+     call_model_args = {class_idx_str: class_pred}
+     gradients = saliency.IntegratedGradients()
+
+     # SmoothGrad-averaged integrated gradients from a black baseline
+     s = gradients.GetSmoothedMask(im, gradcam_call, call_model_args, x_steps=steps, x_baseline=base, batch_size=25)
+
+     smoothgrad_mask_grayscale = saliency.VisualizeImageGrayscale(s)
+
+     with torch.no_grad():
+         output = model(img1)
+         output = torch.nn.functional.softmax(output, dim=1)
+         output = output.cpu().numpy()
+     top_5 = [(id_to_class[int(i)], output[0][i]) for i in np.argsort(output)[0][-5:][::-1]]
+
+     # render the top-5 bar chart to a PIL image
+     fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+     ax.barh([x[0] for x in top_5], [x[1] for x in top_5])
+     ax.set_title('Top 5 Predictions')
+     buf = io.BytesIO()
+     fig.savefig(buf, format='jpg')
+     buf.seek(0)
+     fig_img = Image.open(buf)
+     plt.close(fig)
+     return smoothgrad_mask_grayscale, fig_img
+
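Called outside the UI, saliency_graph returns the grayscale SmoothGrad mask and the rendered prediction chart as images; a quick sketch using one of the example photos bundled with this commit:

    example = Image.open('ex1.jpg')
    mask, chart = saliency_graph(example, steps=25)  # fewer steps run faster but give coarser attributions
    plt.imshow(mask, cmap='gray'); plt.axis('off'); plt.show()  # brighter pixels contributed more to the top class
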
+ # gradio interface
+ def gradio_interface(img):
+     smoothgrad_mask_grayscale, fig_img = saliency_graph(img, steps=25)
+     return smoothgrad_mask_grayscale, fig_img
+
+ with gr.Blocks(title='Looking at the pixels models attend to') as iface:
+     gr.Markdown("This demo finds the pixels most critical to an image classifier's prediction. For a well-generalizing model, the highlighted pixels should ideally fall on the expected object; poorly generalizable models often rely on environmental cues instead and forgo the most important pixels. Highlighting the most important pixels helps explain, and build trust in, whether a given model uses the correct features for its predictions.")
+     gr.Markdown("Choose a model to use for classifying images:")
+     with gr.Row():
+         with gr.Column():
+             test_image = gr.Image(label="Input Image")
+             input_btn = gr.Button("Classify image")
+             model_select_dropdown = gr.Radio(model_names, label="Model to test", interactive=True, value=model_names[0])
+         with gr.Column():
+             output = gr.Image(label="Pixels used for classification")
+             output2 = gr.Image(label="Top 5 Predictions")
+
+     input_btn.click(gradio_interface, inputs=test_image, outputs=[output, output2])
+     model_select_dropdown.change(swap_models, inputs=[model_select_dropdown])
+     examples = gr.Examples(
+         examples=[os.path.join('./', x) for x in os.listdir('./') if x.endswith('.jpg')],
+         inputs=test_image,
+         label="Examples",
+         fn=gradio_interface,
+         cache_examples=False,
+         run_on_click=True,
+         postprocess=True,
+         preprocess=True,
+         outputs=[output, output2])
+
+ iface.launch()
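A practical note on running the script: since the file is itself named gradio.py, launching it from its own directory makes `import gradio as gr` resolve to this script instead of the installed package. Renaming the local copy first (for example to app.py) and running `python app.py` avoids the collision.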
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ numpy
+ matplotlib
+ pillow
+ saliency
+ gradio
+ transformers