wgetdd committed
Commit ee91973
1 Parent(s): b2ce316

Update utils.py

Files changed (1)
  1. utils.py +94 -17
utils.py CHANGED
@@ -7,6 +7,8 @@ import numpy as np
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import show_cam_on_image
+import matplotlib.pyplot as plt
+import textwrap
 
 def apply_normalization(chennels):
     return nn.BatchNorm2d(chennels)
@@ -78,7 +80,38 @@ class CustomResnet(nn.Module):
         x = x.view(-1, 512)
         x = self.linear1(x)
         return F.log_softmax(x, dim=-1)
+
+
+def resize_image(image, target_size=(200, 200)):
+    return cv2.resize(image, target_size)
+
+def wrap_text(text, width=20):
+    return textwrap.fill(text, width)
 
+def save_plot_as_image(images,texts, output_path):
+    num_images = len(images)
+    num_cols = min(4, num_images) # Assuming you want a maximum of 4 columns
+    num_rows = (num_images - 1) // num_cols + 1
+
+    fig, axes = plt.subplots(num_rows, num_cols, figsize=(3 * num_cols, 3 * num_rows))
+
+    subplot_height = 0.9 / num_rows # Adjust this value to control the height of each subplot
+    plt.subplots_adjust(hspace=subplot_height)
+    for i, ax in enumerate(axes.flat):
+        if i < num_images:
+            ax.imshow(images[i], cmap='gray')
+            ax.axis('off')
+            if texts is not None and i < len(texts):
+                wrapped_text = wrap_text(texts[i])
+                ax.set_title(wrapped_text, fontsize=12, pad=5)
+        else:
+            ax.axis('off')
+    plt.tight_layout()
+    # plt.savefig(output_path)
+    # plt.close()
+    return plt
+
+
 # Function to run inference and return top classes
 def get_gradcam(model,input_img, opacity,layer):
     targets = None
@@ -92,23 +125,67 @@ def get_gradcam(model,input_img, opacity,layer):
     input_img = input_img.to(device)
     input_img = input_img.unsqueeze(0)
     outputs = model(input_img)
-    _, prediction = torch.max(outputs, 1)
-    if layer == "layer3":
-        target_layers = [model.convlayer3[-2]]
-    if layer == "layer2":
-        target_layers = [model.convlayer2[-2]]
-    if layer == "layer1":
-        target_layers = [model.convlayer1[-2]]
-    #target_layers = [model.convlayer3[-2]]
-    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
-    grayscale_cam = cam(input_tensor=input_img, targets=targets)
-    grayscale_cam = grayscale_cam[0, :]
-    img = input_img.squeeze(0).to('cpu')
-    img = inv_normalize(img)
-    rgb_img = np.transpose(img, (1, 2, 0))
-    rgb_img = rgb_img.numpy()
-    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)
-    return visualization
+    if layer == "convblock1":
+        target_layers = model.convlayer1
+    elif layer == "convblock2":
+        target_layers = model.convlayer2
+    elif layer == "resblock1":
+        target_layers = model.reslayer1
+    elif layer == "resblock2":
+        target_layers = model.reslayer2
+    elif layer == "convblock3":
+        target_layers = model.convlayer3
+
+    layer_to_user = []
+    for i in target_layers:
+        if str(i) != "ReLU()":
+            layer_to_user.append(i)
+    print(layer_to_user)
+    final_outputs,texts = [],[]
+    for i in range(len(layer_to_user)):
+        cam = GradCAM(model=model, target_layers=[layer_to_user[i]], use_cuda=True)
+        grayscale_cam = cam(input_tensor=input_img, targets=targets)
+        grayscale_cam = grayscale_cam[0, :]
+        img = input_img.squeeze(0).to('cpu')
+        img = inv_normalize(img)
+        rgb_img = np.transpose(img, (1, 2, 0))
+        rgb_img = rgb_img.numpy()
+        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)
+        final_outputs.append(resize_image(visualization))
+        texts.append(str(layer_to_user[i]))
+    figure = save_plot_as_image(final_outputs,texts, "plot.png")
+    return figure
+
+# # Function to run inference and return top classes
+# def get_gradcam(model,input_img, opacity,layer):
+#     targets = None
+#     inv_normalize = transforms.Normalize(
+#         mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
+#         std=[1/0.23, 1/0.23, 1/0.23]
+#     )
+#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+#     transform = transforms.ToTensor()
+#     input_img = transform(input_img)
+#     input_img = input_img.to(device)
+#     input_img = input_img.unsqueeze(0)
+#     outputs = model(input_img)
+#     _, prediction = torch.max(outputs, 1)
+#     if layer == "layer3":
+#         target_layers = [model.convlayer3[-2]]
+#     if layer == "layer2":
+#         target_layers = [model.convlayer2[-2]]
+#     if layer == "layer1":
+#         target_layers = [model.convlayer1[-2]]
+#     #target_layers = [model.convlayer3[-2]]
+#     cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
+#     grayscale_cam = cam(input_tensor=input_img, targets=targets)
+#     grayscale_cam = grayscale_cam[0, :]
+#     img = input_img.squeeze(0).to('cpu')
+#     img = inv_normalize(img)
+#     rgb_img = np.transpose(img, (1, 2, 0))
+#     rgb_img = rgb_img.numpy()
+#     visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=opacity)
+#     return visualization
 
 
 def get_misclassified_images(show_misclassified,num):
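
For reference, here is a minimal usage sketch (not part of this commit) showing how the updated get_gradcam could be driven end to end. It assumes utils.py is importable, that a trained CustomResnet checkpoint exists, and that utils.py imports cv2 and torchvision's transforms elsewhere, since resize_image calls cv2.resize but no import of cv2 is added in this diff. The checkpoint path, image file, opacity value, and 32x32 input size below are all hypothetical. Note that the new code constructs GradCAM with use_cuda=True, so a CUDA device is effectively assumed.

# Usage sketch (hypothetical paths and values; not part of the commit)
import numpy as np
import torch
from PIL import Image

from utils import CustomResnet, get_gradcam  # assumes utils.py is on the import path

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CustomResnet().to(device)
model.load_state_dict(torch.load("model.pth", map_location=device))  # hypothetical checkpoint
model.eval()

# Hypothetical 32x32 RGB test image; resize to whatever size CustomResnet was trained on.
img = np.asarray(Image.open("test_image.png").convert("RGB").resize((32, 32)))

# layer must be one of: "convblock1", "convblock2", "resblock1", "resblock2", "convblock3"
fig = get_gradcam(model, img, opacity=0.6, layer="convblock2")

# get_gradcam returns matplotlib's pyplot module (via save_plot_as_image), so the grid of
# per-layer CAM overlays can be saved or shown directly.
fig.savefig("gradcam_grid.png")

One non-ReLU layer at a time is wrapped in GradCAM inside the loop, and each overlay is resized to 200x200 and titled with the layer's repr before being arranged into the 4-column grid by save_plot_as_image.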