import os
import gc
import PIL
import glob
import timm
import torch
import nopdb
import pickle
import torchvision
import numpy as np
import gradio as gr
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
import IPython.display as ipd
from typing import Tuple, Dict
from timeit import default_timer as timer
from timm.data import resolve_data_config, create_transform
from torchvision import datasets, transforms

# One single-element list per example file, the shape gr.Interface(examples=...) accepts.
# Sorted so the example order does not depend on the filesystem's listing order.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

# Fine-tuned ViT-B/16: the classifier head is replaced by a 38-way linear layer so the
# saved state dict loads cleanly.
vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth', map_location=torch.device('cpu'))
vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
vision_transformer.head = nn.Linear(in_features=768, out_features=38)
vision_transformer.load_state_dict(vision_transformer_weights)
# The model is only used for inference; nn.Module starts in training mode, so switch it.
vision_transformer.eval()

# NOTE(review): newer timm releases compute attention with a fused kernel that never
# creates the `attn` local read by plot_attention() through nopdb. Force the explicit
# path where the flag exists; on older timm versions the attribute is absent and this
# loop does nothing.
for _block in vision_transformer.blocks:
    if hasattr(_block.attn, "fused_attn"):
        _block.attn.fused_attn = False

data_transforms = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],),
])


def inv_normalize(tensor):
    """Min-max rescale a tensor into the [0, 256) range.

    The upper bound sits just below 256 so that a later ``astype(np.uint8)``
    truncation maps the maximum value to 255 instead of wrapping around.
    A constant tensor (max == min) has no range to stretch; it is mapped to
    zeros rather than producing NaN from a division by zero.
    """
    low, high = tensor.min(), tensor.max()
    span = high - low
    if span == 0:
        return torch.zeros_like(tensor)
    return (tensor - low) / span * (256 - 1e-5)


def inv_transform(tensor, normalize=True):
    """Convert a channel-first (C, H, W) image tensor into a PIL image.

    Args:
        tensor: Image tensor in (C, H, W) layout (the transpose below fixes this rank).
        normalize: If True, min-max rescale the values into [0, 256) first. If False,
            the values are assumed to already be on a 0-255 scale; they are clamped
            into that range so the uint8 cast cannot wrap.

    Returns:
        A PIL image built from the uint8 (H, W, C) array.
    """
    if normalize:
        tensor = inv_normalize(tensor)
    else:
        tensor = tensor.clamp(0, 255)
    array = tensor.detach().cpu().numpy()
    array = array.transpose(1, 2, 0).astype(np.uint8)
    return Image.fromarray(array)


# NOTE: class_names.ob is a pickle shipped with the app; only load trusted files this way.
with open('class_names.ob', 'rb') as fp:
    class_names = pickle.load(fp)

img = Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
img_transformed = data_transforms(img)


