import os
import pickle
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import nopdb
import numpy as np
import PIL
import timm
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
example_list = [["examples/" + example] for example in os.listdir("examples")]
# Load the fine-tuned ViT-B/16 checkpoint on CPU and swap in a 38-class head
# before restoring the weights.
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
                                        map_location=torch.device('cpu'))
vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
vision_transformer.head = nn.Linear(in_features=768, out_features=38)
vision_transformer.load_state_dict(vision_transformer_weights)
vision_transformer.eval()  # disable dropout for deterministic inference
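# Sanity check: the new head should map the 768-dim [class] token embedding
# to one logit per disease class.
assert vision_transformer.head.out_features == 38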
# Standard ImageNet preprocessing: resize, center-crop to 224x224, normalize.
data_transforms = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
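# Example (sketch): the transform maps any RGB PIL image to a normalized
# (3, 224, 224) float tensor, e.g.
#   data_transforms(Image.new("RGB", (512, 512))).shape  # torch.Size([3, 224, 224])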
def inv_normalize(tensor):
    """Rescale an image tensor back to the 0-255 range."""
    tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
    return tensor
def inv_transform(tensor, normalize=True):
    """Convert a (C, H, W) tensor back to a PIL image."""
    if normalize:  # honor the flag (it was previously ignored)
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)
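# Example (sketch): round-tripping a preprocessed tensor yields a viewable
# 224x224 PIL image, e.g. inv_transform(data_transforms(some_pil_image)).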
# The class-index -> class-name mapping saved during training.
with open('class_names.ob', 'rb') as fp:
    class_names = pickle.load(fp)
def predict_disease(image) -> Tuple[Dict, float]:
    """Return the top-3 predicted classes with probabilities for an input image."""
    img_tensor = data_transforms(image)
    start_time = timer()
    prediction_dict = {}
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
        for i in range(topk_prob.size(0)):
            prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
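# Example usage (sketch; class names and numbers are illustrative, the real
# labels come from class_names.ob):
#   preds, secs = predict_disease(Image.open(example_list[0][0]).convert('RGB'))
#   preds -> e.g. {'Tomato Yellow Curl Virus': 0.97, ...}, secs -> e.g. 0.04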
def predict_tensor(img_tensor):
    """Run a single forward pass on an already-transformed tensor.

    The outputs are discarded; this exists only to trigger the attention
    capture in plot_attention.
    """
    with torch.inference_mode():
        vision_transformer(img_tensor[None])
random_image = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
def plot_attention(image, encoder_layer_num=5):
    """Given an input image, plot the average attention weight given to each image patch by each attention head."""
    attention_map_outputs = []
    input_data = data_transforms(image)
    # Capture the locals of the chosen encoder block's attention forward pass.
    with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
        predict_tensor(input_data)
    attn = attn_call.locals['attn'][0]  # (heads, tokens, tokens) softmax weights
    with torch.inference_mode():
        # loop over attention heads
        for h_weights in attn:
            h_weights = h_weights.mean(axis=-2)  # average the attention each token receives over all queries
            h_weights = h_weights[1:]  # skip the [class] token
            output_img = plot_weights(input_data, h_weights)
            attention_map_outputs.append(output_img)
    return attention_map_outputs
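# For ViT-B/16 at 224x224 resolution, the captured attention tensor has shape
# (12, 197, 197): 12 heads attending over 196 image patches plus the [class]
# token, so each call produces 12 attention maps (one per head).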
def plot_weights(input_data, patch_weights):
    """Display the image: the brighter a patch, the higher its attention weight."""
    # multiply each 16x16 patch of the input image by the corresponding weight
    plot = inv_normalize(input_data.clone())
    for i in range(patch_weights.shape[0]):
        x = i * 16 % 224           # column offset: patches are laid out row by row
        y = i // (224 // 16) * 16  # row offset on the 14x14 patch grid
        plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
    # rescale back to the visible 0-255 range after weighting
    attn_map_img = inv_transform(plot)
    attn_map_img = attn_map_img.resize((224, 224), Image.LANCZOS)
    return attn_map_img
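# Worked example of the patch-grid arithmetic above: for i = 15,
# x = 15 * 16 % 224 = 16 and y = 15 // 14 * 16 = 16, so the 16th patch is
# painted at the second column of the second row of the 14x14 grid.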
attention_maps = plot_attention(random_image, 5)
title_classify = "Image Based Plant Disease Identification 🍃🤓"
description_classify = """A Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture
fine-tuned to identify plant diseases."""
article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
title_attention = "Visualize Attention Weights 🧊🔍"
description_attention = """The Vision Transformer Base architecture has 12 Transformer encoder layers, each with 12 attention heads."""
article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""
classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)
attention_interface = gr.Interface(
    fn=plot_attention,
    inputs=[gr.Image(type="pil", label="Image"),
            # type="index" passes the 0-based position, so the default "6"
            # selects vision_transformer.blocks[5]
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
                        label="Attention Layer", value="6", type="index")],
    outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)
demo = gr.TabbedInterface([classify_interface, attention_interface],
                          ["Identify Disease", "Visualize Attention Map"],
                          title="NatureAI Diagnostics🧑🩺")

if __name__ == "__main__":
    demo.launch()