smoothjazzuser committed
Commit 91969e1 • 1 Parent(s): 1d9880f
Upload 6 files
Browse files
- .gitattributes +1 -0
- ex1.jpg +0 -0
- ex2.jpg +3 -0
- ex3.jpg +0 -0
- ex4.jpg +0 -0
- gradio.py +169 -0
- requirements.txt +8 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ex2.jpg filter=lfs diff=lfs merge=lfs -text
ex1.jpg ADDED
ex2.jpg ADDED (Git LFS Details)
ex3.jpg ADDED
ex4.jpg ADDED
gradio.py ADDED
@@ -0,0 +1,169 @@
+import warnings
+warnings.filterwarnings('ignore')
+import torch, numpy as np, os
+from torch import nn
+from transformers import AutoModelForImageClassification, AutoConfig, AutoImageProcessor
+import matplotlib.pyplot as plt
+from PIL import Image
+import saliency.core as saliency
+import io
+import gradio as gr
+import PIL
+
+model_choice = 0
+model_names = ["nvidia/mit-b0", 'facebook/convnext-base-224', 'microsoft/resnet-18', 'microsoft/swin-tiny-patch4-window7-224']
+model_name = model_names[model_choice]
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+class Model(nn.Module):
+    def __init__(self, MODEL_NAME=model_name):
+        super().__init__()
+        self.config = AutoConfig.from_pretrained(MODEL_NAME, finetuning_task="image-classification")
+        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
+        self.class_len = self.config.num_labels
+        self.id2label = self.config.id2label
+        self.label2id = self.config.label2id
+
+    def forward(self, x):
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+        if len(x.shape) == 3:
+            x = x.unsqueeze(0)
+        if x.shape[-1] == 3:
+            x = x.permute(0, 3, 1, 2)
+        x = x.to(device)
+        x = self.model(x)
+        return x.logits
+
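+# A note on Model.forward above: the shape handling lets one wrapper accept both
+# HWC numpy images (as Gradio and the saliency library hand them over) and CHW
+# tensors. A rough sketch of the two call patterns, assuming the 224x224
+# checkpoints listed above:
+#   model(np.zeros((224, 224, 3), dtype=np.float32))  # HWC numpy -> logits of shape (1, num_labels)
+#   model(torch.zeros((1, 3, 224, 224)))              # CHW tensor -> logits of shape (1, num_labels)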
+def conv_layer_forward_hook(module, input, output):
+    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
+    global last_conv_layer_outputs
+    last_conv_layer_outputs[saliency.base.CONVOLUTION_LAYER_VALUES] = torch.movedim(output, 3, 1).detach().cpu().numpy()
+
+def conv_layer_backward_hook(module, grad_input, grad_output):
+    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
+    global last_conv_layer_outputs
+    last_conv_layer_outputs[saliency.base.CONVOLUTION_OUTPUT_GRADIENTS] = torch.movedim(grad_output[0], 3, 1).detach().cpu().numpy()
+
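+# The hooks follow the call_model_function conventions of the PAIR-code saliency
+# library: GradCam expects the returned dict to hold the last conv layer's
+# activations under CONVOLUTION_LAYER_VALUES and the gradients flowing back into
+# it under CONVOLUTION_OUTPUT_GRADIENTS. The Integrated Gradients path used below
+# never triggers them, but keeping them registered means a saliency.GradCam object
+# could be swapped in without further plumbing.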
+auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs = None, None, None, None, None
+
+
+def swap_models(name):
+    global model, auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs
+    auto_transformer = AutoImageProcessor.from_pretrained(name)
+    model = Model(MODEL_NAME=name)
+    model = model.to(device).eval()
+    # register hooks on the last convolution layer for Grad-CAM
+    named_modules = dict(model.model.named_modules())
+    last_conv_layer_name = None
+    for module_name, module in named_modules.items():
+        if isinstance(module, torch.nn.Conv2d):
+            last_conv_layer_name = module_name
+
+    last_conv_layer = named_modules[last_conv_layer_name]
+    last_conv_layer_outputs = {}
+
+    last_conv_layer.register_forward_hook(conv_layer_forward_hook)
+    last_conv_layer.register_backward_hook(conv_layer_backward_hook)
+    class_to_id = {v: k for k, v in model.model.config.id2label.items()}
+    id_to_class = {k: v for k, v in model.model.config.id2label.items()}
+
+swap_models(model_name)
+
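+# Everything tied to a checkpoint (image processor, weights, and the Grad-CAM
+# hooks on the last Conv2d found while walking named_modules) is rebuilt in one
+# place, which is what lets the radio button below switch models at runtime:
+#   swap_models('microsoft/resnet-18')  # downloads the weights on first use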
+def saliency_graph(img1, steps=120):
+    img1 = auto_transformer(img1)
+    img1 = np.squeeze(np.array(img1.pixel_values))
+    if img1.shape[0] < img1.shape[1]:
+        img1 = np.moveaxis(img1, 0, -1)
+    img1 = (img1 - np.min(img1)) / (np.max(img1) - np.min(img1))
+
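+    # The processor emits pixel_values in CHW order; the axis check above moves
+    # channels last (the layout the saliency calls below work on), and min-max
+    # scaling maps the normalized pixels back into [0, 1].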
+    class_idx_str = 'class_idx_str'
+    def gradcam_call(images, call_model_args=None, expected_keys=None):
+        if not isinstance(images, np.ndarray) and not isinstance(images, torch.Tensor) and not isinstance(images, PIL.Image.Image):
+            # return two blank images
+            im1 = np.zeros((224, 224, 3))
+            im2 = np.zeros((224, 224, 3))
+            return im1, im2
+
+        if len(images.shape) == 3:
+            images = np.expand_dims(images, 0)
+        images = torch.tensor(images, dtype=torch.float32)
+        images = images.requires_grad_(True)
+        target_class_idx = call_model_args[class_idx_str]
+        y_pred = model(images)
+
+        if saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys:
+            out = y_pred[:, target_class_idx]
+            # move actual color channel to the 1st dimension
+            #images = torch.movedim(images, 3, 1)
+            grads = torch.autograd.grad(out, images, grad_outputs=torch.ones_like(out))
+            grads = grads[0].detach().cpu().numpy()
+            return {saliency.base.INPUT_OUTPUT_GRADIENTS: grads}
+        else:
+            hot = torch.zeros_like(y_pred)
+            hot[:, target_class_idx] = 1
+            model.zero_grad()
+            y_pred.backward(gradient=hot, retain_graph=True)
+            return last_conv_layer_outputs
+
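+    # The saliency library drives gradcam_call itself and uses expected_keys to say
+    # which outputs it needs: Integrated Gradients asks for INPUT_OUTPUT_GRADIENTS
+    # (the gradient of the target logit w.r.t. the input pixels), while a GradCam
+    # caller would take the else branch, whose backward pass fills
+    # last_conv_layer_outputs through the hooks registered in swap_models.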
+    im = img1.astype(np.float32)
+    base = np.zeros(img1.shape)
+
+    pred = model(torch.from_numpy(im))
+    class_pred = pred.argmax(dim=1).item()
+    call_model_args = {class_idx_str: class_pred}
+    gradients = saliency.IntegratedGradients()
+
+    s = gradients.GetSmoothedMask(im, gradcam_call, call_model_args, x_steps=steps, x_baseline=base, batch_size=25)
+
+    smoothgrad_mask_grayscale = saliency.VisualizeImageGrayscale(s)
+
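+    # GetSmoothedMask runs SmoothGrad on top of Integrated Gradients: it averages
+    # attributions over noisy copies of the input, each integrated over x_steps
+    # points from the all-zeros baseline to the image, 25 inputs per model call.
+    # VisualizeImageGrayscale then collapses channels and rescales to [0, 1].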
+    with torch.no_grad():
+        output = model.forward(img1)
+        output = torch.nn.functional.softmax(output, dim=1)
+        output = output.cpu().numpy()
+        top_5 = [(id_to_class[int(i)], output[0][i]) for i in np.argsort(output)[0][-5:][::-1]]
+
+
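+    # softmax turns the raw logits into probabilities; argsort's last five indices,
+    # reversed, give the five most confident labels for the bar chart below.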
+    # Render the top-5 prediction bar chart.
+    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+    ax.barh([x[0] for x in top_5], [x[1] for x in top_5])
+    ax.set_title('Top 5 Predictions')
+    buf = io.BytesIO()
+    fig.savefig(buf, format='jpg')
+    buf.seek(0)
+    fig_img = Image.open(buf)
+    plt.close(fig)
+    return smoothgrad_mask_grayscale, fig_img
+
+# gradio Interface
+def gradio_interface(img):
+    smoothgrad_mask_grayscale, fig_img = saliency_graph(img, steps=25)
+    return smoothgrad_mask_grayscale, fig_img
+
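+# steps=25 (down from the default 120) trades some attribution fidelity for a
+# faster UI round-trip, since each step costs a full forward and backward pass.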
+description = ("This demo finds the pixels in an image that are most critical for predicting a class. "
+               "The best models will ideally make predictions by attending to the expected object; poorly "
+               "generalizing models often rely on environmental cues instead and forego the most important "
+               "pixels. Highlighting the most important pixels helps explain, and build trust in, whether "
+               "a given model uses the correct features to make its prediction.")
+
+with gr.Blocks(title='Looking at the pixels models attend to') as iface:
+    #examples = gr.Examples(examples=["ex1.jpg", "ex2.jpg", "ex3.jpg", "ex4.jpg"], label="Examples", inputs="image", examples_per_page=4)
+    gr.Markdown(description)
+    gr.Markdown("Choose a model to use for classifying images:")
+    with gr.Row():
+        with gr.Column():
+            test_image = gr.Image(label="Input Image")
+            input_btn = gr.Button("Classify image")
+            model_select_dropdown = gr.Radio(model_names, label="Model to test", interactive=True, value=model_names[0])
+        with gr.Column():
+            output = gr.Image(label="Pixels used for classification")
+            output2 = gr.Image(label="Top 5 Predictions")
+
+    input_btn.click(gradio_interface, inputs=test_image, outputs=[output, output2])
+    model_select_dropdown.change(swap_models, inputs=[model_select_dropdown])
+    examples = gr.Examples(
+        examples=[os.path.join('./', x) for x in os.listdir('./') if x.endswith('.jpg')],
+        inputs=test_image,
+        label="Examples",
+        fn=gradio_interface,
+        cache_examples=False,
+        run_on_click=True,
+        postprocess=True,
+        preprocess=True,
+        outputs=[output, output2])
+
+
+iface.launch()
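+# Caveat: the script itself is named gradio.py, so running it directly from this
+# directory can shadow the real gradio package on Python's import path; renaming
+# it (e.g. to app.py, the usual Spaces entry point) avoids the collision.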
requirements.txt ADDED
@@ -0,0 +1,8 @@
+torch
+numpy
+matplotlib
+pillow
+matplotlib
+saliency
+gradio
+transformers