import warnings
warnings.filterwarnings('ignore')
import torch, numpy as np, os
from torch import nn
from transformers import AutoModelForImageClassification, AutoConfig, AutoImageProcessor
import matplotlib.pyplot as plt
from PIL import Image
import saliency.core as saliency
import io
import gradio as gr
import PIL

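# Pick which pretrained backbone to load by index; 3 selects the Swin-Tiny checkpoint.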
model_choice = 3
model_names = ["nvidia/mit-b0", "facebook/convnext-base-224", "microsoft/resnet-18", "microsoft/swin-tiny-patch4-window7-224"]
model_name = model_names[model_choice]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

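# Thin wrapper around a Hugging Face image-classification model: it normalizes
# numpy/torch, batched/unbatched, and channels-last inputs and returns raw logits.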
class Model(nn.Module):
    def __init__(self, MODEL_NAME=model_name):
        super().__init__()
        self.config = AutoConfig.from_pretrained(MODEL_NAME, finetuning_task="image-classification")
        self.model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
        self.class_len = self.config.num_labels
        self.id2label = self.config.id2label
        self.label2id = self.config.label2id

    def forward(self, x):
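        # Accept numpy arrays, single images, and channels-last layouts, and
        # normalize everything to a (N, C, H, W) float tensor on the target device.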
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        if x.shape[-1] == 3:
            x = x.permute(0, 3, 1, 2)
        x = x.to(device)
        x = self.model(x)
        return x.logits

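# Forward/backward hooks store the last conv layer's activations and gradients for the
# saliency library's Grad-CAM path (not used by the Integrated Gradients flow below).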
def conv_layer_forward_hook(module, input, output):
    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
    global last_conv_layer_outputs
    # move channels to the last dimension, as in the saliency library's PyTorch example
    last_conv_layer_outputs[saliency.base.CONVOLUTION_LAYER_VALUES] = torch.movedim(output, 1, 3).detach().cpu().numpy()
def conv_layer_backward_hook(module, grad_input, grad_output):
    """Method from Examples_pytorch.ipynb for the gradcam library https://github.com/PAIR-code/saliency."""
    global last_conv_layer_outputs
    last_conv_layer_outputs[saliency.base.CONVOLUTION_OUTPUT_GRADIENTS] = torch.movedim(grad_output[0], 1, 3).detach().cpu().numpy()

auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs = None, None, None, None, None


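# (Re)load a backbone by name, move it to the device, attach Grad-CAM hooks to its
# last Conv2d layer, and rebuild the label <-> id lookup tables.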
def swap_models(name):
    global model, auto_transformer, class_to_id, id_to_class, last_conv_layer, last_conv_layer_outputs
    auto_transformer = AutoImageProcessor.from_pretrained(name)
    model = Model(MODEL_NAME=name)
    model = model.to(device).eval()
    # register the hooks for the last convolution layer for Grad-Cam
    named_modules = dict(model.model.named_modules())
    last_conv_layer_name = None
    for module_name, module in named_modules.items():
        if isinstance(module, torch.nn.Conv2d):
            last_conv_layer_name = module_name

    last_conv_layer = named_modules[last_conv_layer_name]
    last_conv_layer_outputs = {}

    last_conv_layer.register_forward_hook(conv_layer_forward_hook)
    last_conv_layer.register_full_backward_hook(conv_layer_backward_hook)
    class_to_id = {v:k for k,v in model.model.config.id2label.items()}
    id_to_class = {k:v for k,v in model.model.config.id2label.items()}

swap_models(model_name)

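# Produce a SmoothGrad-smoothed Integrated Gradients mask plus a top-5 prediction
# bar chart for a single input image.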
def saliency_graph(img1, steps=25):
    img1 = auto_transformer(img1)
    img1 = np.squeeze(np.array(img1.pixel_values))
    if img1.shape[0] < img1.shape[1]:
        img1 = np.moveaxis(img1, 0, -1)
    img1 = (img1 - np.min(img1)) / (np.max(img1) - np.min(img1))

    class_idx_str = 'class_idx_str'
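    # call_model function in the format the saliency library expects: it returns input
    # gradients for Integrated Gradients, or the hooked conv activations/gradients for Grad-CAM.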
    def gradcam_call(images, call_model_args=None, expected_keys=None):
        if not isinstance(images, np.ndarray) and not isinstance(images, torch.Tensor) and not isinstance(images, PIL.Image.Image):
            # return two blank images
            im1 = np.zeros((224, 224, 3))
            im2 = np.zeros((224, 224, 3))
            return im1, im2
        
        if len(images.shape) == 3:
            images = np.expand_dims(images, 0)
        images = torch.tensor(images, dtype=torch.float32)
        images = images.requires_grad_(True)
        target_class_idx = call_model_args[class_idx_str]
        y_pred = model(images)
        
        if saliency.base.INPUT_OUTPUT_GRADIENTS in expected_keys:
            out =  y_pred[:, target_class_idx]
            # move actual color channel to the 1st dimension
            #images = torch.movedim(images, 3, 1)
            grads = torch.autograd.grad(out, images, grad_outputs=torch.ones_like(out))
            grads = grads[0].detach().cpu().numpy()
            return {saliency.base.INPUT_OUTPUT_GRADIENTS: grads}
        else:
            hot = torch.zeros_like(y_pred)
            hot[:, target_class_idx] = 1
            model.zero_grad()
            y_pred.backward(gradient=hot, retain_graph=True)
            return last_conv_layer_outputs

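    # Integrated Gradients interpolates between an all-zero baseline and the preprocessed image.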
    im = img1.astype(np.float32)
    base = np.zeros(img1.shape)

    pred = model(torch.from_numpy(im))
    class_pred = pred.argmax(dim=1).item()
    call_model_args = {class_idx_str: class_pred}
    gradients = saliency.IntegratedGradients()

    s = gradients.GetSmoothedMask(im, gradcam_call, call_model_args, x_steps=steps, x_baseline=base, batch_size=25)

    smoothgrad_mask_grayscale = saliency.VisualizeImageGrayscale(s)

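    # Separate forward pass to get softmax probabilities for the top-5 bar chart.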
    with torch.no_grad():
        output = model.forward(img1)
        output = torch.nn.functional.softmax(output, dim=1)
        output = output.cpu().numpy()
    top_5 = [(id_to_class[int(i)], output[0][i]) for i in np.argsort(output)[0][-5:][::-1]]


    # Render the top-5 prediction bar chart.
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.barh([x[0] for x in top_5], [x[1] for x in top_5])
    ax.set_title('Top 5 Predictions')
    buf = io.BytesIO()
    fig.savefig(buf, format='jpg')
    buf.seek(0)
    fig_img = Image.open(buf)
    plt.close(fig)
    return smoothgrad_mask_grayscale, fig_img

# gradio Interface
def gradio_interface(img):
    smoothgrad_mask_grayscale, fig_img = saliency_graph(img, steps=20)
    return smoothgrad_mask_grayscale, fig_img

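# Two-column layout: input image and model selector on the left,
# saliency mask and top-5 chart on the right.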
with gr.Blocks() as iface:
    #examples = gr.Examples(examples=["ex1.jpg", "ex2.jpg", "ex3.jpg", "ex4.jpg"], label="Examples", inputs="image", examples_per_page=4)
    gr.Markdown("This function finds the most critical pixels in an image for predicting a class by looking at the pixels models attend to. The best models will ideally make predictions by highlighting the expected object. Poorly generalizable models will often rely on environmental cues instead and forego looking at the most important pixels. Highlighting the most important pixels helps explain/build trust about whether a given model uses the correct features to make its prediction.")
    with gr.Row():
        with gr.Column():
            test_image = gr.Image(label="Input Image")
            input_btn = gr.Button("Classify image")
            model_select_dropdown = gr.Radio(model_names, label="Model to test", interactive=True, value=model_name)
        with gr.Column():
            output = gr.Image(label="Pixels used for classification")
            output2 = gr.Image(label="Top 5 Predictions")

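    # Wire the button to the saliency pipeline and the radio selector to model hot-swapping.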
    input_btn.click(gradio_interface, test_image, outputs=[output, output2])
    model_select_dropdown.change(swap_models, inputs=[model_select_dropdown])
    examples = gr.Examples(
        examples=[os.path.join('./', x) for x in os.listdir('./') if x.endswith('.jpg')],
        inputs=test_image,
        label="Examples",
        fn=gradio_interface,
        cache_examples=True,
        run_on_click=True,
        postprocess=True,
        preprocess=True,
        outputs=[output, output2])


iface.launch()