File size: 19,429 Bytes
3094730
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import copy
import warnings
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision
from mmcv.transforms import Compose
from mmdet.evaluation import get_classes
from mmdet.utils import ConfigType
from mmengine.config import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS

try:
    from pytorch_grad_cam import (AblationCAM, AblationLayer,
                                  ActivationsAndGradients)
    from pytorch_grad_cam import GradCAM as Base_GradCAM
    from pytorch_grad_cam import GradCAMPlusPlus as Base_GradCAMPlusPlus
    from pytorch_grad_cam.base_cam import BaseCAM
    from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image
    from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
except ImportError:
    pass


def init_detector(
    config: Union[str, Path, Config],
    checkpoint: Optional[str] = None,
    palette: str = 'coco',
    device: str = 'cuda:0',
    cfg_options: Optional[dict] = None,
) -> nn.Module:
    """Initialize a detector from config file.

    Unlike the stock mmdet helper, ``config.model.train_cfg`` is kept so that
    gradient-based AM methods can still compute losses.

    Args:
        config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path,
            :obj:`Path`, or the config object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights and COCO class names are used as the
            model's ``dataset_meta``.
        palette (str): Color palette used for visualization. If palette
            is stored in checkpoint, use checkpoint's palette first, otherwise
            use externally passed palette. Currently, supports 'coco', 'voc',
            'citys' and 'random'. Defaults to coco.
        device (str): The device where the anchors will be put on.
            Defaults to cuda:0.
        cfg_options (dict, optional): Options to override some settings in
            the used config.

    Returns:
        nn.Module: The constructed detector, moved to ``device`` and in eval
        mode. ``model.cfg`` holds the config and ``model.dataset_meta`` always
        holds a ``'classes'`` entry.

    Raises:
        TypeError: If ``config`` is neither a path nor a ``Config``.
    """
    if isinstance(config, (str, Path)):
        config = Config.fromfile(config)
    elif not isinstance(config, Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    if cfg_options is not None:
        config.merge_from_dict(cfg_options)
    elif 'init_cfg' in config.model.backbone:
        # Without overrides, drop the backbone's init_cfg (presumably to avoid
        # pretrained-weight initialisation; weights come from `checkpoint`).
        config.model.backbone.init_cfg = None

    # NOTE: ``config.model.train_cfg`` is intentionally NOT cleared here:
    # grad-based methods need it to compute losses.
    init_default_scope(config.get('default_scope', 'mmyolo'))

    model = MODELS.build(config.model)
    if checkpoint is None:
        # No weights and therefore no meta: fall back to COCO class names so
        # downstream code reading ``dataset_meta['classes']`` still works.
        warnings.simplefilter('once')
        warnings.warn('checkpoint is None, use COCO classes by default.')
        model.dataset_meta = {
            'classes': get_classes('coco'),
            'palette': palette
        }
    else:
        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
        # Weights converted from elsewhere may not have meta fields.
        checkpoint_meta = checkpoint.get('meta', {})
        # save the dataset_meta in the model for convenience
        if 'dataset_meta' in checkpoint_meta:
            # mmdet 3.x, all keys should be lowercase
            model.dataset_meta = {
                k.lower(): v
                for k, v in checkpoint_meta['dataset_meta'].items()
            }
        elif 'CLASSES' in checkpoint_meta:
            # < mmdet 3.x
            classes = checkpoint_meta['CLASSES']
            model.dataset_meta = {'classes': classes, 'palette': palette}
        else:
            warnings.simplefilter('once')
            warnings.warn(
                'dataset_meta or class names are not saved in the '
                'checkpoint\'s meta data, use COCO classes by default.')
            model.dataset_meta = {
                'classes': get_classes('coco'),
                'palette': palette
            }

    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def reshape_transform(feats: Union[Tensor, List[Tensor]],
                      max_shape: Tuple[int, int] = (20, 20),
                      is_need_grad: bool = False):
    """Bring one or several feature maps to a shared spatial size and stack
    them along the channel dimension.

    Args:
        feats: A single feature map or a list of them. The last two
            dimensions are treated as height and width.
        max_shape: Upper bound for the output (H, W). A one-element sequence
            is used for both dimensions. If it contains ``-1``, the bound is
            disabled and the largest input size is used.
        is_need_grad: If True, a list input raises ``NotImplementedError``
            because grad-based methods support only one activation layer.

    Returns:
        Tensor: The absolute values of every map, bilinearly resized to the
        target size and concatenated along dim 1.
    """
    if len(max_shape) == 1:
        max_shape = max_shape * 2

    if isinstance(feats, torch.Tensor):
        feature_maps = [feats]
    elif is_need_grad:
        raise NotImplementedError('The `grad_base` method does not '
                                  'support output multi-activation layers')
    else:
        feature_maps = feats

    largest_h = max(fm.shape[-2] for fm in feature_maps)
    largest_w = max(fm.shape[-1] for fm in feature_maps)
    if -1 in max_shape:
        target_size = (largest_h, largest_w)
    else:
        target_size = (min(largest_h, max_shape[0]),
                       min(largest_w, max_shape[1]))

    resized = [
        torch.nn.functional.interpolate(
            torch.abs(fm), target_size, mode='bilinear')
        for fm in feature_maps
    ]
    return torch.cat(resized, dim=1)


class BoxAMDetectorWrapper(nn.Module):
    """``nn.Module`` facade over an mmyolo detector for CAM libraries.

    Arguments passed when the wrapper is called are ignored. The actual input
    is whatever was last registered via :meth:`set_input_data`. Depending on
    :meth:`need_loss`, a call returns either ``[loss_dict]`` (grad-based
    methods) or the detector's ``test_step`` results (grad-free methods).
    ``score_thr`` is stored on the instance but not used inside this class.
    """

    def __init__(self,
                 cfg: ConfigType,
                 checkpoint: str,
                 score_thr: float,
                 device: str = 'cuda:0'):
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.score_thr = score_thr
        self.checkpoint = checkpoint
        self.detector = init_detector(self.cfg, self.checkpoint, device=device)

        steps = copy.deepcopy(self.cfg.test_dataloader.dataset.pipeline)
        # Images are fed as in-memory arrays instead of file paths.
        steps[0].type = 'mmdet.LoadImageFromNDArray'
        # Annotation loaders are dropped; in loss mode the boxes/labels are
        # put into the data dict directly by ``set_input_data``.
        self.test_pipeline = Compose([
            step for step in steps
            if not step['type'].endswith('LoadAnnotations')
        ])

        self.is_need_loss = False
        self.input_data = None
        self.image = None

    def need_loss(self, is_need_loss: bool):
        """Switch between loss mode (grad-based) and predict mode."""
        self.is_need_loss = is_need_loss

    def set_input_data(self,
                       image: np.ndarray,
                       pred_instances: Optional[InstanceData] = None):
        """Run ``image`` through the test pipeline and cache the result.

        In loss mode, ``pred_instances`` is required; its boxes and labels
        are used as ground truth for the loss computation.
        """
        self.image = image

        if not self.is_need_loss:
            packed = self.test_pipeline(dict(img=self.image, img_id=0))
            # test_step expects batched (list) inputs and samples.
            packed['inputs'] = [packed['inputs']]
            packed['data_samples'] = [packed['data_samples']]
        else:
            assert pred_instances is not None
            pred_instances = pred_instances.numpy()
            packed = self.test_pipeline(
                dict(
                    img=self.image,
                    img_id=0,
                    gt_bboxes=pred_instances.bboxes,
                    gt_bboxes_labels=pred_instances.labels))
        self.input_data = packed

    def __call__(self, *args, **kwargs):
        assert self.input_data is not None
        head = self.detector.bbox_head

        if not self.is_need_loss:
            head.head_module.training = False
            with torch.no_grad():
                return self.detector.test_step(self.input_data)

        # Only the head module's own ``training`` flag is flipped; the rest
        # of the detector keeps the mode it already has.
        head.head_module.training = True
        if hasattr(head, 'featmap_sizes'):
            # Reset cached feature-map sizes to avoid errors in the loss.
            head.featmap_sizes = None

        batch = self.detector.data_preprocessor(
            {
                'inputs': [self.input_data['inputs']],
                'data_samples': [self.input_data['data_samples']]
            },
            training=False)
        losses = self.detector._run_forward(batch, mode='loss')

        if hasattr(head, 'featmap_sizes'):
            head.featmap_sizes = None
        return [losses]


class BoxAMDetectorVisualizer:
    """Box AM visualization class.

    Builds a CAM object around a wrapped detector and renders the resulting
    activation map together with the predicted boxes.

    Args:
        method_class: CAM class to instantiate. A class whose ``__name__`` is
            ``'AblationCAM'`` is built with ablation options from
            ``extra_params``.
        model: Wrapped detector. ``model.device`` (a string) and
            ``model.detector.dataset_meta['classes']`` are read.
        target_layers: Layers whose activations the CAM uses.
        reshape_transform: Optional callable applied to the activations.
        is_need_grad: Whether the method is gradient based.
        extra_params: AblationCAM options: ``batch_size`` (default 1),
            ``ratio_channels_to_ablate`` (default 1.0) and the required
            ``ablation_layer``. ``None`` is treated as an empty dict.
    """

    def __init__(self,
                 method_class,
                 model: nn.Module,
                 target_layers: List,
                 reshape_transform: Optional[Callable] = None,
                 is_need_grad: bool = False,
                 extra_params: Optional[dict] = None):
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.is_need_grad = is_need_grad

        use_cuda = 'cuda' in model.device
        if method_class.__name__ == 'AblationCAM':
            # The default ``None`` previously crashed on ``None.get``.
            extra_params = extra_params or {}
            batch_size = extra_params.get('batch_size', 1)
            ratio_channels_to_ablate = extra_params.get(
                'ratio_channels_to_ablate', 1.)
            self.cam = AblationCAM(
                model,
                target_layers,
                use_cuda=use_cuda,
                reshape_transform=reshape_transform,
                batch_size=batch_size,
                ablation_layer=extra_params['ablation_layer'],
                ratio_channels_to_ablate=ratio_channels_to_ablate)
        else:
            self.cam = method_class(
                model,
                target_layers,
                use_cuda=use_cuda,
                reshape_transform=reshape_transform,
            )
            if self.is_need_grad:
                self.cam.activations_and_grads.release()

        self.classes = model.detector.dataset_meta['classes']
        self.COLORS = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def switch_activations_and_grads(self, model) -> None:
        """In the grad-based method, we need to switch
        ``ActivationsAndGradients`` layer, otherwise an error will occur.

        Each call toggles ``self.is_need_grad``. If it was True, fresh
        activation/gradient hooks are installed on ``model``. Otherwise the
        existing hooks are released.
        """
        self.cam.model = model

        if self.is_need_grad is True:
            self.cam.activations_and_grads = ActivationsAndGradients(
                model, self.target_layers, self.reshape_transform)
            self.is_need_grad = False
        else:
            self.cam.activations_and_grads.release()
            self.is_need_grad = True

    def __call__(self, img, targets, aug_smooth=False, eigen_smooth=False):
        """Run the CAM on an (H, W, C) image array and return the first map."""
        img = torch.from_numpy(img)[None].permute(0, 3, 1, 2)
        return self.cam(img, targets, aug_smooth, eigen_smooth)[0, :]

    @staticmethod
    def _normalize_in_boxes(grayscale_am: np.ndarray,
                            boxes: np.ndarray) -> np.ndarray:
        """Rescale the AM to [0, 1] inside each box and zero it elsewhere.

        Overlapping boxes keep the per-pixel maximum. Boxes are clipped to
        the map, and boxes left empty after clipping are skipped. With no
        usable boxes the result is an all-zero map instead of an error.
        """
        height, width = grayscale_am.shape[:2]
        renormalized_am = np.zeros(grayscale_am.shape, dtype=np.float32)
        for x1, y1, x2, y2 in boxes.astype(np.int32):
            # Negative starts would wrap around under Python slicing.
            x1, y1 = max(int(x1), 0), max(int(y1), 0)
            x2, y2 = min(int(x2), width), min(int(y2), height)
            if x2 <= x1 or y2 <= y1:
                # An empty crop cannot be min/max normalised.
                continue
            scaled = scale_cam_image([grayscale_am[y1:y2, x1:x2].copy()])[0]
            region = renormalized_am[y1:y2, x1:x2]
            # Per-pixel max across boxes; ``region`` is a view, so this
            # updates ``renormalized_am`` in place.
            np.maximum(region, scaled, out=region)
        return scale_cam_image([renormalized_am])[0]

    def show_am(self,
                image: np.ndarray,
                pred_instance: InstanceData,
                grayscale_am: np.ndarray,
                with_norm_in_bboxes: bool = False):
        """Overlay the AM on ``image`` and draw the predicted boxes.

        When ``with_norm_in_bboxes`` is True, the AM is normalized to [0, 1]
        inside every bounding box and zeroed outside all boxes. Otherwise it
        is used as-is. ``image`` is divided by 255 before blending, so it is
        expected in the 0-255 range.

        Returns:
            np.ndarray: The blended image with boxes and labels drawn on it.
        """
        boxes = pred_instance.bboxes
        labels = pred_instance.labels

        if with_norm_in_bboxes is True:
            renormalized_am = self._normalize_in_boxes(grayscale_am, boxes)
        else:
            renormalized_am = grayscale_am

        am_image_renormalized = show_cam_on_image(
            image / 255, renormalized_am, use_rgb=False)

        image_with_bounding_boxes = self._draw_boxes(
            boxes, labels, am_image_renormalized, pred_instance.get('scores'))
        return image_with_bounding_boxes

    def _draw_boxes(self,
                    boxes: List,
                    labels: List,
                    image: np.ndarray,
                    scores: Optional[List] = None):
        """Draw boxes and ``class[: score%]`` captions on ``image`` in place.

        Returns:
            np.ndarray: The same ``image`` object, after drawing.
        """
        for i, box in enumerate(boxes):
            label = labels[i]
            color = self.COLORS[label]
            cv2.rectangle(image, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), color, 2)
            if scores is not None:
                score = scores[i]
                text = str(self.classes[label]) + ': ' + str(
                    round(score * 100, 1))
            else:
                text = self.classes[label]

            cv2.putText(
                image,
                text, (int(box[0]), int(box[1] - 5)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                color,
                1,
                lineType=cv2.LINE_AA)
        return image


class DetAblationLayer(AblationLayer):
    """Ablation layer that also accepts multi-level (list) activations.

    Single-tensor activations are handled by the base ``AblationLayer``. For
    a list of per-level maps, each entry of ``self.indices`` is a channel
    index into the concatenation of all levels' channels. It is mapped back
    to a (level, local channel) pair before ablation.
    """

    def __init__(self):
        super().__init__()
        self.activations = None

    def set_next_batch(self, input_batch_index, activations,
                       num_channels_to_ablate):
        """Take sample ``input_batch_index`` from every level and tile it
        ``num_channels_to_ablate`` times along the batch dimension."""
        if isinstance(activations, torch.Tensor):
            return super().set_next_batch(input_batch_index, activations,
                                          num_channels_to_ablate)

        self.activations = [
            level[input_batch_index, :, :, :].clone().unsqueeze(0).repeat(
                num_channels_to_ablate, 1, 1, 1) for level in activations
        ]

    def __call__(self, x):
        """Ablate channel ``self.indices[row]`` in batch row ``row`` by
        overwriting it with -1000, then return the (mutated) levels."""
        levels = self.activations

        if isinstance(levels, torch.Tensor):
            return super().__call__(x)

        # Running channel totals: bisect on them finds the owning level.
        level_ends = np.cumsum([level.shape[1] for level in levels])
        num_ablated = levels[0].size(0)
        for row in range(num_ablated):
            global_idx = self.indices[row]
            level_id = bisect.bisect_right(level_ends, global_idx)
            local_idx = (global_idx - level_ends[level_id - 1]
                         if level_id > 0 else global_idx)
            levels[level_id][row, local_idx, :, :] = -1000
        return levels


class DetBoxScoreTarget:
    """Det Score calculation class.

    Grad-free methods (``results`` is a ``DetDataSample``): for every focal
    box given at construction, find the current prediction with the highest
    box IoU against it. If that IoU exceeds ``match_iou_thr`` and the
    predicted label equals the focal label, add ``IoU + predicted score`` to
    the total. Otherwise that focal box contributes 0. Returns a tensor of
    shape (1,), which is zero when there are no predictions.

    Grad-based methods (``results`` is a loss dict containing ``'loss_cls'``):
    return the sum of all entries whose key contains ``'loss'`` and is not
    listed in ``ignore_loss_params``. List-valued losses are summed first.

    Args:
        pred_instance: Reference predictions. ``bboxes`` and ``labels`` are
            read.
        match_iou_thr: Minimum IoU for a prediction to count as a match.
        device: Device of the returned score tensor.
        ignore_loss_params: Loss keys to exclude. ``None`` excludes nothing.
    """

    def __init__(self,
                 pred_instance: InstanceData,
                 match_iou_thr: float = 0.5,
                 device: str = 'cuda:0',
                 ignore_loss_params: Optional[List] = None):
        self.focal_bboxes = pred_instance.bboxes
        self.focal_labels = pred_instance.labels
        self.match_iou_thr = match_iou_thr
        self.device = device
        self.ignore_loss_params = ignore_loss_params
        if ignore_loss_params is not None:
            assert isinstance(self.ignore_loss_params, list)

    def __call__(self, results):
        output = torch.tensor([0.], device=self.device)

        if 'loss_cls' in results:
            # grad-based method
            # results is dict
            # ``ignore_loss_params`` defaults to None; the membership test
            # below used to raise TypeError in that case.
            ignored = self.ignore_loss_params or []
            for loss_key, loss_value in results.items():
                if 'loss' not in loss_key or loss_key in ignored:
                    continue
                if isinstance(loss_value, list):
                    output += sum(loss_value)
                else:
                    output += loss_value
            return output
        else:
            # grad-free method
            # results is DetDataSample
            pred_instances = results.pred_instances
            if len(pred_instances) == 0:
                return output

            pred_bboxes = pred_instances.bboxes
            pred_scores = pred_instances.scores
            pred_labels = pred_instances.labels

            for focal_box, focal_label in zip(self.focal_bboxes,
                                              self.focal_labels):
                ious = torchvision.ops.box_iou(focal_box[None],
                                               pred_bboxes[..., :4])
                # ``ious`` has a single row, so the flat argmax is the
                # column (prediction) index.
                index = ious.argmax()
                if ious[0, index] > self.match_iou_thr and pred_labels[
                        index] == focal_label:
                    # TODO: Adaptive adjustment of weights based on algorithms
                    score = ious[0, index] + pred_scores[index]
                    output = output + score
            return output


class SpatialBaseCAM(BaseCAM):
    """CAM base class that keeps the spatial layout of its weights.

    Classification CAMs usually average gradients over height and width to
    get one weight per channel. For detection this throws away *where* the
    evidence is, so here the weights returned by ``get_cam_weights`` are
    multiplied with the activations element-wise, without any averaging.
    """

    def get_cam_image(self,
                      input_tensor: torch.Tensor,
                      target_layer: torch.nn.Module,
                      targets: List[torch.nn.Module],
                      activations: torch.Tensor,
                      grads: torch.Tensor,
                      eigen_smooth: bool = False) -> np.ndarray:
        # Per-position weighting; the weights must broadcast against
        # ``activations``.
        weighted = self.get_cam_weights(input_tensor, target_layer, targets,
                                        activations, grads) * activations
        if not eigen_smooth:
            # Collapse the channel dimension.
            return weighted.sum(axis=1)
        return get_2d_projection(weighted)


class GradCAM(SpatialBaseCAM, Base_GradCAM):
    """Grad-CAM whose weights are the raw gradients, with no spatial
    averaging."""

    def get_cam_weights(self, input_tensor, target_layer, target_category,
                        activations, grads):
        # Every spatial position is weighted by its own gradient value.
        weights = grads
        return weights


class GradCAMPlusPlus(SpatialBaseCAM, Base_GradCAMPlusPlus):
    """Grad-CAM++ whose per-position weights are not averaged over the
    spatial dimensions."""

    def get_cam_weights(self, input_tensor, target_layers, target_category,
                        activations, grads):
        eps = 0.000001
        g2 = grads**2
        g3 = g2 * grads
        # Equation 19 in https://arxiv.org/abs/1710.11063
        # One activation sum per (sample, channel), broadcast over H and W.
        channel_sums = np.sum(activations, axis=(2, 3))[:, :, None, None]
        alpha = g2 / (2 * g2 + channel_sums * g3 + eps)
        # Zero out alpha wherever the gradient itself is zero.
        alpha = np.where(grads != 0, alpha, 0)
        # ReLU on the gradients (eq. 7 of the paper).
        return np.maximum(grads, 0) * alpha