|
import os |
|
import gc |
|
import PIL |
|
import glob |
|
import timm |
|
import torch |
|
import nopdb |
|
import pickle |
|
import torchvision |
|
|
|
import numpy as np |
|
import gradio as gr |
|
from torch import nn |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
import IPython.display as ipd |
|
from typing import Tuple, Dict |
|
from timeit import default_timer as timer |
|
from timm.data import resolve_data_config, create_transform |
|
|
|
# Each entry is a single-element [path] list, the shape Gradio's
# `examples` parameter expects.
example_list = [[f"examples/{image_name}"] for image_name in os.listdir("examples")]
|
|
|
# Fine-tuned ViT-B/16 checkpoint (38 plant-disease classes), loaded on CPU.
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
                                        map_location=torch.device('cpu'))

# Bare architecture; the fine-tuned weights are loaded below.
vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)

# Swap the ImageNet-1k head for a 38-class classifier so the state dict
# shapes match the checkpoint.
vision_transformer.head = nn.Linear(in_features=768,
                                    out_features=38)

vision_transformer.load_state_dict(vision_transformer_weights)

# BUG FIX: the model was left in train mode, so dropout / stochastic depth
# remained active during inference, making predictions non-deterministic.
vision_transformer.eval()
|
|
|
from torchvision import datasets, transforms |
|
|
|
# ImageNet channel statistics used by the pretrained ViT recipe.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize, center-crop to the model's 224px input,
# convert to tensor, then normalize with the ImageNet statistics.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
def inv_normalize(tensor):
    """Min-max rescale a tensor into the [0, 256) display range.

    The smallest value maps to 0 and the largest to just under 256,
    so the result can be cast to uint8 without overflow.

    Args:
        tensor: Any real-valued tensor.

    Returns:
        Tensor of the same shape with values in [0, 256).
    """
    value_range = tensor.max() - tensor.min()
    # BUG FIX: a constant tensor previously divided by zero and produced
    # NaNs; map it to all zeros instead.
    if value_range == 0:
        return torch.zeros_like(tensor)
    return (tensor - tensor.min()) / value_range * (256 - 1e-5)
|
|
|
def inv_transform(tensor, normalize=True):
    """Convert a (C, H, W) tensor back to a PIL image.

    Args:
        tensor: Image tensor. When ``normalize`` is False its values
            should already lie in [0, 256).
        normalize: When True, min-max rescale the tensor to [0, 256)
            first. BUG FIX: this flag was previously ignored and the
            tensor was rescaled unconditionally, defeating callers that
            pass ``normalize=False`` with pre-scaled data.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 array.
    """
    if normalize:
        tensor = inv_normalize(tensor)
    array = tensor.detach().cpu().numpy()
    # (C, H, W) -> (H, W, C), cast to uint8 for PIL.
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return PIL.Image.fromarray(array)
|
|
|
# The 38 class labels were serialized with pickle during training;
# their order matches the model head's output indices.
with open('class_names.ob', 'rb') as class_file:
    class_names = pickle.load(class_file)

# Default sample image, transformed once so a ready-to-use tensor exists
# at module level.
img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
img_transformed = data_transforms(img)
|
|
|
def predict_disease(image) -> Tuple[Dict, float]:
    """Classify a plant-leaf image and report inference time.

    Args:
        image: PIL image supplied by the Gradio UI.

    Returns:
        Tuple of (mapping of the top-3 class names to their softmax
        probabilities, elapsed time in seconds rounded to 5 decimals).
    """
    # Renamed from `input`, which shadowed the builtin.
    image_tensor = data_transforms(image)
    start_time = timer()
    with torch.inference_mode():
        # Add a batch dimension; unpack the single row of logits.
        [logits] = vision_transformer(image_tensor[None])
    probs = torch.softmax(logits, dim=0)
    topk_prob, topk_id = torch.topk(probs, 3)
    prediction_dict = {
        class_names[class_id]: prob
        for prob, class_id in zip(topk_prob.tolist(), topk_id.tolist())
    }
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time
|
|
|
def predict_tensor(img_tensor):
    """Run a forward pass on an already-transformed image tensor.

    Exists to trigger the model so attention weights can be captured by
    nopdb in `plot_attention`. The softmax/top-k post-processing in the
    original was dead code (results were discarded) and has been removed.

    Args:
        img_tensor: Normalized (C, H, W) image tensor.

    Returns:
        The raw logits tensor for the single image (previously None;
        callers that ignore the return value are unaffected).
    """
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
    return logits