def predict_disease(image) -> Tuple[Dict, float]:
    """Classify a plant image and return its top-3 disease predictions.

    Args:
        image: PIL image (from the ``gr.Image(type="pil")`` input). It is converted
            to RGB so that grayscale/RGBA uploads match the 3-channel normalization.

    Returns:
        ``(predictions, seconds)``: a dict mapping each of the 3 most probable class
        names to its softmax probability, and the elapsed inference time rounded to
        5 decimal places.
    """
    input_tensor = data_transforms(image.convert('RGB'))
    start_time = timer()
    with torch.inference_mode():
        # input_tensor[None] adds a batch dimension of 1; unpacking removes it again.
        [logits] = vision_transformer(input_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    prediction_dict = {
        class_names[class_idx]: prob
        for class_idx, prob in zip(topk_id.tolist(), topk_prob.tolist())
    }
    prediction_time = round(timer() - start_time, 5)
    return prediction_dict, prediction_time


def predict_tensor(img_tensor):
    """Run the classifier on an already-transformed (C, H, W) image tensor.

    plot_attention() calls this to drive a forward pass while nopdb records the
    locals of one attention module.

    Returns:
        ``(topk_prob, topk_id)``: the 3 highest softmax probabilities and their
        class indices.
    """
    with torch.inference_mode():
        [logits] = vision_transformer(img_tensor[None])
        probs = torch.softmax(logits, dim=0)
        topk_prob, topk_id = torch.topk(probs, 3)
    return topk_prob, topk_id


random_image = Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')


def plot_attention(image, encoder_layer_num=5):
    """Render one attention map per head for the chosen encoder layer.

    Args:
        image: PIL image to analyse.
        encoder_layer_num: 0-based index into ``vision_transformer.blocks``. The Gradio
            dropdown uses ``type="index"``, so the label "6" arrives here as 5.

    Returns:
        A list of PIL images, one per attention head. In each image, brighter patches
        received more attention on average.
    """
    input_data = data_transforms(image.convert('RGB'))
    attn_module = vision_transformer.blocks[encoder_layer_num].attn
    with nopdb.capture_call(attn_module.forward) as attn_call:
        # Run the model on *this* image. The map must come from the same image it is
        # drawn over; the previous code always ran the module-level example.
        predict_tensor(input_data)
    # Assumed shape (batch, heads, tokens, tokens) per timm's Attention.forward; [0]
    # drops the batch dim, leaving one (query x key) matrix per head. TODO confirm.
    attn = attn_call.locals['attn'][0]

    attention_map_outputs = []
    with torch.inference_mode():
        for head_weights in attn:
            # Average over the query axis (-2): each remaining entry is the mean
            # attention that key token received from all queries.
            head_weights = head_weights.mean(axis=-2)
            head_weights = head_weights[1:]  # skip the [class] token, keep patch tokens
            attention_map_outputs.append(plot_weights(input_data, head_weights))
    return attention_map_outputs


def plot_weights(input_data, patch_weights, patch_size=16, img_size=224):
    """Scale each image patch by its attention weight (brighter = more attention).

    Args:
        input_data: Transformed (C, img_size, img_size) image tensor.
        patch_weights: 1-D tensor with one weight per patch, in row-major patch order.
        patch_size: Side length of one ViT patch in pixels (16 for ViT-B/16).
        img_size: Side length of the square model input in pixels.

    Returns:
        An ``img_size`` x ``img_size`` PIL image.
    """
    plot = inv_normalize(input_data.clone())
    patches_per_row = img_size // patch_size
    for i in range(patch_weights.shape[0]):
        row, col = divmod(i, patches_per_row)
        y, x = row * patch_size, col * patch_size
        plot[:, y:y + patch_size, x:x + patch_size] *= patch_weights[i]
    # Rescale the weighted image to span the full 0-255 range; the weights are
    # fractions, so without this the map would be nearly black.
    attn_map_img = inv_transform(plot, normalize=True)
    attn_map_img = attn_map_img.resize((img_size, img_size), Image.Resampling.LANCZOS)
    return attn_map_img


# Pre-rendered gallery shown before the user submits anything (layer index 5 = label "6").
attention_maps = plot_attention(random_image, 5)

title_classify = "Image Based Plant Disease Identification 🍃🤓"
description_classify = """Finetuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to identify the plant disease."""
article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""

title_attention = "Visualize Attention Weights 🧊🔍"
description_attention = """The Vision Transformer Base architecture has 12 transformer Encoder layers (12 attention heads in each)."""
article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""

# Banner image shared by both tabs.
interface_thumbnail = "https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"

# Tab 1: image in, top-3 label probabilities plus inference time out.
classify_outputs = [
    gr.Label(num_top_classes=3, label="Predictions"),
    gr.Number(label="Prediction time (secs)"),
]
classify_interface = gr.Interface(
    fn=predict_disease,
    inputs=gr.Image(type="pil", label="Image"),
    outputs=classify_outputs,
    examples=example_list,
    title=title_classify,
    description=description_classify,
    article=article_classify,
    thumbnail=interface_thumbnail,
)

# Tab 2: image + encoder layer in, one attention map per head out.
# Layers are labelled "1".."12"; type="index" hands plot_attention the 0-based position.
layer_choices = [str(layer) for layer in range(1, 13)]
attention_inputs = [
    gr.Image(type="pil", label="Image"),
    gr.Dropdown(choices=layer_choices, label="Attention Layer", value="6", type="index"),
]
attention_gallery = gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4))
attention_interface = gr.Interface(
    fn=plot_attention,
    inputs=attention_inputs,
    outputs=attention_gallery,
    examples=example_list,
    title=title_attention,
    description=description_attention,
    article=article_attention,
    thumbnail=interface_thumbnail,
)

demo = gr.TabbedInterface(
    [classify_interface, attention_interface],
    ["Identify Disease", "Visualize Attention Map"],
    title="NatureAI Diagnostics🧑🩺",
)

if __name__ == "__main__":
    demo.launch()