File size: 6,337 Bytes
9bec60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c3fb24
9bec60b
fb71b6d
9bec60b
b694259
9bec60b
 
 
 
 
 
 
3f3a5b0
fb71b6d
 
9bec60b
3f3a5b0
9bec60b
 
 
 
 
 
 
 
 
fb71b6d
9bec60b
a2446f8
9bec60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6c588d
9bec60b
 
 
 
 
 
 
 
 
 
 
a2446f8
b72342a
 
9bec60b
 
 
 
 
 
 
 
 
3f3a5b0
 
 
749a9c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import gc
import PIL
import glob
import timm
import torch
import nopdb
import pickle
import torchvision

import numpy as np
import gradio as gr
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
import IPython.display as ipd
from typing import Tuple, Dict
from timeit import default_timer as timer
from timm.data import resolve_data_config, create_transform

# Example inputs offered in the Gradio UI: one single-element row per file in
# the local "examples" directory. os.listdir() returns entries in arbitrary
# order, so sort them to keep the example gallery stable between runs.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

# Fine-tuned checkpoint, mapped onto the CPU at load time.
# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should ever be placed at this path.
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
                                        map_location=torch.device('cpu'))

# Build the architecture without pretrained weights; the checkpoint above
# supplies every parameter.
vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)

# Swap the classification head for a 38-output linear layer so its shape
# matches the head stored in the checkpoint.
vision_transformer.head = nn.Linear(in_features=768,
                                    out_features=38)

vision_transformer.load_state_dict(vision_transformer_weights)

# The model is only used for inference in this app: put it in evaluation mode
# so train-only layers (e.g. dropout) can never perturb predictions.
vision_transformer.eval()

from torchvision import datasets, transforms

# Preprocessing applied to every PIL image before it reaches the model:
# resize to 256x256, take the central 224x224 crop, convert to a tensor and
# normalise each channel (presumably with the usual ImageNet statistics —
# confirm against the training pipeline).
data_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

def inv_normalize(tensor):
    """Min-max rescale a tensor so its values span [0, 256 - 1e-5].

    The upper bound sits just below 256 so that truncating the result to
    ``uint8`` yields values in 0..255 without wrapping the maximum to 0.

    Args:
        tensor: array-like object supporting ``min()``, ``max()`` and
            element-wise arithmetic (a ``torch.Tensor`` in this file).

    Returns:
        A new tensor of the same shape. When every element of the input is
        identical there is no range to stretch, so an all-zero tensor is
        returned instead of dividing by zero (which would produce NaNs).
    """
    lowest = tensor.min()
    value_range = tensor.max() - lowest
    if value_range == 0:
        # Constant input: (tensor - lowest) is already all zeros.
        return tensor - lowest
    return (tensor - lowest) / value_range * (256 - 1e-5)

def inv_transform(tensor, normalize=True):
    """Turn a channels-first image tensor into a PIL image.

    The tensor is rescaled with ``inv_normalize``, moved to a NumPy array,
    transposed from (C, H, W) to (H, W, C) and truncated to ``uint8``.

    Args:
        tensor: image tensor with three axes ordered (channels, height, width).
        normalize: accepted for interface compatibility but not consulted;
            the tensor is always rescaled.

    Returns:
        The resulting ``PIL.Image.Image``.
    """
    rescaled = inv_normalize(tensor)
    channels_last = rescaled.detach().cpu().numpy().transpose(1, 2, 0)
    pixels = channels_last.astype(np.uint8)
    return PIL.Image.fromarray(pixels)

# Class labels are stored as a pickled object next to the app; predictions
# index into it by class id.
with open("class_names.ob", "rb") as class_names_file:
    class_names = pickle.load(class_names_file)

# Bundled example image, loaded and preprocessed once at import time.
img = Image.open("examples/TomatoYellowCurlVirus6.JPG").convert("RGB")
img_transformed = data_transforms(img)

def predict_disease(image) -> Tuple[Dict, float]:
    """Classify an image and report the three most probable classes.

    Args:
        image: input accepted by ``data_transforms`` (the Gradio interface
            in this file supplies a PIL image).

    Returns:
        A pair ``(predictions, seconds)`` where ``predictions`` maps the
        three highest-scoring class names to their softmax probabilities and
        ``seconds`` is the elapsed time rounded to 5 decimals. The timer is
        started after preprocessing, so only the forward pass and top-k
        extraction are measured.
    """
    model_input = data_transforms(image)
    started_at = timer()
    with torch.inference_mode():
        # Add a batch axis for the model, then drop it again from the output.
        logits = vision_transformer(model_input.unsqueeze(0))[0]
        probabilities = torch.softmax(logits, dim=0)
        top_probs, top_ids = torch.topk(probabilities, 3)
        prediction_dict = {
            class_names[class_id]: prob.item()
            for prob, class_id in zip(top_probs, top_ids)
        }
    elapsed = round(timer() - started_at, 5)
    return prediction_dict, elapsed

