TexR6 committed on
Commit
36fc972
1 Parent(s): 003da28

initial commit

app.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import gc
+ import PIL
+ import glob
+ import timm
+ import torch
+ import nopdb
+ import pickle
+ import torchvision
+
+ import numpy as np
+ import gradio as gr
+ from torch import nn
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import IPython.display as ipd
+ from typing import Tuple, Dict
+ from timeit import default_timer as timer
+ from timm.data import resolve_data_config, create_transform
+
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Make sure the output directory for the attention maps exists.
+ os.makedirs("storage", exist_ok=True)
+
+ vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
+                                         map_location=torch.device('cpu'))
+
+ vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
+
+ # Replace the classification head to match the 38 plant-disease classes.
+ vision_transformer.head = nn.Linear(in_features=768,
+                                     out_features=38)
+
+ vision_transformer.load_state_dict(vision_transformer_weights)
+
+ from torchvision import datasets, transforms
+
+ data_transforms = transforms.Compose([
+     transforms.Resize(size=(256, 256)),
+     transforms.CenterCrop(size=224),
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225])
+ ])
+
+ def inv_normalize(tensor):
+     """Rescale an image tensor back to the 0-255 range."""
+     tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
+     return tensor
+
+ def inv_transform(tensor, normalize=True):
+     """Convert a tensor back to a PIL image."""
+     tensor = inv_normalize(tensor)
+     array = tensor.detach().cpu().numpy()
+     array = array.transpose(1, 2, 0).astype(np.uint8)
+     return PIL.Image.fromarray(array)
+
+ with open('class_names.ob', 'rb') as fp:
+     class_names = pickle.load(fp)
+
+ img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
+ img_transformed = data_transforms(img)
+
+ def predict_disease(image) -> Tuple[Dict, float]:
+     """Return the top-3 prediction classes with probabilities for an input image."""
+     input = data_transforms(image)
+     start_time = timer()
+     prediction_dict = {}
+     with torch.inference_mode():
+         [logits] = vision_transformer(input[None])
+         probs = torch.softmax(logits, dim=0)
+         topk_prob, topk_id = torch.topk(probs, 3)
+         for i in range(topk_prob.size(0)):
+             prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
+     prediction_time = round(timer() - start_time, 5)
+     return prediction_dict, prediction_time
+
+ def predict_tensor(img_tensor):
+     """Run a forward pass on an already-transformed tensor (used for attention capture)."""
+     with torch.inference_mode():
+         [logits] = vision_transformer(img_tensor[None])
+         probs = torch.softmax(logits, dim=0)
+         topk_prob, topk_id = torch.topk(probs, 3)
+
+ # Capture one attention call up front for the default example image (encoder layer 6).
+ with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
+     predict_tensor(img_transformed)
+
+ def plot_attention(image, layer_num):
+     """Plot the average attention weight each attention head assigns to every image patch."""
+     input_data = data_transforms(image)
+     with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
+         predict_tensor(input_data)
+     attn = attn_call.locals['attn'][0]
+     with torch.inference_mode():
+         # loop over attention heads
+         attention_head_num = 0
+         for h_weights in attn:
+             h_weights = h_weights.mean(axis=-2)  # average over all attention keys
+             h_weights = h_weights[1:]  # skip the [class] token
+             attention_head_num += 1
+             plot_weights(input_data, h_weights, attention_head_num)
+     attention_maps = glob.glob('storage/*.png')
+     return attention_maps
+
+ def plot_weights(input_data, patch_weights, num_attention_head):
+     """Save an image in which brighter patches received higher attention."""
+     # multiply each patch of the input image by the corresponding weight
+     plot = inv_normalize(input_data.clone())
+     for i in range(patch_weights.shape[0]):
+         x = i * 16 % 224
+         y = i // (224 // 16) * 16
+         plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
+     attn_map_img = inv_transform(plot, normalize=False)
+     # ANTIALIAS is a deprecated alias of LANCZOS in Pillow.
+     attn_map_img = attn_map_img.resize((224, 224), Image.LANCZOS)
+     attn_map_img.save(f"storage/attention_map_{num_attention_head}.png", "PNG")
+
+ title_classify = "Image Based Plant Disease Identification 🍃🤓"
+
+ description_classify = """A Vision Transformer Base (patch size 16, image size 224) fine-tuned to identify plant diseases."""
+
+ article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
+
+ title_attention = "Visualize Attention Weights 🧊🔍"
+
+ description_attention = """The Vision Transformer Base architecture has 12 Transformer encoder layers (12 attention heads in each)."""
+
+ article_attention = """From the dropdown menu, choose the encoder layer whose attention weights you would like to visualize."""
+
+ classify_interface = gr.Interface(
+     fn=predict_disease,
+     inputs=gr.Image(type="pil", label="Image"),
+     outputs=[gr.Label(num_top_classes=3, label="Predictions"),
+              gr.Number(label="Prediction time (secs)")],
+     examples=example_list,
+     title=title_classify,
+     description=description_classify,
+     article=article_classify,
+     thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+ )
+
+ attention_interface = gr.Interface(
+     fn=plot_attention,
+     inputs=[gr.Image(type="pil", label="Image"),
+             gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
+                         label="Attention Layer", value="6")],
+     # The gallery is populated from plot_attention's return value.
+     outputs=gr.Gallery(label="Attention Maps").style(grid=(3, 4)),
+     examples=example_list,
+     title=title_attention,
+     description=description_attention,
+     article=article_attention,
+     thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
+ )
+
+ demo = gr.TabbedInterface([classify_interface, attention_interface],
+                           ["Identify Disease", "Visualize Attention Map"],
+                           title="NatureAI Diagnostics🧑🩺").launch(debug=False, share=True)
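
app.py wires the two functions above into a two-tab Gradio demo: predict_disease returns the top-3 classes and the inference time, and plot_attention writes one attention map per head of the selected encoder layer to storage/ and returns them for the gallery. Below is a minimal sketch of how the classifier could be sanity-checked outside the UI; it assumes app.py's definitions have already been executed in the current session (e.g. in a notebook, with the launch call skipped) and that the bundled example image is present.

from PIL import Image

# Hypothetical local check: relies on predict_disease and its globals defined in app.py.
test_img = Image.open("examples/AppleScab2.JPG").convert("RGB")
preds, seconds = predict_disease(test_img)
print(f"prediction time: {seconds}s")
for label, prob in preds.items():   # top-3 classes with probabilities
    print(f"{label}: {prob:.3f}")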
class_names.ob ADDED
Binary file (1.08 kB)
 
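class_names.ob is a pickled sequence of the 38 class labels that both app.py and model.py load at import time. A quick way to inspect it (a minimal sketch, run from the repository root):

import pickle

with open("class_names.ob", "rb") as fp:
    class_names = pickle.load(fp)

print(len(class_names))   # expected: 38, matching the classifier head
print(class_names[:3])    # first few disease labels
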
examples/AppleScab2.JPG ADDED
examples/PotatoHealthy2.JPG ADDED
examples/TomatoHealthy2.JPG ADDED
examples/TomatoYellowCurlVirus6.JPG ADDED
model.py ADDED
@@ -0,0 +1,134 @@
+ import torch
+ import pickle
+ import torchvision
+ from torch import nn
+ from torchvision import datasets, transforms
+
+ PATCH_SIZE = 16
+ IMG_SIZE = 224
+
+ class PatchEmbeddings(nn.Module):
+     def __init__(self, in_channels: int=3,
+                  patch_size: int=16,
+                  embedding_dim: int=768):
+         super().__init__()
+
+         # A convolution with kernel size == stride == patch size splits the image into patches.
+         self.generate_patches = nn.Conv2d(in_channels=in_channels,
+                                           out_channels=embedding_dim,
+                                           kernel_size=patch_size,
+                                           stride=patch_size, padding=0)
+
+         self.flatten = nn.Flatten(start_dim=2, end_dim=3)
+
+     def forward(self, x: torch.Tensor):
+         image_resolution = x.shape[-1]
+         assert image_resolution % PATCH_SIZE == 0, "Image size must be divisible by patch size!"
+
+         return self.flatten(self.generate_patches(x)).permute(0, 2, 1)
+
+ class MultiheadSelfAttention(nn.Module):
+     def __init__(self, embedding_dim: int=768,
+                  num_heads: int=12, attn_dropout: float=0):
+         super().__init__()
+
+         self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
+
+         self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads,
+                                                     dropout=attn_dropout, batch_first=True)
+
+     def forward(self, x: torch.Tensor):
+         x = self.layer_norm(x)
+
+         attn_output, _ = self.multihead_attn(query=x, key=x, value=x,
+                                              need_weights=False)
+         return attn_output
+
+ class MLPBlock(nn.Module):
+     def __init__(self, embedding_dim: int=768,
+                  mlp_size: int=3072, dropout: float=0.1):
+         super().__init__()
+
+         self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
+
+         self.mlp = nn.Sequential(
+             nn.Linear(in_features=embedding_dim,
+                       out_features=mlp_size),
+             nn.GELU(),
+             nn.Dropout(p=dropout),
+             nn.Linear(in_features=mlp_size,
+                       out_features=embedding_dim),
+             nn.Dropout(p=dropout)
+         )
+
+     def forward(self, x: torch.Tensor):
+         return self.mlp(self.layer_norm(x))
+
+ class TransformerEncoderBlock(nn.Module):
+     def __init__(self, embedding_dim: int=768,
+                  mlp_size: int=3072, num_heads: int=12,
+                  mlp_dropout: float=0.1, attn_dropout: float=0):
+         super().__init__()
+
+         self.msa_block = MultiheadSelfAttention(embedding_dim=embedding_dim,
+                                                 num_heads=num_heads, attn_dropout=attn_dropout)
+
+         self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
+                                   mlp_size=mlp_size, dropout=mlp_dropout)
+
+     def forward(self, x: torch.Tensor):
+         # Residual connections around both sub-blocks (pre-norm Transformer).
+         x = self.msa_block(x) + x
+         x = self.mlp_block(x) + x
+         return x
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, img_size: int=IMG_SIZE,
+                  in_channels: int=3, patch_size: int=16,
+                  num_transformer_layers: int=12, embedding_dim: int=768,
+                  mlp_size: int=3072, num_heads: int=12,
+                  attn_dropout: float=0, mlp_dropout: float=0.1,
+                  embedding_dropout: float=0.1, num_classes: int=38):
+         super().__init__()
+
+         assert img_size % patch_size == 0, "Image size must be divisible by patch size!"
+
+         self.num_patches = (img_size * img_size) // patch_size**2
+
+         # Learnable [class] token and position embeddings.
+         self.class_embedding = nn.Parameter(data=torch.randn(1, 1, embedding_dim),
+                                             requires_grad=True)
+
+         self.position_embedding = nn.Parameter(data=torch.randn(1, self.num_patches+1, embedding_dim),
+                                                requires_grad=True)
+
+         self.embedding_dropout = nn.Dropout(p=embedding_dropout)
+
+         self.patch_embeddings = PatchEmbeddings(in_channels=in_channels,
+                                                 patch_size=patch_size, embedding_dim=embedding_dim)
+
+         self.transformer_encoder = nn.Sequential(*[TransformerEncoderBlock(embedding_dim=embedding_dim,
+                                                    num_heads=num_heads, mlp_size=mlp_size,
+                                                    mlp_dropout=mlp_dropout,
+                                                    attn_dropout=attn_dropout) for _ in range(num_transformer_layers)])
+
+         self.classifier = nn.Sequential(
+             nn.LayerNorm(normalized_shape=embedding_dim),
+             nn.Linear(in_features=embedding_dim,
+                       out_features=num_classes)
+         )
+
+     def forward(self, x: torch.Tensor):
+         batch_size = x.shape[0]
+
+         # Prepend the [class] token to the patch embeddings of every image in the batch.
+         class_token = self.class_embedding.expand(batch_size, -1, -1)
+
+         x = self.patch_embeddings(x)
+         x = torch.cat((class_token, x), dim=1)
+
+         x = self.position_embedding + x
+         x = self.embedding_dropout(x)
+
+         x = self.transformer_encoder(x)
+         # Classify from the [class] token only.
+         x = self.classifier(x[:, 0])
+
+         return x
+
+ with open("class_names.ob", "rb") as fp:
+     class_names = pickle.load(fp)
+
+ vision_transformer = VisionTransformer(num_classes=len(class_names))
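
model.py re-implements ViT-Base from scratch: patch embedding via a strided convolution, pre-norm multi-head self-attention and MLP blocks with residual connections, learnable [class] and position embeddings, and a LayerNorm + Linear classifier on the [class] token. A quick shape check on a dummy batch, as a sketch; importing model also loads class_names.ob at module level, so run it from the repository root (or instantiate VisionTransformer directly with num_classes=38).

import torch
from model import VisionTransformer  # assumption: model.py is importable and class_names.ob exists

vit = VisionTransformer(num_classes=38)      # 38 plant-disease classes
dummy = torch.randn(2, 3, 224, 224)          # batch of two 224x224 RGB images
with torch.inference_mode():
    logits = vit(dummy)
print(logits.shape)                          # expected: torch.Size([2, 38])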
pytorch_vit_b_16_timm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b82b34f5fc2aa9be5dac0f146fcccb9589481ee5f93717d83f158695329da181
+ size 343366929
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch==1.13.1
+ torchvision==0.14.1
+ gradio==3.16.2
+ timm==0.4.5
+ nopdb
+ IPython
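
The pins target the Gradio 3.x API that app.py uses (e.g. Gallery(...).style(grid=...)) and a timm release compatible with the saved ViT-B/16 checkpoint. An optional environment check (sketch):

from importlib.metadata import version

# Print the installed versions of the pinned packages for comparison against requirements.txt.
for pkg in ("torch", "torchvision", "gradio", "timm", "nopdb"):
    print(pkg, version(pkg))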