TexR6 commited on
Commit
e324642
1 Parent(s): 51cf39f

delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -154
app.py DELETED
@@ -1,154 +0,0 @@
1
- import os
2
- import gc
3
- import PIL
4
- import glob
5
- import timm
6
- import torch
7
- import nopdb
8
- import pickle
9
- import torchvision
10
-
11
- import numpy as np
12
- import gradio as gr
13
- from torch import nn
14
- from PIL import Image
15
- import matplotlib.pyplot as plt
16
- import IPython.display as ipd
17
- from typing import Tuple, Dict
18
- from timeit import default_timer as timer
19
- from timm.data import resolve_data_config, create_transform
20
-
21
- example_list = [["examples/" + example] for example in os.listdir("examples")]
22
-
23
- vision_transformer_weights = torch.load('pytorch_vit_b_16_timm.pth',
24
- map_location=torch.device('cpu'))
25
-
26
- vision_transformer = timm.create_model('vit_base_patch16_224', pretrained=False)
27
-
28
- vision_transformer.head = nn.Linear(in_features=768,
29
- out_features=38)
30
-
31
- vision_transformer.load_state_dict(vision_transformer_weights)
32
-
33
- from torchvision import datasets, transforms
34
-
35
- data_transforms = transforms.Compose([
36
- transforms.Resize(size=(256, 256)),
37
- transforms.CenterCrop(size=224),
38
- transforms.ToTensor(),
39
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
40
- std=[0.229, 0.224, 0.225],)
41
- ])
42
-
43
- def inv_normalize(tensor):
44
- """Normalize an image tensor back to the 0-255 range."""
45
- tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * (256 - 1e-5)
46
- return tensor
47
-
48
- def inv_transform(tensor, normalize=True):
49
- """Convert a tensor back to an image."""
50
- tensor = inv_normalize(tensor)
51
- array = tensor.detach().cpu().numpy()
52
- array = array.transpose(1, 2, 0).astype(np.uint8)
53
- return PIL.Image.fromarray(array)
54
-
55
- with open('class_names.ob', 'rb') as fp:
56
- class_names = pickle.load(fp)
57
-
58
- img = PIL.Image.open('examples/TomatoYellowCurlVirus6.JPG').convert('RGB')
59
- img_transformed = data_transforms(img)
60
-
61
- def predict_disease(image) -> Tuple[Dict, float]:
62
- """Return prediction classes with probabilities for an input image."""
63
- input = data_transforms(image)
64
- start_time = timer()
65
- prediction_dict = {}
66
- with torch.inference_mode():
67
- [logits] = vision_transformer(input[None])
68
- probs = torch.softmax(logits, dim=0)
69
- topk_prob, topk_id = torch.topk(probs, 3)
70
- for i in range(topk_prob.size(0)):
71
- prediction_dict[class_names[topk_id[i]]] = topk_prob[i].item()
72
- prediction_time = round(timer() - start_time, 5)
73
- return prediction_dict, prediction_time
74
-
75
- def predict_tensor(img_tensor):
76
- """Return prediction classes with probabilities for an input image."""
77
- with torch.inference_mode():
78
- [logits] = vision_transformer(img_tensor[None])
79
- probs = torch.softmax(logits, dim=0)
80
- topk_prob, topk_id = torch.topk(probs, 3)
81
-
82
- with nopdb.capture_call(vision_transformer.blocks[5].attn.forward) as attn_call:
83
- predict_tensor(img_transformed)
84
-
85
- def plot_attention(image, layer_num):
86
- """Given an input image, plot the average attention weight given to each image patch by each attention head."""
87
- input_data = data_transforms(image)
88
- with nopdb.capture_call(vision_transformer.blocks[int(layer_num)-1].attn.forward) as attn_call:
89
- predict_tensor(img_transformed)
90
- attn = attn_call.locals['attn'][0]
91
- with torch.inference_mode():
92
- # loop over attention heads
93
- attention_block_num = 0
94
- for h_weights in attn:
95
- h_weights = h_weights.mean(axis=-2) # average over all attention keys
96
- h_weights = h_weights[1:] # skip the [class] token
97
- attention_block_num += 1
98
- plot_weights(input_data, h_weights, attention_block_num)
99
- attention_maps = glob.glob('storage/*.png')
100
- return attention_maps
101
-
102
- def plot_weights(input_data, patch_weights, num_attention_block):
103
- """Display the image: Brighter the patch, higher is the attention."""
104
- # multiply each patch of the input image by the corresponding weight
105
- plot = inv_normalize(input_data.clone())
106
- for i in range(patch_weights.shape[0]):
107
- x = i * 16 % 224
108
- y = i // (224 // 16) * 16
109
- plot[:, y:y + 16, x:x + 16] *= patch_weights[i]
110
- attn_map_img = inv_transform(plot, normalize=False)
111
- attn_map_img = attn_map_img.resize((224, 224), Image.ANTIALIAS)
112
- attn_map_img.save(f"storage/attention_map_{num_attention_block}.png", "PNG")
113
-
114
- title_classify = "Image Based Plant Disease Identification 🍃🤓"
115
-
116
- description_classify = """Finetuned a Vision Transformer Base (Patch Size: 16 | Image Size: 224) architecture to
117
- identify the plant disease."""
118
-
119
- article_classify = """Upload an image from the example list or choose one of your own. [Dataset Classes](https://data.mendeley.com/datasets/tywbtsjrjv/1)"""
120
-
121
- title_attention = "Visualize Attention Weights 🧊🔍"
122
-
123
- description_attention = """The Vision Transformer Base architecture has 12 transformer Encoder layers (12 attention heads in each)."""
124
-
125
- article_attention = """From the dropdown menu, choose the Encoder layer whose attention weights you would like to visualize."""
126
-
127
- classify_interface = gr.Interface(
128
- fn=predict_disease,
129
- inputs=gr.Image(type="pil", label="Image"),
130
- outputs=[gr.Label(num_top_classes=3, label="Predictions"),
131
- gr.Number(label="Prediction time (secs)")],
132
- examples=example_list,
133
- title=title_classify,
134
- description=description_classify,
135
- article=article_classify,
136
- thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
137
- )
138
-
139
- attention_interface = gr.Interface(
140
- fn=plot_attention,
141
- inputs=[gr.Image(type="pil", label="Image"),
142
- gr.Dropdown(choices=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"],
143
- label="Attention Layer", value="6")],
144
- outputs=gr.Gallery(value=attention_maps, label="Attention Maps").style(grid=(3, 4)),
145
- examples=example_list,
146
- title=title_attention,
147
- description=description_attention,
148
- article=article_attention,
149
- thumbnail="https://images.unsplash.com/photo-1470058869958-2a77ade41c02?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1170&q=80"
150
- )
151
-
152
- demo = gr.TabbedInterface([classify_interface, attention_interface],
153
- ["Identify Disease", "Visualize Attention Map"],
154
- title="NatureAI Diagnostics🧑🩺").launch(debug=False, share=True)