delete app.py
app.py
DELETED
@@ -1,154 +0,0 @@
import os
import gc
import PIL
import glob
import timm
import torch
import nopdb
import pickle
import torchvision

import numpy as np
import gradio as gr
from torch import nn
from PIL import Image
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import IPython.display as ipd
from typing import Tuple, Dict
from timeit import default_timer as timer
from timm.data import resolve_data_config, create_transform

# Build the example gallery from the bundled sample images
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Load the fine-tuned ViT-B/16 checkpoint (38 plant-disease classes) on CPU
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
                                        map_location=torch.device('cpu'))

vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)

# Swap in a classification head matching the 38 classes of the dataset
vision_transformer.head = nn.Linear(in_features=768, out_features=38)

vision_transformer.load_state_dict(vision_transformer_weights)
vision_transformer.eval()  # timm models start in train mode; disable dropout for inference
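
# Illustrative sanity check (not in the original app): one dummy batch through
# the network should yield a single 38-way logit vector. The shapes here are
# assumptions that follow from the ViT-B/16 input size and the replaced head.
with torch.inference_mode():
    assert vision_transformer(torch.randn(1, 3, 224, 224)).shape == (1, 38)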

# Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize
data_transforms = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
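
# Illustrative check (an assumption, not part of the original app): the pipeline
# maps any RGB PIL image to a normalized float tensor of shape (3, 224, 224).
assert data_transforms(Image.new('RGB', (500, 375))).shape == (3, 224, 224)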

def inv_normalize(tensor):
    """Min-max rescale an image tensor to the 0-255 range."""
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
    return tensor

def inv_transform(tensor, normalize=True):
    """Convert a tensor back to a PIL image, optionally rescaling to 0-255 first."""
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)
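
# Round-trip sketch (illustrative): inv_transform undoes data_transforms up to
# the min-max rescaling, returning a 224x224 PIL image for any non-constant input.
assert inv_transform(data_transforms(Image.radial_gradient('L').convert('RGB'))).size == (224, 224)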

# Class-index -> name mapping saved at training time
with open('class_names.ob', 'rb') as fp:
    class_names = pickle.load(fp)

# Example image used to warm up the attention capture below
img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
img_transformed = data_transforms(img)

def predict_disease(image) -> Tuple[Dict, float]:
    """Return the top-3 predicted classes with probabilities, plus the prediction time."""
    img_tensor = data_transforms(image)  # avoid shadowing the built-in `input`
    start_time = timer()
    prediction_dict = {}
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
        for i in range(topk_prob.size(0)):
            prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
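
# Usage sketch (commented out; the values shown are hypothetical, and the real
# class names come from class_names.ob):
# preds, secs = predict_disease(img)
# preds -> e.g. {'Tomato Yellow Leaf Curl Virus': 0.97, ...}; secs -> e.g. 0.21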

def predict_tensor(img_tensor):
    """Run a forward pass for its side effects only, so attention weights can be captured."""
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)

# Warm-up capture on the example image, exercising the nopdb hook at startup
with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
    predict_tensor(img_transformed)
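
# What the capture yields (a sketch; the 'attn' local exists in timm's non-fused
# Attention.forward, so this depends on the installed timm version):
# attn_call.locals['attn'][0] is the attention tensor for the single input, with
# shape (num_heads, tokens, tokens) = (12, 197, 197) for ViT-B/16
# (196 image patches + 1 [class] token).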

def plot_attention(image, layer_num):
    """Plot the average attention weight each head of the chosen encoder layer
    gives to each image patch, and return the saved attention maps."""
    input_data = data_transforms(image)
    with nopdb.capture_call(vision_transformer.blocks[int(layer_num) - 1].attn.forward) as attn_call:
        predict_tensor(input_data)  # capture attention for the uploaded image
    attn = attn_call.locals['attn'][0]
    with torch.inference_mode():
        # loop over attention heads
        attention_head_num = 0
        for h_weights in attn:
            h_weights = h_weights.mean(axis=-2)  # average over all attention keys
            h_weights = h_weights[1:]  # skip the [class] token
            attention_head_num += 1
            plot_weights(input_data, h_weights, attention_head_num)
    attention_maps = glob.glob('storage/*.png')
    return attention_maps

def plot_weights(input_data, patch_weights, attention_head_num):
    """Save the image with each patch scaled by its attention: the brighter the patch, the higher the attention."""
    # multiply each 16x16 patch of the input image by the corresponding weight
    plot = inv_normalize(input_data.clone())
    for i in range(patch_weights.shape[0]):
        x = i * 16 % 224
        y = i // (224 // 16) * 16
        plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
    # re-stretch to the full 0-255 range so the weighted patches stay visible
    attn_map_img = inv_transform(plot)
    attn_map_img = attn_map_img.resize((224, 224), Image.LANCZOS)  # ANTIALIAS is a removed alias of LANCZOS
    attn_map_img.save(f"storage/attention_map_{attention_head_num}.png", "PNG")
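
# Patch-grid arithmetic, illustrated: a 224x224 input with 16x16 patches gives a
# 14x14 grid of 196 patches, and patch i maps to column (i % 14), row (i // 14).
assert 224 // 16 == 14 and 14 * 14 == 196
assert (5 * 16 % 224, 5 // 14 * 16) == (80, 0)     # patch 5: row 0, column 5
assert (17 * 16 % 224, 17 // 14 * 16) == (48, 16)  # patch 17: row 1, column 3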

title_classify = "Image Based Plant Disease Identification 🍃🤓"

description_classify = """A Vision Transformer Base model (patch size 16, image size 224) fine-tuned to
identify plant diseases."""

article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""

title_attention = "Visualize Attention Weights 🧊🔍"

description_attention = """The Vision Transformer Base architecture has 12 transformer encoder layers, with 12 attention heads in each."""

article_attention = """From the dropdown menu, choose the encoder layer whose attention weights you would like to visualize."""

classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)

attention_interface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
                        label="Attention Layer", value="6")],
    # `attention_maps` only exists inside plot_attention, so the gallery starts
    # empty instead of referencing an undefined module-level name
    outputs=gr.Gallery(label="Attention Maps").style(grid=(3, 4)),
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)

demo = gr.TabbedInterface([classify_interface, attention_interface],
                          ["Identify Disease", "Visualize Attention Map"],
                          title="NatureAI Diagnostics🧑🩺").launch(debug=False, share=True)