initial commit
- app.py +154 -0
- class_names.ob +0 -0
- examples/AppleScab2.JPG +0 -0
- examples/PotatoHealthy2.JPG +0 -0
- examples/TomatoHealthy2.JPG +0 -0
- examples/TomatoYellowCurlVirus6.JPG +0 -0
- model.py +134 -0
- pytorch_vit_b_16_timm.pth +3 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,154 @@
+import os
+import gc
+import PIL
+import glob
+import timm
+import torch
+import nopdb
+import pickle
+import torchvision
+
+import numpy as np
+import gradio as gr
+from torch import nn
+from PIL import Image
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+from typing import Tuple, Dict
+from timeit import default_timer as timer
+from timm.data import resolve_data_config, create_transform
+from torchvision import datasets, transforms
+
+# Attention maps are written to this directory before being shown in the gallery.
+os.makedirs("storage", exist_ok=True)
+
+example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
+                                        map_location=torch.device('cpu'))
+
+vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
+
+# Replace the 1000-class ImageNet head with a 38-class plant-disease head.
+vision_transformer.head = nn.Linear(in_features=768,
+                                    out_features=38)
+
+vision_transformer.load_state_dict(vision_transformer_weights)
+vision_transformer.eval()
+
+data_transforms = transforms.Compose([
+    transforms.Resize(size=(256, 256)),
+    transforms.CenterCrop(size=224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                         std=[0.229, 0.224, 0.225])
+])
+
+def inv_normalize(tensor):
+    """Rescale an image tensor back to the 0-255 range."""
+    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
+    return tensor
+
+def inv_transform(tensor, normalize=True):
+    """Convert a tensor back to a PIL image."""
+    if normalize:
+        tensor = inv_normalize(tensor)
+    array = tensor.detach().cpu().numpy()
+    array = array.transpose(1, 2, 0).astype(np.uint8)
+    return PIL.Image.fromarray(array)
+
+with open('class_names.ob', 'rb') as fp:
+    class_names = pickle.load(fp)
+
+img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
+img_transformed = data_transforms(img)
+
+def predict_disease(image) -> Tuple[Dict, float]:
+    """Return the top prediction classes with probabilities for an input image."""
+    img_tensor = data_transforms(image)
+    start_time = timer()
+    prediction_dict = {}
+    with torch.inference_mode():
+        [logits] = vision_transformer(img_tensor[None])
+        probs = torch.softmax(logits, dim=0)
+        topk_prob, topk_id = torch.topk(probs, 3)
+        for i in range(topk_prob.size(0)):
+            prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
+    prediction_time = round(timer() - start_time, 5)
+    return prediction_dict, prediction_time
+
+def predict_tensor(img_tensor):
+    """Run a forward pass on an already-transformed tensor; used only to trigger attention capture."""
+    with torch.inference_mode():
+        [logits] = vision_transformer(img_tensor[None])
+        probs = torch.softmax(logits, dim=0)
+        topk_prob, topk_id = torch.topk(probs, 3)
+
+# Capture one attention call at startup (exercises the nopdb hook on an example image).
+with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
+    predict_tensor(img_transformed)
+
+def plot_attention(image, layer_num):
+    """Plot the average attention weight each head of the chosen encoder layer gives to each image patch."""
+    input_data = data_transforms(image)
+    with nopdb.capture_call(vision_transformer.blocks[int(layer_num) - 1].attn.forward) as attn_call:
+        predict_tensor(input_data)
+    attn = attn_call.locals['attn'][0]  # attention weights captured inside the chosen block
+    with torch.inference_mode():
+        # loop over the attention heads
+        attention_head_num = 0
+        for h_weights in attn:
+            h_weights = h_weights.mean(axis=-2)  # average over all attention keys
+            h_weights = h_weights[1:]  # skip the [class] token
+            attention_head_num += 1
+            plot_weights(input_data, h_weights, attention_head_num)
+    attention_maps = glob.glob('storage/*.png')
+    return attention_maps
+
+def plot_weights(input_data, patch_weights, num_attention_head):
+    """Save the image with each patch scaled by its attention weight: the brighter the patch, the higher the attention."""
+    # multiply each 16x16 patch of the input image by the corresponding weight
+    plot = inv_normalize(input_data.clone())
+    for i in range(patch_weights.shape[0]):
+        x = i * 16 % 224
+        y = i // (224 // 16) * 16
+        plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
+    attn_map_img = inv_transform(plot)
+    attn_map_img = attn_map_img.resize((224, 224), Image.LANCZOS)
+    attn_map_img.save(f"storage/attention_map_{num_attention_head}.png", "PNG")
+
+title_classify = "Image Based Plant Disease Identification 🍃🤓"
+
+description_classify = """Fine-tuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to
+identify plant diseases."""
+
+article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
+
+title_attention = "Visualize Attention Weights 🧊🔍"
+
+description_attention = """The Vision Transformer Base architecture has 12 Transformer Encoder layers, each with 12 attention heads."""
+
+article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""
+
+classify_interface = gr.Interface(
+    fn=predict_disease,
+    inputs=gr.Image(type="pil", label="Image"),
+    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
+             gr.Number(label="Prediction time (secs)")],
+    examples=example_list,
+    title=title_classify,
+    description=description_classify,
+    article=article_classify,
+    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+)
+
+attention_interface = gr.Interface(
+    fn=plot_attention,
+    inputs=[gr.Image(type="pil", label="Image"),
+            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
+                        label="Attention Layer", value="6")],
+    outputs=gr.Gallery(label="Attention Maps").style(grid=(3, 4)),
+    examples=example_list,
+    title=title_attention,
+    description=description_attention,
+    article=article_attention,
+    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+)
+
+demo = gr.TabbedInterface([classify_interface, attention_interface],
+                          ["Identify Disease", "Visualize Attention Map"],
+                          title="NatureAI Diagnostics🧑‍⚕️").launch(debug=False, share=True)
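Since app.py builds and launches the demo at import time, the quickest sanity check is an inline snippet placed just before the gr.TabbedInterface call (a minimal sketch; it reuses the names already defined above in app.py):

# Hypothetical smoke test: classify one bundled example image.
test_img = PIL.Image.open("examples/AppleScab2.JPG").convert("RGB")
preds, secs = predict_disease(test_img)
print(preds)  # expect a 3-entry dict mapping class names to probabilities
print(secs)   # prediction time in seconds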
class_names.ob
ADDED
Binary file (1.08 kB).
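class_names.ob is the pickled list of the 38 class names that both app.py and model.py load. A sketch of how such a file could be produced from an ImageFolder-style dataset (the dataset path here is hypothetical, not part of this repo):

import pickle
from torchvision import datasets

# Hypothetical location of the training split; class names come from the folder names.
train_data = datasets.ImageFolder("plant_village/train")
with open("class_names.ob", "wb") as fp:
    pickle.dump(train_data.classes, fp)  # list of class-name strings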
examples/AppleScab2.JPG
ADDED
examples/PotatoHealthy2.JPG
ADDED
examples/TomatoHealthy2.JPG
ADDED
examples/TomatoYellowCurlVirus6.JPG
ADDED
model.py
ADDED
@@ -0,0 +1,134 @@
+import torch
+import pickle
+import torchvision
+from torch import nn
+from torchvision import datasets, transforms
+
+PATCH_SIZE = 16
+IMG_SIZE = 224
+
+class PatchEmbeddings(nn.Module):
+    """Split an image into patches with a strided convolution and project each patch to an embedding."""
+    def __init__(self, in_channels: int=3,
+                 patch_size: int=16,
+                 embedding_dim: int=768):
+        super().__init__()
+
+        self.generate_patches = nn.Conv2d(in_channels=in_channels,
+                                          out_channels=embedding_dim,
+                                          kernel_size=patch_size,
+                                          stride=patch_size, padding=0)
+
+        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
+
+    def forward(self, x: torch.Tensor):
+        image_resolution = x.shape[-1]
+        assert image_resolution % PATCH_SIZE == 0, "Image size must be divisible by patch size!"
+
+        # (B, D, H/P, W/P) -> (B, D, N) -> (B, N, D)
+        return self.flatten(self.generate_patches(x)).permute(0, 2, 1)
+
+class MultiheadSelfAttention(nn.Module):
+    def __init__(self, embedding_dim: int=768,
+                 num_heads: int=12, attn_dropout: float=0.0):
+        super().__init__()
+
+        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
+
+        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads,
+                                                    dropout=attn_dropout, batch_first=True)
+
+    def forward(self, x: torch.Tensor):
+        x = self.layer_norm(x)
+
+        attn_output, _ = self.multihead_attn(query=x, key=x, value=x,
+                                             need_weights=False)
+        return attn_output
+
+class MLPBlock(nn.Module):
+    def __init__(self, embedding_dim: int=768,
+                 mlp_size: int=3072, dropout: float=0.1):
+        super().__init__()
+
+        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
+
+        self.mlp = nn.Sequential(
+            nn.Linear(in_features=embedding_dim,
+                      out_features=mlp_size),
+            nn.GELU(),
+            nn.Dropout(p=dropout),
+            nn.Linear(in_features=mlp_size,
+                      out_features=embedding_dim),
+            nn.Dropout(p=dropout)
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.mlp(self.layer_norm(x))
+
+class TransformerEncoderBlock(nn.Module):
+    def __init__(self, embedding_dim: int=768,
+                 mlp_size: int=3072, num_heads: int=12,
+                 mlp_dropout: float=0.1, attn_dropout: float=0.0):
+        super().__init__()
+
+        self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
+                                                num_heads=num_heads, attn_dropout=attn_dropout)
+
+        self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
+                                  mlp_size=mlp_size, dropout=mlp_dropout)
+
+    def forward(self, x: torch.Tensor):
+        # residual connections around both sub-blocks
+        x = self.msa_block(x) + x
+        x = self.mlp_block(x) + x
+        return x
+
+class VisionTransformer(nn.Module):
+    def __init__(self, img_size: int=IMG_SIZE,
+                 in_channels: int=3, patch_size: int=16,
+                 num_transformer_layers: int=12, embedding_dim: int=768,
+                 mlp_size: int=3072, num_heads: int=12,
+                 attn_dropout: float=0.0, mlp_dropout: float=0.1,
+                 embedding_dropout: float=0.1, num_classes: int=38):
+        super().__init__()
+
+        assert img_size % patch_size == 0, "Image size must be divisible by patch size!"
+
+        self.num_patches = (img_size * img_size) // patch_size**2
+
+        # learnable [class] token prepended to the patch sequence
+        self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
+                                            requires_grad=True)
+
+        self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
+                                               requires_grad=True)
+
+        self.embedding_dropout = nn.Dropout(p=embedding_dropout)
+
+        self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
+                                                patch_size=patch_size, embedding_dim=embedding_dim)
+
+        self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
+                                                                           num_heads=num_heads, mlp_size=mlp_size,
+                                                                           mlp_dropout=mlp_dropout,
+                                                                           attn_dropout=attn_dropout)
+                                                   for _ in range(num_transformer_layers)])
+
+        self.classifier = nn.Sequential(
+            nn.LayerNorm(normalized_shape=embedding_dim),
+            nn.Linear(in_features=embedding_dim,
+                      out_features=num_classes)
+        )
+
+    def forward(self, x: torch.Tensor):
+        batch_size = x.shape[0]
+
+        class_token = self.class_embedding.expand(batch_size, -1, -1)
+
+        x = self.patch_embeddings(x)
+        x = torch.cat((class_token, x), dim=1)
+
+        x = self.position_embedding + x
+        x = self.embedding_dropout(x)
+
+        x = self.transformer_encoder(x)
+        x = self.classifier(x[:, 0])  # classify from the [class] token only
+
+        return x
+
+with open("class_names.ob", "rb") as fp:
+    class_names = pickle.load(fp)
+
+vision_transformer = VisionTransformer(num_classes=len(class_names))
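A minimal shape check for this from-scratch implementation, assuming class_names.ob sits next to model.py (importing model runs the pickle load at the bottom of the file):

import torch
from model import vision_transformer

x = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
with torch.inference_mode():
    logits = vision_transformer(x)
print(logits.shape)  # torch.Size([1, 38]) -- one logit per disease class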
pytorch_vit_b_16_timm.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b82b34f5fc2aa9be5dac0f146fcccb9589481ee5f93717d83f158695329da181
+size 343366929
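The checkpoint is stored as a Git LFS pointer, so the real weights must be fetched before torch.load in app.py will work. A sketch for verifying a fetched checkpoint against the pointer's size and SHA-256:

import os
import hashlib

path = "pytorch_vit_b_16_timm.pth"
assert os.path.getsize(path) == 343366929, "file is still an LFS pointer or truncated"

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
print(sha.hexdigest())  # expect b82b34f5fc2aa9be5dac0f146fcccb9589481ee5f93717d83f158695329da181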
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+torch==1.13.1
+torchvision==0.14.1
+gradio==3.16.2
+timm==0.4.5
+nopdb
+IPython