LeafDoc / app.py
TexR6's picture
Image.ANTIALIAS changed to Image.Resampling.LANCZOS
dd62dfe
import os
import gc
import PIL
import glob
import timm
import torch
import nopdb
import pickle
import torchvision
import numpy as np
import gradio as gr
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
import IPython.display as ipd
from typing import Tuple, Dict
from timeit import default_timer as timer
from timm.data import resolve_data_config, create_transform
# Gradio example inputs: one single-element row per file found in ./examples.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Load the fine-tuned checkpoint onto the CPU regardless of where it was saved.
# NOTE(review): torch.load unpickles the file, so only ship checkpoints that
# come from a trusted source.
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
                                        map_location=torch.device('cpu'))

# Build the architecture without downloading pretrained weights, replace the
# classification head with a 38-output linear layer, then load the checkpoint.
vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
vision_transformer.head = nn.Linear(in_features=768,
                                    out_features=38)
vision_transformer.load_state_dict(vision_transformer_weights)

# The model is only ever used for inference in this app; eval mode disables
# training-only behaviour (e.g. dropout) so predictions are deterministic.
vision_transformer.eval()
from torchvision import datasets, transforms
# Preprocessing applied to every input image before it reaches the model:
# resize to 256x256, center-crop to 224x224, convert to a tensor and
# normalize each channel with the mean/std listed below (these values match
# the commonly used ImageNet channel statistics).
_preprocess_steps = [
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
data_transforms = transforms.Compose(_preprocess_steps)
def inv_normalize(tensor):
    """Min-max rescale an image tensor into the 0-255 range.

    The minimum of ``tensor`` maps to 0 and the maximum to just under 256, so
    a later truncating cast to ``uint8`` yields values in 0..255.

    Args:
        tensor: A non-empty tensor of image values.

    Returns:
        A new tensor with the same shape as ``tensor``. A constant tensor
        (max == min) has no contrast to stretch and is mapped to all zeros
        instead of producing NaNs from a division by zero.
    """
    scale = 256 - 1e-5  # keeps the maximum strictly below 256
    low = tensor.min()
    span = tensor.max() - low
    shifted = tensor - low
    if span == 0:
        # Every element equals the minimum, so ``shifted`` is already all
        # zeros; skip the 0/0 division that would turn the image into NaNs.
        return shifted * scale
    return shifted / span * scale
def inv_transform(tensor, normalize=True):
    """Turn an image tensor into a PIL image.

    The tensor is always rescaled with ``inv_normalize`` first; the
    ``normalize`` argument is accepted but is not consulted by this function.
    The axes are then reordered with ``transpose(1, 2, 0)`` and the values are
    cast to ``uint8`` before building the image.
    """
    rescaled = inv_normalize(tensor)
    pixels = rescaled.detach().cpu().numpy().transpose(1, 2, 0)
    return PIL.Image.fromarray(pixels.astype(np.uint8))
# Class labels are stored as a pickled object next to the app; the model's
# predicted class ids are used to index into it.
with open('class_names.ob', 'rb') as class_names_file:
    class_names = pickle.load(class_names_file)

# Bundled example image and its preprocessed tensor.
img = Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
img_transformed = data_transforms(img)
def predict_disease(image) -> Tuple[Dict, float]:
    """Classify an input image and report the three most probable classes.

    Args:
        image: The image to classify; it is passed through ``data_transforms``.

    Returns:
        A pair ``(predictions, seconds)`` where ``predictions`` maps each of
        the top-3 class names to its softmax probability and ``seconds`` is
        the elapsed time rounded to 5 decimals. The timer starts after
        preprocessing, so only the model call and result assembly are timed.
    """
    model_input = data_transforms(image)
    started_at = timer()

    with torch.inference_mode():
        # The model receives a batch of one; unpack its single row of logits.
        [logits] = vision_transformer(model_input[None])
        probabilities = torch.softmax(logits, dim=0)
        top_probs, top_ids = torch.topk(probabilities, 3)
        prediction_dict = {
            class_names[class_id]: prob.item()
            for prob, class_id in zip(top_probs, top_ids)
        }

    elapsed = round(timer() - started_at, 5)
    return prediction_dict, elapsed
def predict_tensor(img_tensor):
    """Run the classifier on one already-preprocessed image tensor.

    Unlike ``predict_disease`` this takes a tensor rather than an image, which
    makes it suitable for triggering a forward pass whose internals are being
    captured.

    Args:
        img_tensor: A single preprocessed image tensor without a batch
            dimension (one is added here before calling the model).

    Returns:
        A pair ``(topk_prob, topk_id)`` holding the three highest softmax
        probabilities and the corresponding class ids. Earlier versions
        computed these values but returned ``None``; callers that ignore the
        result are unaffected.
    """
    with torch.inference_mode():
        # The model receives a batch of one; unpack its single row of logits.
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    return topk_prob, topk_id
# Example image used to pre-populate the attention visualisation.
random_image = Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
def plot_attention(image, encoder_layer_num=5):
    """Render one attention map per head for a chosen encoder block.

    The image is preprocessed, a forward pass is run while the local
    variables of the selected block's attention ``forward`` are captured, and
    the captured ``attn`` weights of each head are painted onto the image.

    Args:
        image: The image to visualise; it is passed through ``data_transforms``.
        encoder_layer_num: Index into ``vision_transformer.blocks`` selecting
            the encoder block whose attention is captured.

    Returns:
        A list of PIL images, one per attention head of the selected block.
    """
    input_data = data_transforms(image)

    # Capture the locals of the attention forward call during a forward pass
    # on *this* image. (Previously the pass ran on the bundled example tensor,
    # so every upload was overlaid with the example's attention weights.)
    with nopdb.capture_call(vision_transformer.blocks[encoder_layer_num].attn.forward) as attn_call:
        predict_tensor(input_data)

    # Drop the batch dimension; iterating the result yields one matrix per head.
    attn = attn_call.locals['attn'][0]

    attention_map_outputs = []
    with torch.inference_mode():
        # loop over attention heads
        for h_weights in attn:
            # Average along the second-to-last axis, leaving one weight per token.
            h_weights = h_weights.mean(axis=-2)
            h_weights = h_weights[1:]  # skip the [class] token
            attention_map_outputs.append(plot_weights(input_data, h_weights))
    return attention_map_outputs
def plot_weights(input_data, patch_weights):
    """Scale every 16x16 image patch by its weight and render the result.

    The brighter a patch appears in the returned image, the higher its
    attention weight.

    Args:
        input_data: Preprocessed image tensor; it is cloned, not modified.
        patch_weights: One weight per patch, ordered row by row across a
            224x224 image split into 16x16 patches.

    Returns:
        A 224x224 PIL image of the weighted input.
    """
    patch_size = 16
    image_size = 224
    patches_per_row = image_size // patch_size

    # Work on a rescaled copy so the caller's tensor stays untouched.
    weighted = inv_normalize(input_data.clone())
    for index, weight in enumerate(patch_weights):
        row, col = divmod(index, patches_per_row)
        top = row * patch_size
        left = col * patch_size
        weighted[:, top:top + patch_size, left:left + patch_size] *= weight

    rendered = inv_transform(weighted, normalize=False)
    return rendered.resize((image_size, image_size), Image.Resampling.LANCZOS)
# Pre-compute attention maps for the bundled example image (encoder block
# index 5) so the attention gallery has initial content to display.
attention_maps = plot_attention(random_image, 5)
# Title, description and article text for the classification tab.
title_classify = "Image Based Plant Disease Identification 🍃🤓"
description_classify = """Finetuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to
identify the plant disease."""
article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
# Title, description and article text for the attention-visualisation tab.
title_attention = "Visualize Attention Weights 🧊🔍"
description_attention = """The Vision Transformer Base architecture has 12 transformer Encoder layers (12 attention heads in each)."""
article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""
# Classification tab: a PIL image goes in; the top-3 labels and the measured
# prediction time come out.
_classify_input = gr.Image(type="pil", label="Image")
_classify_outputs = [
    gr.Label(num_top_classes=3, label="Predictions"),
    gr.Number(label="Prediction time (secs)"),
]
_classify_thumbnail = "https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"

classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=_classify_input,
    outputs=_classify_outputs,
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail=_classify_thumbnail,
)
# Attention tab: a PIL image plus an encoder-layer choice go in; a gallery of
# per-head attention maps comes out. The dropdown uses type="index", so the
# callback receives the position of the selected entry rather than its label
# (the default entry "6" is position 5).
_layer_choices = [str(layer) for layer in range(1, 13)]
_attention_inputs = [
    gr.Image(type="pil", label="Image"),
    gr.Dropdown(choices=_layer_choices,
                label="Attention Layer", value="6", type="index"),
]
# The gallery starts out showing the maps pre-computed for the example image.
_attention_gallery = gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4))
_attention_thumbnail = "https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"

attention_interface = gr.Interface(
    fn=plot_attention,
    inputs=_attention_inputs,
    outputs=_attention_gallery,
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail=_attention_thumbnail,
)
# Combine both interfaces into a single app with one tab each.
_tab_interfaces = [classify_interface, attention_interface]
_tab_names = ["Identify Disease", "Visualize Attention Map"]
demo = gr.TabbedInterface(_tab_interfaces, _tab_names, title="NatureAI Diagnostics🧑🩺")

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()