subhannadeem1 commited on
Commit
022243f
1 Parent(s): 2dc96c6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import matplotlib.pyplot as plt
4
+ import requests
5
+ import streamlit as st
6
+ import torch
7
+ from PIL import Image
8
+ from torchvision import models
9
+ from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
10
+
11
+ from torchcam import methods
12
+ from torchcam.methods._utils import locate_candidate_layer
13
+ from torchcam.utils import overlay_mask
14
+
15
+ CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
16
+ TV_MODELS = [
17
+ "resnet18",
18
+ "resnet50",
19
+ "mobilenet_v3_small",
20
+ "mobilenet_v3_large",
21
+ "regnet_y_400mf",
22
+ "convnext_tiny",
23
+ "convnext_small",
24
+ ]
25
+ LABEL_MAP = requests.get(
26
+ "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
27
+ ).json()
28
+
29
+
30
+ def main():
31
+
32
+ # Wide mode
33
+ st.set_page_config(layout="wide")
34
+
35
+ # Designing the interface
36
+ st.title("TorchCAM: class activation explorer")
37
+ # For newline
38
+ st.write("\n")
39
+ # Set the columns
40
+ cols = st.columns((1, 1, 1))
41
+ cols[0].header("Input image")
42
+ cols[1].header("Raw CAM")
43
+ cols[-1].header("Overlayed CAM")
44
+
45
+ # Sidebar
46
+ # File selection
47
+ st.sidebar.title("Input selection")
48
+ # Disabling warning
49
+ st.set_option("deprecation.showfileUploaderEncoding", False)
50
+ # Choose your own image
51
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=["png", "jpeg", "jpg"])
52
+ if uploaded_file is not None:
53
+ img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
54
+
55
+ cols[0].image(img, use_column_width=True)
56
+
57
+ # Model selection
58
+ st.sidebar.title("Setup")
59
+ tv_model = st.sidebar.selectbox(
60
+ "Classification model",
61
+ TV_MODELS,
62
+ help="Supported models from Torchvision",
63
+ )
64
+ default_layer = ""
65
+ if tv_model is not None:
66
+ with st.spinner("Loading model..."):
67
+ model = models.__dict__[tv_model](pretrained=True).eval()
68
+ default_layer = locate_candidate_layer(model, (3, 224, 224))
69
+
70
+ if torch.cuda.is_available():
71
+ model = model.cuda()
72
+
73
+ target_layer = st.sidebar.text_input(
74
+ "Target layer",
75
+ default_layer,
76
+ help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
77
+ )
78
+ cam_method = st.sidebar.selectbox(
79
+ "CAM method",
80
+ CAM_METHODS,
81
+ help="The way your class activation map will be computed",
82
+ )
83
+ if cam_method is not None:
84
+ cam_extractor = methods.__dict__[cam_method](
85
+ model, target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None
86
+ )
87
+
88
+ class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
89
+ class_selection = st.sidebar.selectbox("Class selection", ["Predicted class (argmax)"] + class_choices)
90
+
91
+ # For newline
92
+ st.sidebar.write("\n")
93
+
94
+ if st.sidebar.button("Compute CAM"):
95
+
96
+ if uploaded_file is None:
97
+ st.sidebar.error("Please upload an image first")
98
+
99
+ else:
100
+ with st.spinner("Analyzing..."):
101
+
102
+ # Preprocess image
103
+ img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
104
+
105
+ if torch.cuda.is_available():
106
+ img_tensor = img_tensor.cuda()
107
+
108
+ # Forward the image to the model
109
+ out = model(img_tensor.unsqueeze(0))
110
+ # Select the target class
111
+ if class_selection == "Predicted class (argmax)":
112
+ class_idx = out.squeeze(0).argmax().item()
113
+ else:
114
+ class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
115
+ # Retrieve the CAM
116
+ act_maps = cam_extractor(class_idx, out)
117
+ # Fuse the CAMs if there are several
118
+ activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
119
+ # Plot the raw heatmap
120
+ fig, ax = plt.subplots()
121
+ ax.imshow(activation_map.squeeze(0).cpu().numpy())
122
+ ax.axis("off")
123
+ cols[1].pyplot(fig)
124
+
125
+ # Overlayed CAM
126
+ fig, ax = plt.subplots()
127
+ result = overlay_mask(img, to_pil_image(activation_map, mode="F"), alpha=0.5)
128
+ ax.imshow(result)
129
+ ax.axis("off")
130
+ cols[-1].pyplot(fig)
131
+
132
+
133
+ if __name__ == "__main__":
134
+ main()