def predict_tensor(img_tensor):
    """Run one preprocessed image through the model and return its top-3 scores.

    ``plot_attention`` calls this purely to trigger a forward pass while the
    attention call is being captured; the return value is provided so the
    function also does what its name says.

    Args:
        img_tensor: a single, un-batched image tensor (a batch axis is added
            here) — presumably the output of ``data_transforms``.

    Returns:
        ``(topk_prob, topk_id)``: the three largest softmax probabilities and
        the class indices they belong to, as tensors.
    """
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    return topk_prob, topk_id

# Despite the name this is a fixed bundled example; it seeds the attention
# gallery shown before the user uploads anything.
random_image = Image.open("examples/TomatoYellowCurlVirus6.JPG").convert("RGB")

def plot_attention(image, encoder_layer_num=5):
    """Render one attention map per head for a chosen encoder block.

    The image is preprocessed and pushed through the model while ``nopdb``
    captures the call to the selected block's attention ``forward``; the
    local variable ``attn`` of that call is then turned into images.

    Args:
        image: input accepted by ``data_transforms`` (a PIL image in this app).
        encoder_layer_num: zero-based index into ``vision_transformer.blocks``.

    Returns:
        A list of PIL images, one for each entry along the first axis of the
        captured attention for this image (presumably one per attention head).
    """
    input_data = data_transforms(image)
    attention_module = vision_transformer.blocks[encoder_layer_num].attn
    with nopdb.capture_call(attention_module.forward) as attn_call:
        # The forward pass must use the caller's image; running it on a fixed
        # example would show the same attention for every upload.
        predict_tensor(input_data)
    # First batch element of the captured `attn` local.
    # NOTE(review): relies on the attention forward() having a local named
    # `attn` — confirm this holds for the installed timm version.
    attn = attn_call.locals['attn'][0]
    attention_map_outputs = []
    with torch.inference_mode():
        for h_weights in attn:
            # Collapse the second-to-last axis so one weight per token is left
            # (presumably the mean attention each token receives — verify
            # against the attention layout).
            h_weights = h_weights.mean(dim=-2)
            h_weights = h_weights[1:]  # skip the [class] token
            attention_map_outputs.append(plot_weights(input_data, h_weights))
    return attention_map_outputs

def plot_weights(input_data, patch_weights, patch_size=16):
    """Shade an image patch by patch: the brighter a patch, the higher its weight.

    Args:
        input_data: channels-first image tensor (C, H, W); it is cloned, so
            the caller's tensor is left untouched.
        patch_weights: one weight per image patch, ordered row by row from
            the top-left corner.
        patch_size: side length of a square patch in pixels (16 by default,
            matching the patch16 model used in this file).

    Returns:
        A 224x224 PIL image of the weighted input.
    """
    plot = inv_normalize(input_data.clone())
    # Derive the grid width from the image itself instead of assuming 224 px.
    patches_per_row = plot.shape[-1] // patch_size
    for i, weight in enumerate(patch_weights):
        row, col = divmod(i, patches_per_row)
        y = row * patch_size
        x = col * patch_size
        # multiply each patch of the input image by the corresponding weight
        plot[:, y:y + patch_size, x:x + patch_size] *= weight
    # inv_transform rescales the weighted image back to the full 0-255 range.
    attn_map_img = inv_transform(plot, normalize=False)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # under its current name and exists in older releases too.
    return attn_map_img.resize((224, 224), Image.LANCZOS)

# Attention maps for the bundled example at encoder block index 5, computed
# once at import time and used as the gallery's initial content.
attention_maps = plot_attention(random_image, 5)

# --- Text shown on the classification tab ---
title_classify = "Image Based Plant Disease Identification 🍃🤓"

description_classify = """Finetuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to 
                  identify the plant disease."""

article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""

# --- Text shown on the attention-visualisation tab ---
title_attention = "Visualize Attention Weights 🧊🔍"

description_attention = """The Vision Transformer Base architecture has 12 transformer Encoder layers (12 attention heads in each)."""

article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""

# Classification tab: a PIL image goes to predict_disease, whose two return
# values (label->probability dict, elapsed seconds) fill the two outputs in
# order.
classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
              gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)

# Attention tab: a PIL image plus a layer choice go to plot_attention, and the
# returned list of images is shown in a gallery.
attention_interface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.Image(type="pil", label="Image"),
           # type="index" hands plot_attention the position of the selected
           # choice rather than its label, so "6" arrives as 5 — the same
           # block index used for the precomputed attention_maps.
           gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], 
                       label="Attention Layer", value="6", type="index")],
    # NOTE(review): `.style(...)` is tied to the installed Gradio version —
    # confirm it is still supported before upgrading Gradio.
    outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)

# Combine both interfaces into one app with a tab per interface; the tab
# names are given in the same order as the interfaces.
demo = gr.TabbedInterface([classify_interface, attention_interface], 
                          ["Identify Disease", "Visualize Attention Map"], title="NatureAI Diagnostics🧑🩺")

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()