wgetdd committed
Commit c23917f
1 Parent(s): 50a0732

Added src files

Files changed (4):
  1. src/config.py +50 -0
  2. src/detect.py +55 -0
  3. src/model_yolov3.py +181 -0
  4. src/utils.py +268 -0
src/config.py ADDED
@@ -0,0 +1,50 @@
+ import albumentations as A
+ import cv2
+ import torch
+ 
+ from albumentations.pytorch import ToTensorV2
+ 
+ 
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ 
+ IMAGE_SIZE = 416
+ transforms = A.Compose(
+     [
+         A.LongestMaxSize(max_size=IMAGE_SIZE),
+         A.PadIfNeeded(
+             min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, border_mode=cv2.BORDER_CONSTANT
+         ),
+         A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
+         ToTensorV2(),
+     ],
+ )
+ 
+ # Anchor boxes per scale; note these have been rescaled to be between [0, 1].
+ ANCHORS = [
+     [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
+     [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
+     [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
+ ]
+ 
+ # Grid sizes of the three prediction scales: 13, 26, 52 for a 416px input.
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
+ 
+ PASCAL_CLASSES = [
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor",
+ ]
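
Not part of the commit: a minimal sketch of how these constants are consumed downstream (src/detect.py builds its scaled anchors with the same expression); the dummy input image is hypothetical.

import numpy as np
import torch

import src.config as config

# Project the [0, 1] anchors onto each prediction grid: shape (3, 3, 2).
scaled_anchors = (
    torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
print(scaled_anchors.shape)  # torch.Size([3, 3, 2])

# The transform pipeline maps any RGB image to a normalized 416x416 tensor.
dummy = np.zeros((300, 500, 3), dtype=np.uint8)  # hypothetical input image
batch = config.transforms(image=dummy)["image"].unsqueeze(0)
print(batch.shape)  # torch.Size([1, 3, 416, 416])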
src/detect.py ADDED
@@ -0,0 +1,55 @@
+ from typing import List
+ 
+ import cv2
+ import numpy as np
+ import torch
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ 
+ import src.config as config
+ from src.model_yolov3 import YOLOv3
+ from src.utils import cells_to_bboxes, non_max_suppression, draw_predictions, YoloCAM
+ 
+ 
+ model = YOLOv3(num_classes=20)
+ 
+ weights_path = "/home/deepanshu/Desktop/yolov3_gradio-20230811T035430Z-001/yolov3_gradio/Final_trained_model.pth"
+ ckpt = torch.load(weights_path, map_location="cpu")
+ model.load_state_dict(ckpt)
+ model.eval()
+ print("[x] Model loaded.")
+ 
+ # Project the [0, 1] anchors from config onto each prediction grid.
+ scaled_anchors = (
+     torch.tensor(config.ANCHORS)
+     * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
+ ).to(config.DEVICE)
+ 
+ cam = YoloCAM(model=model, target_layers=[model.layers[-2]], use_cuda=False)
+ 
+ 
+ def predict(image: np.ndarray, iou_thresh: float = 0.5, thresh: float = 0.4, show_cam: bool = False, transparency: float = 0.5) -> List[np.ndarray]:
+     with torch.no_grad():
+         transformed_image = config.transforms(image=image)["image"].unsqueeze(0)
+         output = model(transformed_image)
+ 
+     # Decode every scale's cell predictions into image-relative boxes.
+     bboxes = [[] for _ in range(1)]  # single-image batch
+     for i in range(3):
+         batch_size, A, S, _, _ = output[i].shape
+         anchor = scaled_anchors[i]
+         boxes_scale_i = cells_to_bboxes(output[i], anchor, S=S, is_preds=True)
+         for idx, box in enumerate(boxes_scale_i):
+             bboxes[idx] += box
+ 
+     nms_boxes = non_max_suppression(
+         bboxes[0], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
+     )
+     plot_img = draw_predictions(image, nms_boxes, class_labels=config.PASCAL_CLASSES)
+     if not show_cam:
+         return [plot_img]
+ 
+     grayscale_cam = cam(transformed_image, scaled_anchors)[0, :, :]
+     img = cv2.resize(image, (416, 416))
+     img = np.float32(img) / 255
+     cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True, image_weight=transparency)
+     return [plot_img, cam_image]
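
Not part of the commit: a minimal usage sketch for predict(), assuming the weights file referenced above is present. The sample image path is hypothetical; OpenCV loads BGR, while the model and draw_predictions work on RGB arrays.

import cv2

from src.detect import predict

bgr = cv2.imread("samples/dog.jpg")          # hypothetical test image
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)   # convert to RGB for the model
annotated, cam_overlay = predict(rgb, iou_thresh=0.5, thresh=0.4, show_cam=True)
cv2.imwrite("out.png", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))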
src/model_yolov3.py ADDED
@@ -0,0 +1,181 @@
+ """Implementation of YOLOv3 architecture."""
+ from typing import Any, List
+ 
+ import torch
+ import torch.nn as nn
+ 
+ """
+ Information about the architecture config:
+ A tuple is structured as (filters, kernel_size, stride).
+ Every conv is a same convolution.
+ A list is structured as ["B", n]: a residual block repeated n times.
+ "S" is a scale prediction block, used for computing the YOLO loss.
+ "U" upsamples the feature map and concatenates it with a previous layer.
+ """
+ config = [
+     (32, 3, 1),
+     (64, 3, 2),
+     ["B", 1],
+     (128, 3, 2),
+     ["B", 2],
+     (256, 3, 2),
+     ["B", 8],
+     (512, 3, 2),
+     ["B", 8],
+     (1024, 3, 2),
+     ["B", 4],  # To this point is Darknet-53
+     (512, 1, 1),
+     (1024, 3, 1),
+     "S",
+     (256, 1, 1),
+     "U",
+     (256, 1, 1),
+     (512, 3, 1),
+     "S",
+     (128, 1, 1),
+     "U",
+     (128, 1, 1),
+     (256, 3, 1),
+     "S",
+ ]
+ 
+ 
+ class CNNBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
+         super().__init__()
+         self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.leaky = nn.LeakyReLU(0.1)
+         self.use_bn_act = bn_act
+ 
+     def forward(self, x):
+         if self.use_bn_act:
+             return self.leaky(self.bn(self.conv(x)))
+         else:
+             return self.conv(x)
+ 
+ 
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, use_residual=True, num_repeats=1):
+         super().__init__()
+         self.layers = nn.ModuleList()
+         for _ in range(num_repeats):
+             self.layers += [
+                 nn.Sequential(
+                     CNNBlock(channels, channels // 2, kernel_size=1),
+                     CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
+                 )
+             ]
+ 
+         self.use_residual = use_residual
+         self.num_repeats = num_repeats
+ 
+     def forward(self, x):
+         for layer in self.layers:
+             if self.use_residual:
+                 x = x + layer(x)
+             else:
+                 x = layer(x)
+ 
+         return x
+ 
+ 
+ class ScalePrediction(nn.Module):
+     def __init__(self, in_channels, num_classes):
+         super().__init__()
+         self.pred = nn.Sequential(
+             CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
+             CNNBlock(2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1),
+         )
+         self.num_classes = num_classes
+ 
+     def forward(self, x):
+         return (
+             self.pred(x)
+             .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
+             .permute(0, 1, 3, 4, 2)
+         )
+ 
+ 
+ class YOLOv3(nn.Module):
+     def __init__(self, load_config: List[Any] = config, in_channels=3, num_classes=80):
+         super().__init__()
+         self.load_config = load_config
+         self.num_classes = num_classes
+         self.in_channels = in_channels
+         self.layers = self._create_conv_layers()
+ 
+     def forward(self, x):
+         outputs = []  # one output per scale
+         route_connections = []
+         for layer in self.layers:
+             if isinstance(layer, ScalePrediction):
+                 outputs.append(layer(x))
+                 continue
+ 
+             x = layer(x)
+ 
+             if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
+                 route_connections.append(x)
+ 
+             elif isinstance(layer, nn.Upsample):
+                 x = torch.cat([x, route_connections[-1]], dim=1)
+                 route_connections.pop()
+ 
+         return outputs
+ 
+     def _create_conv_layers(self):
+         layers = nn.ModuleList()
+         in_channels = self.in_channels
+ 
+         for module in self.load_config:
+             if isinstance(module, tuple):
+                 out_channels, kernel_size, stride = module
+                 layers.append(
+                     CNNBlock(
+                         in_channels,
+                         out_channels,
+                         kernel_size=kernel_size,
+                         stride=stride,
+                         padding=1 if kernel_size == 3 else 0,
+                     )
+                 )
+                 in_channels = out_channels
+ 
+             elif isinstance(module, list):
+                 num_repeats = module[1]
+                 layers.append(
+                     ResidualBlock(
+                         in_channels,
+                         num_repeats=num_repeats,
+                     )
+                 )
+ 
+             elif isinstance(module, str):
+                 if module == "S":
+                     layers += [
+                         ResidualBlock(in_channels, use_residual=False, num_repeats=1),
+                         CNNBlock(in_channels, in_channels // 2, kernel_size=1),
+                         ScalePrediction(in_channels // 2, num_classes=self.num_classes),
+                     ]
+                     in_channels = in_channels // 2
+ 
+                 elif module == "U":
+                     layers.append(
+                         nn.Upsample(scale_factor=2),
+                     )
+                     in_channels = in_channels * 3
+ 
+         return layers
+ 
+ 
+ if __name__ == "__main__":
+     num_classes = 20
+     IMAGE_SIZE = 416
+     model = YOLOv3(load_config=config, num_classes=num_classes)
+     x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
+     out = model(x)
+     assert out[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
+     assert out[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
+     assert out[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
+     print("Success!")
src/utils.py ADDED
@@ -0,0 +1,268 @@
+ import random
+ from typing import List
+ 
+ import cv2
+ import numpy as np
+ import torch
+ 
+ from pytorch_grad_cam.base_cam import BaseCAM
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+ from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
+ 
+ 
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
+     """
+     Scales the predictions coming from the model to be relative to the
+     entire image, so that they can later be plotted or evaluated.
+     INPUT:
+     predictions: tensor of size (N, 3, S, S, num_classes+5)
+     anchors: the anchors used for the predictions
+     S: the number of cells the image is divided into along the width (and height)
+     is_preds: whether the input is predictions or the true bounding boxes
+     OUTPUT:
+     converted_bboxes: the converted boxes, of size (N, num_anchors * S * S, 6),
+     each box holding [class index, object score, x, y, w, h]
+     """
+     BATCH_SIZE = predictions.shape[0]
+     num_anchors = len(anchors)
+     box_predictions = predictions[..., 1:5]
+     if is_preds:
+         anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
+         box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
+         box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
+         scores = torch.sigmoid(predictions[..., 0:1])
+         best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
+     else:
+         scores = predictions[..., 0:1]
+         best_class = predictions[..., 5:6]
+ 
+     cell_indices = (
+         torch.arange(S)
+         .repeat(predictions.shape[0], 3, S, 1)
+         .unsqueeze(-1)
+         .to(predictions.device)
+     )
+     x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
+     y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
+     w_h = 1 / S * box_predictions[..., 2:4]
+     converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(BATCH_SIZE, num_anchors * S * S, 6)
+     return converted_bboxes.tolist()
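
Not part of the commit: a worked, scalar version of the decode above, with made-up numbers. The sigmoid keeps the center offset inside its cell, and the exponential scales the anchor prior; because the function receives anchors already multiplied by S, the later division by S cancels and the final width/height is simply anchor * exp(t).

import math

S = 13                                  # cells per side at the coarsest scale
cx, cy = 7, 4                           # cell column and row
tx, ty, tw, th = 0.2, -0.5, 0.3, 0.1    # made-up raw predictions
anchor_w, anchor_h = 0.28, 0.22         # first anchor in config.ANCHORS

def sigmoid(v):
    return 1 / (1 + math.exp(-v))

x = (sigmoid(tx) + cx) / S      # ~0.581, fraction of image width
y = (sigmoid(ty) + cy) / S      # ~0.337, fraction of image height
w = anchor_w * math.exp(tw)     # ~0.378
h = anchor_h * math.exp(th)     # ~0.243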
+ 
+ 
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
+     """
+     Video explanation of this function:
+     https://youtu.be/XXYG5ZWtjj0
+ 
+     This function calculates intersection over union (IoU) given pred boxes
+     and target boxes.
+ 
+     Parameters:
+         boxes_preds (tensor): Predictions of Bounding Boxes (BATCH_SIZE, 4)
+         boxes_labels (tensor): Correct labels of Bounding Boxes (BATCH_SIZE, 4)
+         box_format (str): midpoint/corners, if boxes are (x, y, w, h) or (x1, y1, x2, y2)
+ 
+     Returns:
+         tensor: Intersection over union for all examples
+     """
+ 
+     if box_format == "midpoint":
+         box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
+         box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
+         box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
+         box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
+         box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
+         box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
+         box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
+         box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
+ 
+     if box_format == "corners":
+         box1_x1 = boxes_preds[..., 0:1]
+         box1_y1 = boxes_preds[..., 1:2]
+         box1_x2 = boxes_preds[..., 2:3]
+         box1_y2 = boxes_preds[..., 3:4]
+         box2_x1 = boxes_labels[..., 0:1]
+         box2_y1 = boxes_labels[..., 1:2]
+         box2_x2 = boxes_labels[..., 2:3]
+         box2_y2 = boxes_labels[..., 3:4]
+ 
+     x1 = torch.max(box1_x1, box2_x1)
+     y1 = torch.max(box1_y1, box2_y1)
+     x2 = torch.min(box1_x2, box2_x2)
+     y2 = torch.min(box1_y2, box2_y2)
+ 
+     # Clamp to 0 so non-overlapping boxes contribute zero intersection.
+     intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
+     box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
+     box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
+ 
+     return intersection / (box1_area + box2_area - intersection + 1e-6)
+ 
+ 
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
+     """
+     Video explanation of this function:
+     https://youtu.be/YDkjWEN8jNA
+ 
+     Does Non-Max Suppression on the given bboxes.
+ 
+     Parameters:
+         bboxes (list): list of lists containing all bboxes, each specified as
+             [class_pred, prob_score, x1, y1, x2, y2]
+         iou_threshold (float): IoU above which overlapping same-class boxes are suppressed
+         threshold (float): confidence below which boxes are discarded (independent of IoU)
+         box_format (str): "midpoint" or "corners", used to specify bboxes
+ 
+     Returns:
+         list: bboxes after performing NMS given a specific IoU threshold
+     """
+ 
+     assert isinstance(bboxes, list)
+ 
+     bboxes = [box for box in bboxes if box[1] > threshold]
+     bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
+     bboxes_after_nms = []
+ 
+     while bboxes:
+         chosen_box = bboxes.pop(0)
+ 
+         # Keep boxes of other classes, or boxes that overlap the chosen one only weakly.
+         bboxes = [
+             box
+             for box in bboxes
+             if box[0] != chosen_box[0]
+             or intersection_over_union(
+                 torch.tensor(chosen_box[2:]),
+                 torch.tensor(box[2:]),
+                 box_format=box_format,
+             )
+             < iou_threshold
+         ]
+ 
+         bboxes_after_nms.append(chosen_box)
+ 
+     return bboxes_after_nms
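
Not part of the commit: a toy run of the suppression above, with midpoint boxes [class, score, x, y, w, h]. The low-confidence duplicate of the class-0 box is removed, while the class-1 box survives because this NMS is class-wise.

from src.utils import non_max_suppression

boxes = [
    [0, 0.9, 0.50, 0.5, 0.4, 0.4],   # class 0, high confidence
    [0, 0.6, 0.52, 0.5, 0.4, 0.4],   # class 0, IoU ~0.9 with the box above
    [1, 0.8, 0.20, 0.2, 0.1, 0.1],   # class 1, kept regardless of overlap
]
kept = non_max_suppression(boxes, iou_threshold=0.5, threshold=0.4, box_format="midpoint")
print([b[:2] for b in kept])  # [[0, 0.9], [1, 0.8]]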
+ 
+ 
+ def draw_predictions(image: np.ndarray, boxes: List[List], class_labels: List[str]) -> np.ndarray:
+     """Plots predicted bounding boxes on the image."""
+ 
+     colors = [[random.randint(0, 255) for _ in range(3)] for _ in class_labels]
+ 
+     im = np.array(image)
+     height, width, _ = im.shape
+     bbox_thick = int(0.6 * (height + width) / 600)
+ 
+     # Draw a rectangle plus a filled label header for each box.
+     for box in boxes:
+         assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
+         class_pred = box[0]
+         conf = box[1]
+         box = box[2:]
+         upper_left_x = box[0] - box[2] / 2
+         upper_left_y = box[1] - box[3] / 2
+ 
+         x1 = int(upper_left_x * width)
+         y1 = int(upper_left_y * height)
+ 
+         x2 = x1 + int(box[2] * width)
+         y2 = y1 + int(box[3] * height)
+ 
+         cv2.rectangle(
+             image,
+             (x1, y1), (x2, y2),
+             color=colors[int(class_pred)],
+             thickness=bbox_thick
+         )
+         text = f"{class_labels[int(class_pred)]}: {conf:.2f}"
+         t_size = cv2.getTextSize(text, 0, 0.7, thickness=bbox_thick // 2)[0]
+         c3 = (x1 + t_size[0], y1 - t_size[1] - 3)
+ 
+         cv2.rectangle(image, (x1, y1), c3, colors[int(class_pred)], -1)
+         cv2.putText(
+             image,
+             text,
+             (x1, y1 - 2),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.7,
+             (0, 0, 0),
+             bbox_thick // 2,
+             lineType=cv2.LINE_AA,
+         )
+ 
+     return image
+ 
+ 
+ class YoloCAM(BaseCAM):
+     def __init__(self, model, target_layers, use_cuda=False,
+                  reshape_transform=None):
+         super(YoloCAM, self).__init__(model,
+                                       target_layers,
+                                       use_cuda,
+                                       reshape_transform,
+                                       uses_gradients=False)
+ 
+     def forward(self,
+                 input_tensor: torch.Tensor,
+                 scaled_anchors: torch.Tensor,
+                 targets: List[torch.nn.Module] = None,
+                 eigen_smooth: bool = False) -> np.ndarray:
+ 
+         if self.cuda:
+             input_tensor = input_tensor.cuda()
+ 
+         if self.compute_input_gradient:
+             input_tensor = torch.autograd.Variable(input_tensor,
+                                                    requires_grad=True)
+ 
+         outputs = self.activations_and_grads(input_tensor)
+         if targets is None:
+             # Derive targets from the model's own post-NMS detections.
+             bboxes = [[] for _ in range(1)]
+             for i in range(3):
+                 batch_size, A, S, _, _ = outputs[i].shape
+                 anchor = scaled_anchors[i]
+                 boxes_scale_i = cells_to_bboxes(
+                     outputs[i], anchor, S=S, is_preds=True
+                 )
+                 for idx, box in enumerate(boxes_scale_i):
+                     bboxes[idx] += box
+ 
+             nms_boxes = non_max_suppression(
+                 bboxes[0], iou_threshold=0.5, threshold=0.4, box_format="midpoint",
+             )
+             # target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
+             target_categories = [box[0] for box in nms_boxes]
+             targets = [ClassifierOutputTarget(category)
+                        for category in target_categories]
+ 
+         if self.uses_gradients:
+             self.model.zero_grad()
+             loss = sum([target(output)
+                         for target, output in zip(targets, outputs)])
+             loss.backward(retain_graph=True)
+ 
+         # In most of the saliency attribution papers, the saliency is
+         # computed with a single target layer. Commonly it is the last
+         # convolutional layer. Here we support passing a list with
+         # multiple target layers. It will compute the saliency image for
+         # every layer and then aggregate them (with a default mean
+         # aggregation). This gives you more flexibility in case you want
+         # to use, for example, all conv layers, or all Batchnorm layers,
+         # or something else.
+         cam_per_layer = self.compute_cam_per_layer(input_tensor,
+                                                    targets,
+                                                    eigen_smooth)
+         return self.aggregate_multi_layers(cam_per_layer)
+ 
+     def get_cam_image(self,
+                       input_tensor,
+                       target_layer,
+                       target_category,
+                       activations,
+                       grads,
+                       eigen_smooth):
+         return get_2d_projection(activations)