Gosula committed on
Commit
1d30310
1 Parent(s): aab34e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from pytorch_grad_cam import GradCAM
7
+ from pytorch_grad_cam.utils.image import show_cam_on_image
8
+ from custom_resnet import *
9
+ #from resnet import ResNet18 # Assuming you have a custom ResNet18 implementation
10
+
11
def load_custom_state_dict(model, state_dict):
    """Load only the checkpoint entries whose keys exist in ``model``.

    Keys in ``state_dict`` that the model does not define are dropped, and
    model parameters/buffers that are absent from ``state_dict`` keep their
    current values.

    Args:
        model: Module exposing ``state_dict()`` and ``load_state_dict()``.
        state_dict: Mapping of parameter names to tensors.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when a matching key
            holds a tensor whose shape differs from the model's.
    """
    known_keys = model.state_dict().keys()
    # Filter out unexpected keys.
    filtered_state_dict = {
        key: value for key, value in state_dict.items() if key in known_keys
    }
    # strict=False leaves entries missing from the checkpoint untouched, so
    # the model's own tensors do not have to be copied back onto themselves.
    model.load_state_dict(filtered_state_dict, strict=False)
19
+
20
+
21
# Build the network (CustomResNet comes from the custom_resnet module) and
# restore its trained weights on the CPU.
model = CustomResNet()
# NOTE: torch.load unpickles the file, so only load checkpoints that come
# from a trusted source.
state_dict = torch.load("model_pth.ckpt", map_location=torch.device('cpu'))
# The weights are stored under the checkpoint's 'state_dict' key.
load_custom_state_dict(model, state_dict['state_dict'])
# Put the model in inference mode right away; otherwise the first forward
# pass would run in training mode.
model.eval()

# Normalize transform parameterised as (-mean / std, 1 / std), i.e. the
# inverse of a per-channel normalisation with those mean/std values.
# NOTE(review): 0.494 differs from the commonly quoted CIFAR-10 red-channel
# mean of 0.4914 -- confirm against the transform used during training.
inv_normalize = transforms.Normalize(
    mean=[-0.494 / 0.2470, -0.4822 / 0.2435, -0.4465 / 0.2616],
    std=[1 / 0.2470, 1 / 0.2435, 1 / 0.2616],
)
# Class names, indexed by the model's output position.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
32
+
33
def inference(input_img, transparency=0.5, target_layer_number=-1, num_images=1, num_top_classes=3):
    """Classify an image and overlay a GradCAM heatmap on it.

    Args:
        input_img: Image to classify, convertible by ``transforms.ToTensor``
            (presumably an H x W x 3 array with 0-255 values, as delivered by
            the Gradio image input -- confirm against the caller).
        transparency: Forwarded as ``image_weight`` to ``show_cam_on_image``.
        target_layer_number: Index into ``model.layer_2`` selecting the layer
            GradCAM is computed on.
        num_images: Accepted for interface compatibility; only a single
            visualization is produced, so the value is currently unused.
        num_top_classes: How many of the highest-confidence classes to
            return; clamped to the range ``0..len(classes)``.

    Returns:
        A tuple ``(top_classes, visualization)``: a dict mapping class name
        to probability, ordered by descending probability, and the image
        returned by ``show_cam_on_image``.
    """
    # Idempotent; guarantees inference-mode layers even on the first call.
    model.eval()

    # show_cam_on_image expects a float32 image scaled to [0, 1].
    org_img = np.asarray(input_img, dtype=np.float32) / 255
    # NOTE(review): the tensor is fed to the model without mean/std
    # normalisation -- confirm this matches the training preprocessing.
    input_tensor = transforms.ToTensor()(input_img).unsqueeze(0)

    # The classification pass needs no gradients; GradCAM below runs its own
    # forward/backward pass.
    with torch.no_grad():
        probabilities = torch.softmax(model(input_tensor), dim=1)[0]
    confidences = {
        name: float(probabilities[idx]) for idx, name in enumerate(classes)
    }

    # Get GradCAM for the specified target_layer_number. Slider values may
    # arrive as floats, which cannot be used as an index.
    target_layers = [model.layer_2[int(target_layer_number)]]
    # The context manager releases the hooks GradCAM registers on the model,
    # so they do not pile up across requests. targets=None lets the library
    # choose the target class itself.
    with GradCAM(model=model, target_layers=target_layers, use_cuda=False) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=None)[0, :]

    visualization = show_cam_on_image(
        org_img, grayscale_cam, use_rgb=True, image_weight=transparency
    )

    # Keep the top classes requested by the user (at most one per class).
    limit = max(0, min(int(num_top_classes), len(classes)))
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    top_classes = dict(ranked[:limit])

    return top_classes, visualization
77
+
78
+
79
+
80
+
81
title = "CIFAR10 trained on ResNet18 Model with GradCAM"
description = "A simple Gradio interface to infer on ResNet model, and get GradCAM results"
# Each example row is [image path, opacity, layer index]; the remaining
# inputs are not supplied by the examples.
# NOTE(review): these are absolute /content/... paths -- confirm they exist
# in the environment this app is deployed to.
examples = [
    [f"/content/examples/{name}.jpg", 0.5, -1]
    for name in (
        "car_1", "car_2",
        "cat_1", "cat_2",
        "dog_1", "dog_2",
        "frog_1", "frog_2",
        "horse_1", "horse_2",
    )
]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(32, 32), label="Input Image"),
        gr.Slider(0, 1, value=0.5, label="Opacity of GradCAM"),
        gr.Slider(-2, -1, value=-2, step=1, label="Which Layer?"),
        # `value=` (not the legacy `default=`) sets the initial value,
        # matching the sliders around it.
        gr.Number(value=1, label="Number of GradCAM Images to Show"),
        gr.Slider(1, 10, value=3, step=1, label="Number of Top Classes to Show"),
    ],
    outputs=[
        # Allow the label to display as many classes as the slider can
        # request instead of capping the display at 5.
        gr.Label(num_top_classes=len(classes)),
        gr.Image(shape=(32, 32), label="Output").style(width=128, height=128),
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()