F-G Fernandez commited on
Commit
c59b75f
1 Parent(s): a9a3664

fix: Fixed matplotlib call

Browse files
Files changed (1) hide show
  1. app.py +44 -29
app.py CHANGED
@@ -1,21 +1,22 @@
1
- # Copyright (C) 2020-2021, François-Guillaume Fernandez.
2
 
3
- # This program is licensed under the Apache License version 2.
4
- # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
 
 
 
 
6
  import requests
7
  import streamlit as st
8
- import matplotlib.pyplot as plt
9
  from PIL import Image
10
- from io import BytesIO
11
  from torchvision import models
12
- from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
13
 
14
  from torchcam import methods
15
  from torchcam.methods._utils import locate_candidate_layer
16
  from torchcam.utils import overlay_mask
17
 
18
-
19
  CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
20
  TV_MODELS = [
21
  "resnet18",
@@ -39,10 +40,7 @@ def main():
39
  # Designing the interface
40
  st.title("TorchCAM: class activation explorer")
41
  # For newline
42
- st.write('\n')
43
- st.write('Check the project at: https://github.com/frgfm/torch-cam')
44
- # For newline
45
- st.write('\n')
46
  # Set the columns
47
  cols = st.columns((1, 1, 1))
48
  cols[0].header("Input image")
@@ -53,36 +51,50 @@ def main():
53
  # File selection
54
  st.sidebar.title("Input selection")
55
  # Disabling warning
56
- st.set_option('deprecation.showfileUploaderEncoding', False)
57
  # Choose your own image
58
- uploaded_file = st.sidebar.file_uploader("Upload files", type=['png', 'jpeg', 'jpg'])
59
  if uploaded_file is not None:
60
- img = Image.open(BytesIO(uploaded_file.read()), mode='r').convert('RGB')
61
 
62
  cols[0].image(img, use_column_width=True)
63
 
64
  # Model selection
65
  st.sidebar.title("Setup")
66
- tv_model = st.sidebar.selectbox("Classification model", TV_MODELS)
 
 
 
 
67
  default_layer = ""
68
  if tv_model is not None:
69
- with st.spinner('Loading model...'):
70
  model = models.__dict__[tv_model](pretrained=True).eval()
71
  default_layer = locate_candidate_layer(model, (3, 224, 224))
72
 
73
- target_layer = st.sidebar.text_input("Target layer", default_layer)
74
- cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
 
 
 
 
 
 
 
 
 
 
 
75
  if cam_method is not None:
76
  cam_extractor = methods.__dict__[cam_method](
77
- model,
78
- target_layer=target_layer.split("+") if len(target_layer) > 0 else None
79
  )
80
 
81
  class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
82
  class_selection = st.sidebar.selectbox("Class selection", ["Predicted class (argmax)"] + class_choices)
83
 
84
  # For newline
85
- st.sidebar.write('\n')
86
 
87
  if st.sidebar.button("Compute CAM"):
88
 
@@ -90,11 +102,14 @@ def main():
90
  st.sidebar.error("Please upload an image first")
91
 
92
  else:
93
- with st.spinner('Analyzing...'):
94
 
95
  # Preprocess image
96
  img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
97
 
 
 
 
98
  # Forward the image to the model
99
  out = model(img_tensor.unsqueeze(0))
100
  # Select the target class
@@ -103,22 +118,22 @@ def main():
103
  else:
104
  class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
105
  # Retrieve the CAM
106
- cams = cam_extractor(class_idx, out)
107
  # Fuse the CAMs if there are several
108
- cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
109
  # Plot the raw heatmap
110
  fig, ax = plt.subplots()
111
- ax.imshow(cam.numpy())
112
- ax.axis('off')
113
  cols[1].pyplot(fig)
114
 
115
  # Overlayed CAM
116
  fig, ax = plt.subplots()
117
- result = overlay_mask(img, to_pil_image(cam, mode='F'), alpha=0.5)
118
  ax.imshow(result)
119
- ax.axis('off')
120
  cols[-1].pyplot(fig)
121
 
122
 
123
- if __name__ == '__main__':
124
  main()
1
+ # Copyright (C) 2021-2022, François-Guillaume Fernandez.
2
 
3
+ # This program is licensed under the Apache License 2.0.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
5
 
6
+ from io import BytesIO
7
+
8
+ import matplotlib.pyplot as plt
9
  import requests
10
  import streamlit as st
11
+ import torch
12
  from PIL import Image
 
13
  from torchvision import models
14
+ from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
15
 
16
  from torchcam import methods
17
  from torchcam.methods._utils import locate_candidate_layer
18
  from torchcam.utils import overlay_mask
19
 
 
20
  CAM_METHODS = ["CAM", "GradCAM", "GradCAMpp", "SmoothGradCAMpp", "ScoreCAM", "SSCAM", "ISCAM", "XGradCAM", "LayerCAM"]
21
  TV_MODELS = [
22
  "resnet18",
40
  # Designing the interface
41
  st.title("TorchCAM: class activation explorer")
42
  # For newline
43
+ st.write("\n")
 
 
 
44
  # Set the columns
45
  cols = st.columns((1, 1, 1))
46
  cols[0].header("Input image")
51
  # File selection
52
  st.sidebar.title("Input selection")
53
  # Disabling warning
54
+ st.set_option("deprecation.showfileUploaderEncoding", False)
55
  # Choose your own image
56
+ uploaded_file = st.sidebar.file_uploader("Upload files", type=["png", "jpeg", "jpg"])
57
  if uploaded_file is not None:
58
+ img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
59
 
60
  cols[0].image(img, use_column_width=True)
61
 
62
  # Model selection
63
  st.sidebar.title("Setup")
64
+ tv_model = st.sidebar.selectbox(
65
+ "Classification model",
66
+ TV_MODELS,
67
+ help="Supported models from Torchvision",
68
+ )
69
  default_layer = ""
70
  if tv_model is not None:
71
+ with st.spinner("Loading model..."):
72
  model = models.__dict__[tv_model](pretrained=True).eval()
73
  default_layer = locate_candidate_layer(model, (3, 224, 224))
74
 
75
+ if torch.cuda.is_available():
76
+ model = model.cuda()
77
+
78
+ target_layer = st.sidebar.text_input(
79
+ "Target layer",
80
+ default_layer,
81
+ help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
82
+ )
83
+ cam_method = st.sidebar.selectbox(
84
+ "CAM method",
85
+ CAM_METHODS,
86
+ help="The way your class activation map will be computed",
87
+ )
88
  if cam_method is not None:
89
  cam_extractor = methods.__dict__[cam_method](
90
+ model, target_layer=[s.strip() for s in target_layer.split("+")] if len(target_layer) > 0 else None
 
91
  )
92
 
93
  class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
94
  class_selection = st.sidebar.selectbox("Class selection", ["Predicted class (argmax)"] + class_choices)
95
 
96
  # For newline
97
+ st.sidebar.write("\n")
98
 
99
  if st.sidebar.button("Compute CAM"):
100
 
102
  st.sidebar.error("Please upload an image first")
103
 
104
  else:
105
+ with st.spinner("Analyzing..."):
106
 
107
  # Preprocess image
108
  img_tensor = normalize(to_tensor(resize(img, (224, 224))), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
109
 
110
+ if torch.cuda.is_available():
111
+ img_tensor = img_tensor.cuda()
112
+
113
  # Forward the image to the model
114
  out = model(img_tensor.unsqueeze(0))
115
  # Select the target class
118
  else:
119
  class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
120
  # Retrieve the CAM
121
+ act_maps = cam_extractor(class_idx, out)
122
  # Fuse the CAMs if there are several
123
+ activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
124
  # Plot the raw heatmap
125
  fig, ax = plt.subplots()
126
+ ax.imshow(activation_map.squeeze(0).cpu().numpy())
127
+ ax.axis("off")
128
  cols[1].pyplot(fig)
129
 
130
  # Overlayed CAM
131
  fig, ax = plt.subplots()
132
+ result = overlay_mask(img, to_pil_image(activation_map, mode="F"), alpha=0.5)
133
  ax.imshow(result)
134
+ ax.axis("off")
135
  cols[-1].pyplot(fig)
136
 
137
 
138
+ if __name__ == "__main__":
139
  main()