|
|
|
# Sample image used to precompute the default attention-map gallery.
random_image = Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
|
|
|
def plot_attention(image, encoder_layer_num=5):
    """Plot the mean attention each head of one encoder layer gives to every patch.

    Args:
        image: Input PIL image.
        encoder_layer_num: 0-based index of the transformer encoder block
            whose attention weights are visualized.

    Returns:
        List of PIL images, one attention map per attention head.
    """
    print(f'Selected encoder layer num: {encoder_layer_num}')
    attention_map_outputs = []
    input_data = data_transforms(image)
    # Capture the `attn` local inside the chosen block's attention forward.
    with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
        # BUG FIX: the original called predict_tensor(img_transformed) —
        # the module-level sample image — so the attention maps never
        # reflected the image passed in. Run the model on `input_data`.
        predict_tensor(input_data)
    # First (and only) element of the batch: (heads, tokens, tokens).
    attn = attn_call.locals['attn'][0]
    with torch.inference_mode():
        for h_weights in attn:
            # Average the attention each key token receives, then drop
            # the [CLS] token so only image patches remain.
            h_weights = h_weights.mean(axis=-2)
            h_weights = h_weights[1:]
            output_img = plot_weights(input_data, h_weights)
            attention_map_outputs.append(output_img)
    return attention_map_outputs
|
|
|
def plot_weights(input_data, patch_weights, patch_size=16, image_size=224):
    """Overlay per-patch attention weights on the image.

    The brighter a patch appears, the more attention it received.

    Args:
        input_data: Normalized (C, H, W) image tensor.
        patch_weights: 1-D tensor with one weight per image patch.
        patch_size: Side length of each square patch in pixels
            (generalized from the hard-coded 16).
        image_size: Side length of the square model input
            (generalized from the hard-coded 224).

    Returns:
        PIL image of size (image_size, image_size).
    """
    plot = inv_normalize(input_data.clone())
    patches_per_row = image_size // patch_size
    for i in range(patch_weights.shape[0]):
        # Patch index -> top-left pixel coordinates, row-major order.
        x = i % patches_per_row * patch_size
        y = i // patches_per_row * patch_size
        plot[:, y:y + patch_size, x:x + patch_size] *= patch_weights[i]
    attn_map_img = inv_transform(plot, normalize=False)
    # BUG FIX: Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is
    # the same filter under its current name.
    attn_map_img = attn_map_img.resize((image_size, image_size), Image.LANCZOS)
    return attn_map_img
|
|
|
# Precompute attention maps for the sample image; used below as the
# default value of the attention interface's Gallery output.
attention_maps = plot_attention(random_image, 5)
|
|
|
# --- UI copy for the classification tab ---
title_classify = "Image Based Plant Disease Identification 🍃🤓"

description_classify = """Finetuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to
identify the plant disease."""

article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""

# --- UI copy for the attention-visualization tab ---
title_attention = "Visualize Attention Weights 🧊🔍"

description_attention = """The Vision Transformer Base architecture has 12 transformer Encoder layers (12 attention heads in each)."""

article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""
|
|
|
# Tab 1: upload a leaf image, get the top-3 disease predictions and the
# time inference took.
classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="Image"),
    # Two outputs matching predict_disease's (dict, float) return:
    # a label widget for the class probabilities and a number widget
    # for the elapsed seconds.
    outputs=[gr.Label(num_top_classes=3, label="Predictions"),
             gr.Number(label="Prediction time (secs)")],
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)
|
|
|
# Tab 2: pick an image and an encoder layer, view one attention map per head.
attention_interface = gr.Interface(
    fn=plot_attention,
    # type="index" makes the dropdown deliver the option's position, so
    # the displayed "6" arrives as integer 5 — the UI is 1-based while
    # the model's encoder blocks are 0-based.
    inputs=[gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"], value="6", type="index")],
    # NOTE(review): Gallery(...).style(grid=...) is deprecated/removed in
    # newer Gradio releases (columns= replaces it) — confirm the pinned
    # gradio version before upgrading.
    outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
)
|
|
|
# Combine both interfaces into a single two-tab Gradio app.
demo = gr.TabbedInterface([classify_interface, attention_interface],
                          ["Identify Disease", "Visualize Attention Map"], title="NatureAI Diagnostics🧑🩺")

# Launch only when run as a script (not when imported, e.g. by a host
# platform that calls demo itself).
if __name__ == "__main__":
    demo.launch()