TexR6 committed on
Commit 9bec60b
1 Parent(s): e324642

fixed missing argument error

Files changed (1)
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
+ import os
+ import glob
+ import pickle
+ from typing import Tuple, Dict
+ from timeit import default_timer as timer
+
+ import PIL
+ import timm
+ import torch
+ import nopdb
+ import numpy as np
+ import gradio as gr
+ from torch import nn
+ from PIL import Image
+ from torchvision import transforms
+
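+ # Collect the bundled example images; Gradio expects each example as a list of input values.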
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+ os.makedirs("storage", exist_ok=True)  # plot_weights saves attention maps here
+
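+ # Rebuild the ViT-B/16 backbone, swap its head for a 38-class linear layer
+ # (one output per plant-disease class), then load the finetuned weights on CPU.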
+ vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
+                                         map_location=torch.device('cpu'))
+
+ vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
+
+ vision_transformer.head = nn.Linear(in_features=768,
+                                     out_features=38)
+
+ vision_transformer.load_state_dict(vision_transformer_weights)
+
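+ # Standard ImageNet preprocessing: the same statistics the backbone was pretrained with.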
+ data_transforms = transforms.Compose([
+     transforms.Resize(size=(256, 256)),
+     transforms.CenterCrop(size=224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
+ def inv_normalize(tensor):
+     """Min-max rescale an image tensor to the 0-255 range."""
+     tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
+     return tensor
+
+ def inv_transform(tensor):
+     """Convert a CHW image tensor back to a PIL image."""
+     tensor = inv_normalize(tensor)
+     array = tensor.detach().cpu().numpy()
+     array = array.transpose(1, 2, 0).astype(np.uint8)
+     return PIL.Image.fromarray(array)
+
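+ # class_names.ob holds the pickled list mapping class indices to disease names.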
+ with open('class_names.ob', 'rb') as fp:
+     class_names = pickle.load(fp)
+
+ img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
+ img_transformed = data_transforms(img)
+
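+ # Prediction endpoint for the first tab: top-3 classes with probabilities, plus wall-clock time.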
+ def predict_disease(image) -> Tuple[Dict, float]:
+     """Return the top-3 predicted classes with probabilities and the prediction time for an input image."""
+     input_tensor = data_transforms(image)
+     start_time = timer()
+     prediction_dict = {}
+     with torch.inference_mode():
+         [logits] = vision_transformer(input_tensor[None])
+         probs = torch.softmax(logits, dim=0)
+         topk_prob, topk_id = torch.topk(probs, 3)
+         for i in range(topk_prob.size(0)):
+             prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
+     prediction_time = round(timer() - start_time, 5)
+     return prediction_dict, prediction_time
+
+ def predict_tensor(img_tensor):
+     """Run a forward pass on an already-transformed tensor; used only to trigger attention capture."""
+     with torch.inference_mode():
+         [logits] = vision_transformer(img_tensor[None])
+         probs = torch.softmax(logits, dim=0)
+         topk_prob, topk_id = torch.topk(probs, 3)
+
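+ # nopdb.capture_call records the locals of a function call without patching the model;
+ # the 'attn' local inside timm's Attention.forward is the post-softmax attention matrix.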
+ with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
+     predict_tensor(img_transformed)
+
+ random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
+
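+ # With 16x16 patches on a 224x224 input, the token grid is 14x14 (196 patches + 1 [class] token).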
+ def plot_attention(image, layer_num):
+     """Plot the average attention weight each head of the chosen layer gives to every image patch."""
+     input_data = data_transforms(image)
+     with nopdb.capture_call(vision_transformer.blocks[int(layer_num) - 1].attn.forward) as attn_call:
+         predict_tensor(input_data)
+     attn = attn_call.locals['attn'][0]
+     with torch.inference_mode():
+         # loop over the attention heads of the captured layer
+         attention_head_num = 0
+         for h_weights in attn:
+             h_weights = h_weights.mean(axis=-2)  # average over all attention keys
+             h_weights = h_weights[1:]  # skip the [class] token
+             attention_head_num += 1
+             plot_weights(input_data, h_weights, attention_head_num)
+     attention_maps = sorted(glob.glob('storage/*.png'))
+     return attention_maps
+
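+ # Patch i covers pixels x = (i * 16) % 224, y = (i // 14) * 16, matching the loop below.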
+ def plot_weights(input_data, patch_weights, num_attention_head):
+     """Save the image with each patch scaled by its weight: the brighter the patch, the higher the attention."""
+     # multiply each 16x16 patch of the input image by the corresponding weight
+     plot = inv_normalize(input_data.clone())
+     for i in range(patch_weights.shape[0]):
+         x = i * 16 % 224
+         y = i // (224 // 16) * 16
+         plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
+     attn_map_img = inv_transform(plot)
+     attn_map_img = attn_map_img.resize((224, 224), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
+     attn_map_img.save(f"storage/attention_map_{num_attention_head:02d}.png", "PNG")
+
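+ # Warm-up render: pre-computes the gallery's initial attention maps for layer 6.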
+ attention_maps = plot_attention(random_image, 6)
+
+ title_classify = "Image Based Plant Disease Identification 🍃🤓"
+
+ description_classify = """A Vision Transformer Base (patch size: 16 | image size: 224) finetuned to
+ identify plant diseases."""
+
+ article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
+
+ title_attention = "Visualize Attention Weights 🧊🔍"
+
+ description_attention = """The Vision Transformer Base architecture has 12 Transformer encoder layers, each with 12 attention heads."""
+
+ article_attention = """From the dropdown menu, choose the encoder layer whose attention weights you would like to visualize."""
+
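+ # First tab: upload an image, get the top-3 disease predictions and the prediction time.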
+ classify_interface = gr.Interface(
+     fn=predict_disease,
+     inputs=gr.Image(type="pil", label="Image"),
+     outputs=[gr.Label(num_top_classes=3, label="Predictions"),
+              gr.Number(label="Prediction time (secs)")],
+     examples=example_list,
+     title=title_classify,
+     description=description_classify,
+     article=article_classify,
+     thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+ )
+
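+ # Second tab: pick an encoder layer and view one attention map per head (12 images).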
+ attention_interface = gr.Interface(
+     fn=plot_attention,
+     inputs=[gr.Image(type="pil", label="Image"),
+             gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
+                         label="Attention Layer", value="6")],
+     outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
+     examples=[example + ["6"] for example in example_list],  # each example needs a value per input
+     title=title_attention,
+     description=description_attention,
+     article=article_attention,
+     thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+ )
+
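+ # Combine the two interfaces into a tabbed app and launch it.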
+ demo = gr.TabbedInterface([classify_interface, attention_interface],
+                           ["Identify Disease", "Visualize Attention Map"],
+                           title="NatureAI Diagnostics🧑🩺")
+ demo.launch(debug=False, share=True)