InPeerReview committed · verified
Commit c1651d2 · 1 Parent(s): 5132c8d

Upload 9 files

tools/__init__.py ADDED
File without changes
tools/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (134 Bytes)
tools/__pycache__/mask_convert.cpython-38.pyc ADDED
Binary file (2.02 kB)
tools/__pycache__/utilss.cpython-38.pyc ADDED
Binary file (7.35 kB)
tools/grad_cam_CNN.py ADDED
@@ -0,0 +1,72 @@
+ import os
+ import sys
+ sys.path.append('.')
+
+ import matplotlib.pyplot as plt
+ from utilss import GradCAM, show_cam_on_image, center_crop_img
+
+ import argparse
+ from utils.config import Config
+ from train import *
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
+     parser.add_argument("-c", "--config", type=str, default="configs/cdxformer.py")
+     parser.add_argument("--output_dir", default=None)
+     parser.add_argument("--layer", default=None)
+     return parser.parse_args()
+
+ def main():
+     args = get_args()
+
+     if args.layer is None:
+         raise ValueError("Please ensure the parameter '--layer' is not None!\n"
+                          " e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")
+
+     cfg = Config.fromfile(args.config)
+
+     model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
+     model = model.to('cuda')
+
+     test_loader = build_dataloader(cfg.dataset_config, mode='test')
+
+     if args.output_dir:
+         base_dir = args.output_dir
+     else:
+         base_dir = os.path.dirname(cfg.test_ckpt_path)
+     gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
+     if os.path.exists(gradcam_output_dir):
+         raise FileExistsError("Please ensure gradcam_output_dir does not exist!")
+
+     os.makedirs(gradcam_output_dir)
+
+     for batch in tqdm(test_loader):
+         target_layers = [eval(args.layer)]  # resolve the layer object from its dotted name
+         mask, img_id = batch[2].cuda(), batch[3]
+
+         cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True)
+         target_category = 1  # class 1 = "change"
+
+         grayscale_cam_all = cam(input_tensor=(batch[0], batch[1]), target_category=target_category)
+
+         for i in range(grayscale_cam_all.shape[0]):
+             grayscale_cam = grayscale_cam_all[i, :]
+             # img=0 overlays the heatmap on a black background
+             visualization = show_cam_on_image(0,
+                                               grayscale_cam,
+                                               use_rgb=True)
+             fig = plt.figure()
+             ax = fig.add_subplot(111)
+             ax.imshow(visualization)
+             # ax = fig.add_subplot(122)
+             # ax.imshow(mask[i].cpu().numpy())
+             ax.set_xticks([])
+             ax.set_yticks([])
+             for spine in ('top', 'right', 'bottom', 'left'):
+                 ax.spines[spine].set_visible(False)
+             plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
+             plt.close()
+
+
+ if __name__ == '__main__':
+     main()
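
A minimal sketch of the Grad-CAM call pattern this script drives, using a hypothetical toy two-input model so it runs standalone; the ToyCD network, tensor shapes, and the SimpleNamespace stand-in for cfg are illustrative assumptions, not part of this upload:

import types
import torch
import torch.nn as nn
from tools.utilss import GradCAM, show_cam_on_image  # assumes the repo root is on sys.path

# Hypothetical toy bi-temporal model: concatenate both inputs, predict 2 classes.
class ToyCD(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 8, 3, padding=1)
        self.head = nn.Conv2d(8, 2, 1)

    def forward(self, x, y):
        return self.head(torch.relu(self.conv(torch.cat([x, y], dim=1))))

model = ToyCD()
cfg = types.SimpleNamespace(net='toy')  # any value other than 'cdmask'
cam = GradCAM(cfg, model=model, target_layers=[model.conv], use_cuda=False)
x1 = torch.rand(1, 3, 64, 64)
x2 = torch.rand(1, 3, 64, 64)
heat = cam(input_tensor=(x1, x2), target_category=1)   # (1, 64, 64), values in [0, 1]
overlay = show_cam_on_image(0, heat[0], use_rgb=True)  # heatmap on black, as in the script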
tools/grad_cam_transformer.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ import sys
+ sys.path.append('.')
+
+ import matplotlib.pyplot as plt
+ from utilss import GradCAM, show_cam_on_image, center_crop_img
+ import math
+ import argparse
+ from utils.config import Config
+ from train import *
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
+     parser.add_argument("-c", "--config", type=str, default="configs/cdmask.py")
+     parser.add_argument("--output_dir", default=None)
+     parser.add_argument("--layer", default=None)
+     parser.add_argument("--imgsize", type=int, default=256)
+     return parser.parse_args()
+
+ class ResizeTransform:
+     """Reshape transformer tokens into a CNN-style feature map.
+     The input is expected as (B, L, C); if not, adjust the order first."""
+
+     def __init__(self, im_h: int, im_w: int):
+         self.height = im_h
+         self.width = im_w
+
+     def __call__(self, x):
+         # input x: (B, L, C) with L == height * width
+         result = x.reshape(x.size(0),
+                            self.height,
+                            self.width,
+                            x.size(2))
+
+         # Bring the channels to the first dimension, like in CNNs:
+         # [batch_size, H, W, C] -> [batch_size, C, H, W]
+         result = result.permute(0, 3, 1, 2)
+
+         return result
+
+ def main():
+     args = get_args()
+
+     if args.layer is None:
+         raise ValueError("Please ensure the parameter '--layer' is not None!\n"
+                          " e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")
+
+     cfg = Config.fromfile(args.config)
+
+     model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
+     model = model.to('cuda')
+
+     test_loader = build_dataloader(cfg.dataset_config, mode='test')
+
+     if args.output_dir:
+         base_dir = args.output_dir
+     else:
+         base_dir = os.path.dirname(cfg.test_ckpt_path)
+     gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
+     if os.path.exists(gradcam_output_dir):
+         raise FileExistsError("Please ensure gradcam_output_dir does not exist!")
+
+     os.makedirs(gradcam_output_dir)
+
+     for batch in tqdm(test_loader):
+         target_layers = [eval(args.layer)]  # resolve the layer object from its dotted name
+         mask, img_id = batch[2].cuda(), batch[3]
+
+         cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True,
+                       reshape_transform=ResizeTransform(im_h=args.imgsize, im_w=args.imgsize))
+         target_category = 1  # class 1 = "change"
+
+         grayscale_cam_all = cam(input_tensor=(batch[0], batch[1]), target_category=target_category)
+
+         for i in range(grayscale_cam_all.shape[0]):
+             grayscale_cam = grayscale_cam_all[i, :]
+             # img=0 overlays the heatmap on a black background
+             visualization = show_cam_on_image(0,
+                                               grayscale_cam,
+                                               use_rgb=True)
+             fig = plt.figure()
+             ax = fig.add_subplot(111)
+             ax.imshow(visualization)
+             # ax = fig.add_subplot(122)
+             # ax.imshow(mask[i].cpu().numpy())
+             ax.set_xticks([])
+             ax.set_yticks([])
+             for spine in ('top', 'right', 'bottom', 'left'):
+                 ax.spines[spine].set_visible(False)
+             plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
+             plt.close()
+
+
+ if __name__ == '__main__':
+     main()
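
For reference, a small sketch of what ResizeTransform (defined in the script above) does to transformer activations; the concrete shapes here are illustrative assumptions:

import torch

B, H, W, C = 2, 16, 16, 64
tokens = torch.rand(B, H * W, C)             # transformer output: (B, L, C), L = H * W
fmap = ResizeTransform(im_h=H, im_w=W)(tokens)
assert fmap.shape == (B, C, H, W)            # CNN-style map the CAM code expects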
tools/mask_convert.py ADDED
@@ -0,0 +1,103 @@
+ import numpy as np
+ import argparse
+ import glob
+ import os
+ import sys
+ import torch
+ import cv2
+ import random
+ import time
+ import multiprocessing.pool as mpp
+ import multiprocessing as mp
+
+ SEED = 66
+
+ def seed_everything(seed):
+     random.seed(seed)
+     os.environ['PYTHONHASHSEED'] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False  # benchmark=True would undermine determinism
+
+ def label2rgb(mask, mask_pred):
+     """Colour-code a prediction against the ground truth:
+     white = TP, black = TN, green = FN, red = FP."""
+     real_1 = (mask == 1)
+     real_0 = (mask == 0)
+     pred_1 = (mask_pred == 1)
+     pred_0 = (mask_pred == 0)
+
+     TP = np.logical_and(real_1, pred_1)
+     TN = np.logical_and(real_0, pred_0)
+     FN = np.logical_and(real_1, pred_0)
+     FP = np.logical_and(real_0, pred_1)
+
+     mask_TP = TP[np.newaxis, :, :]
+     mask_TN = TN[np.newaxis, :, :]
+     mask_FN = FN[np.newaxis, :, :]
+     mask_FP = FP[np.newaxis, :, :]
+
+     h, w = mask.shape[0], mask.shape[1]
+     mask_rgb = np.zeros(shape=(h, w, 3), dtype=np.uint8)
+     mask_rgb[np.all(mask_TP, axis=0)] = [255, 255, 255]  # TP
+     mask_rgb[np.all(mask_TN, axis=0)] = [0, 0, 0]        # TN
+     mask_rgb[np.all(mask_FN, axis=0)] = [0, 255, 0]      # FN
+     mask_rgb[np.all(mask_FP, axis=0)] = [255, 0, 0]      # FP
+
+     return mask_rgb
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dataset", default="Vaihingen")
+     parser.add_argument("--mask-dir", default="data/Test/masks")
+     parser.add_argument("--output-mask-dir", default="data/Test/masks_rgb")
+     return parser.parse_args()
+
+ def mask_save(inp):
+     (mask, mask_pred, masks_output_dir, file_name) = inp
+     out_mask_path = os.path.join(masks_output_dir, "{}.png".format(file_name))
+
+     label = label2rgb(mask.copy(), mask_pred.copy())
+
+     # label2rgb returns RGB; swap to BGR for cv2.imwrite
+     bgr_label = cv2.cvtColor(label, cv2.COLOR_RGB2BGR)
+     cv2.imwrite(out_mask_path, bgr_label)
+
+ # def get_rgb(inp):
+ #     (mask_path, masks_output_dir, dataset) = inp
+ #     mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
+ #     mask_bgr = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
+ #     mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
+ #     if dataset == "LoveDA":
+ #         rgb_label = loveda_label2rgb(mask.copy())
+ #     elif dataset == "Vaihingen":
+ #         rgb_label = vaihingen_label2rgb(mask.copy())
+ #     elif dataset == "Potsdam":
+ #         rgb_label = potsdam_label2rgb(mask.copy())
+ #     elif dataset == "uavid":
+ #         rgb_label = uavid_label2rgb(mask.copy())
+ #     else:
+ #         return
+ #     # rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)
+
+ #     out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
+ #     rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)
+ #     cv2.imwrite(out_mask_path_rgb, rgb_label)
+
+ # if __name__ == '__main__':
+ #     base_path = "/home/xwma/lrr/rssegmentation/"
+ #     args = parse_args()
+ #     dataset = args.dataset
+
+ #     seed_everything(SEED)
+ #     masks_dir = args.mask_dir
+ #     masks_output_dir = args.output_mask_dir
+ #     masks_dir = base_path + masks_dir
+ #     masks_output_dir = base_path + masks_output_dir
+
+ #     mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
+ #     inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
+ #     if not os.path.exists(masks_output_dir):
+ #         os.makedirs(masks_output_dir)
+
+ #     t0 = time.time()
+ #     mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
+ #     t1 = time.time()
+ #     split_time = t1 - t0
+ #     print('image splitting took: {} s'.format(split_time))
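
A tiny sanity check for label2rgb's confusion-map colouring (the 2x2 toy masks are illustrative): white = TP, black = TN, green = FN (missed change), red = FP (false alarm).

import numpy as np

gt   = np.array([[1, 1], [0, 0]], dtype=np.uint8)
pred = np.array([[1, 0], [1, 0]], dtype=np.uint8)
rgb = label2rgb(gt, pred)
assert (rgb[0, 0] == [255, 255, 255]).all()  # TP
assert (rgb[0, 1] == [0, 255, 0]).all()      # FN
assert (rgb[1, 0] == [255, 0, 0]).all()      # FP
assert (rgb[1, 1] == [0, 0, 0]).all()        # TN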
tools/params_flops.py ADDED
@@ -0,0 +1,55 @@
+ import sys
+ import argparse
+ import torch
+ sys.path.append('.')
+ from train import *
+ from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
+ from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='count params and flops')
+     parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
+     parser.add_argument("--size", type=int, default=256)
+     return parser.parse_args()
+
+ def flops_mamba(model, shape=(3, 224, 224)):
+     # Selective-scan ops are unknown to fvcore, so register custom handlers.
+     # silu/neg/exp/flip are skipped, like relu/permute in _IGNORED_OPS.
+     supported_ops = {
+         "aten::silu": None,
+         "aten::neg": None,
+         "aten::exp": None,
+         "aten::flip": None,
+         # "prim::PythonOp.CrossScan": None,
+         # "prim::PythonOp.CrossMerge": None,
+         "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
+         "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
+         "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
+         "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
+         "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
+     }
+
+     model.cuda().eval()
+
+     input1 = torch.randn((1, *shape), device=next(model.parameters()).device)
+     input2 = torch.randn((1, *shape), device=next(model.parameters()).device)
+     params = parameter_count(model)[""]
+     Gflops, unsupported = flop_count(model=model, inputs=(input1, input2), supported_ops=supported_ops)
+
+     del model, input1, input2
+     # return sum(Gflops.values()) * 1e9
+     return f"params {params / 1e6:.2f}M GFLOPs {sum(Gflops.values()):.2f}"
+
+ if __name__ == "__main__":
+     args = parse_args()
+     cfg = Config.fromfile(args.config)
+     net = myTrain(cfg).net.cuda()
+
+     size = args.size
+     input = torch.rand((1, 3, size, size)).cuda()
+
+     net.eval()
+     flops = FlopCountAnalysis(net, (input, input))
+     print(flop_count_table(flops, max_depth=2))
+
+     print(flops_mamba(net, (3, size, size)))
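
As a point of reference, the plain fvcore counting pattern the script builds on, shown here on a stock torchvision model; resnet18 and the 224x224 input are illustrative stand-ins, not part of this repo:

import torch
from torchvision.models import resnet18
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

net = resnet18().eval()
x = torch.rand(1, 3, 224, 224)
flops = FlopCountAnalysis(net, (x,))          # FLOPs counted per traced op
print(flop_count_table(flops, max_depth=2))   # per-module breakdown
print(f"params: {parameter_count(net)[''] / 1e6:.2f}M")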
tools/utilss.py ADDED
@@ -0,0 +1,249 @@
+ import cv2
+ import numpy as np
+ from torch.nn import functional as F
+ import torch
+
+ class ActivationsAndGradients:
+     """Class for extracting activations and
+     registering gradients from targeted intermediate layers."""
+
+     def __init__(self, model, target_layers, reshape_transform):
+         self.model = model
+         self.gradients = []
+         self.activations = []
+         self.reshape_transform = reshape_transform
+         self.handles = []
+         for target_layer in target_layers:
+             self.handles.append(
+                 target_layer.register_forward_hook(
+                     self.save_activation))
+             # Backward compatibility with older pytorch versions:
+             if hasattr(target_layer, 'register_full_backward_hook'):
+                 self.handles.append(
+                     target_layer.register_full_backward_hook(
+                         self.save_gradient))
+             else:
+                 self.handles.append(
+                     target_layer.register_backward_hook(
+                         self.save_gradient))
+
+     def save_activation(self, module, input, output):
+         activation = output
+         if self.reshape_transform is not None:
+             activation = self.reshape_transform(activation)
+         self.activations.append(activation.cpu().detach())
+
+     def save_gradient(self, module, grad_input, grad_output):
+         # Gradients are computed in reverse order
+         grad = grad_output[0]
+         if self.reshape_transform is not None:
+             grad = self.reshape_transform(grad)
+         self.gradients = [grad.cpu().detach()] + self.gradients
+
+     def __call__(self, x, y):
+         self.gradients = []
+         self.activations = []
+         return self.model(x, y)
+
+     def release(self):
+         for handle in self.handles:
+             handle.remove()
+
+
+ class GradCAM:
+     def __init__(self,
+                  cfg,
+                  model,
+                  target_layers,
+                  reshape_transform=None,
+                  use_cuda=False):
+         self.cfg = cfg
+         self.model = model.eval()
+         self.target_layers = target_layers
+         self.reshape_transform = reshape_transform
+         self.cuda = use_cuda
+         if self.cuda:
+             self.model = model.cuda()
+         self.activations_and_grads = ActivationsAndGradients(
+             self.model, target_layers, reshape_transform)
+
+     """ Get a vector of weights for every channel in the target layer.
+     Methods that return channel weights will typically
+     only need to implement this function. """
+
+     @staticmethod
+     def get_cam_weights(grads):
+         return np.mean(grads, axis=(2, 3), keepdims=True)
+
+     @staticmethod
+     def get_loss(output, target_category):
+         loss = 0
+         for i in range(len(target_category)):
+             # accumulate the logits of the target class for each image
+             loss = loss + output[i, target_category[i]]
+         return loss
+
+     def get_cam_image(self, activations, grads):
+         weights = self.get_cam_weights(grads)
+         weighted_activations = weights * activations
+         cam = weighted_activations.sum(axis=1)
+
+         return cam
+
+     @staticmethod
+     def get_target_width_height(input_tensor):
+         width, height = input_tensor.size(-1), input_tensor.size(-2)
+         return width, height
+
+     def compute_cam_per_layer(self, input_tensor):
+         activations_list = [a.cpu().data.numpy()
+                             for a in self.activations_and_grads.activations]
+         grads_list = [g.cpu().data.numpy()
+                       for g in self.activations_and_grads.gradients]
+         target_size = self.get_target_width_height(input_tensor)
+
+         cam_per_target_layer = []
+         # Loop over the saliency image from every layer
+         for layer_activations, layer_grads in zip(activations_list, grads_list):
+             cam = self.get_cam_image(layer_activations, layer_grads)
+             cam[cam < 0] = 0  # ReLU: keep positive contributions; min-max scaling happens in scale_cam_image
+             scaled = self.scale_cam_image(cam, target_size)
+             cam_per_target_layer.append(scaled[:, None, :])
+
+         return cam_per_target_layer
+
+     def aggregate_multi_layers(self, cam_per_target_layer):
+         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
+         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
+         result = np.mean(cam_per_target_layer, axis=1)
+         return self.scale_cam_image(result)
+
+     @staticmethod
+     def scale_cam_image(cam, target_size=None):
+         result = []
+         for img in cam:
+             img = img - np.min(img)
+             img = img / (1e-7 + np.max(img))
+             if target_size is not None:
+                 img = cv2.resize(img, target_size)
+             result.append(img)
+         result = np.float32(result)
+
+         return result
+
+     def __call__(self, input_tensor, target_category=None):
+         x, y = input_tensor
+         if self.cuda:
+             x = x.cuda()
+             y = y.cuda()
+
+         # Forward pass to get the network's output logits (before softmax)
+         if self.cfg.net == 'cdmask':
+             o, outputs = self.activations_and_grads(x, y)
+             mask_cls_results = outputs["pred_logits"]
+             mask_pred_results = outputs["pred_masks"]
+             mask_pred_results = F.interpolate(
+                 mask_pred_results,
+                 scale_factor=(4, 4),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             mask_cls = F.softmax(mask_cls_results, dim=-1)[..., 1:]
+             mask_pred = mask_pred_results.sigmoid()
+             output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
+         else:
+             output = self.activations_and_grads(x, y)
+
+         if isinstance(target_category, int):
+             target_category = [target_category] * x.size(0)
+
+         if target_category is None:
+             target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
+             print(f"category id: {target_category}")
+         else:
+             assert (len(target_category) == x.size(0))
+
+         self.model.zero_grad()
+         loss = self.get_loss(output, target_category).sum()
+         loss.backward(retain_graph=True)
+
+         # In most saliency attribution papers, the saliency is computed
+         # with a single target layer, commonly the last convolutional layer.
+         # Here we support passing a list with multiple target layers.
+         # The saliency image is computed for every layer and then
+         # aggregated (with a default mean aggregation). This gives more
+         # flexibility if you want to use, for example, all conv layers
+         # or all batchnorm layers.
+         cam_per_layer = self.compute_cam_per_layer(x)
+         return self.aggregate_multi_layers(cam_per_layer)
+
+     def __del__(self):
+         self.activations_and_grads.release()
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_tb):
+         self.activations_and_grads.release()
+         if isinstance(exc_value, IndexError):
+             # Handle IndexError here...
+             print(
+                 f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+             return True
+
+
+ def show_cam_on_image(img: np.ndarray,
+                       mask: np.ndarray,
+                       use_rgb: bool = False,
+                       colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+     """Overlay the CAM mask on the image as a heatmap.
+     By default the heatmap is in BGR format.
+
+     :param img: The base image in RGB or BGR format.
+     :param mask: The CAM mask.
+     :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
+     :param colormap: The OpenCV colormap to be used.
+     :returns: The image with the CAM overlay.
+     """
+
+     heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+     if use_rgb:
+         heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+     heatmap = np.float32(heatmap) / 255
+
+     if np.max(img) > 1:
+         raise ValueError(
+             "The input image should be np.float32 in the range [0, 1]")
+
+     cam = heatmap + img
+     cam = cam / np.max(cam)
+     return np.uint8(255 * cam)
+
+
+ def center_crop_img(img: np.ndarray, size: int):
+     h, w, c = img.shape
+
+     if w == h == size:
+         return img
+
+     # Resize so the shorter side equals `size`, then center-crop the longer side.
+     if w < h:
+         ratio = size / w
+         new_w = size
+         new_h = int(h * ratio)
+     else:
+         ratio = size / h
+         new_h = size
+         new_w = int(w * ratio)
+
+     img = cv2.resize(img, dsize=(new_w, new_h))
+
+     if new_w == size:
+         h = (new_h - size) // 2
+         img = img[h: h + size]
+     else:
+         w = (new_w - size) // 2
+         img = img[:, w: w + size]
+
+     return img
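
A quick check of the helper functions above on synthetic data; the array sizes are illustrative assumptions:

import numpy as np

img  = np.random.rand(32, 32, 3).astype(np.float32)   # RGB image in [0, 1]
mask = np.random.rand(32, 32).astype(np.float32)      # CAM values in [0, 1]
overlay = show_cam_on_image(img, mask, use_rgb=True)
assert overlay.dtype == np.uint8 and overlay.shape == (32, 32, 3)

# center_crop_img resizes the shorter side to `size`, then crops the longer one.
square = center_crop_img(np.zeros((40, 60, 3), dtype=np.uint8), 32)
assert square.shape[:2] == (32, 32)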