Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import bisect | |
import copy | |
import warnings | |
from pathlib import Path | |
from typing import Callable, List, Optional, Tuple, Union | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision | |
from mmcv.transforms import Compose | |
from mmdet.evaluation import get_classes | |
from mmdet.utils import ConfigType | |
from mmengine.config import Config | |
from mmengine.registry import init_default_scope | |
from mmengine.runner import load_checkpoint | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmyolo.registry import MODELS | |
try: | |
from pytorch_grad_cam import (AblationCAM, AblationLayer, | |
ActivationsAndGradients) | |
from pytorch_grad_cam import GradCAM as Base_GradCAM | |
from pytorch_grad_cam import GradCAMPlusPlus as Base_GradCAMPlusPlus | |
from pytorch_grad_cam.base_cam import BaseCAM | |
from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image | |
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection | |
except ImportError: | |
pass | |
def init_detector( | |
config: Union[str, Path, Config], | |
checkpoint: Optional[str] = None, | |
palette: str = 'coco', | |
device: str = 'cuda:0', | |
cfg_options: Optional[dict] = None, | |
) -> nn.Module: | |
"""Initialize a detector from config file. | |
Args: | |
config (str, :obj:`Path`, or :obj:`mmengine.Config`): Config file path, | |
:obj:`Path`, or the config object. | |
checkpoint (str, optional): Checkpoint path. If left as None, the model | |
will not load any weights. | |
palette (str): Color palette used for visualization. If palette | |
is stored in checkpoint, use checkpoint's palette first, otherwise | |
use externally passed palette. Currently, supports 'coco', 'voc', | |
'citys' and 'random'. Defaults to coco. | |
device (str): The device where the anchors will be put on. | |
Defaults to cuda:0. | |
cfg_options (dict, optional): Options to override some settings in | |
the used config. | |
Returns: | |
nn.Module: The constructed detector. | |
""" | |
if isinstance(config, (str, Path)): | |
config = Config.fromfile(config) | |
elif not isinstance(config, Config): | |
raise TypeError('config must be a filename or Config object, ' | |
f'but got {type(config)}') | |
if cfg_options is not None: | |
config.merge_from_dict(cfg_options) | |
elif 'init_cfg' in config.model.backbone: | |
config.model.backbone.init_cfg = None | |
# only change this | |
# grad based method requires train_cfg | |
# config.model.train_cfg = None | |
init_default_scope(config.get('default_scope', 'mmyolo')) | |
model = MODELS.build(config.model) | |
if checkpoint is not None: | |
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') | |
# Weights converted from elsewhere may not have meta fields. | |
checkpoint_meta = checkpoint.get('meta', {}) | |
# save the dataset_meta in the model for convenience | |
if 'dataset_meta' in checkpoint_meta: | |
# mmdet 3.x, all keys should be lowercase | |
model.dataset_meta = { | |
k.lower(): v | |
for k, v in checkpoint_meta['dataset_meta'].items() | |
} | |
elif 'CLASSES' in checkpoint_meta: | |
# < mmdet 3.x | |
classes = checkpoint_meta['CLASSES'] | |
model.dataset_meta = {'classes': classes, 'palette': palette} | |
else: | |
warnings.simplefilter('once') | |
warnings.warn( | |
'dataset_meta or class names are not saved in the ' | |
'checkpoint\'s meta data, use COCO classes by default.') | |
model.dataset_meta = { | |
'classes': get_classes('coco'), | |
'palette': palette | |
} | |
model.cfg = config # save the config in the model for convenience | |
model.to(device) | |
model.eval() | |
return model | |
def reshape_transform(feats: Union[Tensor, List[Tensor]], | |
max_shape: Tuple[int, int] = (20, 20), | |
is_need_grad: bool = False): | |
"""Reshape and aggregate feature maps when the input is a multi-layer | |
feature map. | |
Takes these tensors with different sizes, resizes them to a common shape, | |
and concatenates them. | |
""" | |
if len(max_shape) == 1: | |
max_shape = max_shape * 2 | |
if isinstance(feats, torch.Tensor): | |
feats = [feats] | |
else: | |
if is_need_grad: | |
raise NotImplementedError('The `grad_base` method does not ' | |
'support output multi-activation layers') | |
max_h = max([im.shape[-2] for im in feats]) | |
max_w = max([im.shape[-1] for im in feats]) | |
if -1 in max_shape: | |
max_shape = (max_h, max_w) | |
else: | |
max_shape = (min(max_h, max_shape[0]), min(max_w, max_shape[1])) | |
activations = [] | |
for feat in feats: | |
activations.append( | |
torch.nn.functional.interpolate( | |
torch.abs(feat), max_shape, mode='bilinear')) | |
activations = torch.cat(activations, axis=1) | |
return activations | |
class BoxAMDetectorWrapper(nn.Module): | |
"""Wrap the mmdet model class to facilitate handling of non-tensor | |
situations during inference.""" | |
def __init__(self, | |
cfg: ConfigType, | |
checkpoint: str, | |
score_thr: float, | |
device: str = 'cuda:0'): | |
super().__init__() | |
self.cfg = cfg | |
self.device = device | |
self.score_thr = score_thr | |
self.checkpoint = checkpoint | |
self.detector = init_detector(self.cfg, self.checkpoint, device=device) | |
pipeline_cfg = copy.deepcopy(self.cfg.test_dataloader.dataset.pipeline) | |
pipeline_cfg[0].type = 'mmdet.LoadImageFromNDArray' | |
new_test_pipeline = [] | |
for pipeline in pipeline_cfg: | |
if not pipeline['type'].endswith('LoadAnnotations'): | |
new_test_pipeline.append(pipeline) | |
self.test_pipeline = Compose(new_test_pipeline) | |
self.is_need_loss = False | |
self.input_data = None | |
self.image = None | |
def need_loss(self, is_need_loss: bool): | |
"""Grad-based methods require loss.""" | |
self.is_need_loss = is_need_loss | |
def set_input_data(self, | |
image: np.ndarray, | |
pred_instances: Optional[InstanceData] = None): | |
"""Set the input data to be used in the next step.""" | |
self.image = image | |
if self.is_need_loss: | |
assert pred_instances is not None | |
pred_instances = pred_instances.numpy() | |
data = dict( | |
img=self.image, | |
img_id=0, | |
gt_bboxes=pred_instances.bboxes, | |
gt_bboxes_labels=pred_instances.labels) | |
data = self.test_pipeline(data) | |
else: | |
data = dict(img=self.image, img_id=0) | |
data = self.test_pipeline(data) | |
data['inputs'] = [data['inputs']] | |
data['data_samples'] = [data['data_samples']] | |
self.input_data = data | |
def __call__(self, *args, **kwargs): | |
assert self.input_data is not None | |
if self.is_need_loss: | |
# Maybe this is a direction that can be optimized | |
# self.detector.init_weights() | |
self.detector.bbox_head.head_module.training = True | |
if hasattr(self.detector.bbox_head, 'featmap_sizes'): | |
# Prevent the model algorithm error when calculating loss | |
self.detector.bbox_head.featmap_sizes = None | |
data_ = {} | |
data_['inputs'] = [self.input_data['inputs']] | |
data_['data_samples'] = [self.input_data['data_samples']] | |
data = self.detector.data_preprocessor(data_, training=False) | |
loss = self.detector._run_forward(data, mode='loss') | |
if hasattr(self.detector.bbox_head, 'featmap_sizes'): | |
self.detector.bbox_head.featmap_sizes = None | |
return [loss] | |
else: | |
self.detector.bbox_head.head_module.training = False | |
with torch.no_grad(): | |
results = self.detector.test_step(self.input_data) | |
return results | |
class BoxAMDetectorVisualizer: | |
"""Box AM visualization class.""" | |
def __init__(self, | |
method_class, | |
model: nn.Module, | |
target_layers: List, | |
reshape_transform: Optional[Callable] = None, | |
is_need_grad: bool = False, | |
extra_params: Optional[dict] = None): | |
self.target_layers = target_layers | |
self.reshape_transform = reshape_transform | |
self.is_need_grad = is_need_grad | |
if method_class.__name__ == 'AblationCAM': | |
batch_size = extra_params.get('batch_size', 1) | |
ratio_channels_to_ablate = extra_params.get( | |
'ratio_channels_to_ablate', 1.) | |
self.cam = AblationCAM( | |
model, | |
target_layers, | |
use_cuda=True if 'cuda' in model.device else False, | |
reshape_transform=reshape_transform, | |
batch_size=batch_size, | |
ablation_layer=extra_params['ablation_layer'], | |
ratio_channels_to_ablate=ratio_channels_to_ablate) | |
else: | |
self.cam = method_class( | |
model, | |
target_layers, | |
use_cuda=True if 'cuda' in model.device else False, | |
reshape_transform=reshape_transform, | |
) | |
if self.is_need_grad: | |
self.cam.activations_and_grads.release() | |
self.classes = model.detector.dataset_meta['classes'] | |
self.COLORS = np.random.uniform(0, 255, size=(len(self.classes), 3)) | |
def switch_activations_and_grads(self, model) -> None: | |
"""In the grad-based method, we need to switch | |
``ActivationsAndGradients`` layer, otherwise an error will occur.""" | |
self.cam.model = model | |
if self.is_need_grad is True: | |
self.cam.activations_and_grads = ActivationsAndGradients( | |
model, self.target_layers, self.reshape_transform) | |
self.is_need_grad = False | |
else: | |
self.cam.activations_and_grads.release() | |
self.is_need_grad = True | |
def __call__(self, img, targets, aug_smooth=False, eigen_smooth=False): | |
img = torch.from_numpy(img)[None].permute(0, 3, 1, 2) | |
return self.cam(img, targets, aug_smooth, eigen_smooth)[0, :] | |
def show_am(self, | |
image: np.ndarray, | |
pred_instance: InstanceData, | |
grayscale_am: np.ndarray, | |
with_norm_in_bboxes: bool = False): | |
"""Normalize the AM to be in the range [0, 1] inside every bounding | |
boxes, and zero outside of the bounding boxes.""" | |
boxes = pred_instance.bboxes | |
labels = pred_instance.labels | |
if with_norm_in_bboxes is True: | |
boxes = boxes.astype(np.int32) | |
renormalized_am = np.zeros(grayscale_am.shape, dtype=np.float32) | |
images = [] | |
for x1, y1, x2, y2 in boxes: | |
img = renormalized_am * 0 | |
img[y1:y2, x1:x2] = scale_cam_image( | |
[grayscale_am[y1:y2, x1:x2].copy()])[0] | |
images.append(img) | |
renormalized_am = np.max(np.float32(images), axis=0) | |
renormalized_am = scale_cam_image([renormalized_am])[0] | |
else: | |
renormalized_am = grayscale_am | |
am_image_renormalized = show_cam_on_image( | |
image / 255, renormalized_am, use_rgb=False) | |
image_with_bounding_boxes = self._draw_boxes( | |
boxes, labels, am_image_renormalized, pred_instance.get('scores')) | |
return image_with_bounding_boxes | |
def _draw_boxes(self, | |
boxes: List, | |
labels: List, | |
image: np.ndarray, | |
scores: Optional[List] = None): | |
"""draw boxes on image.""" | |
for i, box in enumerate(boxes): | |
label = labels[i] | |
color = self.COLORS[label] | |
cv2.rectangle(image, (int(box[0]), int(box[1])), | |
(int(box[2]), int(box[3])), color, 2) | |
if scores is not None: | |
score = scores[i] | |
text = str(self.classes[label]) + ': ' + str( | |
round(score * 100, 1)) | |
else: | |
text = self.classes[label] | |
cv2.putText( | |
image, | |
text, (int(box[0]), int(box[1] - 5)), | |
cv2.FONT_HERSHEY_SIMPLEX, | |
0.5, | |
color, | |
1, | |
lineType=cv2.LINE_AA) | |
return image | |
class DetAblationLayer(AblationLayer): | |
"""Det AblationLayer.""" | |
def __init__(self): | |
super().__init__() | |
self.activations = None | |
def set_next_batch(self, input_batch_index, activations, | |
num_channels_to_ablate): | |
"""Extract the next batch member from activations, and repeat it | |
num_channels_to_ablate times.""" | |
if isinstance(activations, torch.Tensor): | |
return super().set_next_batch(input_batch_index, activations, | |
num_channels_to_ablate) | |
self.activations = [] | |
for activation in activations: | |
activation = activation[ | |
input_batch_index, :, :, :].clone().unsqueeze(0) | |
self.activations.append( | |
activation.repeat(num_channels_to_ablate, 1, 1, 1)) | |
def __call__(self, x): | |
"""Go over the activation indices to be ablated, stored in | |
self.indices.""" | |
result = self.activations | |
if isinstance(result, torch.Tensor): | |
return super().__call__(x) | |
channel_cumsum = np.cumsum([r.shape[1] for r in result]) | |
num_channels_to_ablate = result[0].size(0) # batch | |
for i in range(num_channels_to_ablate): | |
pyramid_layer = bisect.bisect_right(channel_cumsum, | |
self.indices[i]) | |
if pyramid_layer > 0: | |
index_in_pyramid_layer = self.indices[i] - channel_cumsum[ | |
pyramid_layer - 1] | |
else: | |
index_in_pyramid_layer = self.indices[i] | |
result[pyramid_layer][i, index_in_pyramid_layer, :, :] = -1000 | |
return result | |
class DetBoxScoreTarget: | |
"""Det Score calculation class. | |
In the case of the grad-free method, the calculation method is that | |
for every original detected bounding box specified in "bboxes", | |
assign a score on how the current bounding boxes match it, | |
1. In Bbox IoU | |
2. In the classification score. | |
3. In Mask IoU if ``segms`` exist. | |
If there is not a large enough overlap, or the category changed, | |
assign a score of 0. The total score is the sum of all the box scores. | |
In the case of the grad-based method, the calculation method is | |
the sum of losses after excluding a specific key. | |
""" | |
def __init__(self, | |
pred_instance: InstanceData, | |
match_iou_thr: float = 0.5, | |
device: str = 'cuda:0', | |
ignore_loss_params: Optional[List] = None): | |
self.focal_bboxes = pred_instance.bboxes | |
self.focal_labels = pred_instance.labels | |
self.match_iou_thr = match_iou_thr | |
self.device = device | |
self.ignore_loss_params = ignore_loss_params | |
if ignore_loss_params is not None: | |
assert isinstance(self.ignore_loss_params, list) | |
def __call__(self, results): | |
output = torch.tensor([0.], device=self.device) | |
if 'loss_cls' in results: | |
# grad-based method | |
# results is dict | |
for loss_key, loss_value in results.items(): | |
if 'loss' not in loss_key or \ | |
loss_key in self.ignore_loss_params: | |
continue | |
if isinstance(loss_value, list): | |
output += sum(loss_value) | |
else: | |
output += loss_value | |
return output | |
else: | |
# grad-free method | |
# results is DetDataSample | |
pred_instances = results.pred_instances | |
if len(pred_instances) == 0: | |
return output | |
pred_bboxes = pred_instances.bboxes | |
pred_scores = pred_instances.scores | |
pred_labels = pred_instances.labels | |
for focal_box, focal_label in zip(self.focal_bboxes, | |
self.focal_labels): | |
ious = torchvision.ops.box_iou(focal_box[None], | |
pred_bboxes[..., :4]) | |
index = ious.argmax() | |
if ious[0, index] > self.match_iou_thr and pred_labels[ | |
index] == focal_label: | |
# TODO: Adaptive adjustment of weights based on algorithms | |
score = ious[0, index] + pred_scores[index] | |
output = output + score | |
return output | |
class SpatialBaseCAM(BaseCAM): | |
"""CAM that maintains spatial information. | |
Gradients are often averaged over the spatial dimension in CAM | |
visualization for classification, but this is unreasonable in detection | |
tasks. There is no need to average the gradients in the detection task. | |
""" | |
def get_cam_image(self, | |
input_tensor: torch.Tensor, | |
target_layer: torch.nn.Module, | |
targets: List[torch.nn.Module], | |
activations: torch.Tensor, | |
grads: torch.Tensor, | |
eigen_smooth: bool = False) -> np.ndarray: | |
weights = self.get_cam_weights(input_tensor, target_layer, targets, | |
activations, grads) | |
weighted_activations = weights * activations | |
if eigen_smooth: | |
cam = get_2d_projection(weighted_activations) | |
else: | |
cam = weighted_activations.sum(axis=1) | |
return cam | |
class GradCAM(SpatialBaseCAM, Base_GradCAM): | |
"""Gradients are no longer averaged over the spatial dimension.""" | |
def get_cam_weights(self, input_tensor, target_layer, target_category, | |
activations, grads): | |
return grads | |
class GradCAMPlusPlus(SpatialBaseCAM, Base_GradCAMPlusPlus): | |
"""Gradients are no longer averaged over the spatial dimension.""" | |
def get_cam_weights(self, input_tensor, target_layers, target_category, | |
activations, grads): | |
grads_power_2 = grads**2 | |
grads_power_3 = grads_power_2 * grads | |
# Equation 19 in https://arxiv.org/abs/1710.11063 | |
sum_activations = np.sum(activations, axis=(2, 3)) | |
eps = 0.000001 | |
aij = grads_power_2 / ( | |
2 * grads_power_2 + | |
sum_activations[:, :, None, None] * grads_power_3 + eps) | |
# Now bring back the ReLU from eq.7 in the paper, | |
# And zero out aijs where the activations are 0 | |
aij = np.where(grads != 0, aij, 0) | |
weights = np.maximum(grads, 0) * aij | |
return weights | |