Upload 89 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- mmyolo/__init__.py +39 -0
- mmyolo/datasets/__init__.py +12 -0
- mmyolo/datasets/transforms/__init__.py +14 -0
- mmyolo/datasets/transforms/mix_img_transforms.py +1150 -0
- mmyolo/datasets/transforms/transforms.py +1557 -0
- mmyolo/datasets/utils.py +114 -0
- mmyolo/datasets/yolov5_coco.py +65 -0
- mmyolo/datasets/yolov5_crowdhuman.py +15 -0
- mmyolo/datasets/yolov5_dota.py +29 -0
- mmyolo/datasets/yolov5_voc.py +15 -0
- mmyolo/deploy/__init__.py +7 -0
- mmyolo/deploy/models/__init__.py +2 -0
- mmyolo/deploy/models/dense_heads/__init__.py +4 -0
- mmyolo/deploy/models/dense_heads/yolov5_head.py +189 -0
- mmyolo/deploy/models/layers/__init__.py +4 -0
- mmyolo/deploy/models/layers/bbox_nms.py +113 -0
- mmyolo/deploy/object_detection.py +132 -0
- mmyolo/engine/__init__.py +3 -0
- mmyolo/engine/hooks/__init__.py +10 -0
- mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py +96 -0
- mmyolo/engine/hooks/switch_to_deploy_hook.py +21 -0
- mmyolo/engine/hooks/yolov5_param_scheduler_hook.py +130 -0
- mmyolo/engine/hooks/yolox_mode_switch_hook.py +54 -0
- mmyolo/engine/optimizers/__init__.py +5 -0
- mmyolo/engine/optimizers/yolov5_optim_constructor.py +132 -0
- mmyolo/engine/optimizers/yolov7_optim_wrapper_constructor.py +139 -0
- mmyolo/models/__init__.py +10 -0
- mmyolo/models/backbones/__init__.py +13 -0
- mmyolo/models/backbones/base_backbone.py +225 -0
- mmyolo/models/backbones/csp_darknet.py +427 -0
- mmyolo/models/backbones/csp_resnet.py +169 -0
- mmyolo/models/backbones/cspnext.py +187 -0
- mmyolo/models/backbones/efficient_rep.py +287 -0
- mmyolo/models/backbones/yolov7_backbone.py +285 -0
- mmyolo/models/data_preprocessors/__init__.py +10 -0
- mmyolo/models/data_preprocessors/data_preprocessor.py +302 -0
- mmyolo/models/dense_heads/__init__.py +20 -0
- mmyolo/models/dense_heads/ppyoloe_head.py +374 -0
- mmyolo/models/dense_heads/rtmdet_head.py +368 -0
- mmyolo/models/dense_heads/rtmdet_ins_head.py +725 -0
- mmyolo/models/dense_heads/rtmdet_rotated_head.py +641 -0
- mmyolo/models/dense_heads/yolov5_head.py +890 -0
- mmyolo/models/dense_heads/yolov6_head.py +369 -0
- mmyolo/models/dense_heads/yolov7_head.py +404 -0
- mmyolo/models/dense_heads/yolov8_head.py +398 -0
- mmyolo/models/dense_heads/yolox_head.py +514 -0
- mmyolo/models/detectors/__init__.py +4 -0
- mmyolo/models/detectors/yolo_detector.py +53 -0
- mmyolo/models/layers/__init__.py +16 -0
- mmyolo/models/layers/ema.py +96 -0
mmyolo/__init__.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import mmdet
import mmengine
from mmengine.utils import digit_version

from .version import __version__, version_info

mmcv_minimum_version = '2.0.0rc4'
mmcv_maximum_version = '2.1.0'
mmcv_version = digit_version(mmcv.__version__)

mmengine_minimum_version = '0.6.0'
mmengine_maximum_version = '1.0.0'
mmengine_version = digit_version(mmengine.__version__)

mmdet_minimum_version = '3.0.0rc6'
mmdet_maximum_version = '3.1.0'
mmdet_version = digit_version(mmdet.__version__)


assert (mmcv_version >= digit_version(mmcv_minimum_version)
        and mmcv_version < digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'

assert (mmengine_version >= digit_version(mmengine_minimum_version)
        and mmengine_version < digit_version(mmengine_maximum_version)), \
    f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
    f'Please install mmengine>={mmengine_minimum_version}, ' \
    f'<{mmengine_maximum_version}.'

assert (mmdet_version >= digit_version(mmdet_minimum_version)
        and mmdet_version < digit_version(mmdet_maximum_version)), \
    f'MMDetection=={mmdet.__version__} is used but incompatible. ' \
    f'Please install mmdet>={mmdet_minimum_version}, ' \
    f'<{mmdet_maximum_version}.'

__all__ = ['__version__', 'version_info', 'digit_version']
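For context, the guards above compare parsed version tuples rather than raw strings, so pre-release tags like 'rc4' order correctly against final releases. A minimal sketch of the same check in isolation, using only digit_version from mmengine; the installed version below is a hypothetical value for illustration:

from mmengine.utils import digit_version

installed = digit_version('2.0.0')  # pretend mmcv.__version__ == '2.0.0'

# Same half-open range as the assertion in __init__.py above.
assert (installed >= digit_version('2.0.0rc4')
        and installed < digit_version('2.1.0')), \
    'MMCV==2.0.0 is used but incompatible.'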
mmyolo/datasets/__init__.py
ADDED
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .transforms import *  # noqa: F401,F403
from .utils import BatchShapePolicy, yolov5_collate
from .yolov5_coco import YOLOv5CocoDataset
from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset
from .yolov5_dota import YOLOv5DOTADataset
from .yolov5_voc import YOLOv5VOCDataset

__all__ = [
    'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy',
    'yolov5_collate', 'YOLOv5CrowdHumanDataset', 'YOLOv5DOTADataset'
]
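For context, a minimal sketch of how one of these registered datasets is typically referenced from an MMEngine-style config dict; the paths, batch size, and the collate_fn wiring below are illustrative assumptions, not part of this diff:

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    collate_fn=dict(type='yolov5_collate'),  # registered in utils.py above
    dataset=dict(
        type='YOLOv5CocoDataset',            # registered dataset class
        data_root='data/coco/',              # illustrative path
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ]))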
mmyolo/datasets/transforms/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
                         PPYOLOERandomDistort, RegularizeRotatedBox,
                         RemoveDataElement, YOLOv5CopyPaste,
                         YOLOv5HSVRandomAug, YOLOv5KeepRatioResize,
                         YOLOv5RandomAffine)

__all__ = [
    'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
    'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
    'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
    'Mosaic9', 'YOLOv5CopyPaste', 'RemoveDataElement', 'RegularizeRotatedBox'
]
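For context, pipeline configs refer to these classes by their registered names through the TRANSFORMS registry. A minimal sketch, assuming mmyolo is installed and that YOLOv5HSVRandomAug has all-default constructor arguments (its definition lives in transforms.py, which is not shown in this view):

import mmyolo.datasets  # noqa: F401  # importing the package runs the register_module decorators
from mmyolo.registry import TRANSFORMS

# Build a transform from its registered name; with no extra keys the
# documented defaults are used.
hsv_aug = TRANSFORMS.build(dict(type='YOLOv5HSVRandomAug'))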
mmyolo/datasets/transforms/mix_img_transforms.py
ADDED
@@ -0,0 +1,1150 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections
import copy
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
from mmcv.transforms import BaseTransform
from mmdet.structures.bbox import autocast_box_type
from mmengine.dataset import BaseDataset
from mmengine.dataset.base_dataset import Compose
from numpy import random

from mmyolo.registry import TRANSFORMS


class BaseMixImageTransform(BaseTransform, metaclass=ABCMeta):
    """A Base Transform of multiple images mixed.

    Suitable for training on multiple images mixed data augmentation like
    mosaic and mixup.

    Cached mosaic transform will random select images from the cache
    and combine them into one output image if use_cached is True.

    Args:
        pre_transform(Sequence[str]): Sequence of transform object or
            config dict to be composed. Defaults to None.
        prob(float): The transformation probability. Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 40.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Defaults to 15.
    """

    def __init__(self,
                 pre_transform: Optional[Sequence[str]] = None,
                 prob: float = 1.0,
                 use_cached: bool = False,
                 max_cached_images: int = 40,
                 random_pop: bool = True,
                 max_refetch: int = 15):

        self.max_refetch = max_refetch
        self.prob = prob

        self.use_cached = use_cached
        self.max_cached_images = max_cached_images
        self.random_pop = random_pop
        self.results_cache = []

        if pre_transform is None:
            self.pre_transform = None
        else:
            self.pre_transform = Compose(pre_transform)

    @abstractmethod
    def get_indexes(self, dataset: Union[BaseDataset,
                                         list]) -> Union[list, int]:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.

        Returns:
            list or int: indexes.
        """
        pass

    @abstractmethod
    def mix_img_transform(self, results: dict) -> dict:
        """Mixed image data transformation.

        Args:
            results (dict): Result dict.

        Returns:
            results (dict): Updated result dict.
        """
        pass

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """Data augmentation function.

        The transform steps are as follows:
        1. Randomly generate index list of other images.
        2. Before Mosaic or MixUp need to go through the necessary
            pre_transform, such as MixUp' pre_transform pipeline
            include: 'LoadImageFromFile','LoadAnnotations',
            'Mosaic' and 'RandomAffine'.
        3. Use mix_img_transform function to implement specific
            mix operations.

        Args:
            results (dict): Result dict.

        Returns:
            results (dict): Updated result dict.
        """

        if random.uniform(0, 1) > self.prob:
            return results

        if self.use_cached:
            # Be careful: deep copying can be very time-consuming
            # if results includes dataset.
            dataset = results.pop('dataset', None)
            self.results_cache.append(copy.deepcopy(results))
            if len(self.results_cache) > self.max_cached_images:
                if self.random_pop:
                    index = random.randint(0, len(self.results_cache) - 1)
                else:
                    index = 0
                self.results_cache.pop(index)

            if len(self.results_cache) <= 4:
                return results
        else:
            assert 'dataset' in results
            # Be careful: deep copying can be very time-consuming
            # if results includes dataset.
            dataset = results.pop('dataset', None)

        for _ in range(self.max_refetch):
            # get index of one or three other images
            if self.use_cached:
                indexes = self.get_indexes(self.results_cache)
            else:
                indexes = self.get_indexes(dataset)

            if not isinstance(indexes, collections.abc.Sequence):
                indexes = [indexes]

            if self.use_cached:
                mix_results = [
                    copy.deepcopy(self.results_cache[i]) for i in indexes
                ]
            else:
                # get images information will be used for Mosaic or MixUp
                mix_results = [
                    copy.deepcopy(dataset.get_data_info(index))
                    for index in indexes
                ]

            if self.pre_transform is not None:
                for i, data in enumerate(mix_results):
                    # pre_transform may also require dataset
                    data.update({'dataset': dataset})
                    # before Mosaic or MixUp need to go through
                    # the necessary pre_transform
                    _results = self.pre_transform(data)
                    _results.pop('dataset')
                    mix_results[i] = _results

            if None not in mix_results:
                results['mix_results'] = mix_results
                break
            print('Repeated calculation')
        else:
            raise RuntimeError(
                'The loading pipeline of the original dataset'
                ' always return None. Please check the correctness '
                'of the dataset and its pipeline.')

        # Mosaic or MixUp
        results = self.mix_img_transform(results)

        if 'mix_results' in results:
            results.pop('mix_results')
        results['dataset'] = dataset

        return results


@TRANSFORMS.register_module()
class Mosaic(BaseMixImageTransform):
    """Mosaic augmentation.

    Given 4 images, mosaic transform combines them into
    one output image. The output image is composed of the parts from each sub-
    image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |           |
                |      +-----------+    pad    |
                |      |           |           |
                |      |  image1   +-----------+
                |      |           |           |
                |      |           |   image2  |
     center_y   |----+-+-----------+-----------+
                |    |   cropped   |           |
                |pad |   image3    |   image4  |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

        1. Choose the mosaic center as the intersections of 4 images
        2. Get the left top image according to the index, and randomly
           sample another 3 images from the custom dataset.
        3. Sub image will be cropped if image is larger than mosaic patch

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Args:
        img_scale (Sequence[int]): Image size after mosaic pipeline of single
            image. The shape order should be (width, height).
            Defaults to (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Defaults to (0.5, 1.5).
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some dataset like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        pad_val (int): Pad value. Defaults to 114.
        pre_transform(Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 40.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of retry iterations for getting
            valid results from the pipeline. If the number of iterations is
            greater than `max_refetch`, but results is still None, then the
            iteration is terminated and raise the error. Defaults to 15.
    """

    def __init__(self,
                 img_scale: Tuple[int, int] = (640, 640),
                 center_ratio_range: Tuple[float, float] = (0.5, 1.5),
                 bbox_clip_border: bool = True,
                 pad_val: float = 114.0,
                 pre_transform: Sequence[dict] = None,
                 prob: float = 1.0,
                 use_cached: bool = False,
                 max_cached_images: int = 40,
                 random_pop: bool = True,
                 max_refetch: int = 15):
        assert isinstance(img_scale, tuple)
        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
                                 f'got {prob}.'
        if use_cached:
            assert max_cached_images >= 4, 'The length of cache must >= 4, ' \
                                           f'but got {max_cached_images}.'

        super().__init__(
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)

        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.bbox_clip_border = bbox_clip_border
        self.pad_val = pad_val

    def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.

        Returns:
            list: indexes.
        """
        indexes = [random.randint(0, len(dataset)) for _ in range(3)]
        return indexes

    def mix_img_transform(self, results: dict) -> dict:
        """Mixed image data transformation.

        Args:
            results (dict): Result dict.

        Returns:
            results (dict): Updated result dict.
        """
        assert 'mix_results' in results
        mosaic_bboxes = []
        mosaic_bboxes_labels = []
        mosaic_ignore_flags = []
        mosaic_masks = []
        with_mask = True if 'gt_masks' in results else False
        # self.img_scale is wh format
        img_scale_w, img_scale_h = self.img_scale

        if len(results['img'].shape) == 3:
            mosaic_img = np.full(
                (int(img_scale_h * 2), int(img_scale_w * 2), 3),
                self.pad_val,
                dtype=results['img'].dtype)
        else:
            mosaic_img = np.full((int(img_scale_h * 2), int(img_scale_w * 2)),
                                 self.pad_val,
                                 dtype=results['img'].dtype)

        # mosaic center x, y
        center_x = int(random.uniform(*self.center_ratio_range) * img_scale_w)
        center_y = int(random.uniform(*self.center_ratio_range) * img_scale_h)
        center_position = (center_x, center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        for i, loc in enumerate(loc_strs):
            if loc == 'top_left':
                results_patch = results
            else:
                results_patch = results['mix_results'][i - 1]

            img_i = results_patch['img']
            h_i, w_i = img_i.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(img_scale_h / h_i, img_scale_w / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

            # adjust coordinate
            gt_bboxes_i = results_patch['gt_bboxes']
            gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
            gt_ignore_flags_i = results_patch['gt_ignore_flags']

            padw = x1_p - x1_c
            padh = y1_p - y1_c
            gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
            gt_bboxes_i.translate_([padw, padh])
            mosaic_bboxes.append(gt_bboxes_i)
            mosaic_bboxes_labels.append(gt_bboxes_labels_i)
            mosaic_ignore_flags.append(gt_ignore_flags_i)
            if with_mask and results_patch.get('gt_masks', None) is not None:
                gt_masks_i = results_patch['gt_masks']
                gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
                gt_masks_i = gt_masks_i.translate(
                    out_shape=(int(self.img_scale[0] * 2),
                               int(self.img_scale[1] * 2)),
                    offset=padw,
                    direction='horizontal')
                gt_masks_i = gt_masks_i.translate(
                    out_shape=(int(self.img_scale[0] * 2),
                               int(self.img_scale[1] * 2)),
                    offset=padh,
                    direction='vertical')
                mosaic_masks.append(gt_masks_i)

        mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
        mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
        mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)

        if self.bbox_clip_border:
            mosaic_bboxes.clip_([2 * img_scale_h, 2 * img_scale_w])
            if with_mask:
                mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
                results['gt_masks'] = mosaic_masks
        else:
            # remove outside bboxes
            inside_inds = mosaic_bboxes.is_inside(
                [2 * img_scale_h, 2 * img_scale_w]).numpy()
            mosaic_bboxes = mosaic_bboxes[inside_inds]
            mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
            mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
            if with_mask:
                mosaic_masks = mosaic_masks[0].cat(mosaic_masks)[inside_inds]
                results['gt_masks'] = mosaic_masks

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['gt_bboxes'] = mosaic_bboxes
        results['gt_bboxes_labels'] = mosaic_bboxes_labels
        results['gt_ignore_flags'] = mosaic_ignore_flags

        return results

    def _mosaic_combine(
            self, loc: str, center_position_xy: Sequence[float],
            img_shape_wh: Sequence[int]) -> Tuple[Tuple[int], Tuple[int]]:
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
              'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
                cropping
                - paste_coord (tuple): paste corner coordinate in mosaic image.
                - crop_coord (tuple): crop corner coordinate in mosaic image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             center_position_xy[0], \
                             center_position_xy[1]
            crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
                y2 - y1), img_shape_wh[0], img_shape_wh[1]

        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[0] * 2), \
                             center_position_xy[1]
            crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
                img_shape_wh[0], x2 - x1), img_shape_wh[1]

        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             center_position_xy[1], \
                             center_position_xy[0], \
                             min(self.img_scale[1] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
                y2 - y1, img_shape_wh[1])

        else:
            # index3 to bottom right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             center_position_xy[1], \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[0] * 2), \
                             min(self.img_scale[1] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = 0, 0, min(img_shape_wh[0],
                                   x2 - x1), min(y2 - y1, img_shape_wh[1])

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'prob={self.prob})'
        return repr_str

|
491 |
+
class Mosaic9(BaseMixImageTransform):
|
492 |
+
"""Mosaic9 augmentation.
|
493 |
+
|
494 |
+
Given 9 images, mosaic transform combines them into
|
495 |
+
one output image. The output image is composed of the parts from each sub-
|
496 |
+
image.
|
497 |
+
|
498 |
+
.. code:: text
|
499 |
+
|
500 |
+
+-------------------------------+------------+
|
501 |
+
| pad | pad | |
|
502 |
+
| +----------+ | |
|
503 |
+
| | +---------------+ top_right |
|
504 |
+
| | | top | image2 |
|
505 |
+
| | top_left | image1 | |
|
506 |
+
| | image8 o--------+------+--------+---+
|
507 |
+
| | | | | |
|
508 |
+
+----+----------+ | right |pad|
|
509 |
+
| | center | image3 | |
|
510 |
+
| left | image0 +---------------+---|
|
511 |
+
| image7 | | | |
|
512 |
+
+---+-----------+---+--------+ | |
|
513 |
+
| | cropped | | bottom_right |pad|
|
514 |
+
| |bottom_left| | image4 | |
|
515 |
+
| | image6 | bottom | | |
|
516 |
+
+---|-----------+ image5 +---------------+---|
|
517 |
+
| pad | | pad |
|
518 |
+
+-----------+------------+-------------------+
|
519 |
+
|
520 |
+
The mosaic transform steps are as follows:
|
521 |
+
|
522 |
+
1. Get the center image according to the index, and randomly
|
523 |
+
sample another 8 images from the custom dataset.
|
524 |
+
2. Randomly offset the image after Mosaic
|
525 |
+
|
526 |
+
Required Keys:
|
527 |
+
|
528 |
+
- img
|
529 |
+
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
|
530 |
+
- gt_bboxes_labels (np.int64) (optional)
|
531 |
+
- gt_ignore_flags (bool) (optional)
|
532 |
+
- mix_results (List[dict])
|
533 |
+
|
534 |
+
Modified Keys:
|
535 |
+
|
536 |
+
- img
|
537 |
+
- img_shape
|
538 |
+
- gt_bboxes (optional)
|
539 |
+
- gt_bboxes_labels (optional)
|
540 |
+
- gt_ignore_flags (optional)
|
541 |
+
|
542 |
+
Args:
|
543 |
+
img_scale (Sequence[int]): Image size after mosaic pipeline of single
|
544 |
+
image. The shape order should be (width, height).
|
545 |
+
Defaults to (640, 640).
|
546 |
+
bbox_clip_border (bool, optional): Whether to clip the objects outside
|
547 |
+
the border of the image. In some dataset like MOT17, the gt bboxes
|
548 |
+
are allowed to cross the border of images. Therefore, we don't
|
549 |
+
need to clip the gt bboxes in these cases. Defaults to True.
|
550 |
+
pad_val (int): Pad value. Defaults to 114.
|
551 |
+
pre_transform(Sequence[dict]): Sequence of transform object or
|
552 |
+
config dict to be composed.
|
553 |
+
prob (float): Probability of applying this transformation.
|
554 |
+
Defaults to 1.0.
|
555 |
+
use_cached (bool): Whether to use cache. Defaults to False.
|
556 |
+
max_cached_images (int): The maximum length of the cache. The larger
|
557 |
+
the cache, the stronger the randomness of this transform. As a
|
558 |
+
rule of thumb, providing 5 caches for each image suffices for
|
559 |
+
randomness. Defaults to 50.
|
560 |
+
random_pop (bool): Whether to randomly pop a result from the cache
|
561 |
+
when the cache is full. If set to False, use FIFO popping method.
|
562 |
+
Defaults to True.
|
563 |
+
max_refetch (int): The maximum number of retry iterations for getting
|
564 |
+
valid results from the pipeline. If the number of iterations is
|
565 |
+
greater than `max_refetch`, but results is still None, then the
|
566 |
+
iteration is terminated and raise the error. Defaults to 15.
|
567 |
+
"""
|
568 |
+
|
569 |
+
def __init__(self,
|
570 |
+
img_scale: Tuple[int, int] = (640, 640),
|
571 |
+
bbox_clip_border: bool = True,
|
572 |
+
pad_val: Union[float, int] = 114.0,
|
573 |
+
pre_transform: Sequence[dict] = None,
|
574 |
+
prob: float = 1.0,
|
575 |
+
use_cached: bool = False,
|
576 |
+
max_cached_images: int = 50,
|
577 |
+
random_pop: bool = True,
|
578 |
+
max_refetch: int = 15):
|
579 |
+
assert isinstance(img_scale, tuple)
|
580 |
+
assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
|
581 |
+
f'got {prob}.'
|
582 |
+
if use_cached:
|
583 |
+
assert max_cached_images >= 9, 'The length of cache must >= 9, ' \
|
584 |
+
f'but got {max_cached_images}.'
|
585 |
+
|
586 |
+
super().__init__(
|
587 |
+
pre_transform=pre_transform,
|
588 |
+
prob=prob,
|
589 |
+
use_cached=use_cached,
|
590 |
+
max_cached_images=max_cached_images,
|
591 |
+
random_pop=random_pop,
|
592 |
+
max_refetch=max_refetch)
|
593 |
+
|
594 |
+
self.img_scale = img_scale
|
595 |
+
self.bbox_clip_border = bbox_clip_border
|
596 |
+
self.pad_val = pad_val
|
597 |
+
|
598 |
+
# intermediate variables
|
599 |
+
self._current_img_shape = [0, 0]
|
600 |
+
self._center_img_shape = [0, 0]
|
601 |
+
self._previous_img_shape = [0, 0]
|
602 |
+
|
603 |
+
def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
|
604 |
+
"""Call function to collect indexes.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
dataset (:obj:`Dataset` or list): The dataset or cached list.
|
608 |
+
|
609 |
+
Returns:
|
610 |
+
list: indexes.
|
611 |
+
"""
|
612 |
+
indexes = [random.randint(0, len(dataset)) for _ in range(8)]
|
613 |
+
return indexes
|
614 |
+
|
615 |
+
def mix_img_transform(self, results: dict) -> dict:
|
616 |
+
"""Mixed image data transformation.
|
617 |
+
|
618 |
+
Args:
|
619 |
+
results (dict): Result dict.
|
620 |
+
|
621 |
+
Returns:
|
622 |
+
results (dict): Updated result dict.
|
623 |
+
"""
|
624 |
+
assert 'mix_results' in results
|
625 |
+
|
626 |
+
mosaic_bboxes = []
|
627 |
+
mosaic_bboxes_labels = []
|
628 |
+
mosaic_ignore_flags = []
|
629 |
+
|
630 |
+
img_scale_w, img_scale_h = self.img_scale
|
631 |
+
|
632 |
+
if len(results['img'].shape) == 3:
|
633 |
+
mosaic_img = np.full(
|
634 |
+
(int(img_scale_h * 3), int(img_scale_w * 3), 3),
|
635 |
+
self.pad_val,
|
636 |
+
dtype=results['img'].dtype)
|
637 |
+
else:
|
638 |
+
mosaic_img = np.full((int(img_scale_h * 3), int(img_scale_w * 3)),
|
639 |
+
self.pad_val,
|
640 |
+
dtype=results['img'].dtype)
|
641 |
+
|
642 |
+
# index = 0 is mean original image
|
643 |
+
# len(results['mix_results']) = 8
|
644 |
+
loc_strs = ('center', 'top', 'top_right', 'right', 'bottom_right',
|
645 |
+
'bottom', 'bottom_left', 'left', 'top_left')
|
646 |
+
|
647 |
+
results_all = [results, *results['mix_results']]
|
648 |
+
for index, results_patch in enumerate(results_all):
|
649 |
+
img_i = results_patch['img']
|
650 |
+
# keep_ratio resize
|
651 |
+
img_i_h, img_i_w = img_i.shape[:2]
|
652 |
+
scale_ratio_i = min(img_scale_h / img_i_h, img_scale_w / img_i_w)
|
653 |
+
img_i = mmcv.imresize(
|
654 |
+
img_i,
|
655 |
+
(int(img_i_w * scale_ratio_i), int(img_i_h * scale_ratio_i)))
|
656 |
+
|
657 |
+
paste_coord = self._mosaic_combine(loc_strs[index],
|
658 |
+
img_i.shape[:2])
|
659 |
+
|
660 |
+
padw, padh = paste_coord[:2]
|
661 |
+
x1, y1, x2, y2 = (max(x, 0) for x in paste_coord)
|
662 |
+
mosaic_img[y1:y2, x1:x2] = img_i[y1 - padh:, x1 - padw:]
|
663 |
+
|
664 |
+
gt_bboxes_i = results_patch['gt_bboxes']
|
665 |
+
gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
|
666 |
+
gt_ignore_flags_i = results_patch['gt_ignore_flags']
|
667 |
+
gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
|
668 |
+
gt_bboxes_i.translate_([padw, padh])
|
669 |
+
|
670 |
+
mosaic_bboxes.append(gt_bboxes_i)
|
671 |
+
mosaic_bboxes_labels.append(gt_bboxes_labels_i)
|
672 |
+
mosaic_ignore_flags.append(gt_ignore_flags_i)
|
673 |
+
|
674 |
+
# Offset
|
675 |
+
offset_x = int(random.uniform(0, img_scale_w))
|
676 |
+
offset_y = int(random.uniform(0, img_scale_h))
|
677 |
+
mosaic_img = mosaic_img[offset_y:offset_y + 2 * img_scale_h,
|
678 |
+
offset_x:offset_x + 2 * img_scale_w]
|
679 |
+
|
680 |
+
mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
|
681 |
+
mosaic_bboxes.translate_([-offset_x, -offset_y])
|
682 |
+
mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
|
683 |
+
mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
|
684 |
+
|
685 |
+
if self.bbox_clip_border:
|
686 |
+
mosaic_bboxes.clip_([2 * img_scale_h, 2 * img_scale_w])
|
687 |
+
else:
|
688 |
+
# remove outside bboxes
|
689 |
+
inside_inds = mosaic_bboxes.is_inside(
|
690 |
+
[2 * img_scale_h, 2 * img_scale_w]).numpy()
|
691 |
+
mosaic_bboxes = mosaic_bboxes[inside_inds]
|
692 |
+
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
|
693 |
+
mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
|
694 |
+
|
695 |
+
results['img'] = mosaic_img
|
696 |
+
results['img_shape'] = mosaic_img.shape
|
697 |
+
results['gt_bboxes'] = mosaic_bboxes
|
698 |
+
results['gt_bboxes_labels'] = mosaic_bboxes_labels
|
699 |
+
results['gt_ignore_flags'] = mosaic_ignore_flags
|
700 |
+
return results
|
701 |
+
|
702 |
+
def _mosaic_combine(self, loc: str,
|
703 |
+
img_shape_hw: Tuple[int, int]) -> Tuple[int, ...]:
|
704 |
+
"""Calculate global coordinate of mosaic image.
|
705 |
+
|
706 |
+
Args:
|
707 |
+
loc (str): Index for the sub-image.
|
708 |
+
img_shape_hw (Sequence[int]): Height and width of sub-image
|
709 |
+
|
710 |
+
Returns:
|
711 |
+
paste_coord (tuple): paste corner coordinate in mosaic image.
|
712 |
+
"""
|
713 |
+
assert loc in ('center', 'top', 'top_right', 'right', 'bottom_right',
|
714 |
+
'bottom', 'bottom_left', 'left', 'top_left')
|
715 |
+
|
716 |
+
img_scale_w, img_scale_h = self.img_scale
|
717 |
+
|
718 |
+
self._current_img_shape = img_shape_hw
|
719 |
+
current_img_h, current_img_w = self._current_img_shape
|
720 |
+
previous_img_h, previous_img_w = self._previous_img_shape
|
721 |
+
center_img_h, center_img_w = self._center_img_shape
|
722 |
+
|
723 |
+
if loc == 'center':
|
724 |
+
self._center_img_shape = self._current_img_shape
|
725 |
+
# xmin, ymin, xmax, ymax
|
726 |
+
paste_coord = img_scale_w, \
|
727 |
+
img_scale_h, \
|
728 |
+
img_scale_w + current_img_w, \
|
729 |
+
img_scale_h + current_img_h
|
730 |
+
elif loc == 'top':
|
731 |
+
paste_coord = img_scale_w, \
|
732 |
+
img_scale_h - current_img_h, \
|
733 |
+
img_scale_w + current_img_w, \
|
734 |
+
img_scale_h
|
735 |
+
elif loc == 'top_right':
|
736 |
+
paste_coord = img_scale_w + previous_img_w, \
|
737 |
+
img_scale_h - current_img_h, \
|
738 |
+
img_scale_w + previous_img_w + current_img_w, \
|
739 |
+
img_scale_h
|
740 |
+
elif loc == 'right':
|
741 |
+
paste_coord = img_scale_w + center_img_w, \
|
742 |
+
img_scale_h, \
|
743 |
+
img_scale_w + center_img_w + current_img_w, \
|
744 |
+
img_scale_h + current_img_h
|
745 |
+
elif loc == 'bottom_right':
|
746 |
+
paste_coord = img_scale_w + center_img_w, \
|
747 |
+
img_scale_h + previous_img_h, \
|
748 |
+
img_scale_w + center_img_w + current_img_w, \
|
749 |
+
img_scale_h + previous_img_h + current_img_h
|
750 |
+
elif loc == 'bottom':
|
751 |
+
paste_coord = img_scale_w + center_img_w - current_img_w, \
|
752 |
+
img_scale_h + center_img_h, \
|
753 |
+
img_scale_w + center_img_w, \
|
754 |
+
img_scale_h + center_img_h + current_img_h
|
755 |
+
elif loc == 'bottom_left':
|
756 |
+
paste_coord = img_scale_w + center_img_w - \
|
757 |
+
previous_img_w - current_img_w, \
|
758 |
+
img_scale_h + center_img_h, \
|
759 |
+
img_scale_w + center_img_w - previous_img_w, \
|
760 |
+
img_scale_h + center_img_h + current_img_h
|
761 |
+
elif loc == 'left':
|
762 |
+
paste_coord = img_scale_w - current_img_w, \
|
763 |
+
img_scale_h + center_img_h - current_img_h, \
|
764 |
+
img_scale_w, \
|
765 |
+
img_scale_h + center_img_h
|
766 |
+
elif loc == 'top_left':
|
767 |
+
paste_coord = img_scale_w - current_img_w, \
|
768 |
+
img_scale_h + center_img_h - \
|
769 |
+
previous_img_h - current_img_h, \
|
770 |
+
img_scale_w, \
|
771 |
+
img_scale_h + center_img_h - previous_img_h
|
772 |
+
|
773 |
+
self._previous_img_shape = self._current_img_shape
|
774 |
+
# xmin, ymin, xmax, ymax
|
775 |
+
return paste_coord
|
776 |
+
|
777 |
+
def __repr__(self) -> str:
|
778 |
+
repr_str = self.__class__.__name__
|
779 |
+
repr_str += f'(img_scale={self.img_scale}, '
|
780 |
+
repr_str += f'pad_val={self.pad_val}, '
|
781 |
+
repr_str += f'prob={self.prob})'
|
782 |
+
return repr_str
|
783 |
+
|
784 |
+
|
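For context, the use_cached path described in the docstrings above keeps a deep-copied cache of recent results and samples the extra images from it instead of re-fetching them from the dataset. A minimal sketch of turning it on for Mosaic9; the cache size shown is the documented default, everything else is illustrative:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Mosaic9',
        img_scale=(640, 640),
        use_cached=True,          # sample the 8 extra images from the cache
        max_cached_images=50,     # must be >= 9; larger means more randomness
        random_pop=True),         # evict a random entry once the cache is full
]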
@TRANSFORMS.register_module()
class YOLOv5MixUp(BaseMixImageTransform):
    """MixUp data augmentation for YOLOv5.

    .. code:: text

    The mixup transform steps are as follows:

        1. Another random image is picked by dataset.
        2. Randomly obtain the fusion ratio from the beta distribution,
           then fuse the target
           of the original image and mixup image through this ratio.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - mix_results (List[dict])


    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)


    Args:
        alpha (float): parameter of beta distribution to get mixup ratio.
            Defaults to 32.
        beta (float): parameter of beta distribution to get mixup ratio.
            Defaults to 32.
        pre_transform (Sequence[dict]): Sequence of transform object or
            config dict to be composed.
        prob (float): Probability of applying this transformation.
            Defaults to 1.0.
        use_cached (bool): Whether to use cache. Defaults to False.
        max_cached_images (int): The maximum length of the cache. The larger
            the cache, the stronger the randomness of this transform. As a
            rule of thumb, providing 10 caches for each image suffices for
            randomness. Defaults to 20.
        random_pop (bool): Whether to randomly pop a result from the cache
            when the cache is full. If set to False, use FIFO popping method.
            Defaults to True.
        max_refetch (int): The maximum number of iterations. If the number of
            iterations is greater than `max_refetch`, but gt_bbox is still
            empty, then the iteration is terminated. Defaults to 15.
    """

    def __init__(self,
                 alpha: float = 32.0,
                 beta: float = 32.0,
                 pre_transform: Sequence[dict] = None,
                 prob: float = 1.0,
                 use_cached: bool = False,
                 max_cached_images: int = 20,
                 random_pop: bool = True,
                 max_refetch: int = 15):
        if use_cached:
            assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
                                           f'but got {max_cached_images}.'
        super().__init__(
            pre_transform=pre_transform,
            prob=prob,
            use_cached=use_cached,
            max_cached_images=max_cached_images,
            random_pop=random_pop,
            max_refetch=max_refetch)
        self.alpha = alpha
        self.beta = beta

    def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`Dataset` or list): The dataset or cached list.

        Returns:
            int: indexes.
        """
        return random.randint(0, len(dataset))

    def mix_img_transform(self, results: dict) -> dict:
        """YOLOv5 MixUp transform function.

        Args:
            results (dict): Result dict

        Returns:
            results (dict): Updated result dict.
        """
        assert 'mix_results' in results

        retrieve_results = results['mix_results'][0]
        retrieve_img = retrieve_results['img']
        ori_img = results['img']
        assert ori_img.shape == retrieve_img.shape

        # Randomly obtain the fusion ratio from the beta distribution,
        # which is around 0.5
        ratio = np.random.beta(self.alpha, self.beta)
        mixup_img = (ori_img * ratio + retrieve_img * (1 - ratio))

        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
        retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']

        mixup_gt_bboxes = retrieve_gt_bboxes.cat(
            (results['gt_bboxes'], retrieve_gt_bboxes), dim=0)
        mixup_gt_bboxes_labels = np.concatenate(
            (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
        mixup_gt_ignore_flags = np.concatenate(
            (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
        if 'gt_masks' in results:
            assert 'gt_masks' in retrieve_results
            mixup_gt_masks = results['gt_masks'].cat(
                [results['gt_masks'], retrieve_results['gt_masks']])
            results['gt_masks'] = mixup_gt_masks

        results['img'] = mixup_img.astype(np.uint8)
        results['img_shape'] = mixup_img.shape
        results['gt_bboxes'] = mixup_gt_bboxes
        results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
        results['gt_ignore_flags'] = mixup_gt_ignore_flags

        return results

+
@TRANSFORMS.register_module()
|
918 |
+
class YOLOXMixUp(BaseMixImageTransform):
|
919 |
+
"""MixUp data augmentation for YOLOX.
|
920 |
+
|
921 |
+
.. code:: text
|
922 |
+
|
923 |
+
mixup transform
|
924 |
+
+---------------+--------------+
|
925 |
+
| mixup image | |
|
926 |
+
| +--------|--------+ |
|
927 |
+
| | | | |
|
928 |
+
+---------------+ | |
|
929 |
+
| | | |
|
930 |
+
| | image | |
|
931 |
+
| | | |
|
932 |
+
| | | |
|
933 |
+
| +-----------------+ |
|
934 |
+
| pad |
|
935 |
+
+------------------------------+
|
936 |
+
|
937 |
+
The mixup transform steps are as follows:
|
938 |
+
|
939 |
+
1. Another random image is picked by dataset and embedded in
|
940 |
+
the top left patch(after padding and resizing)
|
941 |
+
2. The target of mixup transform is the weighted average of mixup
|
942 |
+
image and origin image.
|
943 |
+
|
944 |
+
Required Keys:
|
945 |
+
|
946 |
+
- img
|
947 |
+
- gt_bboxes (BaseBoxes[torch.float32]) (optional)
|
948 |
+
- gt_bboxes_labels (np.int64) (optional)
|
949 |
+
- gt_ignore_flags (bool) (optional)
|
950 |
+
- mix_results (List[dict])
|
951 |
+
|
952 |
+
|
953 |
+
Modified Keys:
|
954 |
+
|
955 |
+
- img
|
956 |
+
- img_shape
|
957 |
+
- gt_bboxes (optional)
|
958 |
+
- gt_bboxes_labels (optional)
|
959 |
+
- gt_ignore_flags (optional)
|
960 |
+
|
961 |
+
|
962 |
+
Args:
|
963 |
+
img_scale (Sequence[int]): Image output size after mixup pipeline.
|
964 |
+
The shape order should be (width, height). Defaults to (640, 640).
|
965 |
+
ratio_range (Sequence[float]): Scale ratio of mixup image.
|
966 |
+
Defaults to (0.5, 1.5).
|
967 |
+
flip_ratio (float): Horizontal flip ratio of mixup image.
|
968 |
+
Defaults to 0.5.
|
969 |
+
pad_val (int): Pad value. Defaults to 114.
|
970 |
+
bbox_clip_border (bool, optional): Whether to clip the objects outside
|
971 |
+
the border of the image. In some dataset like MOT17, the gt bboxes
|
972 |
+
are allowed to cross the border of images. Therefore, we don't
|
973 |
+
need to clip the gt bboxes in these cases. Defaults to True.
|
974 |
+
pre_transform(Sequence[dict]): Sequence of transform object or
|
975 |
+
config dict to be composed.
|
976 |
+
prob (float): Probability of applying this transformation.
|
977 |
+
Defaults to 1.0.
|
978 |
+
use_cached (bool): Whether to use cache. Defaults to False.
|
979 |
+
max_cached_images (int): The maximum length of the cache. The larger
|
980 |
+
the cache, the stronger the randomness of this transform. As a
|
981 |
+
rule of thumb, providing 10 caches for each image suffices for
|
982 |
+
randomness. Defaults to 20.
|
983 |
+
random_pop (bool): Whether to randomly pop a result from the cache
|
984 |
+
when the cache is full. If set to False, use FIFO popping method.
|
985 |
+
Defaults to True.
|
986 |
+
max_refetch (int): The maximum number of iterations. If the number of
|
987 |
+
iterations is greater than `max_refetch`, but gt_bbox is still
|
988 |
+
empty, then the iteration is terminated. Defaults to 15.
|
989 |
+
"""
|
990 |
+
|
991 |
+
def __init__(self,
|
992 |
+
img_scale: Tuple[int, int] = (640, 640),
|
993 |
+
ratio_range: Tuple[float, float] = (0.5, 1.5),
|
994 |
+
flip_ratio: float = 0.5,
|
995 |
+
pad_val: float = 114.0,
|
996 |
+
bbox_clip_border: bool = True,
|
997 |
+
pre_transform: Sequence[dict] = None,
|
998 |
+
prob: float = 1.0,
|
999 |
+
use_cached: bool = False,
|
1000 |
+
max_cached_images: int = 20,
|
1001 |
+
random_pop: bool = True,
|
1002 |
+
max_refetch: int = 15):
|
1003 |
+
assert isinstance(img_scale, tuple)
|
1004 |
+
if use_cached:
|
1005 |
+
assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
|
1006 |
+
f'but got {max_cached_images}.'
|
1007 |
+
super().__init__(
|
1008 |
+
pre_transform=pre_transform,
|
1009 |
+
prob=prob,
|
1010 |
+
use_cached=use_cached,
|
1011 |
+
max_cached_images=max_cached_images,
|
1012 |
+
random_pop=random_pop,
|
1013 |
+
max_refetch=max_refetch)
|
1014 |
+
self.img_scale = img_scale
|
1015 |
+
self.ratio_range = ratio_range
|
1016 |
+
self.flip_ratio = flip_ratio
|
1017 |
+
self.pad_val = pad_val
|
1018 |
+
self.bbox_clip_border = bbox_clip_border
|
1019 |
+
|
1020 |
+
def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
|
1021 |
+
"""Call function to collect indexes.
|
1022 |
+
|
1023 |
+
Args:
|
1024 |
+
dataset (:obj:`Dataset` or list): The dataset or cached list.
|
1025 |
+
|
1026 |
+
Returns:
|
1027 |
+
int: indexes.
|
1028 |
+
"""
|
1029 |
+
return random.randint(0, len(dataset))
|
1030 |
+
|
1031 |
+
def mix_img_transform(self, results: dict) -> dict:
|
1032 |
+
"""YOLOX MixUp transform function.
|
1033 |
+
|
1034 |
+
Args:
|
1035 |
+
results (dict): Result dict.
|
1036 |
+
|
1037 |
+
Returns:
|
1038 |
+
results (dict): Updated result dict.
|
1039 |
+
"""
|
1040 |
+
assert 'mix_results' in results
|
1041 |
+
assert len(
|
1042 |
+
results['mix_results']) == 1, 'MixUp only support 2 images now !'
|
1043 |
+
|
1044 |
+
if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
|
1045 |
+
# empty bbox
|
1046 |
+
return results
|
1047 |
+
|
1048 |
+
retrieve_results = results['mix_results'][0]
|
1049 |
+
retrieve_img = retrieve_results['img']
|
1050 |
+
|
1051 |
+
jit_factor = random.uniform(*self.ratio_range)
|
1052 |
+
is_filp = random.uniform(0, 1) > self.flip_ratio
|
1053 |
+
|
1054 |
+
if len(retrieve_img.shape) == 3:
|
1055 |
+
out_img = np.ones((self.img_scale[1], self.img_scale[0], 3),
|
1056 |
+
dtype=retrieve_img.dtype) * self.pad_val
|
1057 |
+
else:
|
1058 |
+
out_img = np.ones(
|
1059 |
+
self.img_scale[::-1], dtype=retrieve_img.dtype) * self.pad_val
|
1060 |
+
|
1061 |
+
# 1. keep_ratio resize
|
1062 |
+
scale_ratio = min(self.img_scale[1] / retrieve_img.shape[0],
|
1063 |
+
self.img_scale[0] / retrieve_img.shape[1])
|
1064 |
+
retrieve_img = mmcv.imresize(
|
1065 |
+
retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
|
1066 |
+
int(retrieve_img.shape[0] * scale_ratio)))
|
1067 |
+
|
1068 |
+
# 2. paste
|
1069 |
+
out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
|
1070 |
+
|
1071 |
+
# 3. scale jit
|
1072 |
+
scale_ratio *= jit_factor
|
1073 |
+
out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
|
1074 |
+
int(out_img.shape[0] * jit_factor)))
|
1075 |
+
|
1076 |
+
# 4. flip
|
1077 |
+
if is_filp:
|
1078 |
+
out_img = out_img[:, ::-1, :]
|
1079 |
+
|
1080 |
+
# 5. random crop
|
1081 |
+
ori_img = results['img']
|
1082 |
+
origin_h, origin_w = out_img.shape[:2]
|
1083 |
+
target_h, target_w = ori_img.shape[:2]
|
1084 |
+
padded_img = np.ones((max(origin_h, target_h), max(
|
1085 |
+
origin_w, target_w), 3)) * self.pad_val
|
1086 |
+
padded_img = padded_img.astype(np.uint8)
|
1087 |
+
padded_img[:origin_h, :origin_w] = out_img
|
1088 |
+
|
1089 |
+
x_offset, y_offset = 0, 0
|
1090 |
+
if padded_img.shape[0] > target_h:
|
1091 |
+
y_offset = random.randint(0, padded_img.shape[0] - target_h)
|
1092 |
+
if padded_img.shape[1] > target_w:
|
1093 |
+
x_offset = random.randint(0, padded_img.shape[1] - target_w)
|
1094 |
+
padded_cropped_img = padded_img[y_offset:y_offset + target_h,
|
1095 |
+
x_offset:x_offset + target_w]
|
1096 |
+
|
1097 |
+
# 6. adjust bbox
|
1098 |
+
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
|
1099 |
+
retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
|
1100 |
+
if self.bbox_clip_border:
|
1101 |
+
retrieve_gt_bboxes.clip_([origin_h, origin_w])
|
1102 |
+
|
1103 |
+
if is_filp:
|
1104 |
+
retrieve_gt_bboxes.flip_([origin_h, origin_w],
|
1105 |
+
direction='horizontal')
|
1106 |
+
|
1107 |
+
# 7. filter
|
1108 |
+
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
|
1109 |
+
cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
|
1110 |
+
if self.bbox_clip_border:
|
1111 |
+
cp_retrieve_gt_bboxes.clip_([target_h, target_w])
|
1112 |
+
|
1113 |
+
# 8. mix up
|
1114 |
+
mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img
|
1115 |
+
|
1116 |
+
retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
|
1117 |
+
retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
|
1118 |
+
|
1119 |
+
mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
|
1120 |
+
(results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
|
1121 |
+
mixup_gt_bboxes_labels = np.concatenate(
|
1122 |
+
(results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
|
1123 |
+
mixup_gt_ignore_flags = np.concatenate(
|
1124 |
+
(results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
|
1125 |
+
|
1126 |
+
if not self.bbox_clip_border:
|
1127 |
+
# remove outside bbox
|
1128 |
+
inside_inds = mixup_gt_bboxes.is_inside([target_h,
|
1129 |
+
target_w]).numpy()
|
1130 |
+
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
|
1131 |
+
mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
|
1132 |
+
mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
|
1133 |
+
|
1134 |
+
results['img'] = mixup_img.astype(np.uint8)
|
1135 |
+
results['img_shape'] = mixup_img.shape
|
1136 |
+
results['gt_bboxes'] = mixup_gt_bboxes
|
1137 |
+
results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
|
1138 |
+
results['gt_ignore_flags'] = mixup_gt_ignore_flags
|
1139 |
+
|
1140 |
+
return results
|
1141 |
+
|
1142 |
+
def __repr__(self) -> str:
|
1143 |
+
repr_str = self.__class__.__name__
|
1144 |
+
repr_str += f'(img_scale={self.img_scale}, '
|
1145 |
+
repr_str += f'ratio_range={self.ratio_range}, '
|
1146 |
+
repr_str += f'flip_ratio={self.flip_ratio}, '
|
1147 |
+
repr_str += f'pad_val={self.pad_val}, '
|
1148 |
+
repr_str += f'max_refetch={self.max_refetch}, '
|
1149 |
+
repr_str += f'bbox_clip_border={self.bbox_clip_border})'
|
1150 |
+
return repr_str
|
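For context, a minimal sketch of the YOLOX-style combination of the two transforms in this file, mirroring the step list in the YOLOXMixUp docstring. The affine step in between is assumed from the usual YOLOX pipeline and is not defined in this file; ratio_range and img_scale values are illustrative:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Mosaic',
        img_scale=(640, 640),
        pad_val=114.0,
        pre_transform=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ]),
    # assumption: a RandomAffine-style transform normally sits here
    dict(
        type='YOLOXMixUp',
        img_scale=(640, 640),
        ratio_range=(0.8, 1.6),   # illustrative; the default is (0.5, 1.5)
        pad_val=114.0,
        pre_transform=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ]),
]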
mmyolo/datasets/transforms/transforms.py
ADDED
@@ -0,0 +1,1557 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from copy import deepcopy
from typing import List, Sequence, Tuple, Union

import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import BaseTransform, Compose
from mmcv.transforms.utils import cache_randomness
from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
from mmdet.datasets.transforms import Resize as MMDET_Resize
from mmdet.structures.bbox import (HorizontalBoxes, autocast_box_type,
                                   get_box_type)
from mmdet.structures.mask import PolygonMasks
from numpy import random

from mmyolo.registry import TRANSFORMS

# TODO: Waiting for MMCV support
TRANSFORMS.register_module(module=Compose, force=True)

@TRANSFORMS.register_module()
class YOLOv5KeepRatioResize(MMDET_Resize):
    """Resize images & bboxes (if present).

    This transform resizes the input image according to ``scale``.
    Bboxes (if present) are then resized with the same scale factor.

    Required Keys:

    - img (np.uint8)
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)

    Modified Keys:

    - img (np.uint8)
    - img_shape (tuple)
    - gt_bboxes (optional)
    - scale (float)

    Added Keys:

    - scale_factor (np.float32)

    Args:
        scale (Union[int, Tuple[int, int]]): Images scales for resizing.
    """

    def __init__(self,
                 scale: Union[int, Tuple[int, int]],
                 keep_ratio: bool = True,
                 **kwargs):
        assert keep_ratio is True
        super().__init__(scale=scale, keep_ratio=True, **kwargs)

    @staticmethod
    def _get_rescale_ratio(old_size: Tuple[int, int],
                           scale: Union[float, Tuple[int]]) -> float:
        """Calculate the ratio for rescaling.

        Args:
            old_size (tuple[int]): The old size (w, h) of image.
            scale (float | tuple[int]): The scaling factor or maximum size.
                If it is a float number, then the image will be rescaled by
                this factor, else if it is a tuple of 2 integers, then
                the image will be rescaled as large as possible within
                the scale.

        Returns:
            float: The resize ratio.
        """
        w, h = old_size
        if isinstance(scale, (float, int)):
            if scale <= 0:
                raise ValueError(f'Invalid scale {scale}, must be positive.')
            scale_factor = scale
        elif isinstance(scale, tuple):
            max_long_edge = max(scale)
            max_short_edge = min(scale)
            scale_factor = min(max_long_edge / max(h, w),
                               max_short_edge / min(h, w))
        else:
            raise TypeError('Scale must be a number or tuple of int, '
                            f'but got {type(scale)}')

        return scale_factor

    def _resize_img(self, results: dict):
        """Resize images with ``results['scale']``."""
        assert self.keep_ratio is True

        if results.get('img', None) is not None:
            image = results['img']
            original_h, original_w = image.shape[:2]
            ratio = self._get_rescale_ratio((original_h, original_w),
                                            self.scale)

            if ratio != 1:
                # resize image according to the ratio
                image = mmcv.imrescale(
                    img=image,
                    scale=ratio,
                    interpolation='area' if ratio < 1 else 'bilinear',
                    backend=self.backend)

            resized_h, resized_w = image.shape[:2]
            scale_ratio = resized_h / original_h

            scale_factor = (scale_ratio, scale_ratio)

            results['img'] = image
            results['img_shape'] = image.shape[:2]
            results['scale_factor'] = scale_factor

@TRANSFORMS.register_module()
class LetterResize(MMDET_Resize):
    """Resize and pad image while meeting stride-multiple constraints.

    Required Keys:

    - img (np.uint8)
    - batch_shape (np.int64) (optional)

    Modified Keys:

    - img (np.uint8)
    - img_shape (tuple)
    - gt_bboxes (optional)

    Added Keys:

    - pad_param (np.float32)

    Args:
        scale (Union[int, Tuple[int, int]]): Images scales for resizing.
        pad_val (dict): Padding value. Defaults to dict(img=0, seg=255).
        use_mini_pad (bool): Whether using minimum rectangle padding.
            Defaults to False.
        stretch_only (bool): Whether stretch to the specified size directly.
            Defaults to False.
        allow_scale_up (bool): Allow scale up when ratio > 1. Defaults to True.
    """

    def __init__(self,
                 scale: Union[int, Tuple[int, int]],
                 pad_val: dict = dict(img=0, mask=0, seg=255),
                 use_mini_pad: bool = False,
                 stretch_only: bool = False,
                 allow_scale_up: bool = True,
                 **kwargs):
        super().__init__(scale=scale, keep_ratio=True, **kwargs)

        self.pad_val = pad_val
        if isinstance(pad_val, (int, float)):
            pad_val = dict(img=pad_val, seg=255)
        assert isinstance(
            pad_val, dict), f'pad_val must be dict, but got {type(pad_val)}'

        self.use_mini_pad = use_mini_pad
        self.stretch_only = stretch_only
        self.allow_scale_up = allow_scale_up

    def _resize_img(self, results: dict):
        """Resize images with ``results['scale']``."""
        image = results.get('img', None)
        if image is None:
            return

        # Use batch_shape if a batch_shape policy is configured
        if 'batch_shape' in results:
            scale = tuple(results['batch_shape'])  # hw
        else:
            scale = self.scale[::-1]  # wh -> hw

        image_shape = image.shape[:2]  # height, width

        # Scale ratio (new / old)
        ratio = min(scale[0] / image_shape[0], scale[1] / image_shape[1])

        # only scale down, do not scale up (for better test mAP)
        if not self.allow_scale_up:
            ratio = min(ratio, 1.0)

        ratio = [ratio, ratio]  # float -> (float, float) for (height, width)

        # compute the best size of the image
        no_pad_shape = (int(round(image_shape[0] * ratio[0])),
                        int(round(image_shape[1] * ratio[1])))

        # padding height & width
        padding_h, padding_w = [
            scale[0] - no_pad_shape[0], scale[1] - no_pad_shape[1]
        ]
        if self.use_mini_pad:
            # minimum rectangle padding
            padding_w, padding_h = np.mod(padding_w, 32), np.mod(padding_h, 32)

        elif self.stretch_only:
            # stretch to the specified size directly
            padding_h, padding_w = 0.0, 0.0
            no_pad_shape = (scale[0], scale[1])
            ratio = [scale[0] / image_shape[0],
                     scale[1] / image_shape[1]]  # height, width ratios

        if image_shape != no_pad_shape:
            # compare with no resize and padding size
            image = mmcv.imresize(
                image, (no_pad_shape[1], no_pad_shape[0]),
                interpolation=self.interpolation,
                backend=self.backend)

        scale_factor = (ratio[1], ratio[0])  # mmcv scale factor is (w, h)

        if 'scale_factor' in results:
            results['scale_factor_origin'] = results['scale_factor']
        results['scale_factor'] = scale_factor

        # padding
        top_padding, left_padding = int(round(padding_h // 2 - 0.1)), int(
            round(padding_w // 2 - 0.1))
        bottom_padding = padding_h - top_padding
        right_padding = padding_w - left_padding

        padding_list = [
            top_padding, bottom_padding, left_padding, right_padding
        ]
        if top_padding != 0 or bottom_padding != 0 or \
                left_padding != 0 or right_padding != 0:

            pad_val = self.pad_val.get('img', 0)
            if isinstance(pad_val, int) and image.ndim == 3:
                pad_val = tuple(pad_val for _ in range(image.shape[2]))

            image = mmcv.impad(
                img=image,
                padding=(padding_list[2], padding_list[0], padding_list[3],
                         padding_list[1]),
                pad_val=pad_val,
                padding_mode='constant')

        results['img'] = image
        results['img_shape'] = image.shape
        if 'pad_param' in results:
            results['pad_param_origin'] = results['pad_param'] * \
                np.repeat(ratio, 2)
        results['pad_param'] = np.array(padding_list, dtype=np.float32)

    def _resize_masks(self, results: dict):
        """Resize masks with ``results['scale']``."""
        if results.get('gt_masks', None) is None:
            return

        gt_masks = results['gt_masks']
        assert isinstance(
            gt_masks, PolygonMasks
        ), f'Only supports PolygonMasks, but got {type(gt_masks)}'

        # resize the gt_masks
        gt_mask_h = results['gt_masks'].height * results['scale_factor'][1]
        gt_mask_w = results['gt_masks'].width * results['scale_factor'][0]
        gt_masks = results['gt_masks'].resize(
            (int(round(gt_mask_h)), int(round(gt_mask_w))))

        top_padding, _, left_padding, _ = results['pad_param']
        if int(left_padding) != 0:
            gt_masks = gt_masks.translate(
                out_shape=results['img_shape'][:2],
                offset=int(left_padding),
                direction='horizontal')
        if int(top_padding) != 0:
            gt_masks = gt_masks.translate(
                out_shape=results['img_shape'][:2],
                offset=int(top_padding),
                direction='vertical')
        results['gt_masks'] = gt_masks

    def _resize_bboxes(self, results: dict):
        """Resize bounding boxes with ``results['scale_factor']``."""
        if results.get('gt_bboxes', None) is None:
            return
        results['gt_bboxes'].rescale_(results['scale_factor'])

        if len(results['pad_param']) != 4:
            return
        results['gt_bboxes'].translate_(
            (results['pad_param'][2], results['pad_param'][0]))

        if self.clip_object_border:
            results['gt_bboxes'].clip_(results['img_shape'])

    def transform(self, results: dict) -> dict:
        results = super().transform(results)
        if 'scale_factor_origin' in results:
            scale_factor_origin = results.pop('scale_factor_origin')
            results['scale_factor'] = (results['scale_factor'][0] *
                                       scale_factor_origin[0],
                                       results['scale_factor'][1] *
                                       scale_factor_origin[1])
        if 'pad_param_origin' in results:
            pad_param_origin = results.pop('pad_param_origin')
            results['pad_param'] += pad_param_origin
        return results
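
# Note (not part of the diff above): the two resize transforms are usually
# chained in a test/val pipeline, letterboxing the keep-ratio-resized image to
# the final input size. The sketch below is an illustrative assumption about a
# typical MMYOLO pipeline config; the 640x640 scale and pad value 114 are
# example choices, not values taken from this file.
example_test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='YOLOv5KeepRatioResize', scale=(640, 640)),
    dict(
        type='LetterResize',
        scale=(640, 640),
        allow_scale_up=False,
        pad_val=dict(img=114)),
    dict(type='LoadAnnotations', with_bbox=True),
]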

# TODO: Check if it can be merged with mmdet.YOLOXHSVRandomAug
@TRANSFORMS.register_module()
class YOLOv5HSVRandomAug(BaseTransform):
    """Apply HSV augmentation to image sequentially.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        hue_delta ([int, float]): delta of hue. Defaults to 0.015.
        saturation_delta ([int, float]): delta of saturation. Defaults to 0.7.
        value_delta ([int, float]): delta of value. Defaults to 0.4.
    """

    def __init__(self,
                 hue_delta: Union[int, float] = 0.015,
                 saturation_delta: Union[int, float] = 0.7,
                 value_delta: Union[int, float] = 0.4):
        self.hue_delta = hue_delta
        self.saturation_delta = saturation_delta
        self.value_delta = value_delta

    def transform(self, results: dict) -> dict:
        """The HSV augmentation transform function.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        hsv_gains = \
            random.uniform(-1, 1, 3) * \
            [self.hue_delta, self.saturation_delta, self.value_delta] + 1
        hue, sat, val = cv2.split(
            cv2.cvtColor(results['img'], cv2.COLOR_BGR2HSV))

        table_list = np.arange(0, 256, dtype=hsv_gains.dtype)
        lut_hue = ((table_list * hsv_gains[0]) % 180).astype(np.uint8)
        lut_sat = np.clip(table_list * hsv_gains[1], 0, 255).astype(np.uint8)
        lut_val = np.clip(table_list * hsv_gains[2], 0, 255).astype(np.uint8)

        im_hsv = cv2.merge(
            (cv2.LUT(hue, lut_hue), cv2.LUT(sat,
                                            lut_sat), cv2.LUT(val, lut_val)))
        results['img'] = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(hue_delta={self.hue_delta}, '
        repr_str += f'saturation_delta={self.saturation_delta}, '
        repr_str += f'value_delta={self.value_delta})'
        return repr_str
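
# Note (illustrative, not part of the diff above): the transform draws one
# multiplicative gain per HSV channel from [1 - delta, 1 + delta] and applies
# it through a 256-entry lookup table. A standalone numpy sketch of that gain
# and LUT construction, with made-up random seed:
#
#     import numpy as np
#     rng = np.random.default_rng(0)
#     hue_delta, saturation_delta, value_delta = 0.015, 0.7, 0.4
#     gains = rng.uniform(-1, 1, 3) * [hue_delta, saturation_delta,
#                                      value_delta] + 1
#     table = np.arange(256, dtype=np.float64)
#     lut_hue = ((table * gains[0]) % 180).astype(np.uint8)   # hue wraps at 180
#     lut_sat = np.clip(table * gains[1], 0, 255).astype(np.uint8)
#     lut_val = np.clip(table * gains[2], 0, 255).astype(np.uint8)
#     print(gains, lut_hue[:5], lut_sat[:5], lut_val[:5])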

@TRANSFORMS.register_module()
class LoadAnnotations(MMDET_LoadAnnotations):
    """Because the yolo series does not need to consider ignore bboxes for the
    time being, in order to speed up the pipeline, it can be excluded in
    advance."""

    def __init__(self,
                 mask2bbox: bool = False,
                 poly2mask: bool = False,
                 **kwargs) -> None:
        self.mask2bbox = mask2bbox
        assert not poly2mask, 'Does not support BitmapMasks considering ' \
                              'that bitmap consumes more memory.'
        super().__init__(poly2mask=poly2mask, **kwargs)
        if self.mask2bbox:
            assert self.with_mask, 'Using mask2bbox requires ' \
                                   'with_mask is True.'
        self._mask_ignore_flag = None

    def transform(self, results: dict) -> dict:
        """Function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            dict: The dict contains loaded bounding box, label and
            semantic segmentation.
        """
        if self.mask2bbox:
            self._load_masks(results)
            if self.with_label:
                self._load_labels(results)
                self._update_mask_ignore_data(results)
            gt_bboxes = results['gt_masks'].get_bboxes(dst_type='hbox')
            results['gt_bboxes'] = gt_bboxes
        else:
            results = super().transform(results)
            self._update_mask_ignore_data(results)
        return results

    def _update_mask_ignore_data(self, results: dict) -> None:
        if 'gt_masks' not in results:
            return

        if 'gt_bboxes_labels' in results and len(
                results['gt_bboxes_labels']) != len(results['gt_masks']):
            assert len(results['gt_bboxes_labels']) == len(
                self._mask_ignore_flag)
            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                self._mask_ignore_flag]

        if 'gt_bboxes' in results and len(results['gt_bboxes']) != len(
                results['gt_masks']):
            assert len(results['gt_bboxes']) == len(self._mask_ignore_flag)
            results['gt_bboxes'] = results['gt_bboxes'][self._mask_ignore_flag]

    def _load_bboxes(self, results: dict):
        """Private function to load bounding box annotations.
        Note: BBoxes with ignore_flag of 1 is not considered.
        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        gt_bboxes = []
        gt_ignore_flags = []
        for instance in results.get('instances', []):
            if instance['ignore_flag'] == 0:
                gt_bboxes.append(instance['bbox'])
                gt_ignore_flags.append(instance['ignore_flag'])
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)

        if self.box_type is None:
            results['gt_bboxes'] = np.array(
                gt_bboxes, dtype=np.float32).reshape((-1, 4))
        else:
            _, box_type_cls = get_box_type(self.box_type)
            results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)

    def _load_labels(self, results: dict):
        """Private function to load label annotations.

        Note: BBoxes with ignore_flag of 1 is not considered.
        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        Returns:
            dict: The dict contains loaded label annotations.
        """
        gt_bboxes_labels = []
        for instance in results.get('instances', []):
            if instance['ignore_flag'] == 0:
                gt_bboxes_labels.append(instance['bbox_label'])
        results['gt_bboxes_labels'] = np.array(
            gt_bboxes_labels, dtype=np.int64)

    def _load_masks(self, results: dict) -> None:
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
        """
        gt_masks = []
        gt_ignore_flags = []
        self._mask_ignore_flag = []
        for instance in results.get('instances', []):
            if instance['ignore_flag'] == 0:
                if 'mask' in instance:
                    gt_mask = instance['mask']
                    if isinstance(gt_mask, list):
                        gt_mask = [
                            np.array(polygon) for polygon in gt_mask
                            if len(polygon) % 2 == 0 and len(polygon) >= 6
                        ]
                        if len(gt_mask) == 0:
                            # ignore
                            self._mask_ignore_flag.append(0)
                        else:
                            gt_masks.append(gt_mask)
                            gt_ignore_flags.append(instance['ignore_flag'])
                            self._mask_ignore_flag.append(1)
                    else:
                        raise NotImplementedError(
                            'Only supports mask annotations in polygon '
                            'format currently')
                else:
                    # TODO: Actually, gt with bbox and without mask needs
                    # to be retained
                    self._mask_ignore_flag.append(0)
        self._mask_ignore_flag = np.array(self._mask_ignore_flag, dtype=bool)
        results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)

        h, w = results['ori_shape']
        gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
        results['gt_masks'] = gt_masks

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'mask2bbox={self.mask2bbox}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f"imdecode_backend='{self.imdecode_backend}', "
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str

@TRANSFORMS.register_module()
class YOLOv5RandomAffine(BaseTransform):
    """Random affine transform data augmentation in YOLOv5 and YOLOv8. It is
    different from the implementation in YOLOX.

    This operation randomly generates affine transform matrix which including
    rotation, translation, shear and scaling transforms.
    If you set use_mask_refine == True, the code will use the masks
    annotation to refine the bbox.
    Our implementation is slightly different from the official. In COCO
    dataset, a gt may have multiple mask tags. The official YOLOv5
    annotation file already combines the masks that an object has,
    but our code takes into account the fact that an object has multiple masks.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)
    - gt_masks (PolygonMasks) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)
    - gt_masks (PolygonMasks) (optional)

    Args:
        max_rotate_degree (float): Maximum degrees of rotation transform.
            Defaults to 10.
        max_translate_ratio (float): Maximum ratio of translation.
            Defaults to 0.1.
        scaling_ratio_range (tuple[float]): Min and max ratio of
            scaling transform. Defaults to (0.5, 1.5).
        max_shear_degree (float): Maximum degrees of shear
            transform. Defaults to 2.
        border (tuple[int]): Distance from width and height sides of input
            image to adjust output shape. Only used in mosaic dataset.
            Defaults to (0, 0).
        border_val (tuple[int]): Border padding values of 3 channels.
            Defaults to (114, 114, 114).
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. In some dataset like MOT17, the gt bboxes
            are allowed to cross the border of images. Therefore, we don't
            need to clip the gt bboxes in these cases. Defaults to True.
        min_bbox_size (float): Width and height threshold to filter bboxes.
            If the height or width of a box is smaller than this value, it
            will be removed. Defaults to 2.
        min_area_ratio (float): Threshold of area ratio between
            original bboxes and wrapped bboxes. If smaller than this value,
            the box will be removed. Defaults to 0.1.
        use_mask_refine (bool): Whether to refine bbox by mask.
        max_aspect_ratio (float): Aspect ratio of width and height
            threshold to filter bboxes. If max(h/w, w/h) larger than this
            value, the box will be removed. Defaults to 20.
        resample_num (int): Number of poly to resample to.
    """

    def __init__(self,
                 max_rotate_degree: float = 10.0,
                 max_translate_ratio: float = 0.1,
                 scaling_ratio_range: Tuple[float, float] = (0.5, 1.5),
                 max_shear_degree: float = 2.0,
                 border: Tuple[int, int] = (0, 0),
                 border_val: Tuple[int, int, int] = (114, 114, 114),
                 bbox_clip_border: bool = True,
                 min_bbox_size: int = 2,
                 min_area_ratio: float = 0.1,
                 use_mask_refine: bool = False,
                 max_aspect_ratio: float = 20.,
                 resample_num: int = 1000):
        assert 0 <= max_translate_ratio <= 1
        assert scaling_ratio_range[0] <= scaling_ratio_range[1]
        assert scaling_ratio_range[0] > 0
        self.max_rotate_degree = max_rotate_degree
        self.max_translate_ratio = max_translate_ratio
        self.scaling_ratio_range = scaling_ratio_range
        self.max_shear_degree = max_shear_degree
        self.border = border
        self.border_val = border_val
        self.bbox_clip_border = bbox_clip_border
        self.min_bbox_size = min_bbox_size
        self.min_area_ratio = min_area_ratio
        self.use_mask_refine = use_mask_refine
        self.max_aspect_ratio = max_aspect_ratio
        self.resample_num = resample_num

    @autocast_box_type()
    def transform(self, results: dict) -> dict:
        """The YOLOv5 random affine transform function.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        img = results['img']
        # self.border is wh format
        height = img.shape[0] + self.border[1] * 2
        width = img.shape[1] + self.border[0] * 2

        # Note: Different from YOLOX
        center_matrix = np.eye(3, dtype=np.float32)
        center_matrix[0, 2] = -img.shape[1] / 2
        center_matrix[1, 2] = -img.shape[0] / 2

        warp_matrix, scaling_ratio = self._get_random_homography_matrix(
            height, width)
        warp_matrix = warp_matrix @ center_matrix

        img = cv2.warpPerspective(
            img,
            warp_matrix,
            dsize=(width, height),
            borderValue=self.border_val)
        results['img'] = img
        results['img_shape'] = img.shape
        img_h, img_w = img.shape[:2]

        bboxes = results['gt_bboxes']
        num_bboxes = len(bboxes)
        if num_bboxes:
            orig_bboxes = bboxes.clone()
            if self.use_mask_refine and 'gt_masks' in results:
                # If the dataset has annotations of mask,
                # the mask will be used to refine bbox.
                gt_masks = results['gt_masks']

                gt_masks_resample = self.resample_masks(gt_masks)
                gt_masks = self.warp_mask(gt_masks_resample, warp_matrix,
                                          img_h, img_w)

                # refine bboxes by masks
                bboxes = gt_masks.get_bboxes(dst_type='hbox')
                # filter bboxes outside image
                valid_index = self.filter_gt_bboxes(orig_bboxes,
                                                    bboxes).numpy()
                results['gt_masks'] = gt_masks[valid_index]
            else:
                bboxes.project_(warp_matrix)
                if self.bbox_clip_border:
                    bboxes.clip_([height, width])

                # filter bboxes
                orig_bboxes.rescale_([scaling_ratio, scaling_ratio])

                # Be careful: valid_index must convert to numpy,
                # otherwise it will raise out of bounds when len(valid_index)=1
                valid_index = self.filter_gt_bboxes(orig_bboxes,
                                                    bboxes).numpy()
                if 'gt_masks' in results:
                    results['gt_masks'] = PolygonMasks(
                        results['gt_masks'].masks, img_h, img_w)

            results['gt_bboxes'] = bboxes[valid_index]
            results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
                valid_index]
            results['gt_ignore_flags'] = results['gt_ignore_flags'][
                valid_index]

        return results

    @staticmethod
    def warp_poly(poly: np.ndarray, warp_matrix: np.ndarray, img_w: int,
                  img_h: int) -> np.ndarray:
        """Function to warp one mask and filter points outside image.

        Args:
            poly (np.ndarray): Segmentation annotation with shape (n, ) and
                with format (x1, y1, x2, y2, ...).
            warp_matrix (np.ndarray): Affine transformation matrix.
                Shape: (3, 3).
            img_w (int): Width of output image.
            img_h (int): Height of output image.
        """
        # TODO: Current logic may cause retained masks unusable for
        # semantic segmentation training, which is same as official
        # implementation.
        poly = poly.reshape((-1, 2))
        poly = np.concatenate((poly, np.ones(
            (len(poly), 1), dtype=poly.dtype)),
                              axis=-1)
        # transform poly
        poly = poly @ warp_matrix.T
        poly = poly[:, :2] / poly[:, 2:3]

        # filter point outside image
        x, y = poly.T
        valid_ind_point = (x >= 0) & (y >= 0) & (x <= img_w) & (y <= img_h)
        return poly[valid_ind_point].reshape(-1)

    def warp_mask(self, gt_masks: PolygonMasks, warp_matrix: np.ndarray,
                  img_w: int, img_h: int) -> PolygonMasks:
        """Warp masks by warp_matrix and retain masks inside image after
        warping.

        Args:
            gt_masks (PolygonMasks): Annotations of semantic segmentation.
            warp_matrix (np.ndarray): Affine transformation matrix.
                Shape: (3, 3).
            img_w (int): Width of output image.
            img_h (int): Height of output image.

        Returns:
            PolygonMasks: Masks after warping.
        """
        masks = gt_masks.masks

        new_masks = []
        for poly_per_obj in masks:
            warpped_poly_per_obj = []
            # One gt may have multiple masks.
            for poly in poly_per_obj:
                valid_poly = self.warp_poly(poly, warp_matrix, img_w, img_h)
                if len(valid_poly):
                    warpped_poly_per_obj.append(valid_poly.reshape(-1))
            # If all the masks are invalid,
            # add [0, 0, 0, 0, 0, 0,] here.
            if not warpped_poly_per_obj:
                # This will be filtered in function `filter_gt_bboxes`.
                warpped_poly_per_obj = [
                    np.zeros(6, dtype=poly_per_obj[0].dtype)
                ]
            new_masks.append(warpped_poly_per_obj)

        gt_masks = PolygonMasks(new_masks, img_h, img_w)
        return gt_masks

    def resample_masks(self, gt_masks: PolygonMasks) -> PolygonMasks:
        """Function to resample each mask annotation with shape (2 * n, ) to
        shape (resample_num * 2, ).

        Args:
            gt_masks (PolygonMasks): Annotations of semantic segmentation.
        """
        masks = gt_masks.masks
        new_masks = []
        for poly_per_obj in masks:
            resample_poly_per_obj = []
            for poly in poly_per_obj:
                poly = poly.reshape((-1, 2))  # xy
                poly = np.concatenate((poly, poly[0:1, :]), axis=0)
                x = np.linspace(0, len(poly) - 1, self.resample_num)
                xp = np.arange(len(poly))
                poly = np.concatenate([
                    np.interp(x, xp, poly[:, i]) for i in range(2)
                ]).reshape(2, -1).T.reshape(-1)
                resample_poly_per_obj.append(poly)
            new_masks.append(resample_poly_per_obj)
        return PolygonMasks(new_masks, gt_masks.height, gt_masks.width)

    def filter_gt_bboxes(self, origin_bboxes: HorizontalBoxes,
                         wrapped_bboxes: HorizontalBoxes) -> torch.Tensor:
        """Filter gt bboxes.

        Args:
            origin_bboxes (HorizontalBoxes): Origin bboxes.
            wrapped_bboxes (HorizontalBoxes): Wrapped bboxes.

        Returns:
            dict: The result dict.
        """
        origin_w = origin_bboxes.widths
        origin_h = origin_bboxes.heights
        wrapped_w = wrapped_bboxes.widths
        wrapped_h = wrapped_bboxes.heights
        aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16),
                                  wrapped_h / (wrapped_w + 1e-16))

        wh_valid_idx = (wrapped_w > self.min_bbox_size) & \
                       (wrapped_h > self.min_bbox_size)
        area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h +
                                                  1e-16) > self.min_area_ratio
        aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio
        return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx

    @cache_randomness
    def _get_random_homography_matrix(self, height: int,
                                      width: int) -> Tuple[np.ndarray, float]:
        """Get random homography matrix.

        Args:
            height (int): Image height.
            width (int): Image width.

        Returns:
            Tuple[np.ndarray, float]: The result of warp_matrix and
            scaling_ratio.
        """
        # Rotation
        rotation_degree = random.uniform(-self.max_rotate_degree,
                                         self.max_rotate_degree)
        rotation_matrix = self._get_rotation_matrix(rotation_degree)

        # Scaling
        scaling_ratio = random.uniform(self.scaling_ratio_range[0],
                                       self.scaling_ratio_range[1])
        scaling_matrix = self._get_scaling_matrix(scaling_ratio)

        # Shear
        x_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        y_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        shear_matrix = self._get_shear_matrix(x_degree, y_degree)

        # Translation
        trans_x = random.uniform(0.5 - self.max_translate_ratio,
                                 0.5 + self.max_translate_ratio) * width
        trans_y = random.uniform(0.5 - self.max_translate_ratio,
                                 0.5 + self.max_translate_ratio) * height
        translate_matrix = self._get_translation_matrix(trans_x, trans_y)
        warp_matrix = (
            translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix)
        return warp_matrix, scaling_ratio

    @staticmethod
    def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
        """Get rotation matrix.

        Args:
            rotate_degrees (float): Rotate degrees.

        Returns:
            np.ndarray: The rotation matrix.
        """
        radian = math.radians(rotate_degrees)
        rotation_matrix = np.array(
            [[np.cos(radian), -np.sin(radian), 0.],
             [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
            dtype=np.float32)
        return rotation_matrix

    @staticmethod
    def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
        """Get scaling matrix.

        Args:
            scale_ratio (float): Scale ratio.

        Returns:
            np.ndarray: The scaling matrix.
        """
        scaling_matrix = np.array(
            [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
            dtype=np.float32)
        return scaling_matrix

    @staticmethod
    def _get_shear_matrix(x_shear_degrees: float,
                          y_shear_degrees: float) -> np.ndarray:
        """Get shear matrix.

        Args:
            x_shear_degrees (float): X shear degrees.
            y_shear_degrees (float): Y shear degrees.

        Returns:
            np.ndarray: The shear matrix.
        """
        x_radian = math.radians(x_shear_degrees)
        y_radian = math.radians(y_shear_degrees)
        shear_matrix = np.array([[1, np.tan(x_radian), 0.],
                                 [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
                                dtype=np.float32)
        return shear_matrix

    @staticmethod
    def _get_translation_matrix(x: float, y: float) -> np.ndarray:
        """Get translation matrix.

        Args:
            x (float): X translation.
            y (float): Y translation.

        Returns:
            np.ndarray: The translation matrix.
        """
        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
                                      dtype=np.float32)
        return translation_matrix

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
        repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
        repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
        repr_str += f'max_shear_degree={self.max_shear_degree}, '
        repr_str += f'border={self.border}, '
        repr_str += f'border_val={self.border_val}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str
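
# Note (illustrative, not part of the diff above): the warp matrix is composed
# as translate @ shear @ rotate @ scale in homogeneous coordinates, then
# re-centered and applied to the image and boxes. A toy sketch with arbitrary
# example values, showing only the composition order and how a corner maps:
#
#     import numpy as np
#     def rot(deg):
#         r = np.deg2rad(deg)
#         return np.array([[np.cos(r), -np.sin(r), 0.],
#                          [np.sin(r), np.cos(r), 0.], [0., 0., 1.]])
#     scale = np.diag([1.2, 1.2, 1.0])
#     translate = np.array([[1., 0., 5.], [0., 1., -3.], [0., 0., 1.]])
#     warp = translate @ rot(10) @ scale          # shear omitted for brevity
#     corner = np.array([100.0, 50.0, 1.0])       # homogeneous (x, y, 1)
#     print((warp @ corner)[:2])                  # warped (x, y)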

@TRANSFORMS.register_module()
class PPYOLOERandomDistort(BaseTransform):
    """Random hue, saturation, contrast and brightness distortion.

    Required Keys:

    - img

    Modified Keys:

    - img (np.float32)

    Args:
        hue_cfg (dict): Hue settings. Defaults to dict(min=-18,
            max=18, prob=0.5).
        saturation_cfg (dict): Saturation settings. Defaults to dict(
            min=0.5, max=1.5, prob=0.5).
        contrast_cfg (dict): Contrast settings. Defaults to dict(
            min=0.5, max=1.5, prob=0.5).
        brightness_cfg (dict): Brightness settings. Defaults to dict(
            min=0.5, max=1.5, prob=0.5).
        num_distort_func (int): The number of distort function. Defaults
            to 4.
    """

    def __init__(self,
                 hue_cfg: dict = dict(min=-18, max=18, prob=0.5),
                 saturation_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
                 contrast_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
                 brightness_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
                 num_distort_func: int = 4):
        self.hue_cfg = hue_cfg
        self.saturation_cfg = saturation_cfg
        self.contrast_cfg = contrast_cfg
        self.brightness_cfg = brightness_cfg
        self.num_distort_func = num_distort_func
        assert 0 < self.num_distort_func <= 4, \
            'num_distort_func must > 0 and <= 4'
        for cfg in [
                self.hue_cfg, self.saturation_cfg, self.contrast_cfg,
                self.brightness_cfg
        ]:
            assert 0. <= cfg['prob'] <= 1., 'prob must >=0 and <=1'

    def transform_hue(self, results):
        """Transform hue randomly."""
        if random.uniform(0., 1.) >= self.hue_cfg['prob']:
            return results
        img = results['img']
        delta = random.uniform(self.hue_cfg['min'], self.hue_cfg['max'])
        u = np.cos(delta * np.pi)
        w = np.sin(delta * np.pi)
        delta_iq = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
        rgb2yiq_matrix = np.array([[0.114, 0.587, 0.299],
                                   [-0.321, -0.274, 0.596],
                                   [0.311, -0.523, 0.211]])
        yiq2rgb_matric = np.array([[1.0, -1.107, 1.705], [1.0, -0.272, -0.647],
                                   [1.0, 0.956, 0.621]])
        t = np.dot(np.dot(yiq2rgb_matric, delta_iq), rgb2yiq_matrix).T
        img = np.dot(img, t)
        results['img'] = img
        return results

    def transform_saturation(self, results):
        """Transform saturation randomly."""
        if random.uniform(0., 1.) >= self.saturation_cfg['prob']:
            return results
        img = results['img']
        delta = random.uniform(self.saturation_cfg['min'],
                               self.saturation_cfg['max'])

        # convert bgr img to gray img
        gray = img * np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
        gray = gray.sum(axis=2, keepdims=True)
        gray *= (1.0 - delta)
        img *= delta
        img += gray
        results['img'] = img
        return results

    def transform_contrast(self, results):
        """Transform contrast randomly."""
        if random.uniform(0., 1.) >= self.contrast_cfg['prob']:
            return results
        img = results['img']
        delta = random.uniform(self.contrast_cfg['min'],
                               self.contrast_cfg['max'])
        img *= delta
        results['img'] = img
        return results

    def transform_brightness(self, results):
        """Transform brightness randomly."""
        if random.uniform(0., 1.) >= self.brightness_cfg['prob']:
            return results
        img = results['img']
        delta = random.uniform(self.brightness_cfg['min'],
                               self.brightness_cfg['max'])
        img += delta
        results['img'] = img
        return results

    def transform(self, results: dict) -> dict:
        """The hue, saturation, contrast and brightness distortion function.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        results['img'] = results['img'].astype(np.float32)

        functions = [
            self.transform_brightness, self.transform_contrast,
            self.transform_saturation, self.transform_hue
        ]
        distortions = random.permutation(functions)[:self.num_distort_func]
        for func in distortions:
            results = func(results)
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(hue_cfg={self.hue_cfg}, '
        repr_str += f'saturation_cfg={self.saturation_cfg}, '
        repr_str += f'contrast_cfg={self.contrast_cfg}, '
        repr_str += f'brightness_cfg={self.brightness_cfg}, '
        repr_str += f'num_distort_func={self.num_distort_func})'
        return repr_str

@TRANSFORMS.register_module()
class PPYOLOERandomCrop(BaseTransform):
    """Random crop the img and bboxes. Different thresholds are used in
    PPYOLOE to judge whether the clipped image meets the requirements. This
    implementation is different from the implementation of RandomCrop in mmdet.

    Required Keys:

    - img
    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_ignore_flags (optional)

    Added Keys:

    - pad_param (np.float32)

    Args:
        aspect_ratio (List[float]): Aspect ratio of cropped region. Default to
            [.5, 2].
        thresholds (List[float]): Iou thresholds for deciding a valid bbox crop
            in [min, max] format. Defaults to [.0, .1, .3, .5, .7, .9].
        scaling (List[float]): Ratio between a cropped region and the original
            image in [min, max] format. Default to [.3, 1.].
        num_attempts (int): Number of tries for each threshold before
            giving up. Default to 50.
        allow_no_crop (bool): Allow return without actually cropping them.
            Default to True.
        cover_all_box (bool): Ensure all bboxes are covered in the final crop.
            Default to False.
    """

    def __init__(self,
                 aspect_ratio: List[float] = [.5, 2.],
                 thresholds: List[float] = [.0, .1, .3, .5, .7, .9],
                 scaling: List[float] = [.3, 1.],
                 num_attempts: int = 50,
                 allow_no_crop: bool = True,
                 cover_all_box: bool = False):
        self.aspect_ratio = aspect_ratio
        self.thresholds = thresholds
        self.scaling = scaling
        self.num_attempts = num_attempts
        self.allow_no_crop = allow_no_crop
        self.cover_all_box = cover_all_box

    def _crop_data(self, results: dict, crop_box: Tuple[int, int, int, int],
                   valid_inds: np.ndarray) -> Union[dict, None]:
        """Function to randomly crop images, bounding boxes, masks, semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.
            crop_box (Tuple[int, int, int, int]): Expected absolute coordinates
                for cropping, (x1, y1, x2, y2).
            valid_inds (np.ndarray): The indexes of gt that needs to be
                retained.

        Returns:
            results (Union[dict, None]): Randomly cropped results, 'img_shape'
                key in result dict is updated according to crop size. None will
                be returned when there is no valid bbox after cropping.
        """
        # crop the image
        img = results['img']
        crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        results['img'] = img
        img_shape = img.shape
        results['img_shape'] = img.shape

        # crop bboxes accordingly and clip to the image boundary
        if results.get('gt_bboxes', None) is not None:
            bboxes = results['gt_bboxes']
            bboxes.translate_([-crop_x1, -crop_y1])
            bboxes.clip_(img_shape[:2])

            results['gt_bboxes'] = bboxes[valid_inds]

            if results.get('gt_ignore_flags', None) is not None:
                results['gt_ignore_flags'] = \
                    results['gt_ignore_flags'][valid_inds]

            if results.get('gt_bboxes_labels', None) is not None:
                results['gt_bboxes_labels'] = \
                    results['gt_bboxes_labels'][valid_inds]

            if results.get('gt_masks', None) is not None:
                results['gt_masks'] = results['gt_masks'][
                    valid_inds.nonzero()[0]].crop(
                        np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))

        # crop semantic seg
        if results.get('gt_seg_map', None) is not None:
            results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
                                                          crop_x1:crop_x2]

        return results

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """The random crop transform function.

        Args:
            results (dict): The result dict.

        Returns:
            dict: The result dict.
        """
        if results.get('gt_bboxes', None) is None or len(
                results['gt_bboxes']) == 0:
            return results

        orig_img_h, orig_img_w = results['img'].shape[:2]
        gt_bboxes = results['gt_bboxes']

        thresholds = list(self.thresholds)
        if self.allow_no_crop:
            thresholds.append('no_crop')
        random.shuffle(thresholds)

        for thresh in thresholds:
            # Determine the coordinates for cropping
            if thresh == 'no_crop':
                return results

            found = False
            for i in range(self.num_attempts):
                crop_h, crop_w = self._get_crop_size((orig_img_h, orig_img_w))
                if self.aspect_ratio is None:
                    if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
                        continue

                # get image crop_box
                margin_h = max(orig_img_h - crop_h, 0)
                margin_w = max(orig_img_w - crop_w, 0)
                offset_h, offset_w = self._rand_offset((margin_h, margin_w))
                crop_y1, crop_y2 = offset_h, offset_h + crop_h
                crop_x1, crop_x2 = offset_w, offset_w + crop_w

                crop_box = [crop_x1, crop_y1, crop_x2, crop_y2]
                # Calculate the iou between gt_bboxes and crop_boxes
                iou = self._iou_matrix(gt_bboxes,
                                       np.array([crop_box], dtype=np.float32))
                # If the maximum value of the iou is less than thresh,
                # the current crop_box is considered invalid.
                if iou.max() < thresh:
                    continue

                # If cover_all_box == True and the minimum value of
                # the iou is less than thresh, the current crop_box
                # is considered invalid.
                if self.cover_all_box and iou.min() < thresh:
                    continue

                # Get which gt_bboxes to keep after cropping.
                valid_inds = self._get_valid_inds(
                    gt_bboxes, np.array(crop_box, dtype=np.float32))
                if valid_inds.size > 0:
                    found = True
                    break

            if found:
                results = self._crop_data(results, crop_box, valid_inds)
                return results
        return results

    @cache_randomness
    def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
        """Randomly generate crop offset.

        Args:
            margin (Tuple[int, int]): The upper bound for the offset generated
                randomly.

        Returns:
            Tuple[int, int]: The random offset for the crop.
        """
        margin_h, margin_w = margin
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)

        return (offset_h, offset_w)

    @cache_randomness
    def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
        """Randomly generates the crop size based on `image_size`.

        Args:
            image_size (Tuple[int, int]): (h, w).

        Returns:
            crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
        """
        h, w = image_size
        scale = random.uniform(*self.scaling)
        if self.aspect_ratio is not None:
            min_ar, max_ar = self.aspect_ratio
            aspect_ratio = random.uniform(
                max(min_ar, scale**2), min(max_ar, scale**-2))
            h_scale = scale / np.sqrt(aspect_ratio)
            w_scale = scale * np.sqrt(aspect_ratio)
        else:
            h_scale = random.uniform(*self.scaling)
            w_scale = random.uniform(*self.scaling)
        crop_h = h * h_scale
        crop_w = w * w_scale
        return int(crop_h), int(crop_w)

    def _iou_matrix(self,
                    gt_bbox: HorizontalBoxes,
                    crop_bbox: np.ndarray,
                    eps: float = 1e-10) -> np.ndarray:
        """Calculate iou between gt and image crop box.

        Args:
            gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
            crop_bbox (np.ndarray): Image crop coordinates in
                [x1, y1, x2, y2] format.
            eps (float): Default to 1e-10.
        Return:
            (np.ndarray): IoU.
        """
        gt_bbox = gt_bbox.tensor.numpy()
        lefttop = np.maximum(gt_bbox[:, np.newaxis, :2], crop_bbox[:, :2])
        rightbottom = np.minimum(gt_bbox[:, np.newaxis, 2:], crop_bbox[:, 2:])

        overlap = np.prod(
            rightbottom - lefttop,
            axis=2) * (lefttop < rightbottom).all(axis=2)
        area_gt_bbox = np.prod(gt_bbox[:, 2:] - gt_bbox[:, :2], axis=1)
        area_crop_bbox = np.prod(crop_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
        area_o = (area_gt_bbox[:, np.newaxis] + area_crop_bbox - overlap)
        return overlap / (area_o + eps)

    def _get_valid_inds(self, gt_bbox: HorizontalBoxes,
                        img_crop_bbox: np.ndarray) -> np.ndarray:
        """Get which Bboxes to keep at the current cropping coordinates.

        Args:
            gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
            img_crop_bbox (np.ndarray): Image crop coordinates in
                [x1, y1, x2, y2] format.

        Returns:
            (np.ndarray): Valid indexes.
        """
        cropped_box = gt_bbox.tensor.numpy().copy()
        gt_bbox = gt_bbox.tensor.numpy().copy()

        cropped_box[:, :2] = np.maximum(gt_bbox[:, :2], img_crop_bbox[:2])
        cropped_box[:, 2:] = np.minimum(gt_bbox[:, 2:], img_crop_bbox[2:])
        cropped_box[:, :2] -= img_crop_bbox[:2]
        cropped_box[:, 2:] -= img_crop_bbox[:2]

        centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
        valid = np.logical_and(img_crop_bbox[:2] <= centers,
                               centers < img_crop_bbox[2:]).all(axis=1)
        valid = np.logical_and(
            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))

        return np.where(valid)[0]

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(aspect_ratio={self.aspect_ratio}, '
        repr_str += f'thresholds={self.thresholds}, '
        repr_str += f'scaling={self.scaling}, '
        repr_str += f'num_attempts={self.num_attempts}, '
        repr_str += f'allow_no_crop={self.allow_no_crop}, '
        repr_str += f'cover_all_box={self.cover_all_box})'
        return repr_str
1331 |
+
@TRANSFORMS.register_module()
|
1332 |
+
class YOLOv5CopyPaste(BaseTransform):
|
1333 |
+
"""Copy-Paste used in YOLOv5 and YOLOv8.
|
1334 |
+
|
1335 |
+
This transform randomly copy some objects in the image to the mirror
|
1336 |
+
position of the image.It is different from the `CopyPaste` in mmdet.
|
1337 |
+
|
1338 |
+
Required Keys:
|
1339 |
+
|
1340 |
+
- img (np.uint8)
|
1341 |
+
- gt_bboxes (BaseBoxes[torch.float32])
|
1342 |
+
- gt_bboxes_labels (np.int64) (optional)
|
1343 |
+
- gt_ignore_flags (bool) (optional)
|
1344 |
+
- gt_masks (PolygonMasks) (optional)
|
1345 |
+
|
1346 |
+
Modified Keys:
|
1347 |
+
|
1348 |
+
- img
|
1349 |
+
- gt_bboxes
|
1350 |
+
- gt_bboxes_labels (np.int64) (optional)
|
1351 |
+
- gt_ignore_flags (optional)
|
1352 |
+
- gt_masks (optional)
|
1353 |
+
|
1354 |
+
Args:
|
1355 |
+
ioa_thresh (float): Ioa thresholds for deciding valid bbox.
|
1356 |
+
prob (float): Probability of choosing objects.
|
1357 |
+
Defaults to 0.5.
|
1358 |
+
"""
|
1359 |
+
|
1360 |
+
def __init__(self, ioa_thresh: float = 0.3, prob: float = 0.5):
|
1361 |
+
self.ioa_thresh = ioa_thresh
|
1362 |
+
self.prob = prob
|
1363 |
+
|
1364 |
+
@autocast_box_type()
|
1365 |
+
def transform(self, results: dict) -> Union[dict, None]:
|
1366 |
+
"""The YOLOv5 and YOLOv8 Copy-Paste transform function.
|
1367 |
+
|
1368 |
+
Args:
|
1369 |
+
results (dict): The result dict.
|
1370 |
+
|
1371 |
+
Returns:
|
1372 |
+
dict: The result dict.
|
1373 |
+
"""
|
1374 |
+
if len(results.get('gt_masks', [])) == 0:
|
1375 |
+
return results
|
1376 |
+
gt_masks = results['gt_masks']
|
1377 |
+
assert isinstance(gt_masks, PolygonMasks),\
|
1378 |
+
'only support type of PolygonMasks,' \
|
1379 |
+
' but get type: %s' % type(gt_masks)
|
1380 |
+
gt_bboxes = results['gt_bboxes']
|
1381 |
+
gt_bboxes_labels = results.get('gt_bboxes_labels', None)
|
1382 |
+
img = results['img']
|
1383 |
+
img_h, img_w = img.shape[:2]
|
1384 |
+
|
1385 |
+
# calculate ioa
|
1386 |
+
gt_bboxes_flip = deepcopy(gt_bboxes)
|
1387 |
+
gt_bboxes_flip.flip_(img.shape)
|
1388 |
+
|
1389 |
+
ioa = self.bbox_ioa(gt_bboxes_flip, gt_bboxes)
|
1390 |
+
indexes = torch.nonzero((ioa < self.ioa_thresh).all(1))[:, 0]
|
1391 |
+
n = len(indexes)
|
1392 |
+
valid_inds = random.choice(
|
1393 |
+
indexes, size=round(self.prob * n), replace=False)
|
1394 |
+
if len(valid_inds) == 0:
|
1395 |
+
return results
|
1396 |
+
|
1397 |
+
if gt_bboxes_labels is not None:
|
1398 |
+
# prepare labels
|
1399 |
+
gt_bboxes_labels = np.concatenate(
|
1400 |
+
(gt_bboxes_labels, gt_bboxes_labels[valid_inds]), axis=0)
|
1401 |
+
|
1402 |
+
# prepare bboxes
|
1403 |
+
copypaste_bboxes = gt_bboxes_flip[valid_inds]
|
1404 |
+
gt_bboxes = gt_bboxes.cat([gt_bboxes, copypaste_bboxes])
|
1405 |
+
|
1406 |
+
# prepare images
|
1407 |
+
copypaste_gt_masks = gt_masks[valid_inds]
|
1408 |
+
copypaste_gt_masks_flip = copypaste_gt_masks.flip()
|
1409 |
+
# convert poly format to bitmap format
|
1410 |
+
# example: poly: [[array(0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]
|
1411 |
+
# -> bitmap: a mask with shape equal to (1, img_h, img_w)
|
1412 |
+
# # type1 low speed
|
1413 |
+
# copypaste_gt_masks_bitmap = copypaste_gt_masks.to_ndarray()
|
1414 |
+
# copypaste_mask = np.sum(copypaste_gt_masks_bitmap, axis=0) > 0
|
1415 |
+
|
1416 |
+
# type2
|
1417 |
+
copypaste_mask = np.zeros((img_h, img_w), dtype=np.uint8)
|
1418 |
+
for poly in copypaste_gt_masks.masks:
|
1419 |
+
poly = [i.reshape((-1, 1, 2)).astype(np.int32) for i in poly]
|
1420 |
+
cv2.drawContours(copypaste_mask, poly, -1, (1, ), cv2.FILLED)
|
1421 |
+
|
1422 |
+
copypaste_mask = copypaste_mask.astype(bool)
|
1423 |
+
|
1424 |
+
# copy objects, and paste to the mirror position of the image
|
1425 |
+
copypaste_mask_flip = mmcv.imflip(
|
1426 |
+
copypaste_mask, direction='horizontal')
|
1427 |
+
copypaste_img = mmcv.imflip(img, direction='horizontal')
|
1428 |
+
img[copypaste_mask_flip] = copypaste_img[copypaste_mask_flip]
|
1429 |
+
|
1430 |
+
# prepare masks
|
1431 |
+
gt_masks = copypaste_gt_masks.cat([gt_masks, copypaste_gt_masks_flip])
|
1432 |
+
|
1433 |
+
if 'gt_ignore_flags' in results:
|
1434 |
+
# prepare gt_ignore_flags
|
1435 |
+
gt_ignore_flags = results['gt_ignore_flags']
|
1436 |
+
gt_ignore_flags = np.concatenate(
|
1437 |
+
[gt_ignore_flags, gt_ignore_flags[valid_inds]], axis=0)
|
1438 |
+
results['gt_ignore_flags'] = gt_ignore_flags
|
1439 |
+
|
1440 |
+
results['img'] = img
|
1441 |
+
results['gt_bboxes'] = gt_bboxes
|
1442 |
+
if gt_bboxes_labels is not None:
|
1443 |
+
results['gt_bboxes_labels'] = gt_bboxes_labels
|
1444 |
+
results['gt_masks'] = gt_masks
|
1445 |
+
|
1446 |
+
return results
|
1447 |
+
|
1448 |
+
@staticmethod
|
1449 |
+
def bbox_ioa(gt_bboxes_flip: HorizontalBoxes,
|
1450 |
+
gt_bboxes: HorizontalBoxes,
|
1451 |
+
eps: float = 1e-7) -> np.ndarray:
|
1452 |
+
"""Calculate ioa between gt_bboxes_flip and gt_bboxes.
|
1453 |
+
|
1454 |
+
Args:
|
1455 |
+
gt_bboxes_flip (HorizontalBoxes): Flipped ground truth
|
1456 |
+
bounding boxes.
|
1457 |
+
gt_bboxes (HorizontalBoxes): Ground truth bounding boxes.
|
1458 |
+
eps (float): Default to 1e-10.
|
1459 |
+
Return:
|
1460 |
+
(Tensor): Ioa.
|
1461 |
+
"""
|
1462 |
+
gt_bboxes_flip = gt_bboxes_flip.tensor
|
1463 |
+
gt_bboxes = gt_bboxes.tensor
|
1464 |
+
|
1465 |
+
# Get the coordinates of bounding boxes
|
1466 |
+
b1_x1, b1_y1, b1_x2, b1_y2 = gt_bboxes_flip.T
|
1467 |
+
b2_x1, b2_y1, b2_x2, b2_y2 = gt_bboxes.T
|
1468 |
+
|
1469 |
+
# Intersection area
|
1470 |
+
inter_area = (torch.minimum(b1_x2[:, None],
|
1471 |
+
b2_x2) - torch.maximum(b1_x1[:, None],
|
1472 |
+
b2_x1)).clip(0) * \
|
1473 |
+
(torch.minimum(b1_y2[:, None],
|
1474 |
+
b2_y2) - torch.maximum(b1_y1[:, None],
|
1475 |
+
b2_y1)).clip(0)
|
1476 |
+
|
1477 |
+
# box2 area
|
1478 |
+
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
|
1479 |
+
|
1480 |
+
# Intersection over box2 area
|
1481 |
+
return inter_area / box2_area
|
1482 |
+
|
1483 |
+
def __repr__(self) -> str:
|
1484 |
+
repr_str = self.__class__.__name__
|
1485 |
+
repr_str += f'(ioa_thresh={self.ioa_thresh},'
|
1486 |
+
repr_str += f'prob={self.prob})'
|
1487 |
+
return repr_str
|
1488 |
+
|
1489 |
+
|
1490 |
+
@TRANSFORMS.register_module()
|
1491 |
+
class RemoveDataElement(BaseTransform):
|
1492 |
+
"""Remove unnecessary data element in results.
|
1493 |
+
|
1494 |
+
Args:
|
1495 |
+
keys (Union[str, Sequence[str]]): Keys need to be removed.
|
1496 |
+
"""
|
1497 |
+
|
1498 |
+
def __init__(self, keys: Union[str, Sequence[str]]):
|
1499 |
+
self.keys = [keys] if isinstance(keys, str) else keys
|
1500 |
+
|
1501 |
+
def transform(self, results: dict) -> dict:
|
1502 |
+
for key in self.keys:
|
1503 |
+
results.pop(key, None)
|
1504 |
+
return results
|
1505 |
+
|
1506 |
+
def __repr__(self) -> str:
|
1507 |
+
repr_str = self.__class__.__name__
|
1508 |
+
repr_str += f'(keys={self.keys})'
|
1509 |
+
return repr_str
|
1510 |
+
|
1511 |
+
|
1512 |
+
@TRANSFORMS.register_module()
|
1513 |
+
class RegularizeRotatedBox(BaseTransform):
|
1514 |
+
"""Regularize rotated boxes.
|
1515 |
+
|
1516 |
+
Due to the angle periodicity, one rotated box can be represented in
|
1517 |
+
many different (x, y, w, h, t). To make each rotated box unique,
|
1518 |
+
``regularize_boxes`` will take the remainder of the angle divided by
|
1519 |
+
180 degrees.
|
1520 |
+
|
1521 |
+
For convenience, three angle_version can be used here:
|
1522 |
+
|
1523 |
+
- 'oc': OpenCV Definition. Has the same box representation as
|
1524 |
+
``cv2.minAreaRect`` the angle ranges in [-90, 0).
|
1525 |
+
- 'le90': Long Edge Definition (90). the angle ranges in [-90, 90).
|
1526 |
+
The width is always longer than the height.
|
1527 |
+
- 'le135': Long Edge Definition (135). the angle ranges in [-45, 135).
|
1528 |
+
The width is always longer than the height.
|
1529 |
+
|
1530 |
+
Required Keys:
|
1531 |
+
|
1532 |
+
- gt_bboxes (RotatedBoxes[torch.float32])
|
1533 |
+
|
1534 |
+
Modified Keys:
|
1535 |
+
|
1536 |
+
- gt_bboxes
|
1537 |
+
|
1538 |
+
Args:
|
1539 |
+
angle_version (str): Angle version. Can only be 'oc',
|
1540 |
+
'le90', or 'le135'. Defaults to 'le90.
|
1541 |
+
"""
|
1542 |
+
|
1543 |
+
def __init__(self, angle_version='le90') -> None:
|
1544 |
+
self.angle_version = angle_version
|
1545 |
+
try:
|
1546 |
+
from mmrotate.structures.bbox import RotatedBoxes
|
1547 |
+
self.box_type = RotatedBoxes
|
1548 |
+
except ImportError:
|
1549 |
+
raise ImportError(
|
1550 |
+
'Please run "mim install -r requirements/mmrotate.txt" '
|
1551 |
+
'to install mmrotate first for rotated detection.')
|
1552 |
+
|
1553 |
+
def transform(self, results: dict) -> dict:
|
1554 |
+
assert isinstance(results['gt_bboxes'], self.box_type)
|
1555 |
+
results['gt_bboxes'] = self.box_type(
|
1556 |
+
results['gt_bboxes'].regularize_boxes(self.angle_version))
|
1557 |
+
return results
|
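Note: the transforms above are registered in TRANSFORMS and are normally referenced from a pipeline config rather than called directly. The following is a minimal, assumed config fragment (not part of this upload) sketching how YOLOv5CopyPaste and RemoveDataElement could be wired into a training pipeline.

# Hypothetical pipeline fragment; the loading transforms are standard
# mmdet/mmyolo transforms and the values simply repeat the defaults above.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    # Copy objects to their mirrored position with 50% probability.
    dict(type='YOLOv5CopyPaste', ioa_thresh=0.3, prob=0.5),
    # Drop the polygon masks once they are no longer needed downstream.
    dict(type='RemoveDataElement', keys=['gt_masks']),
]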
mmyolo/datasets/utils.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import numpy as np
import torch
from mmengine.dataset import COLLATE_FUNCTIONS

from ..registry import TASK_UTILS


@COLLATE_FUNCTIONS.register_module()
def yolov5_collate(data_batch: Sequence,
                   use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Args:
        data_batch (Sequence): Batch of data.
        use_ms_training (bool): Whether to use multi-scale training.
    """
    batch_imgs = []
    batch_bboxes_labels = []
    batch_masks = []
    for i in range(len(data_batch)):
        datasamples = data_batch[i]['data_samples']
        inputs = data_batch[i]['inputs']
        batch_imgs.append(inputs)

        gt_bboxes = datasamples.gt_instances.bboxes.tensor
        gt_labels = datasamples.gt_instances.labels
        if 'masks' in datasamples.gt_instances:
            masks = datasamples.gt_instances.masks.to_tensor(
                dtype=torch.bool, device=gt_bboxes.device)
            batch_masks.append(masks)
        batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
        bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
                                  dim=1)
        batch_bboxes_labels.append(bboxes_labels)

    collated_results = {
        'data_samples': {
            'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
        }
    }
    if len(batch_masks) > 0:
        collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)

    if use_ms_training:
        collated_results['inputs'] = batch_imgs
    else:
        collated_results['inputs'] = torch.stack(batch_imgs, 0)
    return collated_results


@TASK_UTILS.register_module()
class BatchShapePolicy:
    """BatchShapePolicy is only used in the testing phase, which can reduce the
    number of pad pixels during batch inference.

    Args:
        batch_size (int): Single GPU batch size during batch inference.
            Defaults to 32.
        img_size (int): Expected output image size. Defaults to 640.
        size_divisor (int): The minimum size that is divisible
            by size_divisor. Defaults to 32.
        extra_pad_ratio (float): Extra pad ratio. Defaults to 0.5.
    """

    def __init__(self,
                 batch_size: int = 32,
                 img_size: int = 640,
                 size_divisor: int = 32,
                 extra_pad_ratio: float = 0.5):
        self.batch_size = batch_size
        self.img_size = img_size
        self.size_divisor = size_divisor
        self.extra_pad_ratio = extra_pad_ratio

    def __call__(self, data_list: List[dict]) -> List[dict]:
        image_shapes = []
        for data_info in data_list:
            image_shapes.append((data_info['width'], data_info['height']))

        image_shapes = np.array(image_shapes, dtype=np.float64)

        n = len(image_shapes)  # number of images
        batch_index = np.floor(np.arange(n) / self.batch_size).astype(
            np.int64)  # batch index
        number_of_batches = batch_index[-1] + 1  # number of batches

        aspect_ratio = image_shapes[:, 1] / image_shapes[:, 0]  # aspect ratio
        irect = aspect_ratio.argsort()

        data_list = [data_list[i] for i in irect]

        aspect_ratio = aspect_ratio[irect]
        # Set training image shapes
        shapes = [[1, 1]] * number_of_batches
        for i in range(number_of_batches):
            aspect_ratio_index = aspect_ratio[batch_index == i]
            min_index, max_index = aspect_ratio_index.min(
            ), aspect_ratio_index.max()
            if max_index < 1:
                shapes[i] = [max_index, 1]
            elif min_index > 1:
                shapes[i] = [1, 1 / min_index]

        batch_shapes = np.ceil(
            np.array(shapes) * self.img_size / self.size_divisor +
            self.extra_pad_ratio).astype(np.int64) * self.size_divisor

        for i, data_info in enumerate(data_list):
            data_info['batch_shape'] = batch_shapes[batch_index[i]]

        return data_list
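For intuition, BatchShapePolicy only needs 'width' and 'height' in each data_info; it sorts images by aspect ratio and attaches a per-batch 'batch_shape' rounded up to a multiple of size_divisor. A small standalone sketch with made-up sizes:

# Toy usage sketch (hypothetical image sizes, not from the repo's tests).
policy = BatchShapePolicy(batch_size=2, img_size=640, size_divisor=32)
data_list = [
    dict(width=640, height=480),
    dict(width=800, height=600),
    dict(width=480, height=640),
    dict(width=600, height=800),
]
for info in policy(data_list):
    print(info['batch_shape'])
# For this toy data the two wide images share shape [512 672] and the two tall
# ones [672 512], so each batch is padded far less than a fixed square shape.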
mmyolo/datasets/yolov5_coco.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional

from mmdet.datasets import BaseDetDataset, CocoDataset

from ..registry import DATASETS, TASK_UTILS


class BatchShapePolicyDataset(BaseDetDataset):
    """Dataset with the batch shape policy that makes paddings with least
    pixels during batch inference process, which does not require the image
    scales of all batches to be the same throughout validation."""

    def __init__(self,
                 *args,
                 batch_shapes_cfg: Optional[dict] = None,
                 **kwargs):
        self.batch_shapes_cfg = batch_shapes_cfg
        super().__init__(*args, **kwargs)

    def full_init(self):
        """rewrite full_init() to be compatible with serialize_data in
        BatchShapePolicy."""
        if self._fully_initialized:
            return
        # load data information
        self.data_list = self.load_data_list()

        # batch_shapes_cfg
        if self.batch_shapes_cfg:
            batch_shapes_policy = TASK_UTILS.build(self.batch_shapes_cfg)
            self.data_list = batch_shapes_policy(self.data_list)
            del batch_shapes_policy

        # filter illegal data, such as data that has no annotations.
        self.data_list = self.filter_data()
        # Get subset data according to indices.
        if self._indices is not None:
            self.data_list = self._get_unserialized_subset(self._indices)

        # serialize data_list
        if self.serialize_data:
            self.data_bytes, self.data_address = self._serialize_data()

        self._fully_initialized = True

    def prepare_data(self, idx: int) -> Any:
        """Pass the dataset to the pipeline during training to support mixed
        data augmentation, such as Mosaic and MixUp."""
        if self.test_mode is False:
            data_info = self.get_data_info(idx)
            data_info['dataset'] = self
            return self.pipeline(data_info)
        else:
            return super().prepare_data(idx)


@DATASETS.register_module()
class YOLOv5CocoDataset(BatchShapePolicyDataset, CocoDataset):
    """Dataset for YOLOv5 COCO Dataset.

    We only add `BatchShapePolicy` function compared with CocoDataset. See
    `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass
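A config-style sketch (assumed fragment, mirroring typical mmyolo test-time settings, with illustrative data paths) of how `batch_shapes_cfg` plugs into YOLOv5CocoDataset; passing None disables the policy:

# Assumed config fragment; paths and sizes are illustrative only.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=32,
    img_size=640,
    size_divisor=32,
    extra_pad_ratio=0.5)

val_dataloader = dict(
    batch_size=32,
    dataset=dict(
        type='YOLOv5CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        batch_shapes_cfg=batch_shapes_cfg))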
mmyolo/datasets/yolov5_crowdhuman.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets import CrowdHumanDataset

from ..registry import DATASETS
from .yolov5_coco import BatchShapePolicyDataset


@DATASETS.register_module()
class YOLOv5CrowdHumanDataset(BatchShapePolicyDataset, CrowdHumanDataset):
    """Dataset for YOLOv5 CrowdHuman Dataset.

    We only add `BatchShapePolicy` function compared with CrowdHumanDataset.
    See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass
mmyolo/datasets/yolov5_dota.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from ..registry import DATASETS

try:
    from mmrotate.datasets import DOTADataset
    MMROTATE_AVAILABLE = True
except ImportError:
    from mmengine.dataset import BaseDataset
    DOTADataset = BaseDataset
    MMROTATE_AVAILABLE = False


@DATASETS.register_module()
class YOLOv5DOTADataset(BatchShapePolicyDataset, DOTADataset):
    """Dataset for YOLOv5 DOTA Dataset.

    We only add `BatchShapePolicy` function compared with DOTADataset. See
    `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """

    def __init__(self, *args, **kwargs):
        if not MMROTATE_AVAILABLE:
            raise ImportError(
                'Please run "mim install -r requirements/mmrotate.txt" '
                'to install mmrotate first for rotated detection.')

        super().__init__(*args, **kwargs)
mmyolo/datasets/yolov5_voc.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets import VOCDataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from ..registry import DATASETS


@DATASETS.register_module()
class YOLOv5VOCDataset(BatchShapePolicyDataset, VOCDataset):
    """Dataset for YOLOv5 VOC Dataset.

    We only add `BatchShapePolicy` function compared with VOCDataset. See
    `mmyolo/datasets/utils.py#BatchShapePolicy` for details
    """
    pass
mmyolo/deploy/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.codebase.base import MMCodebase

from .models import *  # noqa: F401,F403
from .object_detection import MMYOLO, YOLOObjectDetection

__all__ = ['MMCodebase', 'MMYOLO', 'YOLOObjectDetection']
mmyolo/deploy/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import dense_heads  # noqa: F401,F403
mmyolo/deploy/models/dense_heads/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import yolov5_head  # noqa: F401,F403

__all__ = ['yolov5_head']
mmyolo/deploy/models/dense_heads/yolov5_head.py
ADDED
@@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from functools import partial
from typing import List, Optional, Tuple

import torch
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.codebase.mmdet.models.layers import multiclass_nms
from mmdeploy.core import FUNCTION_REWRITER
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.deploy.models.layers import efficient_nms
from mmyolo.models.dense_heads import YOLOv5Head


def yolov5_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
                        stride: int) -> Tensor:
    """Decode YOLOv5 bounding boxes.

    Args:
        priors (Tensor): Prior boxes in center-offset form.
        bbox_preds (Tensor): Predicted bounding boxes.
        stride (int): Stride of the feature map.

    Returns:
        Tensor: Decoded bounding boxes.
    """
    bbox_preds = bbox_preds.sigmoid()

    x_center = (priors[..., 0] + priors[..., 2]) * 0.5
    y_center = (priors[..., 1] + priors[..., 3]) * 0.5
    w = priors[..., 2] - priors[..., 0]
    h = priors[..., 3] - priors[..., 1]

    x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
    y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
    w_pred = (bbox_preds[..., 2] * 2)**2 * w
    h_pred = (bbox_preds[..., 3] * 2)**2 * h

    decoded_bboxes = torch.stack(
        [x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)

    return decoded_bboxes


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmyolo.models.dense_heads.yolov5_head.'
    'YOLOv5Head.predict_by_feat')
def yolov5_head__predict_by_feat(self,
                                 cls_scores: List[Tensor],
                                 bbox_preds: List[Tensor],
                                 objectnesses: Optional[List[Tensor]] = None,
                                 batch_img_metas: Optional[List[dict]] = None,
                                 cfg: Optional[ConfigDict] = None,
                                 rescale: bool = False,
                                 with_nms: bool = True) -> Tuple[InstanceData]:
    """Transform a batch of output features extracted by the head into
    bbox results.

    Args:
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 4, H, W).
        objectnesses (list[Tensor], Optional): Score factor for
            all scale level, each is a 4D-tensor, has shape
            (batch_size, 1, H, W).
        batch_img_metas (list[dict], Optional): Batch image meta info.
            Defaults to None.
        cfg (ConfigDict, optional): Test / postprocessing
            configuration, if None, test_cfg would be used.
            Defaults to None.
        rescale (bool): If True, return boxes in original image space.
            Defaults to False.
        with_nms (bool): If True, do nms before return boxes.
            Defaults to True.
    Returns:
        tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
            where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
            size and the score between 0 and 1. The shape of the second
            tensor in the tuple is (N, num_box), and each element
            represents the class label of the corresponding box.
    """
    ctx = FUNCTION_REWRITER.get_context()
    detector_type = type(self)
    deploy_cfg = ctx.cfg
    use_efficientnms = deploy_cfg.get('use_efficientnms', False)
    dtype = cls_scores[0].dtype
    device = cls_scores[0].device
    bbox_decoder = self.bbox_coder.decode
    nms_func = multiclass_nms
    if use_efficientnms:
        if detector_type is YOLOv5Head:
            nms_func = partial(efficient_nms, box_coding=0)
            bbox_decoder = yolov5_bbox_decoder
        else:
            nms_func = efficient_nms

    assert len(cls_scores) == len(bbox_preds)
    cfg = self.test_cfg if cfg is None else cfg
    cfg = copy.deepcopy(cfg)

    num_imgs = cls_scores[0].shape[0]
    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, dtype=dtype, device=device)

    flatten_priors = torch.cat(mlvl_priors)

    mlvl_strides = [
        flatten_priors.new_full(
            (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
            stride)
        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
    ]
    flatten_stride = torch.cat(mlvl_strides)

    # flatten cls_scores, bbox_preds and objectness
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
        for cls_score in cls_scores
    ]
    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()

    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)

    if objectnesses is not None:
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]
        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
        cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))

    scores = cls_scores

    bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
                          flatten_stride)

    if not with_nms:
        return bboxes, scores

    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    return nms_func(bboxes, scores, max_output_boxes_per_class, iou_threshold,
                    score_threshold, pre_top_k, keep_top_k)


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmyolo.models.dense_heads.yolov5_head.'
    'YOLOv5Head.predict',
    backend='rknn')
def yolov5_head__predict__rknn(self, x: Tuple[Tensor], *args,
                               **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """Perform forward propagation of the detection head and predict detection
    results on the features of the upstream network.

    Args:
        x (tuple[Tensor]): Multi-level features from the
            upstream network, each is a 4D-tensor.
    """
    outs = self(x)
    return outs


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmyolo.models.dense_heads.yolov5_head.'
    'YOLOv5HeadModule.forward',
    backend='rknn')
def yolov5_head_module__forward__rknn(
        self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
    """Forward feature of a single scale level."""
    out = []
    for i, feat in enumerate(x):
        out.append(self.convs_pred[i](feat))
    return out
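As a quick sanity check on `yolov5_bbox_decoder`, a standalone numeric sketch (hypothetical tensors, independent of any deploy backend):

import torch

# One 16x16 prior centred at (8, 8); raw predictions of zero give sigmoid 0.5.
priors = torch.tensor([[0., 0., 16., 16.]])
bbox_preds = torch.zeros(1, 4)
decoded = yolov5_bbox_decoder(priors, bbox_preds, stride=8)
# With sigmoid(0) = 0.5 the centre offset is zero and (0.5 * 2) ** 2 = 1,
# so the decoded cx, cy, w, h stay at tensor([[8., 8., 16., 16.]]).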
mmyolo/deploy/models/layers/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_nms import efficient_nms

__all__ = ['efficient_nms']
mmyolo/deploy/models/layers/bbox_nms.py
ADDED
@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdeploy.core import mark
from torch import Tensor


def _efficient_nms(
    boxes: Tensor,
    scores: Tensor,
    max_output_boxes_per_class: int = 1000,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Wrapper for `efficient_nms` with TensorRT.

    Args:
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to 100.
        box_coding (int): Bounding boxes format for nms.
            Defaults to 0 means [x, y, w, h].
            Set to 1 means [x1, y1, x2, y2].

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
            and `labels` of shape [N, num_det].
    """
    boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
    _, det_boxes, det_scores, labels = TRTEfficientNMSop.apply(
        boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
        score_threshold)
    dets = torch.cat([det_boxes, det_scores.unsqueeze(2)], -1)

    # retain shape info
    batch_size = boxes.size(0)

    dets_shape = dets.shape
    label_shape = labels.shape
    dets = dets.reshape([batch_size, *dets_shape[1:]])
    labels = labels.reshape([batch_size, *label_shape[1:]])
    return dets, labels


@mark('efficient_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
def efficient_nms(*args, **kwargs):
    """Wrapper function for `_efficient_nms`."""
    return _efficient_nms(*args, **kwargs)


class TRTEfficientNMSop(torch.autograd.Function):
    """Efficient NMS op for TensorRT."""

    @staticmethod
    def forward(
        ctx,
        boxes,
        scores,
        background_class=-1,
        box_coding=0,
        iou_threshold=0.45,
        max_output_boxes=100,
        plugin_version='1',
        score_activation=0,
        score_threshold=0.25,
    ):
        """Forward function of TRTEfficientNMSop."""
        batch_size, num_boxes, num_classes = scores.shape
        num_det = torch.randint(
            0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
        det_boxes = torch.randn(batch_size, max_output_boxes, 4)
        det_scores = torch.randn(batch_size, max_output_boxes)
        det_classes = torch.randint(
            0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
        return num_det, det_boxes, det_scores, det_classes

    @staticmethod
    def symbolic(g,
                 boxes,
                 scores,
                 background_class=-1,
                 box_coding=0,
                 iou_threshold=0.45,
                 max_output_boxes=100,
                 plugin_version='1',
                 score_activation=0,
                 score_threshold=0.25):
        """Symbolic function of TRTEfficientNMSop."""
        out = g.op(
            'TRT::EfficientNMS_TRT',
            boxes,
            scores,
            background_class_i=background_class,
            box_coding_i=box_coding,
            iou_threshold_f=iou_threshold,
            max_output_boxes_i=max_output_boxes,
            plugin_version_s=plugin_version,
            score_activation_i=score_activation,
            score_threshold_f=score_threshold,
            outputs=4)
        nums, boxes, scores, classes = out
        return nums, boxes, scores, classes
mmyolo/deploy/object_detection.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, Optional

import torch
from mmdeploy.codebase.base import CODEBASE, MMCodebase
from mmdeploy.codebase.mmdet.deploy import ObjectDetection
from mmdeploy.utils import Codebase, Task
from mmengine import Config
from mmengine.registry import Registry

MMYOLO_TASK = Registry('mmyolo_tasks')


@CODEBASE.register_module(Codebase.MMYOLO.value)
class MMYOLO(MMCodebase):
    """MMYOLO codebase class."""

    task_registry = MMYOLO_TASK

    @classmethod
    def register_deploy_modules(cls):
        """register all rewriters for mmdet."""
        import mmdeploy.codebase.mmdet.models  # noqa: F401
        import mmdeploy.codebase.mmdet.ops  # noqa: F401
        import mmdeploy.codebase.mmdet.structures  # noqa: F401

    @classmethod
    def register_all_modules(cls):
        """register all modules."""
        from mmdet.utils.setup_env import \
            register_all_modules as register_all_modules_mmdet

        from mmyolo.utils.setup_env import \
            register_all_modules as register_all_modules_mmyolo

        cls.register_deploy_modules()
        register_all_modules_mmyolo(True)
        register_all_modules_mmdet(False)


def _get_dataset_metainfo(model_cfg: Config):
    """Get metainfo of dataset.

    Args:
        model_cfg (Config): Input model Config object.

    Returns:
        list[str]: A list of strings specifying the names of different
            classes.
    """
    from mmyolo import datasets  # noqa
    from mmyolo.registry import DATASETS

    module_dict = DATASETS.module_dict
    for dataloader_name in [
            'test_dataloader', 'val_dataloader', 'train_dataloader'
    ]:
        if dataloader_name not in model_cfg:
            continue
        dataloader_cfg = model_cfg[dataloader_name]
        dataset_cfg = dataloader_cfg.dataset
        dataset_cls = module_dict.get(dataset_cfg.type, None)
        if dataset_cls is None:
            continue
        if hasattr(dataset_cls, '_load_metainfo') and isinstance(
                dataset_cls._load_metainfo, Callable):
            meta = dataset_cls._load_metainfo(
                dataset_cfg.get('metainfo', None))
            if meta is not None:
                return meta
        if hasattr(dataset_cls, 'METAINFO'):
            return dataset_cls.METAINFO

    return None


@MMYOLO_TASK.register_module(Task.OBJECT_DETECTION.value)
class YOLOObjectDetection(ObjectDetection):
    """YOLO Object Detection task."""

    def get_visualizer(self, name: str, save_dir: str):
        """Get visualizer.

        Args:
            name (str): Name of visualizer.
            save_dir (str): Directory to save visualization results.

        Returns:
            Visualizer: A visualizer instance.
        """
        from mmdet.visualization import DetLocalVisualizer  # noqa: F401,F403
        metainfo = _get_dataset_metainfo(self.model_cfg)
        visualizer = super().get_visualizer(name, save_dir)
        if metainfo is not None:
            visualizer.dataset_meta = metainfo
        return visualizer

    def build_pytorch_model(self,
                            model_checkpoint: Optional[str] = None,
                            cfg_options: Optional[Dict] = None,
                            **kwargs) -> torch.nn.Module:
        """Initialize torch model.

        Args:
            model_checkpoint (str): The checkpoint file of torch model,
                defaults to `None`.
            cfg_options (dict): Optional config key-pair parameters.

        Returns:
            nn.Module: An initialized torch model generated by other OpenMMLab
                codebases.
        """
        from copy import deepcopy

        from mmengine.model import revert_sync_batchnorm
        from mmengine.registry import MODELS

        from mmyolo.utils import switch_to_deploy

        model = deepcopy(self.model_cfg.model)
        preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
        preprocess_cfg.update(
            deepcopy(self.model_cfg.get('data_preprocessor', {})))
        model.setdefault('data_preprocessor', preprocess_cfg)
        model = MODELS.build(model)
        if model_checkpoint is not None:
            from mmengine.runner.checkpoint import load_checkpoint
            load_checkpoint(model, model_checkpoint, map_location=self.device)

        model = revert_sync_batchnorm(model)
        switch_to_deploy(model)
        model = model.to(self.device)
        model.eval()
        return model
mmyolo/engine/__init__.py
ADDED
@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import *  # noqa: F401,F403
from .optimizers import *  # noqa: F401,F403
mmyolo/engine/hooks/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
from .switch_to_deploy_hook import SwitchToDeployHook
from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
from .yolox_mode_switch_hook import YOLOXModeSwitchHook

__all__ = [
    'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
    'PPYOLOEParamSchedulerHook'
]
mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS


@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of PPYOLOE. We
    use this hook to implement adaptive computation for `warmup_total_iters`,
    which is not possible with the built-in ParamScheduler in mmyolo.

    Args:
        warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
        start_factor (float): The number we multiply learning rate in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 0.
        warmup_epochs (int): Epochs for warmup. Defaults to 5.
        min_lr_ratio (float): Minimum learning rate ratio. Defaults to 0.0.
        total_epochs (int): In PPYOLOE, `total_epochs` is set to
            training_epochs x 1.2. Defaults to 360.
    """
    priority = 9

    def __init__(self,
                 warmup_min_iter: int = 1000,
                 start_factor: float = 0.,
                 warmup_epochs: int = 5,
                 min_lr_ratio: float = 0.0,
                 total_epochs: int = 360):

        self.warmup_min_iter = warmup_min_iter
        self.start_factor = start_factor
        self.warmup_epochs = warmup_epochs
        self.min_lr_ratio = min_lr_ratio
        self.total_epochs = total_epochs

        self._warmup_end = False
        self._base_lr = None

    def before_train(self, runner: Runner):
        """Operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """
        optimizer = runner.optim_wrapper.optimizer
        for group in optimizer.param_groups:
            # If the param has never been scheduled, record the current value
            # as the initial value.
            group.setdefault('initial_lr', group['lr'])

        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None):
        """Operations before each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
        """
        cur_iters = runner.iter
        optimizer = runner.optim_wrapper.optimizer
        dataloader_len = len(runner.train_dataloader)

        # The minimum warmup is self.warmup_min_iter
        warmup_total_iters = max(
            round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)

        if cur_iters <= warmup_total_iters:
            # warm up
            alpha = cur_iters / warmup_total_iters
            factor = self.start_factor * (1 - alpha) + alpha

            for group_idx, param in enumerate(optimizer.param_groups):
                param['lr'] = self._base_lr[group_idx] * factor
        else:
            for group_idx, param in enumerate(optimizer.param_groups):
                total_iters = self.total_epochs * dataloader_len
                lr = self._min_lr[group_idx] + (
                    self._base_lr[group_idx] -
                    self._min_lr[group_idx]) * 0.5 * (
                        math.cos((cur_iters - warmup_total_iters) * math.pi /
                                 (total_iters - warmup_total_iters)) + 1.0)
                param['lr'] = lr
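To summarize the arithmetic above: the hook applies a linear warmup for `warmup_total_iters` iterations and then a cosine decay from the base learning rate down to `min_lr_ratio * base_lr`. A standalone sketch of the same formula (not used by the hook itself):

import math


def ppyoloe_lr(cur_iter: int, base_lr: float, min_lr: float,
               warmup_total_iters: int, total_iters: int,
               start_factor: float = 0.) -> float:
    """Sketch reproducing the schedule above for a single parameter group."""
    if cur_iter <= warmup_total_iters:
        # Linear warmup from start_factor * base_lr up to base_lr.
        alpha = cur_iter / warmup_total_iters
        return base_lr * (start_factor * (1 - alpha) + alpha)
    # Cosine decay from base_lr down to min_lr over the remaining iterations.
    progress = (cur_iter - warmup_total_iters) * math.pi / (
        total_iters - warmup_total_iters)
    return min_lr + (base_lr - min_lr) * 0.5 * (math.cos(progress) + 1.0)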
mmyolo/engine/hooks/switch_to_deploy_hook.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmengine.hooks import Hook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS
from mmyolo.utils import switch_to_deploy


@HOOKS.register_module()
class SwitchToDeployHook(Hook):
    """Switch to deploy mode before testing.

    This hook converts the multi-channel structure of the training network
    (high performance) to the one-way structure of the testing network (fast
    speed and memory saving).
    """

    def before_test_epoch(self, runner: Runner):
        """Switch to deploy mode before testing."""
        switch_to_deploy(runner.model)
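In practice this hook is usually enabled from a config; a minimal assumed fragment:

# Assumed config fragment: fuse re-parameterizable modules before testing.
custom_hooks = [dict(type='SwitchToDeployHook')]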
mmyolo/engine/hooks/yolov5_param_scheduler_hook.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

import numpy as np
from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner

from mmyolo.registry import HOOKS


def linear_fn(lr_factor: float, max_epochs: int):
    """Generate linear function."""
    return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor


def cosine_fn(lr_factor: float, max_epochs: int):
    """Generate cosine function."""
    return lambda x: (
        (1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1


@HOOKS.register_module()
class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
    """A hook to update learning rate and momentum in optimizer of YOLOv5."""
    priority = 9

    scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}

    def __init__(self,
                 scheduler_type: str = 'linear',
                 lr_factor: float = 0.01,
                 max_epochs: int = 300,
                 warmup_epochs: int = 3,
                 warmup_bias_lr: float = 0.1,
                 warmup_momentum: float = 0.8,
                 warmup_mim_iter: int = 1000,
                 **kwargs):

        assert scheduler_type in self.scheduler_maps

        self.warmup_epochs = warmup_epochs
        self.warmup_bias_lr = warmup_bias_lr
        self.warmup_momentum = warmup_momentum
        self.warmup_mim_iter = warmup_mim_iter

        kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
        self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)

        self._warmup_end = False
        self._base_lr = None
        self._base_momentum = None

    def before_train(self, runner: Runner):
        """Operations before train.

        Args:
            runner (Runner): The runner of the training process.
        """
        optimizer = runner.optim_wrapper.optimizer
        for group in optimizer.param_groups:
            # If the param has never been scheduled, record the current value
            # as the initial value.
            group.setdefault('initial_lr', group['lr'])
            group.setdefault('initial_momentum', group.get('momentum', -1))

        self._base_lr = [
            group['initial_lr'] for group in optimizer.param_groups
        ]
        self._base_momentum = [
            group['initial_momentum'] for group in optimizer.param_groups
        ]

    def before_train_iter(self,
                          runner: Runner,
                          batch_idx: int,
                          data_batch: Optional[dict] = None):
        """Operations before each training iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
        """
        cur_iters = runner.iter
        cur_epoch = runner.epoch
        optimizer = runner.optim_wrapper.optimizer

        # The minimum warmup is self.warmup_mim_iter
        warmup_total_iters = max(
            round(self.warmup_epochs * len(runner.train_dataloader)),
            self.warmup_mim_iter)

        if cur_iters <= warmup_total_iters:
            xp = [0, warmup_total_iters]
            for group_idx, param in enumerate(optimizer.param_groups):
                if group_idx == 2:
                    # bias learning rate will be handled specially
                    yp = [
                        self.warmup_bias_lr,
                        self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
                    ]
                else:
                    yp = [
                        0.0,
                        self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
                    ]
                param['lr'] = np.interp(cur_iters, xp, yp)

                if 'momentum' in param:
                    param['momentum'] = np.interp(
                        cur_iters, xp,
                        [self.warmup_momentum, self._base_momentum[group_idx]])
        else:
            self._warmup_end = True

    def after_train_epoch(self, runner: Runner):
        """Operations after each training epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self._warmup_end:
            return

        cur_epoch = runner.epoch
        optimizer = runner.optim_wrapper.optimizer
        for group_idx, param in enumerate(optimizer.param_groups):
            param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
                cur_epoch)
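A quick check of the two factor generators defined at the top of this file (standalone sketch, hypothetical values):

lin = linear_fn(lr_factor=0.01, max_epochs=300)
cos = cosine_fn(lr_factor=0.01, max_epochs=300)
print(lin(0), lin(300))  # 1.0 at epoch 0, 0.01 at the final epoch
print(cos(0), cos(300))  # 1.0 at epoch 0, 0.01 at the final epoch
# Both multiply the group's base lr in after_train_epoch once warmup has ended.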
mmyolo/engine/hooks/yolox_mode_switch_hook.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmyolo.registry import HOOKS


@HOOKS.register_module()
class YOLOXModeSwitchHook(Hook):
    """Switch the mode of YOLOX during training.

    This hook turns off the mosaic and mixup data augmentation and switches
    to use L1 loss in bbox_head.

    Args:
        num_last_epochs (int): The number of epochs at the end of training
            during which the data augmentation is closed and the model
            switches to L1 loss. Defaults to 15.
    """

    def __init__(self,
                 num_last_epochs: int = 15,
                 new_train_pipeline: Sequence[dict] = None):
        self.num_last_epochs = num_last_epochs
        self.new_train_pipeline_cfg = new_train_pipeline

    def before_train_epoch(self, runner: Runner):
        """Close mosaic and mixup augmentation and switch to use L1 loss."""
        epoch = runner.epoch
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
            runner.logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}')

            train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
            train_dataloader_cfg.dataset.pipeline = self.new_train_pipeline_cfg
            # Note: Why rebuild the dataset?
            # Because build_dataloader makes a deep copy of the dataset,
            # reusing the old one carries potential risks, such as the global
            # instance object FileClient data being disordered.
            # This problem needs to be solved in the future.
            new_train_dataloader = Runner.build_dataloader(
                train_dataloader_cfg)
            runner.train_loop.dataloader = new_train_dataloader

            runner.logger.info('recreate the dataloader!')
            runner.logger.info('Add additional bbox reg loss now!')
            model.bbox_head.use_bbox_aux = True
mmyolo/engine/optimizers/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor

__all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
mmyolo/engine/optimizers/yolov5_optim_constructor.py
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper

from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOv5OptimizerConstructor:
    """YOLOv5 constructor for optimizers.

    It has the following functions:

    - divides the optimizer parameters into 3 groups:
      Conv, Bias and BN

    - support `weight_decay` parameter adaption based on
      `batch_size_per_gpu`

    Args:
        optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
            Positional fields are

            - ``type``: class name of the OptimizerWrapper
            - ``optimizer``: The configuration of optimizer.

            Optional fields are

            - any arguments of the corresponding optimizer wrapper type,
              e.g., accumulative_counts, clip_grad, etc.

            The positional fields of ``optimizer`` are

            - `type`: class name of the optimizer.

            Optional fields are

            - any arguments of the corresponding optimizer type, e.g.,
              lr, weight_decay, momentum, etc.

        paramwise_cfg (dict, optional): Parameter-wise options. Must include
            `base_total_batch_size` if not None. If the total input batch
            is smaller than `base_total_batch_size`, the `weight_decay`
            parameter will be kept unchanged, otherwise linear scaling.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optim_wrapper_cfg = dict(
        >>>     dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
        >>>         momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16))
        >>> paramwise_cfg = dict(base_total_batch_size=64)
        >>> optim_wrapper_builder = YOLOv5OptimizerConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)
    """

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None):
        if paramwise_cfg is None:
            paramwise_cfg = {'base_total_batch_size': 64}
        assert 'base_total_batch_size' in paramwise_cfg

        if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optim_wrapper_cfg)}')
        assert 'optimizer' in optim_wrapper_cfg, (
            '`optim_wrapper_cfg` must contain "optimizer" config')

        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
        self.base_total_batch_size = paramwise_cfg['base_total_batch_size']

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if is_model_wrapper(model):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        weight_decay = optimizer_cfg.pop('weight_decay', 0)

        if 'batch_size_per_gpu' in optimizer_cfg:
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # No scaling if total_batch_size is less than
            # base_total_batch_size, otherwise linear scaling.
            total_batch_size = get_world_size() * batch_size_per_gpu
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size

            if scale_factor != 1:
                weight_decay *= scale_factor
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')

        params_groups = [], [], []

        for v in model.modules():
            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
                params_groups[2].append(v.bias)
            # Includes SyncBatchNorm
            if isinstance(v, nn.modules.batchnorm._NormBase):
                params_groups[1].append(v.weight)
            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
                params_groups[0].append(v.weight)

        # Note: Make sure bias is in the last parameter group
        optimizer_cfg['params'] = []
        # conv
        optimizer_cfg['params'].append({
            'params': params_groups[0],
            'weight_decay': weight_decay
        })
        # bn
        optimizer_cfg['params'].append({'params': params_groups[1]})
        # bias
        optimizer_cfg['params'].append({'params': params_groups[2]})

        print_log(
            'Optimizer groups: %g .bias, %g conv.weight, %g other' %
            (len(params_groups[2]), len(params_groups[0]), len(
                params_groups[1])), 'current')
        del params_groups

        optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
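For intuition on the scaling rule implemented in `__call__` above, here is a stand-alone worked example of the same arithmetic; the 8-GPU / batch-16 numbers are illustrative, not taken from this upload.

# Minimal sketch of the weight_decay scaling rule above, assuming 8 GPUs,
# batch_size_per_gpu=16 and the default base_total_batch_size=64.
base_total_batch_size = 64
world_size, batch_size_per_gpu, weight_decay = 8, 16, 0.0005

total_batch_size = world_size * batch_size_per_gpu                     # 128
accumulate = max(round(base_total_batch_size / total_batch_size), 1)   # 1
scale_factor = total_batch_size * accumulate / base_total_batch_size   # 2.0
if scale_factor != 1:
    weight_decay *= scale_factor                                       # 0.001
print(weight_decay)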
mmyolo/engine/optimizers/yolov7_optim_wrapper_constructor.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch.nn as nn
from mmengine.dist import get_world_size
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.optim import OptimWrapper

from mmyolo.models.dense_heads.yolov7_head import ImplicitA, ImplicitM
from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
                             OPTIMIZERS)


# TODO: Consider merging into YOLOv5OptimizerConstructor
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class YOLOv7OptimWrapperConstructor:
    """YOLOv7 constructor for optimizer wrappers.

    It has the following functions:

    - divides the optimizer parameters into 3 groups:
      Conv, Bias and BN/ImplicitA/ImplicitM

    - support `weight_decay` parameter adaption based on
      `batch_size_per_gpu`

    Args:
        optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
            Positional fields are

            - ``type``: class name of the OptimizerWrapper
            - ``optimizer``: The configuration of optimizer.

            Optional fields are

            - any arguments of the corresponding optimizer wrapper type,
              e.g., accumulative_counts, clip_grad, etc.

            The positional fields of ``optimizer`` are

            - `type`: class name of the optimizer.

            Optional fields are

            - any arguments of the corresponding optimizer type, e.g.,
              lr, weight_decay, momentum, etc.

        paramwise_cfg (dict, optional): Parameter-wise options. Must include
            `base_total_batch_size` if not None. If the total input batch
            is smaller than `base_total_batch_size`, the `weight_decay`
            parameter will be kept unchanged, otherwise linear scaling.

    Example:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optim_wrapper_cfg = dict(
        >>>     dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
        >>>         momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16))
        >>> paramwise_cfg = dict(base_total_batch_size=64)
        >>> optim_wrapper_builder = YOLOv7OptimWrapperConstructor(
        >>>     optim_wrapper_cfg, paramwise_cfg)
        >>> optim_wrapper = optim_wrapper_builder(model)
    """

    def __init__(self,
                 optim_wrapper_cfg: dict,
                 paramwise_cfg: Optional[dict] = None):
        if paramwise_cfg is None:
            paramwise_cfg = {'base_total_batch_size': 64}
        assert 'base_total_batch_size' in paramwise_cfg

        if not isinstance(optim_wrapper_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optim_wrapper_cfg)}')
        assert 'optimizer' in optim_wrapper_cfg, (
            '`optim_wrapper_cfg` must contain "optimizer" config')

        self.optim_wrapper_cfg = optim_wrapper_cfg
        self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
        self.base_total_batch_size = paramwise_cfg['base_total_batch_size']

    def __call__(self, model: nn.Module) -> OptimWrapper:
        if is_model_wrapper(model):
            model = model.module
        optimizer_cfg = self.optimizer_cfg.copy()
        weight_decay = optimizer_cfg.pop('weight_decay', 0)

        if 'batch_size_per_gpu' in optimizer_cfg:
            batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
            # No scaling if total_batch_size is less than
            # base_total_batch_size, otherwise linear scaling.
            total_batch_size = get_world_size() * batch_size_per_gpu
            accumulate = max(
                round(self.base_total_batch_size / total_batch_size), 1)
            scale_factor = total_batch_size * \
                accumulate / self.base_total_batch_size

            if scale_factor != 1:
                weight_decay *= scale_factor
                print_log(f'Scaled weight_decay to {weight_decay}', 'current')

        params_groups = [], [], []
        for v in model.modules():
            # no decay
            # Caution: Coupling with model
            if isinstance(v, (ImplicitA, ImplicitM)):
                params_groups[0].append(v.implicit)
            elif isinstance(v, nn.modules.batchnorm._NormBase):
                params_groups[0].append(v.weight)
            # apply decay
            elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
                params_groups[1].append(v.weight)  # apply decay

            # biases, no decay
            if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
                params_groups[2].append(v.bias)

        # Note: Make sure bias is in the last parameter group
        optimizer_cfg['params'] = []
        # conv
        optimizer_cfg['params'].append({
            'params': params_groups[1],
            'weight_decay': weight_decay
        })
        # bn ...
        optimizer_cfg['params'].append({'params': params_groups[0]})
        # bias
        optimizer_cfg['params'].append({'params': params_groups[2]})

        print_log(
            'Optimizer groups: %g .bias, %g conv.weight, %g other' %
            (len(params_groups[2]), len(params_groups[1]), len(
                params_groups[0])), 'current')
        del params_groups

        optimizer = OPTIMIZERS.build(optimizer_cfg)
        optim_wrapper = OPTIM_WRAPPERS.build(
            self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
        return optim_wrapper
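The grouping above differs from the YOLOv5 constructor only in that ImplicitA/ImplicitM parameters join the no-decay group. A minimal plain-PyTorch sketch of the resulting three-group optimizer layout (tensor shapes and hyperparameters are illustrative assumptions, not values from this upload):

# Sketch of the three param-group layout produced above: decayed conv weights
# first, then BN/ImplicitA/ImplicitM weights, then biases (both without decay,
# since SGD's default weight_decay is 0).
import torch

conv_w = [torch.nn.Parameter(torch.randn(3, 3))]
norm_and_implicit_w = [torch.nn.Parameter(torch.randn(3))]
biases = [torch.nn.Parameter(torch.randn(3))]

optimizer = torch.optim.SGD(
    [
        {'params': conv_w, 'weight_decay': 5e-4},  # decay applied
        {'params': norm_and_implicit_w},           # no decay
        {'params': biases},                        # no decay
    ],
    lr=0.01, momentum=0.937)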
mmyolo/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .backbones import *  # noqa: F401,F403
from .data_preprocessors import *  # noqa: F401,F403
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .layers import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .plugins import *  # noqa: F401,F403
from .task_modules import *  # noqa: F401,F403
mmyolo/models/backbones/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_backbone import BaseBackbone
from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
from .csp_resnet import PPYOLOECSPResNet
from .cspnext import CSPNeXt
from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
from .yolov7_backbone import YOLOv7Backbone

__all__ = [
    'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
    'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
    'YOLOv8CSPDarknet'
]
mmyolo/models/backbones/base_backbone.py
ADDED
@@ -0,0 +1,225 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Sequence, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmyolo.registry import MODELS


@MODELS.register_module()
class BaseBackbone(BaseModule, metaclass=ABCMeta):
    """BaseBackbone backbone used in YOLO series.

    .. code:: text

     Backbone model structure diagram
     +-----------+
     |   input   |
     +-----------+
           v
     +-----------+
     |   stem    |
     |   layer   |
     +-----------+
           v
     +-----------+
     |   stage   |
     |  layer 1  |
     +-----------+
           v
     +-----------+
     |   stage   |
     |  layer 2  |
     +-----------+
           v
         ......
           v
     +-----------+
     |   stage   |
     |  layer n  |
     +-----------+
     In P5 model, n=4
     In P6 model, n=5

    Args:
        arch_setting (list): Architecture of BaseBackbone.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.
        deepen_factor (float): Depth multiplier, multiply number of
            blocks in CSP layer by this amount. Defaults to 1.0.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        input_channels: Number of input image channels. Defaults to 3.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to (2, 3, 4).
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to None.
        act_cfg (dict): Config dict for activation layer.
            Defaults to None.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 arch_setting: list,
                 deepen_factor: float = 1.0,
                 widen_factor: float = 1.0,
                 input_channels: int = 3,
                 out_indices: Sequence[int] = (2, 3, 4),
                 frozen_stages: int = -1,
                 plugins: Union[dict, List[dict]] = None,
                 norm_cfg: ConfigType = None,
                 act_cfg: ConfigType = None,
                 norm_eval: bool = False,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        self.num_stages = len(arch_setting)
        self.arch_setting = arch_setting

        assert set(out_indices).issubset(
            i for i in range(len(arch_setting) + 1))

        if frozen_stages not in range(-1, len(arch_setting) + 1):
            raise ValueError('"frozen_stages" must be in range(-1, '
                             'len(arch_setting) + 1). But received '
                             f'{frozen_stages}')

        self.input_channels = input_channels
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.widen_factor = widen_factor
        self.deepen_factor = deepen_factor
        self.norm_eval = norm_eval
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.plugins = plugins

        self.stem = self.build_stem_layer()
        self.layers = ['stem']

        for idx, setting in enumerate(arch_setting):
            stage = []
            stage += self.build_stage_layer(idx, setting)
            if plugins is not None:
                stage += self.make_stage_plugins(plugins, idx, setting)
            self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
            self.layers.append(f'stage{idx + 1}')

    @abstractmethod
    def build_stem_layer(self):
        """Build a stem layer."""
        pass

    @abstractmethod
    def build_stage_layer(self, stage_idx: int, setting: list):
        """Build a stage layer.

        Args:
            stage_idx (int): The index of a stage layer.
            setting (list): The architecture setting of a stage layer.
        """
        pass

    def make_stage_plugins(self, plugins, stage_idx, setting):
        """Make plugins for backbone ``stage_idx`` th stage.

        Currently we support to insert ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
        into the backbone.

        An example of plugins format could be:

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True)),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True)),
            ... ]
            >>> model = YOLOv5CSPDarknet()
            >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
            >>> assert len(stage_plugins) == 1

        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> yyy

        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> xxx -> yyy

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of stage to build
                If stages is missing, the plugin would be applied to all
                stages.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list[nn.Module]: Plugins for current stage
        """
        # TODO: It is not general enough to support any channel and needs
        # to be refactored
        in_channels = int(setting[1] * self.widen_factor)
        plugin_layers = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            if stages is None or stages[stage_idx]:
                name, layer = build_plugin_layer(
                    plugin['cfg'], in_channels=in_channels)
                plugin_layers.append(layer)
        return plugin_layers

    def _freeze_stages(self):
        """Freeze the parameters of the specified stage so that they are no
        longer updated."""
        if self.frozen_stages >= 0:
            for i in range(self.frozen_stages + 1):
                m = getattr(self, self.layers[i])
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode: bool = True):
        """Convert the model into training mode while keep normalization layer
        frozen."""
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def forward(self, x: torch.Tensor) -> tuple:
        """Forward batch_inputs from the data_preprocessor."""
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
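To make the contract of BaseBackbone concrete, a hypothetical toy subclass (not part of this upload) only needs to implement `build_stem_layer` and `build_stage_layer`; the base class then assembles `stem` plus `stage1..stageN` and handles `forward`, `out_indices` and stage freezing.

# Hypothetical toy subclass, for illustration only.
import torch.nn as nn


class TinyBackbone(BaseBackbone):  # not part of mmyolo

    def build_stem_layer(self) -> nn.Module:
        return nn.Conv2d(self.input_channels,
                         int(self.arch_setting[0][0] * self.widen_factor),
                         kernel_size=3, stride=2, padding=1)

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        in_c = int(setting[0] * self.widen_factor)
        out_c = int(setting[1] * self.widen_factor)
        return [nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1)]


# arch_setting rows follow the same "in_channels, out_channels, ..." layout
# used by the concrete backbones below; out_indices must stay within the
# number of stages plus the stem.
toy = TinyBackbone(arch_setting=[[16, 32], [32, 64], [64, 128]],
                   out_indices=(1, 2, 3))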
mmyolo/models/backbones/csp_darknet.py
ADDED
@@ -0,0 +1,427 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import List, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
7 |
+
from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
|
8 |
+
from mmdet.utils import ConfigType, OptMultiConfig
|
9 |
+
|
10 |
+
from mmyolo.registry import MODELS
|
11 |
+
from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
|
12 |
+
from ..utils import make_divisible, make_round
|
13 |
+
from .base_backbone import BaseBackbone
|
14 |
+
|
15 |
+
|
16 |
+
@MODELS.register_module()
|
17 |
+
class YOLOv5CSPDarknet(BaseBackbone):
|
18 |
+
"""CSP-Darknet backbone used in YOLOv5.
|
19 |
+
Args:
|
20 |
+
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
|
21 |
+
Defaults to P5.
|
22 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
23 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
24 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
25 |
+
should be same as 'num_stages'.
|
26 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
27 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
28 |
+
widen_factor (float): Width multiplier, multiply number of
|
29 |
+
channels in each layer by this amount. Defaults to 1.0.
|
30 |
+
input_channels (int): Number of input image channels. Defaults to: 3.
|
31 |
+
out_indices (Tuple[int]): Output from which stages.
|
32 |
+
Defaults to (2, 3, 4).
|
33 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
34 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
35 |
+
norm_cfg (dict): Dictionary to construct and config norm layer.
|
36 |
+
Defaults to dict(type='BN', requires_grad=True).
|
37 |
+
act_cfg (dict): Config dict for activation layer.
|
38 |
+
Defaults to dict(type='SiLU', inplace=True).
|
39 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
40 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
41 |
+
and its variants only. Defaults to False.
|
42 |
+
init_cfg (Union[dict,list[dict]], optional): Initialization config
|
43 |
+
dict. Defaults to None.
|
44 |
+
Example:
|
45 |
+
>>> from mmyolo.models import YOLOv5CSPDarknet
|
46 |
+
>>> import torch
|
47 |
+
>>> model = YOLOv5CSPDarknet()
|
48 |
+
>>> model.eval()
|
49 |
+
>>> inputs = torch.rand(1, 3, 416, 416)
|
50 |
+
>>> level_outputs = model(inputs)
|
51 |
+
>>> for level_out in level_outputs:
|
52 |
+
... print(tuple(level_out.shape))
|
53 |
+
...
|
54 |
+
(1, 256, 52, 52)
|
55 |
+
(1, 512, 26, 26)
|
56 |
+
(1, 1024, 13, 13)
|
57 |
+
"""
|
58 |
+
# From left to right:
|
59 |
+
# in_channels, out_channels, num_blocks, add_identity, use_spp
|
60 |
+
arch_settings = {
|
61 |
+
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
62 |
+
[256, 512, 9, True, False], [512, 1024, 3, True, True]],
|
63 |
+
'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
64 |
+
[256, 512, 9, True, False], [512, 768, 3, True, False],
|
65 |
+
[768, 1024, 3, True, True]]
|
66 |
+
}
|
67 |
+
|
68 |
+
def __init__(self,
|
69 |
+
arch: str = 'P5',
|
70 |
+
plugins: Union[dict, List[dict]] = None,
|
71 |
+
deepen_factor: float = 1.0,
|
72 |
+
widen_factor: float = 1.0,
|
73 |
+
input_channels: int = 3,
|
74 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
75 |
+
frozen_stages: int = -1,
|
76 |
+
norm_cfg: ConfigType = dict(
|
77 |
+
type='BN', momentum=0.03, eps=0.001),
|
78 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
79 |
+
norm_eval: bool = False,
|
80 |
+
init_cfg: OptMultiConfig = None):
|
81 |
+
super().__init__(
|
82 |
+
self.arch_settings[arch],
|
83 |
+
deepen_factor,
|
84 |
+
widen_factor,
|
85 |
+
input_channels=input_channels,
|
86 |
+
out_indices=out_indices,
|
87 |
+
plugins=plugins,
|
88 |
+
frozen_stages=frozen_stages,
|
89 |
+
norm_cfg=norm_cfg,
|
90 |
+
act_cfg=act_cfg,
|
91 |
+
norm_eval=norm_eval,
|
92 |
+
init_cfg=init_cfg)
|
93 |
+
|
94 |
+
def build_stem_layer(self) -> nn.Module:
|
95 |
+
"""Build a stem layer."""
|
96 |
+
return ConvModule(
|
97 |
+
self.input_channels,
|
98 |
+
make_divisible(self.arch_setting[0][0], self.widen_factor),
|
99 |
+
kernel_size=6,
|
100 |
+
stride=2,
|
101 |
+
padding=2,
|
102 |
+
norm_cfg=self.norm_cfg,
|
103 |
+
act_cfg=self.act_cfg)
|
104 |
+
|
105 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
106 |
+
"""Build a stage layer.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
stage_idx (int): The index of a stage layer.
|
110 |
+
setting (list): The architecture setting of a stage layer.
|
111 |
+
"""
|
112 |
+
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
|
113 |
+
|
114 |
+
in_channels = make_divisible(in_channels, self.widen_factor)
|
115 |
+
out_channels = make_divisible(out_channels, self.widen_factor)
|
116 |
+
num_blocks = make_round(num_blocks, self.deepen_factor)
|
117 |
+
stage = []
|
118 |
+
conv_layer = ConvModule(
|
119 |
+
in_channels,
|
120 |
+
out_channels,
|
121 |
+
kernel_size=3,
|
122 |
+
stride=2,
|
123 |
+
padding=1,
|
124 |
+
norm_cfg=self.norm_cfg,
|
125 |
+
act_cfg=self.act_cfg)
|
126 |
+
stage.append(conv_layer)
|
127 |
+
csp_layer = CSPLayer(
|
128 |
+
out_channels,
|
129 |
+
out_channels,
|
130 |
+
num_blocks=num_blocks,
|
131 |
+
add_identity=add_identity,
|
132 |
+
norm_cfg=self.norm_cfg,
|
133 |
+
act_cfg=self.act_cfg)
|
134 |
+
stage.append(csp_layer)
|
135 |
+
if use_spp:
|
136 |
+
spp = SPPFBottleneck(
|
137 |
+
out_channels,
|
138 |
+
out_channels,
|
139 |
+
kernel_sizes=5,
|
140 |
+
norm_cfg=self.norm_cfg,
|
141 |
+
act_cfg=self.act_cfg)
|
142 |
+
stage.append(spp)
|
143 |
+
return stage
|
144 |
+
|
145 |
+
def init_weights(self):
|
146 |
+
"""Initialize the parameters."""
|
147 |
+
if self.init_cfg is None:
|
148 |
+
for m in self.modules():
|
149 |
+
if isinstance(m, torch.nn.Conv2d):
|
150 |
+
# In order to be consistent with the source code,
|
151 |
+
# reset the Conv2d initialization parameters
|
152 |
+
m.reset_parameters()
|
153 |
+
else:
|
154 |
+
super().init_weights()
|
155 |
+
|
156 |
+
|
157 |
+
@MODELS.register_module()
|
158 |
+
class YOLOv8CSPDarknet(BaseBackbone):
|
159 |
+
"""CSP-Darknet backbone used in YOLOv8.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
arch (str): Architecture of CSP-Darknet, from {P5}.
|
163 |
+
Defaults to P5.
|
164 |
+
last_stage_out_channels (int): Final layer output channel.
|
165 |
+
Defaults to 1024.
|
166 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
167 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
168 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
169 |
+
should be same as 'num_stages'.
|
170 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
171 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
172 |
+
widen_factor (float): Width multiplier, multiply number of
|
173 |
+
channels in each layer by this amount. Defaults to 1.0.
|
174 |
+
input_channels (int): Number of input image channels. Defaults to: 3.
|
175 |
+
out_indices (Tuple[int]): Output from which stages.
|
176 |
+
Defaults to (2, 3, 4).
|
177 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
178 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
179 |
+
norm_cfg (dict): Dictionary to construct and config norm layer.
|
180 |
+
Defaults to dict(type='BN', requires_grad=True).
|
181 |
+
act_cfg (dict): Config dict for activation layer.
|
182 |
+
Defaults to dict(type='SiLU', inplace=True).
|
183 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
184 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
185 |
+
and its variants only. Defaults to False.
|
186 |
+
init_cfg (Union[dict,list[dict]], optional): Initialization config
|
187 |
+
dict. Defaults to None.
|
188 |
+
|
189 |
+
Example:
|
190 |
+
>>> from mmyolo.models import YOLOv8CSPDarknet
|
191 |
+
>>> import torch
|
192 |
+
>>> model = YOLOv8CSPDarknet()
|
193 |
+
>>> model.eval()
|
194 |
+
>>> inputs = torch.rand(1, 3, 416, 416)
|
195 |
+
>>> level_outputs = model(inputs)
|
196 |
+
>>> for level_out in level_outputs:
|
197 |
+
... print(tuple(level_out.shape))
|
198 |
+
...
|
199 |
+
(1, 256, 52, 52)
|
200 |
+
(1, 512, 26, 26)
|
201 |
+
(1, 1024, 13, 13)
|
202 |
+
"""
|
203 |
+
# From left to right:
|
204 |
+
# in_channels, out_channels, num_blocks, add_identity, use_spp
|
205 |
+
# the final out_channels will be set according to the param.
|
206 |
+
arch_settings = {
|
207 |
+
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
208 |
+
[256, 512, 6, True, False], [512, None, 3, True, True]],
|
209 |
+
}
|
210 |
+
|
211 |
+
def __init__(self,
|
212 |
+
arch: str = 'P5',
|
213 |
+
last_stage_out_channels: int = 1024,
|
214 |
+
plugins: Union[dict, List[dict]] = None,
|
215 |
+
deepen_factor: float = 1.0,
|
216 |
+
widen_factor: float = 1.0,
|
217 |
+
input_channels: int = 3,
|
218 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
219 |
+
frozen_stages: int = -1,
|
220 |
+
norm_cfg: ConfigType = dict(
|
221 |
+
type='BN', momentum=0.03, eps=0.001),
|
222 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
223 |
+
norm_eval: bool = False,
|
224 |
+
init_cfg: OptMultiConfig = None):
|
225 |
+
self.arch_settings[arch][-1][1] = last_stage_out_channels
|
226 |
+
super().__init__(
|
227 |
+
self.arch_settings[arch],
|
228 |
+
deepen_factor,
|
229 |
+
widen_factor,
|
230 |
+
input_channels=input_channels,
|
231 |
+
out_indices=out_indices,
|
232 |
+
plugins=plugins,
|
233 |
+
frozen_stages=frozen_stages,
|
234 |
+
norm_cfg=norm_cfg,
|
235 |
+
act_cfg=act_cfg,
|
236 |
+
norm_eval=norm_eval,
|
237 |
+
init_cfg=init_cfg)
|
238 |
+
|
239 |
+
def build_stem_layer(self) -> nn.Module:
|
240 |
+
"""Build a stem layer."""
|
241 |
+
return ConvModule(
|
242 |
+
self.input_channels,
|
243 |
+
make_divisible(self.arch_setting[0][0], self.widen_factor),
|
244 |
+
kernel_size=3,
|
245 |
+
stride=2,
|
246 |
+
padding=1,
|
247 |
+
norm_cfg=self.norm_cfg,
|
248 |
+
act_cfg=self.act_cfg)
|
249 |
+
|
250 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
251 |
+
"""Build a stage layer.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
stage_idx (int): The index of a stage layer.
|
255 |
+
setting (list): The architecture setting of a stage layer.
|
256 |
+
"""
|
257 |
+
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
|
258 |
+
|
259 |
+
in_channels = make_divisible(in_channels, self.widen_factor)
|
260 |
+
out_channels = make_divisible(out_channels, self.widen_factor)
|
261 |
+
num_blocks = make_round(num_blocks, self.deepen_factor)
|
262 |
+
stage = []
|
263 |
+
conv_layer = ConvModule(
|
264 |
+
in_channels,
|
265 |
+
out_channels,
|
266 |
+
kernel_size=3,
|
267 |
+
stride=2,
|
268 |
+
padding=1,
|
269 |
+
norm_cfg=self.norm_cfg,
|
270 |
+
act_cfg=self.act_cfg)
|
271 |
+
stage.append(conv_layer)
|
272 |
+
csp_layer = CSPLayerWithTwoConv(
|
273 |
+
out_channels,
|
274 |
+
out_channels,
|
275 |
+
num_blocks=num_blocks,
|
276 |
+
add_identity=add_identity,
|
277 |
+
norm_cfg=self.norm_cfg,
|
278 |
+
act_cfg=self.act_cfg)
|
279 |
+
stage.append(csp_layer)
|
280 |
+
if use_spp:
|
281 |
+
spp = SPPFBottleneck(
|
282 |
+
out_channels,
|
283 |
+
out_channels,
|
284 |
+
kernel_sizes=5,
|
285 |
+
norm_cfg=self.norm_cfg,
|
286 |
+
act_cfg=self.act_cfg)
|
287 |
+
stage.append(spp)
|
288 |
+
return stage
|
289 |
+
|
290 |
+
def init_weights(self):
|
291 |
+
"""Initialize the parameters."""
|
292 |
+
if self.init_cfg is None:
|
293 |
+
for m in self.modules():
|
294 |
+
if isinstance(m, torch.nn.Conv2d):
|
295 |
+
# In order to be consistent with the source code,
|
296 |
+
# reset the Conv2d initialization parameters
|
297 |
+
m.reset_parameters()
|
298 |
+
else:
|
299 |
+
super().init_weights()
|
300 |
+
|
301 |
+
|
302 |
+
@MODELS.register_module()
|
303 |
+
class YOLOXCSPDarknet(BaseBackbone):
|
304 |
+
"""CSP-Darknet backbone used in YOLOX.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
|
308 |
+
Defaults to P5.
|
309 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
310 |
+
|
311 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
312 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
313 |
+
should be same as 'num_stages'.
|
314 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
315 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
316 |
+
widen_factor (float): Width multiplier, multiply number of
|
317 |
+
channels in each layer by this amount. Defaults to 1.0.
|
318 |
+
input_channels (int): Number of input image channels. Defaults to 3.
|
319 |
+
out_indices (Tuple[int]): Output from which stages.
|
320 |
+
Defaults to (2, 3, 4).
|
321 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
322 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
323 |
+
use_depthwise (bool): Whether to use depthwise separable convolution.
|
324 |
+
Defaults to False.
|
325 |
+
spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP
|
326 |
+
layers. Defaults to (5, 9, 13).
|
327 |
+
norm_cfg (dict): Dictionary to construct and config norm layer.
|
328 |
+
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
329 |
+
act_cfg (dict): Config dict for activation layer.
|
330 |
+
Defaults to dict(type='SiLU', inplace=True).
|
331 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
332 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
333 |
+
and its variants only.
|
334 |
+
init_cfg (Union[dict,list[dict]], optional): Initialization config
|
335 |
+
dict. Defaults to None.
|
336 |
+
Example:
|
337 |
+
>>> from mmyolo.models import YOLOXCSPDarknet
|
338 |
+
>>> import torch
|
339 |
+
>>> model = YOLOXCSPDarknet()
|
340 |
+
>>> model.eval()
|
341 |
+
>>> inputs = torch.rand(1, 3, 416, 416)
|
342 |
+
>>> level_outputs = model(inputs)
|
343 |
+
>>> for level_out in level_outputs:
|
344 |
+
... print(tuple(level_out.shape))
|
345 |
+
...
|
346 |
+
(1, 256, 52, 52)
|
347 |
+
(1, 512, 26, 26)
|
348 |
+
(1, 1024, 13, 13)
|
349 |
+
"""
|
350 |
+
# From left to right:
|
351 |
+
# in_channels, out_channels, num_blocks, add_identity, use_spp
|
352 |
+
arch_settings = {
|
353 |
+
'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
|
354 |
+
[256, 512, 9, True, False], [512, 1024, 3, False, True]],
|
355 |
+
}
|
356 |
+
|
357 |
+
def __init__(self,
|
358 |
+
arch: str = 'P5',
|
359 |
+
plugins: Union[dict, List[dict]] = None,
|
360 |
+
deepen_factor: float = 1.0,
|
361 |
+
widen_factor: float = 1.0,
|
362 |
+
input_channels: int = 3,
|
363 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
364 |
+
frozen_stages: int = -1,
|
365 |
+
use_depthwise: bool = False,
|
366 |
+
spp_kernal_sizes: Tuple[int] = (5, 9, 13),
|
367 |
+
norm_cfg: ConfigType = dict(
|
368 |
+
type='BN', momentum=0.03, eps=0.001),
|
369 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
370 |
+
norm_eval: bool = False,
|
371 |
+
init_cfg: OptMultiConfig = None):
|
372 |
+
self.use_depthwise = use_depthwise
|
373 |
+
self.spp_kernal_sizes = spp_kernal_sizes
|
374 |
+
super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
|
375 |
+
input_channels, out_indices, frozen_stages, plugins,
|
376 |
+
norm_cfg, act_cfg, norm_eval, init_cfg)
|
377 |
+
|
378 |
+
def build_stem_layer(self) -> nn.Module:
|
379 |
+
"""Build a stem layer."""
|
380 |
+
return Focus(
|
381 |
+
3,
|
382 |
+
make_divisible(64, self.widen_factor),
|
383 |
+
kernel_size=3,
|
384 |
+
norm_cfg=self.norm_cfg,
|
385 |
+
act_cfg=self.act_cfg)
|
386 |
+
|
387 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
388 |
+
"""Build a stage layer.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
stage_idx (int): The index of a stage layer.
|
392 |
+
setting (list): The architecture setting of a stage layer.
|
393 |
+
"""
|
394 |
+
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
|
395 |
+
|
396 |
+
in_channels = make_divisible(in_channels, self.widen_factor)
|
397 |
+
out_channels = make_divisible(out_channels, self.widen_factor)
|
398 |
+
num_blocks = make_round(num_blocks, self.deepen_factor)
|
399 |
+
stage = []
|
400 |
+
conv = DepthwiseSeparableConvModule \
|
401 |
+
if self.use_depthwise else ConvModule
|
402 |
+
conv_layer = conv(
|
403 |
+
in_channels,
|
404 |
+
out_channels,
|
405 |
+
kernel_size=3,
|
406 |
+
stride=2,
|
407 |
+
padding=1,
|
408 |
+
norm_cfg=self.norm_cfg,
|
409 |
+
act_cfg=self.act_cfg)
|
410 |
+
stage.append(conv_layer)
|
411 |
+
if use_spp:
|
412 |
+
spp = SPPFBottleneck(
|
413 |
+
out_channels,
|
414 |
+
out_channels,
|
415 |
+
kernel_sizes=self.spp_kernal_sizes,
|
416 |
+
norm_cfg=self.norm_cfg,
|
417 |
+
act_cfg=self.act_cfg)
|
418 |
+
stage.append(spp)
|
419 |
+
csp_layer = CSPLayer(
|
420 |
+
out_channels,
|
421 |
+
out_channels,
|
422 |
+
num_blocks=num_blocks,
|
423 |
+
add_identity=add_identity,
|
424 |
+
norm_cfg=self.norm_cfg,
|
425 |
+
act_cfg=self.act_cfg)
|
426 |
+
stage.append(csp_layer)
|
427 |
+
return stage
|
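The stage-plugin mechanism documented above is driven from the backbone config. A hedged sketch follows; the `CBAM` plugin name and its `reduce_ratio` argument are assumptions about mmyolo's plugins package, not confirmed by this upload.

# Illustrative backbone config applying a plugin to the last two stages only;
# the plugin type and its arguments are assumed, the in_channels argument is
# injected by make_stage_plugins() above.
backbone = dict(
    type='YOLOv5CSPDarknet',
    plugins=[
        dict(
            cfg=dict(type='CBAM', reduce_ratio=16),  # assumed plugin cfg
            stages=(False, False, True, True))
    ])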
mmyolo/models/backbones/csp_resnet.py
ADDED
@@ -0,0 +1,169 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import List, Tuple, Union
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from mmcv.cnn import ConvModule
|
6 |
+
from mmdet.utils import ConfigType, OptMultiConfig
|
7 |
+
|
8 |
+
from mmyolo.models.backbones import BaseBackbone
|
9 |
+
from mmyolo.models.layers.yolo_bricks import CSPResLayer
|
10 |
+
from mmyolo.registry import MODELS
|
11 |
+
|
12 |
+
|
13 |
+
@MODELS.register_module()
|
14 |
+
class PPYOLOECSPResNet(BaseBackbone):
|
15 |
+
"""CSP-ResNet backbone used in PPYOLOE.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
arch (str): Architecture of CSPNeXt, from {P5, P6}.
|
19 |
+
Defaults to P5.
|
20 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
21 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
22 |
+
widen_factor (float): Width multiplier, multiply number of
|
23 |
+
channels in each layer by this amount. Defaults to 1.0.
|
24 |
+
out_indices (Sequence[int]): Output from which stages.
|
25 |
+
Defaults to (2, 3, 4).
|
26 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
27 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
28 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
29 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
30 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
31 |
+
should be same as 'num_stages'.
|
32 |
+
arch_ovewrite (list): Overwrite default arch settings.
|
33 |
+
Defaults to None.
|
34 |
+
block_cfg (dict): Config dict for block. Defaults to
|
35 |
+
dict(type='PPYOLOEBasicBlock', shortcut=True, use_alpha=True)
|
36 |
+
norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
|
37 |
+
config norm layer. Defaults to dict(type='BN', momentum=0.1,
|
38 |
+
eps=1e-5).
|
39 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
40 |
+
Defaults to dict(type='SiLU', inplace=True).
|
41 |
+
attention_cfg (dict): Config dict for `EffectiveSELayer`.
|
42 |
+
Defaults to dict(type='EffectiveSELayer',
|
43 |
+
act_cfg=dict(type='HSigmoid')).
|
44 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
45 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
46 |
+
and its variants only.
|
47 |
+
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
|
48 |
+
list[:obj:`ConfigDict`]): Initialization config dict.
|
49 |
+
use_large_stem (bool): Whether to use large stem layer.
|
50 |
+
Defaults to False.
|
51 |
+
"""
|
52 |
+
# From left to right:
|
53 |
+
# in_channels, out_channels, num_blocks
|
54 |
+
arch_settings = {
|
55 |
+
'P5': [[64, 128, 3], [128, 256, 6], [256, 512, 6], [512, 1024, 3]]
|
56 |
+
}
|
57 |
+
|
58 |
+
def __init__(self,
|
59 |
+
arch: str = 'P5',
|
60 |
+
deepen_factor: float = 1.0,
|
61 |
+
widen_factor: float = 1.0,
|
62 |
+
input_channels: int = 3,
|
63 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
64 |
+
frozen_stages: int = -1,
|
65 |
+
plugins: Union[dict, List[dict]] = None,
|
66 |
+
arch_ovewrite: dict = None,
|
67 |
+
block_cfg: ConfigType = dict(
|
68 |
+
type='PPYOLOEBasicBlock', shortcut=True, use_alpha=True),
|
69 |
+
norm_cfg: ConfigType = dict(
|
70 |
+
type='BN', momentum=0.1, eps=1e-5),
|
71 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
72 |
+
attention_cfg: ConfigType = dict(
|
73 |
+
type='EffectiveSELayer', act_cfg=dict(type='HSigmoid')),
|
74 |
+
norm_eval: bool = False,
|
75 |
+
init_cfg: OptMultiConfig = None,
|
76 |
+
use_large_stem: bool = False):
|
77 |
+
arch_setting = self.arch_settings[arch]
|
78 |
+
if arch_ovewrite:
|
79 |
+
arch_setting = arch_ovewrite
|
80 |
+
arch_setting = [[
|
81 |
+
int(in_channels * widen_factor),
|
82 |
+
int(out_channels * widen_factor),
|
83 |
+
round(num_blocks * deepen_factor)
|
84 |
+
] for in_channels, out_channels, num_blocks in arch_setting]
|
85 |
+
self.block_cfg = block_cfg
|
86 |
+
self.use_large_stem = use_large_stem
|
87 |
+
self.attention_cfg = attention_cfg
|
88 |
+
|
89 |
+
super().__init__(
|
90 |
+
arch_setting,
|
91 |
+
deepen_factor,
|
92 |
+
widen_factor,
|
93 |
+
input_channels=input_channels,
|
94 |
+
out_indices=out_indices,
|
95 |
+
plugins=plugins,
|
96 |
+
frozen_stages=frozen_stages,
|
97 |
+
norm_cfg=norm_cfg,
|
98 |
+
act_cfg=act_cfg,
|
99 |
+
norm_eval=norm_eval,
|
100 |
+
init_cfg=init_cfg)
|
101 |
+
|
102 |
+
def build_stem_layer(self) -> nn.Module:
|
103 |
+
"""Build a stem layer."""
|
104 |
+
if self.use_large_stem:
|
105 |
+
stem = nn.Sequential(
|
106 |
+
ConvModule(
|
107 |
+
self.input_channels,
|
108 |
+
self.arch_setting[0][0] // 2,
|
109 |
+
3,
|
110 |
+
stride=2,
|
111 |
+
padding=1,
|
112 |
+
act_cfg=self.act_cfg,
|
113 |
+
norm_cfg=self.norm_cfg),
|
114 |
+
ConvModule(
|
115 |
+
self.arch_setting[0][0] // 2,
|
116 |
+
self.arch_setting[0][0] // 2,
|
117 |
+
3,
|
118 |
+
stride=1,
|
119 |
+
padding=1,
|
120 |
+
norm_cfg=self.norm_cfg,
|
121 |
+
act_cfg=self.act_cfg),
|
122 |
+
ConvModule(
|
123 |
+
self.arch_setting[0][0] // 2,
|
124 |
+
self.arch_setting[0][0],
|
125 |
+
3,
|
126 |
+
stride=1,
|
127 |
+
padding=1,
|
128 |
+
norm_cfg=self.norm_cfg,
|
129 |
+
act_cfg=self.act_cfg))
|
130 |
+
else:
|
131 |
+
stem = nn.Sequential(
|
132 |
+
ConvModule(
|
133 |
+
self.input_channels,
|
134 |
+
self.arch_setting[0][0] // 2,
|
135 |
+
3,
|
136 |
+
stride=2,
|
137 |
+
padding=1,
|
138 |
+
norm_cfg=self.norm_cfg,
|
139 |
+
act_cfg=self.act_cfg),
|
140 |
+
ConvModule(
|
141 |
+
self.arch_setting[0][0] // 2,
|
142 |
+
self.arch_setting[0][0],
|
143 |
+
3,
|
144 |
+
stride=1,
|
145 |
+
padding=1,
|
146 |
+
norm_cfg=self.norm_cfg,
|
147 |
+
act_cfg=self.act_cfg))
|
148 |
+
return stem
|
149 |
+
|
150 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
151 |
+
"""Build a stage layer.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
stage_idx (int): The index of a stage layer.
|
155 |
+
setting (list): The architecture setting of a stage layer.
|
156 |
+
"""
|
157 |
+
in_channels, out_channels, num_blocks = setting
|
158 |
+
|
159 |
+
cspres_layer = CSPResLayer(
|
160 |
+
in_channels=in_channels,
|
161 |
+
out_channels=out_channels,
|
162 |
+
num_block=num_blocks,
|
163 |
+
block_cfg=self.block_cfg,
|
164 |
+
stride=2,
|
165 |
+
norm_cfg=self.norm_cfg,
|
166 |
+
act_cfg=self.act_cfg,
|
167 |
+
attention_cfg=self.attention_cfg,
|
168 |
+
use_spp=False)
|
169 |
+
return [cspres_layer]
|
mmyolo/models/backbones/cspnext.py
ADDED
@@ -0,0 +1,187 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
from typing import List, Sequence, Union
|
4 |
+
|
5 |
+
import torch.nn as nn
|
6 |
+
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
7 |
+
from mmdet.models.backbones.csp_darknet import CSPLayer
|
8 |
+
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
|
9 |
+
|
10 |
+
from mmyolo.registry import MODELS
|
11 |
+
from ..layers import SPPFBottleneck
|
12 |
+
from .base_backbone import BaseBackbone
|
13 |
+
|
14 |
+
|
15 |
+
@MODELS.register_module()
|
16 |
+
class CSPNeXt(BaseBackbone):
|
17 |
+
"""CSPNeXt backbone used in RTMDet.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
arch (str): Architecture of CSPNeXt, from {P5, P6}.
|
21 |
+
Defaults to P5.
|
22 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
23 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
24 |
+
widen_factor (float): Width multiplier, multiply number of
|
25 |
+
channels in each layer by this amount. Defaults to 1.0.
|
26 |
+
out_indices (Sequence[int]): Output from which stages.
|
27 |
+
Defaults to (2, 3, 4).
|
28 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
29 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
30 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
31 |
+
- cfg (dict, required): Cfg dict to build plugin.Defaults to
|
32 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
33 |
+
should be same as 'num_stages'.
|
34 |
+
use_depthwise (bool): Whether to use depthwise separable convolution.
|
35 |
+
Defaults to False.
|
36 |
+
expand_ratio (float): Ratio to adjust the number of channels of the
|
37 |
+
hidden layer. Defaults to 0.5.
|
38 |
+
arch_ovewrite (list): Overwrite default arch settings.
|
39 |
+
Defaults to None.
|
40 |
+
channel_attention (bool): Whether to add channel attention in each
|
41 |
+
stage. Defaults to True.
|
42 |
+
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
|
43 |
+
convolution layer. Defaults to None.
|
44 |
+
norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
|
45 |
+
config norm layer. Defaults to dict(type='BN', requires_grad=True).
|
46 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
47 |
+
Defaults to dict(type='SiLU', inplace=True).
|
48 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
49 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
50 |
+
and its variants only.
|
51 |
+
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
|
52 |
+
list[:obj:`ConfigDict`]): Initialization config dict.
|
53 |
+
"""
|
54 |
+
# From left to right:
|
55 |
+
# in_channels, out_channels, num_blocks, add_identity, use_spp
|
56 |
+
arch_settings = {
|
57 |
+
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
58 |
+
[256, 512, 6, True, False], [512, 1024, 3, False, True]],
|
59 |
+
'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
|
60 |
+
[256, 512, 6, True, False], [512, 768, 3, True, False],
|
61 |
+
[768, 1024, 3, False, True]]
|
62 |
+
}
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
arch: str = 'P5',
|
67 |
+
deepen_factor: float = 1.0,
|
68 |
+
widen_factor: float = 1.0,
|
69 |
+
input_channels: int = 3,
|
70 |
+
out_indices: Sequence[int] = (2, 3, 4),
|
71 |
+
frozen_stages: int = -1,
|
72 |
+
plugins: Union[dict, List[dict]] = None,
|
73 |
+
use_depthwise: bool = False,
|
74 |
+
expand_ratio: float = 0.5,
|
75 |
+
arch_ovewrite: dict = None,
|
76 |
+
channel_attention: bool = True,
|
77 |
+
conv_cfg: OptConfigType = None,
|
78 |
+
norm_cfg: ConfigType = dict(type='BN'),
|
79 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
80 |
+
norm_eval: bool = False,
|
81 |
+
init_cfg: OptMultiConfig = dict(
|
82 |
+
type='Kaiming',
|
83 |
+
layer='Conv2d',
|
84 |
+
a=math.sqrt(5),
|
85 |
+
distribution='uniform',
|
86 |
+
mode='fan_in',
|
87 |
+
nonlinearity='leaky_relu')
|
88 |
+
) -> None:
|
89 |
+
arch_setting = self.arch_settings[arch]
|
90 |
+
if arch_ovewrite:
|
91 |
+
arch_setting = arch_ovewrite
|
92 |
+
self.channel_attention = channel_attention
|
93 |
+
self.use_depthwise = use_depthwise
|
94 |
+
self.conv = DepthwiseSeparableConvModule \
|
95 |
+
if use_depthwise else ConvModule
|
96 |
+
self.expand_ratio = expand_ratio
|
97 |
+
self.conv_cfg = conv_cfg
|
98 |
+
|
99 |
+
super().__init__(
|
100 |
+
arch_setting,
|
101 |
+
deepen_factor,
|
102 |
+
widen_factor,
|
103 |
+
input_channels,
|
104 |
+
out_indices,
|
105 |
+
frozen_stages=frozen_stages,
|
106 |
+
plugins=plugins,
|
107 |
+
norm_cfg=norm_cfg,
|
108 |
+
act_cfg=act_cfg,
|
109 |
+
norm_eval=norm_eval,
|
110 |
+
init_cfg=init_cfg)
|
111 |
+
|
112 |
+
def build_stem_layer(self) -> nn.Module:
|
113 |
+
"""Build a stem layer."""
|
114 |
+
stem = nn.Sequential(
|
115 |
+
ConvModule(
|
116 |
+
3,
|
117 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
118 |
+
3,
|
119 |
+
padding=1,
|
120 |
+
stride=2,
|
121 |
+
norm_cfg=self.norm_cfg,
|
122 |
+
act_cfg=self.act_cfg),
|
123 |
+
ConvModule(
|
124 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
125 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
126 |
+
3,
|
127 |
+
padding=1,
|
128 |
+
stride=1,
|
129 |
+
norm_cfg=self.norm_cfg,
|
130 |
+
act_cfg=self.act_cfg),
|
131 |
+
ConvModule(
|
132 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
133 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
134 |
+
3,
|
135 |
+
padding=1,
|
136 |
+
stride=1,
|
137 |
+
norm_cfg=self.norm_cfg,
|
138 |
+
act_cfg=self.act_cfg))
|
139 |
+
return stem
|
140 |
+
|
141 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
142 |
+
"""Build a stage layer.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
stage_idx (int): The index of a stage layer.
|
146 |
+
setting (list): The architecture setting of a stage layer.
|
147 |
+
"""
|
148 |
+
in_channels, out_channels, num_blocks, add_identity, use_spp = setting
|
149 |
+
|
150 |
+
in_channels = int(in_channels * self.widen_factor)
|
151 |
+
out_channels = int(out_channels * self.widen_factor)
|
152 |
+
num_blocks = max(round(num_blocks * self.deepen_factor), 1)
|
153 |
+
|
154 |
+
stage = []
|
155 |
+
conv_layer = self.conv(
|
156 |
+
in_channels,
|
157 |
+
out_channels,
|
158 |
+
3,
|
159 |
+
stride=2,
|
160 |
+
padding=1,
|
161 |
+
conv_cfg=self.conv_cfg,
|
162 |
+
norm_cfg=self.norm_cfg,
|
163 |
+
act_cfg=self.act_cfg)
|
164 |
+
stage.append(conv_layer)
|
165 |
+
if use_spp:
|
166 |
+
spp = SPPFBottleneck(
|
167 |
+
out_channels,
|
168 |
+
out_channels,
|
169 |
+
kernel_sizes=5,
|
170 |
+
conv_cfg=self.conv_cfg,
|
171 |
+
norm_cfg=self.norm_cfg,
|
172 |
+
act_cfg=self.act_cfg)
|
173 |
+
stage.append(spp)
|
174 |
+
csp_layer = CSPLayer(
|
175 |
+
out_channels,
|
176 |
+
out_channels,
|
177 |
+
num_blocks=num_blocks,
|
178 |
+
add_identity=add_identity,
|
179 |
+
use_depthwise=self.use_depthwise,
|
180 |
+
use_cspnext_block=True,
|
181 |
+
expand_ratio=self.expand_ratio,
|
182 |
+
channel_attention=self.channel_attention,
|
183 |
+
conv_cfg=self.conv_cfg,
|
184 |
+
norm_cfg=self.norm_cfg,
|
185 |
+
act_cfg=self.act_cfg)
|
186 |
+
stage.append(csp_layer)
|
187 |
+
return stage
|
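As with the other backbones, the default stage table of CSPNeXt can be swapped out wholesale via `arch_ovewrite`. An illustrative sketch only; the three-stage table and scaling factors below are made-up values, not taken from this upload.

# Each row keeps the documented layout:
# [in_channels, out_channels, num_blocks, add_identity, use_spp].
tiny_arch = [
    [64, 128, 1, True, False],
    [128, 256, 1, True, False],
    [256, 512, 1, False, True],
]
backbone = CSPNeXt(
    arch_ovewrite=tiny_arch,
    deepen_factor=0.33,
    widen_factor=0.5,
    out_indices=(1, 2, 3))  # only 3 stages, so indices stay below 4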
mmyolo/models/backbones/efficient_rep.py
ADDED
@@ -0,0 +1,287 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
|
3 |
+
from typing import List, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from mmdet.utils import ConfigType, OptMultiConfig
|
8 |
+
|
9 |
+
from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
|
10 |
+
from mmyolo.registry import MODELS
|
11 |
+
from ..layers import BepC3StageBlock, RepStageBlock
|
12 |
+
from ..utils import make_round
|
13 |
+
from .base_backbone import BaseBackbone
|
14 |
+
|
15 |
+
|
16 |
+
@MODELS.register_module()
|
17 |
+
class YOLOv6EfficientRep(BaseBackbone):
|
18 |
+
"""EfficientRep backbone used in YOLOv6.
|
19 |
+
Args:
|
20 |
+
arch (str): Architecture of BaseDarknet, from {P5, P6}.
|
21 |
+
Defaults to P5.
|
22 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
23 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
24 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
25 |
+
should be same as 'num_stages'.
|
26 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
27 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
28 |
+
widen_factor (float): Width multiplier, multiply number of
|
29 |
+
channels in each layer by this amount. Defaults to 1.0.
|
30 |
+
input_channels (int): Number of input image channels. Defaults to 3.
|
31 |
+
out_indices (Tuple[int]): Output from which stages.
|
32 |
+
Defaults to (2, 3, 4).
|
33 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
34 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
35 |
+
norm_cfg (dict): Dictionary to construct and config norm layer.
|
36 |
+
Defaults to dict(type='BN', requires_grad=True).
|
37 |
+
act_cfg (dict): Config dict for activation layer.
|
38 |
+
Defaults to dict(type='LeakyReLU', negative_slope=0.1).
|
39 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
40 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
41 |
+
and its variants only. Defaults to False.
|
42 |
+
block_cfg (dict): Config dict for the block used to build each
|
43 |
+
layer. Defaults to dict(type='RepVGGBlock').
|
44 |
+
init_cfg (Union[dict, list[dict]], optional): Initialization config
|
45 |
+
dict. Defaults to None.
|
46 |
+
Example:
|
47 |
+
>>> from mmyolo.models import YOLOv6EfficientRep
|
48 |
+
>>> import torch
|
49 |
+
>>> model = YOLOv6EfficientRep()
|
50 |
+
>>> model.eval()
|
51 |
+
>>> inputs = torch.rand(1, 3, 416, 416)
|
52 |
+
>>> level_outputs = model(inputs)
|
53 |
+
>>> for level_out in level_outputs:
|
54 |
+
... print(tuple(level_out.shape))
|
55 |
+
...
|
56 |
+
(1, 256, 52, 52)
|
57 |
+
(1, 512, 26, 26)
|
58 |
+
(1, 1024, 13, 13)
|
59 |
+
"""
|
60 |
+
# From left to right:
|
61 |
+
# in_channels, out_channels, num_blocks, use_spp
|
62 |
+
arch_settings = {
|
63 |
+
'P5': [[64, 128, 6, False], [128, 256, 12, False],
|
64 |
+
[256, 512, 18, False], [512, 1024, 6, True]]
|
65 |
+
}
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
arch: str = 'P5',
|
69 |
+
plugins: Union[dict, List[dict]] = None,
|
70 |
+
deepen_factor: float = 1.0,
|
71 |
+
widen_factor: float = 1.0,
|
72 |
+
input_channels: int = 3,
|
73 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
74 |
+
frozen_stages: int = -1,
|
75 |
+
norm_cfg: ConfigType = dict(
|
76 |
+
type='BN', momentum=0.03, eps=0.001),
|
77 |
+
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
|
78 |
+
norm_eval: bool = False,
|
79 |
+
block_cfg: ConfigType = dict(type='RepVGGBlock'),
|
80 |
+
init_cfg: OptMultiConfig = None):
|
81 |
+
self.block_cfg = block_cfg
|
82 |
+
super().__init__(
|
83 |
+
self.arch_settings[arch],
|
84 |
+
deepen_factor,
|
85 |
+
widen_factor,
|
86 |
+
input_channels=input_channels,
|
87 |
+
out_indices=out_indices,
|
88 |
+
plugins=plugins,
|
89 |
+
frozen_stages=frozen_stages,
|
90 |
+
norm_cfg=norm_cfg,
|
91 |
+
act_cfg=act_cfg,
|
92 |
+
norm_eval=norm_eval,
|
93 |
+
init_cfg=init_cfg)
|
94 |
+
|
95 |
+
def build_stem_layer(self) -> nn.Module:
|
96 |
+
"""Build a stem layer."""
|
97 |
+
|
98 |
+
block_cfg = self.block_cfg.copy()
|
99 |
+
block_cfg.update(
|
100 |
+
dict(
|
101 |
+
in_channels=self.input_channels,
|
102 |
+
out_channels=int(self.arch_setting[0][0] * self.widen_factor),
|
103 |
+
kernel_size=3,
|
104 |
+
stride=2,
|
105 |
+
))
|
106 |
+
return MODELS.build(block_cfg)
|
107 |
+
|
108 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
109 |
+
"""Build a stage layer.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
stage_idx (int): The index of a stage layer.
|
113 |
+
setting (list): The architecture setting of a stage layer.
|
114 |
+
"""
|
115 |
+
in_channels, out_channels, num_blocks, use_spp = setting
|
116 |
+
|
117 |
+
in_channels = int(in_channels * self.widen_factor)
|
118 |
+
out_channels = int(out_channels * self.widen_factor)
|
119 |
+
num_blocks = make_round(num_blocks, self.deepen_factor)
|
120 |
+
|
121 |
+
rep_stage_block = RepStageBlock(
|
122 |
+
in_channels=out_channels,
|
123 |
+
out_channels=out_channels,
|
124 |
+
num_blocks=num_blocks,
|
125 |
+
block_cfg=self.block_cfg,
|
126 |
+
)
|
127 |
+
|
128 |
+
block_cfg = self.block_cfg.copy()
|
129 |
+
block_cfg.update(
|
130 |
+
dict(
|
131 |
+
in_channels=in_channels,
|
132 |
+
out_channels=out_channels,
|
133 |
+
kernel_size=3,
|
134 |
+
stride=2))
|
135 |
+
stage = []
|
136 |
+
|
137 |
+
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
|
138 |
+
|
139 |
+
stage.append(ef_block)
|
140 |
+
|
141 |
+
if use_spp:
|
142 |
+
spp = SPPFBottleneck(
|
143 |
+
in_channels=out_channels,
|
144 |
+
out_channels=out_channels,
|
145 |
+
kernel_sizes=5,
|
146 |
+
norm_cfg=self.norm_cfg,
|
147 |
+
act_cfg=self.act_cfg)
|
148 |
+
stage.append(spp)
|
149 |
+
return stage
|
150 |
+
|
151 |
+
def init_weights(self):
|
152 |
+
if self.init_cfg is None:
|
153 |
+
"""Initialize the parameters."""
|
154 |
+
for m in self.modules():
|
155 |
+
if isinstance(m, torch.nn.Conv2d):
|
156 |
+
# In order to be consistent with the source code,
|
157 |
+
# reset the Conv2d initialization parameters
|
158 |
+
m.reset_parameters()
|
159 |
+
else:
|
160 |
+
super().init_weights()
|
161 |
+
|
162 |
+
|
163 |
+
@MODELS.register_module()
|
164 |
+
class YOLOv6CSPBep(YOLOv6EfficientRep):
|
165 |
+
"""CSPBep backbone used in YOLOv6.
|
166 |
+
Args:
|
167 |
+
arch (str): Architecture of BaseDarknet, from {P5, P6}.
|
168 |
+
Defaults to P5.
|
169 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
170 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
171 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
172 |
+
should be same as 'num_stages'.
|
173 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
174 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
175 |
+
widen_factor (float): Width multiplier, multiply number of
|
176 |
+
channels in each layer by this amount. Defaults to 1.0.
|
177 |
+
input_channels (int): Number of input image channels. Defaults to 3.
|
178 |
+
out_indices (Tuple[int]): Output from which stages.
|
179 |
+
Defaults to (2, 3, 4).
|
180 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
181 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
182 |
+
norm_cfg (dict): Dictionary to construct and config norm layer.
|
183 |
+
Defaults to dict(type='BN', requires_grad=True).
|
184 |
+
act_cfg (dict): Config dict for activation layer.
|
185 |
+
Defaults to dict(type='LeakyReLU', negative_slope=0.1).
|
186 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
187 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
188 |
+
and its variants only. Defaults to False.
|
189 |
+
block_cfg (dict): Config dict for the block used to build each
|
190 |
+
layer. Defaults to dict(type='RepVGGBlock').
|
191 |
+
block_act_cfg (dict): Config dict for activation layer used in each
|
192 |
+
stage. Defaults to dict(type='SiLU', inplace=True).
|
193 |
+
init_cfg (Union[dict, list[dict]], optional): Initialization config
|
194 |
+
dict. Defaults to None.
|
195 |
+
Example:
|
196 |
+
>>> from mmyolo.models import YOLOv6CSPBep
|
197 |
+
>>> import torch
|
198 |
+
>>> model = YOLOv6CSPBep()
|
199 |
+
>>> model.eval()
|
200 |
+
>>> inputs = torch.rand(1, 3, 416, 416)
|
201 |
+
>>> level_outputs = model(inputs)
|
202 |
+
>>> for level_out in level_outputs:
|
203 |
+
... print(tuple(level_out.shape))
|
204 |
+
...
|
205 |
+
(1, 256, 52, 52)
|
206 |
+
(1, 512, 26, 26)
|
207 |
+
(1, 1024, 13, 13)
|
208 |
+
"""
|
209 |
+
# From left to right:
|
210 |
+
# in_channels, out_channels, num_blocks, use_spp
|
211 |
+
arch_settings = {
|
212 |
+
'P5': [[64, 128, 6, False], [128, 256, 12, False],
|
213 |
+
[256, 512, 18, False], [512, 1024, 6, True]]
|
214 |
+
}
|
215 |
+
|
216 |
+
def __init__(self,
|
217 |
+
arch: str = 'P5',
|
218 |
+
plugins: Union[dict, List[dict]] = None,
|
219 |
+
deepen_factor: float = 1.0,
|
220 |
+
widen_factor: float = 1.0,
|
221 |
+
input_channels: int = 3,
|
222 |
+
hidden_ratio: float = 0.5,
|
223 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
224 |
+
frozen_stages: int = -1,
|
225 |
+
norm_cfg: ConfigType = dict(
|
226 |
+
type='BN', momentum=0.03, eps=0.001),
|
227 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
228 |
+
norm_eval: bool = False,
|
229 |
+
block_cfg: ConfigType = dict(type='ConvWrapper'),
|
230 |
+
init_cfg: OptMultiConfig = None):
|
231 |
+
self.hidden_ratio = hidden_ratio
|
232 |
+
super().__init__(
|
233 |
+
arch=arch,
|
234 |
+
deepen_factor=deepen_factor,
|
235 |
+
widen_factor=widen_factor,
|
236 |
+
input_channels=input_channels,
|
237 |
+
out_indices=out_indices,
|
238 |
+
plugins=plugins,
|
239 |
+
frozen_stages=frozen_stages,
|
240 |
+
norm_cfg=norm_cfg,
|
241 |
+
act_cfg=act_cfg,
|
242 |
+
norm_eval=norm_eval,
|
243 |
+
block_cfg=block_cfg,
|
244 |
+
init_cfg=init_cfg)
|
245 |
+
|
246 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
247 |
+
"""Build a stage layer.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
stage_idx (int): The index of a stage layer.
|
251 |
+
setting (list): The architecture setting of a stage layer.
|
252 |
+
"""
|
253 |
+
in_channels, out_channels, num_blocks, use_spp = setting
|
254 |
+
in_channels = int(in_channels * self.widen_factor)
|
255 |
+
out_channels = int(out_channels * self.widen_factor)
|
256 |
+
num_blocks = make_round(num_blocks, self.deepen_factor)
|
257 |
+
|
258 |
+
rep_stage_block = BepC3StageBlock(
|
259 |
+
in_channels=out_channels,
|
260 |
+
out_channels=out_channels,
|
261 |
+
num_blocks=num_blocks,
|
262 |
+
hidden_ratio=self.hidden_ratio,
|
263 |
+
block_cfg=self.block_cfg,
|
264 |
+
norm_cfg=self.norm_cfg,
|
265 |
+
act_cfg=self.act_cfg)
|
266 |
+
block_cfg = self.block_cfg.copy()
|
267 |
+
block_cfg.update(
|
268 |
+
dict(
|
269 |
+
in_channels=in_channels,
|
270 |
+
out_channels=out_channels,
|
271 |
+
kernel_size=3,
|
272 |
+
stride=2))
|
273 |
+
stage = []
|
274 |
+
|
275 |
+
ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
|
276 |
+
|
277 |
+
stage.append(ef_block)
|
278 |
+
|
279 |
+
if use_spp:
|
280 |
+
spp = SPPFBottleneck(
|
281 |
+
in_channels=out_channels,
|
282 |
+
out_channels=out_channels,
|
283 |
+
kernel_sizes=5,
|
284 |
+
norm_cfg=self.norm_cfg,
|
285 |
+
act_cfg=self.act_cfg)
|
286 |
+
stage.append(spp)
|
287 |
+
return stage
|
mmyolo/models/backbones/yolov7_backbone.py
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from mmcv.cnn import ConvModule
|
6 |
+
from mmdet.models.backbones.csp_darknet import Focus
|
7 |
+
from mmdet.utils import ConfigType, OptMultiConfig
|
8 |
+
|
9 |
+
from mmyolo.registry import MODELS
|
10 |
+
from ..layers import MaxPoolAndStrideConvBlock
|
11 |
+
from .base_backbone import BaseBackbone
|
12 |
+
|
13 |
+
|
14 |
+
@MODELS.register_module()
|
15 |
+
class YOLOv7Backbone(BaseBackbone):
|
16 |
+
"""Backbone used in YOLOv7.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
arch (str): Architecture of YOLOv7Defaults to L.
|
20 |
+
deepen_factor (float): Depth multiplier, multiply number of
|
21 |
+
blocks in CSP layer by this amount. Defaults to 1.0.
|
22 |
+
widen_factor (float): Width multiplier, multiply number of
|
23 |
+
channels in each layer by this amount. Defaults to 1.0.
|
24 |
+
out_indices (Sequence[int]): Output from which stages.
|
25 |
+
Defaults to (2, 3, 4).
|
26 |
+
frozen_stages (int): Stages to be frozen (stop grad and set eval
|
27 |
+
mode). -1 means not freezing any parameters. Defaults to -1.
|
28 |
+
plugins (list[dict]): List of plugins for stages, each dict contains:
|
29 |
+
|
30 |
+
- cfg (dict, required): Cfg dict to build plugin.
|
31 |
+
- stages (tuple[bool], optional): Stages to apply plugin, length
|
32 |
+
should be same as 'num_stages'.
|
33 |
+
norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
|
34 |
+
config norm layer. Defaults to dict(type='BN', requires_grad=True).
|
35 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
36 |
+
Defaults to dict(type='SiLU', inplace=True).
|
37 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
38 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
39 |
+
and its variants only.
|
40 |
+
init_cfg (:obj:`ConfigDict` or dict or list[dict] or
|
41 |
+
list[:obj:`ConfigDict`]): Initialization config dict.
|
42 |
+
"""
|
43 |
+
_tiny_stage1_cfg = dict(type='TinyDownSampleBlock', middle_ratio=0.5)
|
44 |
+
_tiny_stage2_4_cfg = dict(type='TinyDownSampleBlock', middle_ratio=1.0)
|
45 |
+
_l_expand_channel_2x = dict(
|
46 |
+
type='ELANBlock',
|
47 |
+
middle_ratio=0.5,
|
48 |
+
block_ratio=0.5,
|
49 |
+
num_blocks=2,
|
50 |
+
num_convs_in_block=2)
|
51 |
+
_l_no_change_channel = dict(
|
52 |
+
type='ELANBlock',
|
53 |
+
middle_ratio=0.25,
|
54 |
+
block_ratio=0.25,
|
55 |
+
num_blocks=2,
|
56 |
+
num_convs_in_block=2)
|
57 |
+
_x_expand_channel_2x = dict(
|
58 |
+
type='ELANBlock',
|
59 |
+
middle_ratio=0.4,
|
60 |
+
block_ratio=0.4,
|
61 |
+
num_blocks=3,
|
62 |
+
num_convs_in_block=2)
|
63 |
+
_x_no_change_channel = dict(
|
64 |
+
type='ELANBlock',
|
65 |
+
middle_ratio=0.2,
|
66 |
+
block_ratio=0.2,
|
67 |
+
num_blocks=3,
|
68 |
+
num_convs_in_block=2)
|
69 |
+
_w_no_change_channel = dict(
|
70 |
+
type='ELANBlock',
|
71 |
+
middle_ratio=0.5,
|
72 |
+
block_ratio=0.5,
|
73 |
+
num_blocks=2,
|
74 |
+
num_convs_in_block=2)
|
75 |
+
_e_no_change_channel = dict(
|
76 |
+
type='ELANBlock',
|
77 |
+
middle_ratio=0.4,
|
78 |
+
block_ratio=0.4,
|
79 |
+
num_blocks=3,
|
80 |
+
num_convs_in_block=2)
|
81 |
+
_d_no_change_channel = dict(
|
82 |
+
type='ELANBlock',
|
83 |
+
middle_ratio=1 / 3,
|
84 |
+
block_ratio=1 / 3,
|
85 |
+
num_blocks=4,
|
86 |
+
num_convs_in_block=2)
|
87 |
+
_e2e_no_change_channel = dict(
|
88 |
+
type='EELANBlock',
|
89 |
+
num_elan_block=2,
|
90 |
+
middle_ratio=0.4,
|
91 |
+
block_ratio=0.4,
|
92 |
+
num_blocks=3,
|
93 |
+
num_convs_in_block=2)
|
94 |
+
|
95 |
+
# From left to right:
|
96 |
+
# in_channels, out_channels, Block_params
|
97 |
+
arch_settings = {
|
98 |
+
'Tiny': [[64, 64, _tiny_stage1_cfg], [64, 128, _tiny_stage2_4_cfg],
|
99 |
+
[128, 256, _tiny_stage2_4_cfg],
|
100 |
+
[256, 512, _tiny_stage2_4_cfg]],
|
101 |
+
'L': [[64, 256, _l_expand_channel_2x],
|
102 |
+
[256, 512, _l_expand_channel_2x],
|
103 |
+
[512, 1024, _l_expand_channel_2x],
|
104 |
+
[1024, 1024, _l_no_change_channel]],
|
105 |
+
'X': [[80, 320, _x_expand_channel_2x],
|
106 |
+
[320, 640, _x_expand_channel_2x],
|
107 |
+
[640, 1280, _x_expand_channel_2x],
|
108 |
+
[1280, 1280, _x_no_change_channel]],
|
109 |
+
'W':
|
110 |
+
[[64, 128, _w_no_change_channel], [128, 256, _w_no_change_channel],
|
111 |
+
[256, 512, _w_no_change_channel], [512, 768, _w_no_change_channel],
|
112 |
+
[768, 1024, _w_no_change_channel]],
|
113 |
+
'E':
|
114 |
+
[[80, 160, _e_no_change_channel], [160, 320, _e_no_change_channel],
|
115 |
+
[320, 640, _e_no_change_channel], [640, 960, _e_no_change_channel],
|
116 |
+
[960, 1280, _e_no_change_channel]],
|
117 |
+
'D': [[96, 192,
|
118 |
+
_d_no_change_channel], [192, 384, _d_no_change_channel],
|
119 |
+
[384, 768, _d_no_change_channel],
|
120 |
+
[768, 1152, _d_no_change_channel],
|
121 |
+
[1152, 1536, _d_no_change_channel]],
|
122 |
+
'E2E': [[80, 160, _e2e_no_change_channel],
|
123 |
+
[160, 320, _e2e_no_change_channel],
|
124 |
+
[320, 640, _e2e_no_change_channel],
|
125 |
+
[640, 960, _e2e_no_change_channel],
|
126 |
+
[960, 1280, _e2e_no_change_channel]],
|
127 |
+
}
|
128 |
+
|
129 |
+
def __init__(self,
|
130 |
+
arch: str = 'L',
|
131 |
+
deepen_factor: float = 1.0,
|
132 |
+
widen_factor: float = 1.0,
|
133 |
+
input_channels: int = 3,
|
134 |
+
out_indices: Tuple[int] = (2, 3, 4),
|
135 |
+
frozen_stages: int = -1,
|
136 |
+
plugins: Union[dict, List[dict]] = None,
|
137 |
+
norm_cfg: ConfigType = dict(
|
138 |
+
type='BN', momentum=0.03, eps=0.001),
|
139 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
140 |
+
norm_eval: bool = False,
|
141 |
+
init_cfg: OptMultiConfig = None):
|
142 |
+
assert arch in self.arch_settings.keys()
|
143 |
+
self.arch = arch
|
144 |
+
super().__init__(
|
145 |
+
self.arch_settings[arch],
|
146 |
+
deepen_factor,
|
147 |
+
widen_factor,
|
148 |
+
input_channels=input_channels,
|
149 |
+
out_indices=out_indices,
|
150 |
+
plugins=plugins,
|
151 |
+
frozen_stages=frozen_stages,
|
152 |
+
norm_cfg=norm_cfg,
|
153 |
+
act_cfg=act_cfg,
|
154 |
+
norm_eval=norm_eval,
|
155 |
+
init_cfg=init_cfg)
|
156 |
+
|
157 |
+
def build_stem_layer(self) -> nn.Module:
|
158 |
+
"""Build a stem layer."""
|
159 |
+
if self.arch in ['L', 'X']:
|
160 |
+
stem = nn.Sequential(
|
161 |
+
ConvModule(
|
162 |
+
3,
|
163 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
164 |
+
3,
|
165 |
+
padding=1,
|
166 |
+
stride=1,
|
167 |
+
norm_cfg=self.norm_cfg,
|
168 |
+
act_cfg=self.act_cfg),
|
169 |
+
ConvModule(
|
170 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
171 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
172 |
+
3,
|
173 |
+
padding=1,
|
174 |
+
stride=2,
|
175 |
+
norm_cfg=self.norm_cfg,
|
176 |
+
act_cfg=self.act_cfg),
|
177 |
+
ConvModule(
|
178 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
179 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
180 |
+
3,
|
181 |
+
padding=1,
|
182 |
+
stride=1,
|
183 |
+
norm_cfg=self.norm_cfg,
|
184 |
+
act_cfg=self.act_cfg))
|
185 |
+
elif self.arch == 'Tiny':
|
186 |
+
stem = nn.Sequential(
|
187 |
+
ConvModule(
|
188 |
+
3,
|
189 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
190 |
+
3,
|
191 |
+
padding=1,
|
192 |
+
stride=2,
|
193 |
+
norm_cfg=self.norm_cfg,
|
194 |
+
act_cfg=self.act_cfg),
|
195 |
+
ConvModule(
|
196 |
+
int(self.arch_setting[0][0] * self.widen_factor // 2),
|
197 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
198 |
+
3,
|
199 |
+
padding=1,
|
200 |
+
stride=2,
|
201 |
+
norm_cfg=self.norm_cfg,
|
202 |
+
act_cfg=self.act_cfg))
|
203 |
+
elif self.arch in ['W', 'E', 'D', 'E2E']:
|
204 |
+
stem = Focus(
|
205 |
+
3,
|
206 |
+
int(self.arch_setting[0][0] * self.widen_factor),
|
207 |
+
kernel_size=3,
|
208 |
+
norm_cfg=self.norm_cfg,
|
209 |
+
act_cfg=self.act_cfg)
|
210 |
+
return stem
|
211 |
+
|
212 |
+
def build_stage_layer(self, stage_idx: int, setting: list) -> list:
|
213 |
+
"""Build a stage layer.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
stage_idx (int): The index of a stage layer.
|
217 |
+
setting (list): The architecture setting of a stage layer.
|
218 |
+
"""
|
219 |
+
in_channels, out_channels, stage_block_cfg = setting
|
220 |
+
in_channels = int(in_channels * self.widen_factor)
|
221 |
+
out_channels = int(out_channels * self.widen_factor)
|
222 |
+
|
223 |
+
stage_block_cfg = stage_block_cfg.copy()
|
224 |
+
stage_block_cfg.setdefault('norm_cfg', self.norm_cfg)
|
225 |
+
stage_block_cfg.setdefault('act_cfg', self.act_cfg)
|
226 |
+
|
227 |
+
stage_block_cfg['in_channels'] = in_channels
|
228 |
+
stage_block_cfg['out_channels'] = out_channels
|
229 |
+
|
230 |
+
stage = []
|
231 |
+
if self.arch in ['W', 'E', 'D', 'E2E']:
|
232 |
+
stage_block_cfg['in_channels'] = out_channels
|
233 |
+
elif self.arch in ['L', 'X']:
|
234 |
+
if stage_idx == 0:
|
235 |
+
stage_block_cfg['in_channels'] = out_channels // 2
|
236 |
+
|
237 |
+
downsample_layer = self._build_downsample_layer(
|
238 |
+
stage_idx, in_channels, out_channels)
|
239 |
+
stage.append(MODELS.build(stage_block_cfg))
|
240 |
+
if downsample_layer is not None:
|
241 |
+
stage.insert(0, downsample_layer)
|
242 |
+
return stage
|
243 |
+
|
244 |
+
def _build_downsample_layer(self, stage_idx: int, in_channels: int,
|
245 |
+
out_channels: int) -> Optional[nn.Module]:
|
246 |
+
"""Build a downsample layer pre stage."""
|
247 |
+
if self.arch in ['E', 'D', 'E2E']:
|
248 |
+
downsample_layer = MaxPoolAndStrideConvBlock(
|
249 |
+
in_channels,
|
250 |
+
out_channels,
|
251 |
+
use_in_channels_of_middle=True,
|
252 |
+
norm_cfg=self.norm_cfg,
|
253 |
+
act_cfg=self.act_cfg)
|
254 |
+
elif self.arch == 'W':
|
255 |
+
downsample_layer = ConvModule(
|
256 |
+
in_channels,
|
257 |
+
out_channels,
|
258 |
+
3,
|
259 |
+
stride=2,
|
260 |
+
padding=1,
|
261 |
+
norm_cfg=self.norm_cfg,
|
262 |
+
act_cfg=self.act_cfg)
|
263 |
+
elif self.arch == 'Tiny':
|
264 |
+
if stage_idx != 0:
|
265 |
+
downsample_layer = nn.MaxPool2d(2, 2)
|
266 |
+
else:
|
267 |
+
downsample_layer = None
|
268 |
+
elif self.arch in ['L', 'X']:
|
269 |
+
if stage_idx == 0:
|
270 |
+
downsample_layer = ConvModule(
|
271 |
+
in_channels,
|
272 |
+
out_channels // 2,
|
273 |
+
3,
|
274 |
+
stride=2,
|
275 |
+
padding=1,
|
276 |
+
norm_cfg=self.norm_cfg,
|
277 |
+
act_cfg=self.act_cfg)
|
278 |
+
else:
|
279 |
+
downsample_layer = MaxPoolAndStrideConvBlock(
|
280 |
+
in_channels,
|
281 |
+
in_channels,
|
282 |
+
use_in_channels_of_middle=False,
|
283 |
+
norm_cfg=self.norm_cfg,
|
284 |
+
act_cfg=self.act_cfg)
|
285 |
+
return downsample_layer
|
mmyolo/models/data_preprocessors/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .data_preprocessor import (PPYOLOEBatchRandomResize,
|
3 |
+
PPYOLOEDetDataPreprocessor,
|
4 |
+
YOLOv5DetDataPreprocessor,
|
5 |
+
YOLOXBatchSyncRandomResize)
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
|
9 |
+
'PPYOLOEBatchRandomResize', 'YOLOXBatchSyncRandomResize'
|
10 |
+
]
|
mmyolo/models/data_preprocessors/data_preprocessor.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import random
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from mmdet.models import BatchSyncRandomResize
|
8 |
+
from mmdet.models.data_preprocessors import DetDataPreprocessor
|
9 |
+
from mmengine import MessageHub, is_list_of
|
10 |
+
from mmengine.structures import BaseDataElement
|
11 |
+
from torch import Tensor
|
12 |
+
|
13 |
+
from mmyolo.registry import MODELS
|
14 |
+
|
15 |
+
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
|
16 |
+
None]
|
17 |
+
|
18 |
+
|
19 |
+
@MODELS.register_module()
|
20 |
+
class YOLOXBatchSyncRandomResize(BatchSyncRandomResize):
|
21 |
+
"""YOLOX batch random resize.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
random_size_range (tuple): The multi-scale random range during
|
25 |
+
multi-scale training.
|
26 |
+
interval (int): The iter interval of change
|
27 |
+
image size. Defaults to 10.
|
28 |
+
size_divisor (int): Image size divisible factor.
|
29 |
+
Defaults to 32.
|
30 |
+
"""
|
31 |
+
|
32 |
+
def forward(self, inputs: Tensor, data_samples: dict) -> Tensor and dict:
|
33 |
+
"""resize a batch of images and bboxes to shape ``self._input_size``"""
|
34 |
+
h, w = inputs.shape[-2:]
|
35 |
+
inputs = inputs.float()
|
36 |
+
assert isinstance(data_samples, dict)
|
37 |
+
|
38 |
+
if self._input_size is None:
|
39 |
+
self._input_size = (h, w)
|
40 |
+
scale_y = self._input_size[0] / h
|
41 |
+
scale_x = self._input_size[1] / w
|
42 |
+
if scale_x != 1 or scale_y != 1:
|
43 |
+
inputs = F.interpolate(
|
44 |
+
inputs,
|
45 |
+
size=self._input_size,
|
46 |
+
mode='bilinear',
|
47 |
+
align_corners=False)
|
48 |
+
|
49 |
+
data_samples['bboxes_labels'][:, 2::2] *= scale_x
|
50 |
+
data_samples['bboxes_labels'][:, 3::2] *= scale_y
|
51 |
+
|
52 |
+
message_hub = MessageHub.get_current_instance()
|
53 |
+
if (message_hub.get_info('iter') + 1) % self._interval == 0:
|
54 |
+
self._input_size = self._get_random_size(
|
55 |
+
aspect_ratio=float(w / h), device=inputs.device)
|
56 |
+
|
57 |
+
return inputs, data_samples
|
58 |
+
|
59 |
+
|
60 |
+
@MODELS.register_module()
|
61 |
+
class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
|
62 |
+
"""Rewrite collate_fn to get faster training speed.
|
63 |
+
|
64 |
+
Note: It must be used together with `mmyolo.datasets.utils.yolov5_collate`
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
|
68 |
+
super().__init__(*args, non_blocking=non_blocking, **kwargs)
|
69 |
+
|
70 |
+
def forward(self, data: dict, training: bool = False) -> dict:
|
71 |
+
"""Perform normalization, padding and bgr2rgb conversion based on
|
72 |
+
``DetDataPreprocessorr``.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
data (dict): Data sampled from dataloader.
|
76 |
+
training (bool): Whether to enable training time augmentation.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
dict: Data in the same format as the model input.
|
80 |
+
"""
|
81 |
+
if not training:
|
82 |
+
return super().forward(data, training)
|
83 |
+
|
84 |
+
data = self.cast_data(data)
|
85 |
+
inputs, data_samples = data['inputs'], data['data_samples']
|
86 |
+
assert isinstance(data['data_samples'], dict)
|
87 |
+
|
88 |
+
# TODO: Supports multi-scale training
|
89 |
+
if self._channel_conversion and inputs.shape[1] == 3:
|
90 |
+
inputs = inputs[:, [2, 1, 0], ...]
|
91 |
+
if self._enable_normalize:
|
92 |
+
inputs = (inputs - self.mean) / self.std
|
93 |
+
|
94 |
+
if self.batch_augments is not None:
|
95 |
+
for batch_aug in self.batch_augments:
|
96 |
+
inputs, data_samples = batch_aug(inputs, data_samples)
|
97 |
+
|
98 |
+
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
|
99 |
+
data_samples_output = {
|
100 |
+
'bboxes_labels': data_samples['bboxes_labels'],
|
101 |
+
'img_metas': img_metas
|
102 |
+
}
|
103 |
+
if 'masks' in data_samples:
|
104 |
+
data_samples_output['masks'] = data_samples['masks']
|
105 |
+
|
106 |
+
return {'inputs': inputs, 'data_samples': data_samples_output}
|
107 |
+
|
108 |
+
|
109 |
+
@MODELS.register_module()
|
110 |
+
class PPYOLOEDetDataPreprocessor(DetDataPreprocessor):
|
111 |
+
"""Image pre-processor for detection tasks.
|
112 |
+
|
113 |
+
The main difference between PPYOLOEDetDataPreprocessor and
|
114 |
+
DetDataPreprocessor is the normalization order. The official
|
115 |
+
PPYOLOE resize image first, and then normalize image.
|
116 |
+
In DetDataPreprocessor, the order is reversed.
|
117 |
+
|
118 |
+
Note: It must be used together with
|
119 |
+
`mmyolo.datasets.utils.yolov5_collate`
|
120 |
+
"""
|
121 |
+
|
122 |
+
def forward(self, data: dict, training: bool = False) -> dict:
|
123 |
+
"""Perform normalization、padding and bgr2rgb conversion based on
|
124 |
+
``BaseDataPreprocessor``. This class use batch_augments first, and then
|
125 |
+
normalize the image, which is different from the `DetDataPreprocessor`
|
126 |
+
.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
data (dict): Data sampled from dataloader.
|
130 |
+
training (bool): Whether to enable training time augmentation.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
dict: Data in the same format as the model input.
|
134 |
+
"""
|
135 |
+
if not training:
|
136 |
+
return super().forward(data, training)
|
137 |
+
|
138 |
+
assert isinstance(data['inputs'], list) and is_list_of(
|
139 |
+
data['inputs'], torch.Tensor), \
|
140 |
+
'"inputs" should be a list of Tensor, but got ' \
|
141 |
+
f'{type(data["inputs"])}. The possible reason for this ' \
|
142 |
+
'is that you are not using it with ' \
|
143 |
+
'"mmyolo.datasets.utils.yolov5_collate". Please refer to ' \
|
144 |
+
'"cconfigs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py".'
|
145 |
+
|
146 |
+
data = self.cast_data(data)
|
147 |
+
inputs, data_samples = data['inputs'], data['data_samples']
|
148 |
+
assert isinstance(data['data_samples'], dict)
|
149 |
+
|
150 |
+
# Process data.
|
151 |
+
batch_inputs = []
|
152 |
+
for _input in inputs:
|
153 |
+
# channel transform
|
154 |
+
if self._channel_conversion:
|
155 |
+
_input = _input[[2, 1, 0], ...]
|
156 |
+
# Convert to float after channel conversion to ensure
|
157 |
+
# efficiency
|
158 |
+
_input = _input.float()
|
159 |
+
batch_inputs.append(_input)
|
160 |
+
|
161 |
+
# Batch random resize image.
|
162 |
+
if self.batch_augments is not None:
|
163 |
+
for batch_aug in self.batch_augments:
|
164 |
+
inputs, data_samples = batch_aug(batch_inputs, data_samples)
|
165 |
+
|
166 |
+
if self._enable_normalize:
|
167 |
+
inputs = (inputs - self.mean) / self.std
|
168 |
+
|
169 |
+
img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
|
170 |
+
data_samples = {
|
171 |
+
'bboxes_labels': data_samples['bboxes_labels'],
|
172 |
+
'img_metas': img_metas
|
173 |
+
}
|
174 |
+
|
175 |
+
return {'inputs': inputs, 'data_samples': data_samples}
|
176 |
+
|
177 |
+
|
178 |
+
# TODO: No generality. Its input data format is different
|
179 |
+
# mmdet's batch aug, and it must be compatible in the future.
|
180 |
+
@MODELS.register_module()
|
181 |
+
class PPYOLOEBatchRandomResize(BatchSyncRandomResize):
|
182 |
+
"""PPYOLOE batch random resize.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
random_size_range (tuple): The multi-scale random range during
|
186 |
+
multi-scale training.
|
187 |
+
interval (int): The iter interval of change
|
188 |
+
image size. Defaults to 10.
|
189 |
+
size_divisor (int): Image size divisible factor.
|
190 |
+
Defaults to 32.
|
191 |
+
random_interp (bool): Whether to choose interp_mode randomly.
|
192 |
+
If set to True, the type of `interp_mode` must be list.
|
193 |
+
If set to False, the type of `interp_mode` must be str.
|
194 |
+
Defaults to True.
|
195 |
+
interp_mode (Union[List, str]): The modes available for resizing
|
196 |
+
are ('nearest', 'bilinear', 'bicubic', 'area').
|
197 |
+
keep_ratio (bool): Whether to keep the aspect ratio when resizing
|
198 |
+
the image. Now we only support keep_ratio=False.
|
199 |
+
Defaults to False.
|
200 |
+
"""
|
201 |
+
|
202 |
+
def __init__(self,
|
203 |
+
random_size_range: Tuple[int, int],
|
204 |
+
interval: int = 1,
|
205 |
+
size_divisor: int = 32,
|
206 |
+
random_interp=True,
|
207 |
+
interp_mode: Union[List[str], str] = [
|
208 |
+
'nearest', 'bilinear', 'bicubic', 'area'
|
209 |
+
],
|
210 |
+
keep_ratio: bool = False) -> None:
|
211 |
+
super().__init__(random_size_range, interval, size_divisor)
|
212 |
+
self.random_interp = random_interp
|
213 |
+
self.keep_ratio = keep_ratio
|
214 |
+
# TODO: need to support keep_ratio==True
|
215 |
+
assert not self.keep_ratio, 'We do not yet support keep_ratio=True'
|
216 |
+
|
217 |
+
if self.random_interp:
|
218 |
+
assert isinstance(interp_mode, list) and len(interp_mode) > 1,\
|
219 |
+
'While random_interp==True, the type of `interp_mode`' \
|
220 |
+
' must be list and len(interp_mode) must large than 1'
|
221 |
+
self.interp_mode_list = interp_mode
|
222 |
+
self.interp_mode = None
|
223 |
+
else:
|
224 |
+
assert isinstance(interp_mode, str),\
|
225 |
+
'While random_interp==False, the type of ' \
|
226 |
+
'`interp_mode` must be str'
|
227 |
+
assert interp_mode in ['nearest', 'bilinear', 'bicubic', 'area']
|
228 |
+
self.interp_mode_list = None
|
229 |
+
self.interp_mode = interp_mode
|
230 |
+
|
231 |
+
def forward(self, inputs: list,
|
232 |
+
data_samples: dict) -> Tuple[Tensor, Tensor]:
|
233 |
+
"""Resize a batch of images and bboxes to shape ``self._input_size``.
|
234 |
+
|
235 |
+
The inputs and data_samples should be list, and
|
236 |
+
``PPYOLOEBatchRandomResize`` must be used with
|
237 |
+
``PPYOLOEDetDataPreprocessor`` and ``yolov5_collate`` with
|
238 |
+
``use_ms_training == True``.
|
239 |
+
"""
|
240 |
+
assert isinstance(inputs, list),\
|
241 |
+
'The type of inputs must be list. The possible reason for this ' \
|
242 |
+
'is that you are not using it with `PPYOLOEDetDataPreprocessor` ' \
|
243 |
+
'and `yolov5_collate` with use_ms_training == True.'
|
244 |
+
|
245 |
+
bboxes_labels = data_samples['bboxes_labels']
|
246 |
+
|
247 |
+
message_hub = MessageHub.get_current_instance()
|
248 |
+
if (message_hub.get_info('iter') + 1) % self._interval == 0:
|
249 |
+
# get current input size
|
250 |
+
self._input_size, interp_mode = self._get_random_size_and_interp()
|
251 |
+
if self.random_interp:
|
252 |
+
self.interp_mode = interp_mode
|
253 |
+
|
254 |
+
# TODO: need to support type(inputs)==Tensor
|
255 |
+
if isinstance(inputs, list):
|
256 |
+
outputs = []
|
257 |
+
for i in range(len(inputs)):
|
258 |
+
_batch_input = inputs[i]
|
259 |
+
h, w = _batch_input.shape[-2:]
|
260 |
+
scale_y = self._input_size[0] / h
|
261 |
+
scale_x = self._input_size[1] / w
|
262 |
+
if scale_x != 1. or scale_y != 1.:
|
263 |
+
if self.interp_mode in ('nearest', 'area'):
|
264 |
+
align_corners = None
|
265 |
+
else:
|
266 |
+
align_corners = False
|
267 |
+
_batch_input = F.interpolate(
|
268 |
+
_batch_input.unsqueeze(0),
|
269 |
+
size=self._input_size,
|
270 |
+
mode=self.interp_mode,
|
271 |
+
align_corners=align_corners)
|
272 |
+
|
273 |
+
# rescale boxes
|
274 |
+
indexes = bboxes_labels[:, 0] == i
|
275 |
+
bboxes_labels[indexes, 2] *= scale_x
|
276 |
+
bboxes_labels[indexes, 3] *= scale_y
|
277 |
+
bboxes_labels[indexes, 4] *= scale_x
|
278 |
+
bboxes_labels[indexes, 5] *= scale_y
|
279 |
+
|
280 |
+
data_samples['bboxes_labels'] = bboxes_labels
|
281 |
+
else:
|
282 |
+
_batch_input = _batch_input.unsqueeze(0)
|
283 |
+
|
284 |
+
outputs.append(_batch_input)
|
285 |
+
|
286 |
+
# convert to Tensor
|
287 |
+
return torch.cat(outputs, dim=0), data_samples
|
288 |
+
else:
|
289 |
+
raise NotImplementedError('Not implemented yet!')
|
290 |
+
|
291 |
+
def _get_random_size_and_interp(self) -> Tuple[int, int]:
|
292 |
+
"""Randomly generate a shape in ``_random_size_range`` and a
|
293 |
+
interp_mode in interp_mode_list."""
|
294 |
+
size = random.randint(*self._random_size_range)
|
295 |
+
input_size = (self._size_divisor * size, self._size_divisor * size)
|
296 |
+
|
297 |
+
if self.random_interp:
|
298 |
+
interp_ind = random.randint(0, len(self.interp_mode_list) - 1)
|
299 |
+
interp_mode = self.interp_mode_list[interp_ind]
|
300 |
+
else:
|
301 |
+
interp_mode = None
|
302 |
+
return input_size, interp_mode
|
mmyolo/models/dense_heads/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule
|
3 |
+
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
4 |
+
from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule
|
5 |
+
from .rtmdet_rotated_head import (RTMDetRotatedHead,
|
6 |
+
RTMDetRotatedSepBNHeadModule)
|
7 |
+
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
|
8 |
+
from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
|
9 |
+
from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
|
10 |
+
from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
|
11 |
+
from .yolox_head import YOLOXHead, YOLOXHeadModule
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
|
15 |
+
'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
|
16 |
+
'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
|
17 |
+
'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
|
18 |
+
'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
|
19 |
+
'RTMDetInsSepBNHeadModule'
|
20 |
+
]
|
mmyolo/models/dense_heads/ppyoloe_head.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import Sequence, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from mmdet.models.utils import multi_apply
|
8 |
+
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
|
9 |
+
OptMultiConfig, reduce_mean)
|
10 |
+
from mmengine import MessageHub
|
11 |
+
from mmengine.model import BaseModule, bias_init_with_prob
|
12 |
+
from mmengine.structures import InstanceData
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
from mmyolo.registry import MODELS
|
16 |
+
from ..layers.yolo_bricks import PPYOLOESELayer
|
17 |
+
from ..utils import gt_instances_preprocess
|
18 |
+
from .yolov6_head import YOLOv6Head
|
19 |
+
|
20 |
+
|
21 |
+
@MODELS.register_module()
|
22 |
+
class PPYOLOEHeadModule(BaseModule):
|
23 |
+
"""PPYOLOEHead head module used in `PPYOLOE.
|
24 |
+
|
25 |
+
<https://arxiv.org/abs/2203.16250>`_.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
num_classes (int): Number of categories excluding the background
|
29 |
+
category.
|
30 |
+
in_channels (int): Number of channels in the input feature map.
|
31 |
+
widen_factor (float): Width multiplier, multiply number of
|
32 |
+
channels in each layer by this amount. Defaults to 1.0.
|
33 |
+
num_base_priors (int): The number of priors (points) at a point
|
34 |
+
on the feature grid.
|
35 |
+
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
36 |
+
Defaults to (8, 16, 32).
|
37 |
+
reg_max (int): Max value of integral set :math: ``{0, ..., reg_max}``
|
38 |
+
in QFL setting. Defaults to 16.
|
39 |
+
norm_cfg (dict): Config dict for normalization layer.
|
40 |
+
Defaults to dict(type='BN', momentum=0.03, eps=0.001).
|
41 |
+
act_cfg (dict): Config dict for activation layer.
|
42 |
+
Defaults to dict(type='SiLU', inplace=True).
|
43 |
+
init_cfg (dict or list[dict], optional): Initialization config dict.
|
44 |
+
Defaults to None.
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(self,
|
48 |
+
num_classes: int,
|
49 |
+
in_channels: Union[int, Sequence],
|
50 |
+
widen_factor: float = 1.0,
|
51 |
+
num_base_priors: int = 1,
|
52 |
+
featmap_strides: Sequence[int] = (8, 16, 32),
|
53 |
+
reg_max: int = 16,
|
54 |
+
norm_cfg: ConfigType = dict(
|
55 |
+
type='BN', momentum=0.1, eps=1e-5),
|
56 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
57 |
+
init_cfg: OptMultiConfig = None):
|
58 |
+
super().__init__(init_cfg=init_cfg)
|
59 |
+
|
60 |
+
self.num_classes = num_classes
|
61 |
+
self.featmap_strides = featmap_strides
|
62 |
+
self.num_levels = len(self.featmap_strides)
|
63 |
+
self.num_base_priors = num_base_priors
|
64 |
+
self.norm_cfg = norm_cfg
|
65 |
+
self.act_cfg = act_cfg
|
66 |
+
self.reg_max = reg_max
|
67 |
+
|
68 |
+
if isinstance(in_channels, int):
|
69 |
+
self.in_channels = [int(in_channels * widen_factor)
|
70 |
+
] * self.num_levels
|
71 |
+
else:
|
72 |
+
self.in_channels = [int(i * widen_factor) for i in in_channels]
|
73 |
+
|
74 |
+
self._init_layers()
|
75 |
+
|
76 |
+
def init_weights(self, prior_prob=0.01):
|
77 |
+
"""Initialize the weight and bias of PPYOLOE head."""
|
78 |
+
super().init_weights()
|
79 |
+
for conv in self.cls_preds:
|
80 |
+
conv.bias.data.fill_(bias_init_with_prob(prior_prob))
|
81 |
+
conv.weight.data.fill_(0.)
|
82 |
+
|
83 |
+
for conv in self.reg_preds:
|
84 |
+
conv.bias.data.fill_(1.0)
|
85 |
+
conv.weight.data.fill_(0.)
|
86 |
+
|
87 |
+
def _init_layers(self):
|
88 |
+
"""initialize conv layers in PPYOLOE head."""
|
89 |
+
self.cls_preds = nn.ModuleList()
|
90 |
+
self.reg_preds = nn.ModuleList()
|
91 |
+
self.cls_stems = nn.ModuleList()
|
92 |
+
self.reg_stems = nn.ModuleList()
|
93 |
+
|
94 |
+
for in_channel in self.in_channels:
|
95 |
+
self.cls_stems.append(
|
96 |
+
PPYOLOESELayer(
|
97 |
+
in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
98 |
+
self.reg_stems.append(
|
99 |
+
PPYOLOESELayer(
|
100 |
+
in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
|
101 |
+
|
102 |
+
for in_channel in self.in_channels:
|
103 |
+
self.cls_preds.append(
|
104 |
+
nn.Conv2d(in_channel, self.num_classes, 3, padding=1))
|
105 |
+
self.reg_preds.append(
|
106 |
+
nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))
|
107 |
+
|
108 |
+
# init proj
|
109 |
+
proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
|
110 |
+
[1, self.reg_max + 1, 1, 1])
|
111 |
+
self.register_buffer('proj', proj, persistent=False)
|
112 |
+
|
113 |
+
def forward(self, x: Tuple[Tensor]) -> Tensor:
|
114 |
+
"""Forward features from the upstream network.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
x (Tuple[Tensor]): Features from the upstream network, each is
|
118 |
+
a 4D-tensor.
|
119 |
+
Returns:
|
120 |
+
Tuple[List]: A tuple of multi-level classification scores, bbox
|
121 |
+
predictions.
|
122 |
+
"""
|
123 |
+
assert len(x) == self.num_levels
|
124 |
+
|
125 |
+
return multi_apply(self.forward_single, x, self.cls_stems,
|
126 |
+
self.cls_preds, self.reg_stems, self.reg_preds)
|
127 |
+
|
128 |
+
def forward_single(self, x: Tensor, cls_stem: nn.ModuleList,
|
129 |
+
cls_pred: nn.ModuleList, reg_stem: nn.ModuleList,
|
130 |
+
reg_pred: nn.ModuleList) -> Tensor:
|
131 |
+
"""Forward feature of a single scale level."""
|
132 |
+
b, _, h, w = x.shape
|
133 |
+
hw = h * w
|
134 |
+
avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
|
135 |
+
cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
|
136 |
+
bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
|
137 |
+
# TODO: Test whether use matmul instead of conv can speed up training.
|
138 |
+
bbox_dist_preds = bbox_dist_preds.reshape(
|
139 |
+
[-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)
|
140 |
+
|
141 |
+
bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
|
142 |
+
|
143 |
+
if self.training:
|
144 |
+
return cls_logit, bbox_preds, bbox_dist_preds
|
145 |
+
else:
|
146 |
+
return cls_logit, bbox_preds
|
147 |
+
|
148 |
+
|
149 |
+
@MODELS.register_module()
|
150 |
+
class PPYOLOEHead(YOLOv6Head):
|
151 |
+
"""PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.
|
152 |
+
The YOLOv6 head and the PPYOLOE head are only slightly different.
|
153 |
+
Distribution focal loss is extra used in PPYOLOE, but not in YOLOv6.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
head_module(ConfigType): Base module used for YOLOv5Head
|
157 |
+
prior_generator(dict): Points generator feature maps in
|
158 |
+
2D points-based detectors.
|
159 |
+
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
|
160 |
+
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
|
161 |
+
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
|
162 |
+
loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
|
163 |
+
loss.
|
164 |
+
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
165 |
+
anchor head. Defaults to None.
|
166 |
+
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
167 |
+
anchor head. Defaults to None.
|
168 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
169 |
+
list[dict], optional): Initialization config dict.
|
170 |
+
Defaults to None.
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self,
|
174 |
+
head_module: ConfigType,
|
175 |
+
prior_generator: ConfigType = dict(
|
176 |
+
type='mmdet.MlvlPointGenerator',
|
177 |
+
offset=0.5,
|
178 |
+
strides=[8, 16, 32]),
|
179 |
+
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
180 |
+
loss_cls: ConfigType = dict(
|
181 |
+
type='mmdet.VarifocalLoss',
|
182 |
+
use_sigmoid=True,
|
183 |
+
alpha=0.75,
|
184 |
+
gamma=2.0,
|
185 |
+
iou_weighted=True,
|
186 |
+
reduction='sum',
|
187 |
+
loss_weight=1.0),
|
188 |
+
loss_bbox: ConfigType = dict(
|
189 |
+
type='IoULoss',
|
190 |
+
iou_mode='giou',
|
191 |
+
bbox_format='xyxy',
|
192 |
+
reduction='mean',
|
193 |
+
loss_weight=2.5,
|
194 |
+
return_iou=False),
|
195 |
+
loss_dfl: ConfigType = dict(
|
196 |
+
type='mmdet.DistributionFocalLoss',
|
197 |
+
reduction='mean',
|
198 |
+
loss_weight=0.5 / 4),
|
199 |
+
train_cfg: OptConfigType = None,
|
200 |
+
test_cfg: OptConfigType = None,
|
201 |
+
init_cfg: OptMultiConfig = None):
|
202 |
+
super().__init__(
|
203 |
+
head_module=head_module,
|
204 |
+
prior_generator=prior_generator,
|
205 |
+
bbox_coder=bbox_coder,
|
206 |
+
loss_cls=loss_cls,
|
207 |
+
loss_bbox=loss_bbox,
|
208 |
+
train_cfg=train_cfg,
|
209 |
+
test_cfg=test_cfg,
|
210 |
+
init_cfg=init_cfg)
|
211 |
+
self.loss_dfl = MODELS.build(loss_dfl)
|
212 |
+
# ppyoloe doesn't need loss_obj
|
213 |
+
self.loss_obj = None
|
214 |
+
|
215 |
+
def loss_by_feat(
|
216 |
+
self,
|
217 |
+
cls_scores: Sequence[Tensor],
|
218 |
+
bbox_preds: Sequence[Tensor],
|
219 |
+
bbox_dist_preds: Sequence[Tensor],
|
220 |
+
batch_gt_instances: Sequence[InstanceData],
|
221 |
+
batch_img_metas: Sequence[dict],
|
222 |
+
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
223 |
+
"""Calculate the loss based on the features extracted by the detection
|
224 |
+
head.
|
225 |
+
|
226 |
+
Args:
|
227 |
+
cls_scores (Sequence[Tensor]): Box scores for each scale level,
|
228 |
+
each is a 4D-tensor, the channel number is
|
229 |
+
num_priors * num_classes.
|
230 |
+
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
|
231 |
+
level, each is a 4D-tensor, the channel number is
|
232 |
+
num_priors * 4.
|
233 |
+
bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
|
234 |
+
each scale level with shape (bs, reg_max + 1, H*W, 4).
|
235 |
+
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
236 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
237 |
+
attributes.
|
238 |
+
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
239 |
+
image size, scaling factor, etc.
|
240 |
+
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
|
241 |
+
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
242 |
+
data that is ignored during training and testing.
|
243 |
+
Defaults to None.
|
244 |
+
Returns:
|
245 |
+
dict[str, Tensor]: A dictionary of losses.
|
246 |
+
"""
|
247 |
+
|
248 |
+
# get epoch information from message hub
|
249 |
+
message_hub = MessageHub.get_current_instance()
|
250 |
+
current_epoch = message_hub.get_info('epoch')
|
251 |
+
|
252 |
+
num_imgs = len(batch_img_metas)
|
253 |
+
|
254 |
+
current_featmap_sizes = [
|
255 |
+
cls_score.shape[2:] for cls_score in cls_scores
|
256 |
+
]
|
257 |
+
# If the shape does not equal, generate new one
|
258 |
+
if current_featmap_sizes != self.featmap_sizes_train:
|
259 |
+
self.featmap_sizes_train = current_featmap_sizes
|
260 |
+
|
261 |
+
mlvl_priors_with_stride = self.prior_generator.grid_priors(
|
262 |
+
self.featmap_sizes_train,
|
263 |
+
dtype=cls_scores[0].dtype,
|
264 |
+
device=cls_scores[0].device,
|
265 |
+
with_stride=True)
|
266 |
+
|
267 |
+
self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
|
268 |
+
self.flatten_priors_train = torch.cat(
|
269 |
+
mlvl_priors_with_stride, dim=0)
|
270 |
+
self.stride_tensor = self.flatten_priors_train[..., [2]]
|
271 |
+
|
272 |
+
# gt info
|
273 |
+
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
|
274 |
+
gt_labels = gt_info[:, :, :1]
|
275 |
+
gt_bboxes = gt_info[:, :, 1:] # xyxy
|
276 |
+
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
|
277 |
+
|
278 |
+
# pred info
|
279 |
+
flatten_cls_preds = [
|
280 |
+
cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
281 |
+
self.num_classes)
|
282 |
+
for cls_pred in cls_scores
|
283 |
+
]
|
284 |
+
flatten_pred_bboxes = [
|
285 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
286 |
+
for bbox_pred in bbox_preds
|
287 |
+
]
|
288 |
+
# (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
|
289 |
+
flatten_pred_dists = [
|
290 |
+
bbox_pred_org.permute(0, 2, 3, 1).reshape(
|
291 |
+
num_imgs, -1, (self.head_module.reg_max + 1) * 4)
|
292 |
+
for bbox_pred_org in bbox_dist_preds
|
293 |
+
]
|
294 |
+
|
295 |
+
flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
|
296 |
+
flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
|
297 |
+
flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
|
298 |
+
flatten_pred_bboxes = self.bbox_coder.decode(
|
299 |
+
self.flatten_priors_train[..., :2], flatten_pred_bboxes,
|
300 |
+
self.stride_tensor[..., 0])
|
301 |
+
pred_scores = torch.sigmoid(flatten_cls_preds)
|
302 |
+
|
303 |
+
if current_epoch < self.initial_epoch:
|
304 |
+
assigned_result = self.initial_assigner(
|
305 |
+
flatten_pred_bboxes.detach(), self.flatten_priors_train,
|
306 |
+
self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
|
307 |
+
else:
|
308 |
+
assigned_result = self.assigner(flatten_pred_bboxes.detach(),
|
309 |
+
pred_scores.detach(),
|
310 |
+
self.flatten_priors_train,
|
311 |
+
gt_labels, gt_bboxes,
|
312 |
+
pad_bbox_flag)
|
313 |
+
|
314 |
+
assigned_bboxes = assigned_result['assigned_bboxes']
|
315 |
+
assigned_scores = assigned_result['assigned_scores']
|
316 |
+
fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
|
317 |
+
|
318 |
+
# cls loss
|
319 |
+
with torch.cuda.amp.autocast(enabled=False):
|
320 |
+
loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
|
321 |
+
|
322 |
+
# rescale bbox
|
323 |
+
assigned_bboxes /= self.stride_tensor
|
324 |
+
flatten_pred_bboxes /= self.stride_tensor
|
325 |
+
|
326 |
+
assigned_scores_sum = assigned_scores.sum()
|
327 |
+
# reduce_mean between all gpus
|
328 |
+
assigned_scores_sum = torch.clamp(
|
329 |
+
reduce_mean(assigned_scores_sum), min=1)
|
330 |
+
loss_cls /= assigned_scores_sum
|
331 |
+
|
332 |
+
# select positive samples mask
|
333 |
+
num_pos = fg_mask_pre_prior.sum()
|
334 |
+
if num_pos > 0:
|
335 |
+
# when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
|
336 |
+
# will not report an error
|
337 |
+
# iou loss
|
338 |
+
prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
|
339 |
+
pred_bboxes_pos = torch.masked_select(
|
340 |
+
flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
|
341 |
+
assigned_bboxes_pos = torch.masked_select(
|
342 |
+
assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
|
343 |
+
bbox_weight = torch.masked_select(
|
344 |
+
assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
|
345 |
+
loss_bbox = self.loss_bbox(
|
346 |
+
pred_bboxes_pos,
|
347 |
+
assigned_bboxes_pos,
|
348 |
+
weight=bbox_weight,
|
349 |
+
avg_factor=assigned_scores_sum)
|
350 |
+
|
351 |
+
# dfl loss
|
352 |
+
dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
|
353 |
+
[1, 1, (self.head_module.reg_max + 1) * 4])
|
354 |
+
|
355 |
+
pred_dist_pos = torch.masked_select(
|
356 |
+
flatten_dist_preds,
|
357 |
+
dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
|
358 |
+
assigned_ltrb = self.bbox_coder.encode(
|
359 |
+
self.flatten_priors_train[..., :2] / self.stride_tensor,
|
360 |
+
assigned_bboxes,
|
361 |
+
max_dis=self.head_module.reg_max,
|
362 |
+
eps=0.01)
|
363 |
+
assigned_ltrb_pos = torch.masked_select(
|
364 |
+
assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
|
365 |
+
loss_dfl = self.loss_dfl(
|
366 |
+
pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
|
367 |
+
assigned_ltrb_pos.reshape(-1),
|
368 |
+
weight=bbox_weight.expand(-1, 4).reshape(-1),
|
369 |
+
avg_factor=assigned_scores_sum)
|
370 |
+
else:
|
371 |
+
loss_bbox = flatten_pred_bboxes.sum() * 0
|
372 |
+
loss_dfl = flatten_pred_bboxes.sum() * 0
|
373 |
+
|
374 |
+
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
|
mmyolo/models/dense_heads/rtmdet_head.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from typing import List, Sequence, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from mmcv.cnn import ConvModule, is_norm
|
7 |
+
from mmdet.models.task_modules.samplers import PseudoSampler
|
8 |
+
from mmdet.structures.bbox import distance2bbox
|
9 |
+
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
|
10 |
+
OptInstanceList, OptMultiConfig, reduce_mean)
|
11 |
+
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
|
12 |
+
normal_init)
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
from mmyolo.registry import MODELS, TASK_UTILS
|
16 |
+
from ..utils import gt_instances_preprocess
|
17 |
+
from .yolov5_head import YOLOv5Head
|
18 |
+
|
19 |
+
|
20 |
+
@MODELS.register_module()
|
21 |
+
class RTMDetSepBNHeadModule(BaseModule):
|
22 |
+
"""Detection Head of RTMDet.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
num_classes (int): Number of categories excluding the background
|
26 |
+
category.
|
27 |
+
in_channels (int): Number of channels in the input feature map.
|
28 |
+
widen_factor (float): Width multiplier, multiply number of
|
29 |
+
channels in each layer by this amount. Defaults to 1.0.
|
30 |
+
num_base_priors (int): The number of priors (points) at a point
|
31 |
+
on the feature grid. Defaults to 1.
|
32 |
+
feat_channels (int): Number of hidden channels. Used in child classes.
|
33 |
+
Defaults to 256
|
34 |
+
stacked_convs (int): Number of stacking convs of the head.
|
35 |
+
Defaults to 2.
|
36 |
+
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
37 |
+
Defaults to (8, 16, 32).
|
38 |
+
share_conv (bool): Whether to share conv layers between stages.
|
39 |
+
Defaults to True.
|
40 |
+
pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
|
41 |
+
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
|
42 |
+
convolution layer. Defaults to None.
|
43 |
+
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
|
44 |
+
layer. Defaults to ``dict(type='BN')``.
|
45 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
46 |
+
Default: dict(type='SiLU', inplace=True).
|
47 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
48 |
+
list[dict], optional): Initialization config dict.
|
49 |
+
Defaults to None.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
num_classes: int,
|
55 |
+
in_channels: int,
|
56 |
+
widen_factor: float = 1.0,
|
57 |
+
num_base_priors: int = 1,
|
58 |
+
feat_channels: int = 256,
|
59 |
+
stacked_convs: int = 2,
|
60 |
+
featmap_strides: Sequence[int] = [8, 16, 32],
|
61 |
+
share_conv: bool = True,
|
62 |
+
pred_kernel_size: int = 1,
|
63 |
+
conv_cfg: OptConfigType = None,
|
64 |
+
norm_cfg: ConfigType = dict(type='BN'),
|
65 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
66 |
+
init_cfg: OptMultiConfig = None,
|
67 |
+
):
|
68 |
+
super().__init__(init_cfg=init_cfg)
|
69 |
+
self.share_conv = share_conv
|
70 |
+
self.num_classes = num_classes
|
71 |
+
self.pred_kernel_size = pred_kernel_size
|
72 |
+
self.feat_channels = int(feat_channels * widen_factor)
|
73 |
+
self.stacked_convs = stacked_convs
|
74 |
+
self.num_base_priors = num_base_priors
|
75 |
+
|
76 |
+
self.conv_cfg = conv_cfg
|
77 |
+
self.norm_cfg = norm_cfg
|
78 |
+
self.act_cfg = act_cfg
|
79 |
+
self.featmap_strides = featmap_strides
|
80 |
+
|
81 |
+
self.in_channels = int(in_channels * widen_factor)
|
82 |
+
|
83 |
+
self._init_layers()
|
84 |
+
|
85 |
+
def _init_layers(self):
|
86 |
+
"""Initialize layers of the head."""
|
87 |
+
self.cls_convs = nn.ModuleList()
|
88 |
+
self.reg_convs = nn.ModuleList()
|
89 |
+
|
90 |
+
self.rtm_cls = nn.ModuleList()
|
91 |
+
self.rtm_reg = nn.ModuleList()
|
92 |
+
for n in range(len(self.featmap_strides)):
|
93 |
+
cls_convs = nn.ModuleList()
|
94 |
+
reg_convs = nn.ModuleList()
|
95 |
+
for i in range(self.stacked_convs):
|
96 |
+
chn = self.in_channels if i == 0 else self.feat_channels
|
97 |
+
cls_convs.append(
|
98 |
+
ConvModule(
|
99 |
+
chn,
|
100 |
+
self.feat_channels,
|
101 |
+
3,
|
102 |
+
stride=1,
|
103 |
+
padding=1,
|
104 |
+
conv_cfg=self.conv_cfg,
|
105 |
+
norm_cfg=self.norm_cfg,
|
106 |
+
act_cfg=self.act_cfg))
|
107 |
+
reg_convs.append(
|
108 |
+
ConvModule(
|
109 |
+
chn,
|
110 |
+
self.feat_channels,
|
111 |
+
3,
|
112 |
+
stride=1,
|
113 |
+
padding=1,
|
114 |
+
conv_cfg=self.conv_cfg,
|
115 |
+
norm_cfg=self.norm_cfg,
|
116 |
+
act_cfg=self.act_cfg))
|
117 |
+
self.cls_convs.append(cls_convs)
|
118 |
+
self.reg_convs.append(reg_convs)
|
119 |
+
|
120 |
+
self.rtm_cls.append(
|
121 |
+
nn.Conv2d(
|
122 |
+
self.feat_channels,
|
123 |
+
self.num_base_priors * self.num_classes,
|
124 |
+
self.pred_kernel_size,
|
125 |
+
padding=self.pred_kernel_size // 2))
|
126 |
+
self.rtm_reg.append(
|
127 |
+
nn.Conv2d(
|
128 |
+
self.feat_channels,
|
129 |
+
self.num_base_priors * 4,
|
130 |
+
self.pred_kernel_size,
|
131 |
+
padding=self.pred_kernel_size // 2))
|
132 |
+
|
133 |
+
if self.share_conv:
|
134 |
+
for n in range(len(self.featmap_strides)):
|
135 |
+
for i in range(self.stacked_convs):
|
136 |
+
self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
|
137 |
+
self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
|
138 |
+
|
139 |
+
def init_weights(self) -> None:
|
140 |
+
"""Initialize weights of the head."""
|
141 |
+
# Use prior in model initialization to improve stability
|
142 |
+
super().init_weights()
|
143 |
+
for m in self.modules():
|
144 |
+
if isinstance(m, nn.Conv2d):
|
145 |
+
normal_init(m, mean=0, std=0.01)
|
146 |
+
if is_norm(m):
|
147 |
+
constant_init(m, 1)
|
148 |
+
bias_cls = bias_init_with_prob(0.01)
|
149 |
+
for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
|
150 |
+
normal_init(rtm_cls, std=0.01, bias=bias_cls)
|
151 |
+
normal_init(rtm_reg, std=0.01)
|
152 |
+
|
153 |
+
def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
|
154 |
+
"""Forward features from the upstream network.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
feats (tuple[Tensor]): Features from the upstream network, each is
|
158 |
+
a 4D-tensor.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
tuple: Usually a tuple of classification scores and bbox prediction
|
162 |
+
- cls_scores (list[Tensor]): Classification scores for all scale
|
163 |
+
levels, each is a 4D-tensor, the channels number is
|
164 |
+
num_base_priors * num_classes.
|
165 |
+
- bbox_preds (list[Tensor]): Box energies / deltas for all scale
|
166 |
+
levels, each is a 4D-tensor, the channels number is
|
167 |
+
num_base_priors * 4.
|
168 |
+
"""
|
169 |
+
|
170 |
+
cls_scores = []
|
171 |
+
bbox_preds = []
|
172 |
+
for idx, x in enumerate(feats):
|
173 |
+
cls_feat = x
|
174 |
+
reg_feat = x
|
175 |
+
|
176 |
+
for cls_layer in self.cls_convs[idx]:
|
177 |
+
cls_feat = cls_layer(cls_feat)
|
178 |
+
cls_score = self.rtm_cls[idx](cls_feat)
|
179 |
+
|
180 |
+
for reg_layer in self.reg_convs[idx]:
|
181 |
+
reg_feat = reg_layer(reg_feat)
|
182 |
+
|
183 |
+
reg_dist = self.rtm_reg[idx](reg_feat)
|
184 |
+
cls_scores.append(cls_score)
|
185 |
+
bbox_preds.append(reg_dist)
|
186 |
+
return tuple(cls_scores), tuple(bbox_preds)
|
187 |
+
|
188 |
+
|
189 |
+
@MODELS.register_module()
|
190 |
+
class RTMDetHead(YOLOv5Head):
|
191 |
+
"""RTMDet head.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
head_module(ConfigType): Base module used for RTMDetHead
|
195 |
+
prior_generator: Points generator feature maps in
|
196 |
+
2D points-based detectors.
|
197 |
+
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
|
198 |
+
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
|
199 |
+
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
|
200 |
+
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
201 |
+
anchor head. Defaults to None.
|
202 |
+
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
203 |
+
anchor head. Defaults to None.
|
204 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
205 |
+
list[dict], optional): Initialization config dict.
|
206 |
+
Defaults to None.
|
207 |
+
"""
|
208 |
+
|
209 |
+
def __init__(self,
|
210 |
+
head_module: ConfigType,
|
211 |
+
prior_generator: ConfigType = dict(
|
212 |
+
type='mmdet.MlvlPointGenerator',
|
213 |
+
offset=0,
|
214 |
+
strides=[8, 16, 32]),
|
215 |
+
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
216 |
+
loss_cls: ConfigType = dict(
|
217 |
+
type='mmdet.QualityFocalLoss',
|
218 |
+
use_sigmoid=True,
|
219 |
+
beta=2.0,
|
220 |
+
loss_weight=1.0),
|
221 |
+
loss_bbox: ConfigType = dict(
|
222 |
+
type='mmdet.GIoULoss', loss_weight=2.0),
|
223 |
+
train_cfg: OptConfigType = None,
|
224 |
+
test_cfg: OptConfigType = None,
|
225 |
+
init_cfg: OptMultiConfig = None):
|
226 |
+
|
227 |
+
super().__init__(
|
228 |
+
head_module=head_module,
|
229 |
+
prior_generator=prior_generator,
|
230 |
+
bbox_coder=bbox_coder,
|
231 |
+
loss_cls=loss_cls,
|
232 |
+
loss_bbox=loss_bbox,
|
233 |
+
train_cfg=train_cfg,
|
234 |
+
test_cfg=test_cfg,
|
235 |
+
init_cfg=init_cfg)
|
236 |
+
|
237 |
+
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
|
238 |
+
if self.use_sigmoid_cls:
|
239 |
+
self.cls_out_channels = self.num_classes
|
240 |
+
else:
|
241 |
+
self.cls_out_channels = self.num_classes + 1
|
242 |
+
# rtmdet doesn't need loss_obj
|
243 |
+
self.loss_obj = None
|
244 |
+
|
245 |
+
def special_init(self):
|
246 |
+
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
|
247 |
+
different algorithms have special initialization process.
|
248 |
+
|
249 |
+
The special_init function is designed to deal with this situation.
|
250 |
+
"""
|
251 |
+
if self.train_cfg:
|
252 |
+
self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
|
253 |
+
if self.train_cfg.get('sampler', None) is not None:
|
254 |
+
self.sampler = TASK_UTILS.build(
|
255 |
+
self.train_cfg.sampler, default_args=dict(context=self))
|
256 |
+
else:
|
257 |
+
self.sampler = PseudoSampler(context=self)
|
258 |
+
|
259 |
+
self.featmap_sizes_train = None
|
260 |
+
self.flatten_priors_train = None
|
261 |
+
|
262 |
+
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
263 |
+
"""Forward features from the upstream network.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
x (Tuple[Tensor]): Features from the upstream network, each is
|
267 |
+
a 4D-tensor.
|
268 |
+
Returns:
|
269 |
+
Tuple[List]: A tuple of multi-level classification scores, bbox
|
270 |
+
predictions, and objectnesses.
|
271 |
+
"""
|
272 |
+
return self.head_module(x)
|
273 |
+
|
274 |
+
def loss_by_feat(
|
275 |
+
self,
|
276 |
+
cls_scores: List[Tensor],
|
277 |
+
bbox_preds: List[Tensor],
|
278 |
+
batch_gt_instances: InstanceList,
|
279 |
+
batch_img_metas: List[dict],
|
280 |
+
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
281 |
+
"""Compute losses of the head.
|
282 |
+
|
283 |
+
Args:
|
284 |
+
cls_scores (list[Tensor]): Box scores for each scale level
|
285 |
+
Has shape (N, num_anchors * num_classes, H, W)
|
286 |
+
bbox_preds (list[Tensor]): Decoded box for each scale
|
287 |
+
level with shape (N, num_anchors * 4, H, W) in
|
288 |
+
[tl_x, tl_y, br_x, br_y] format.
|
289 |
+
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
290 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
291 |
+
attributes.
|
292 |
+
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
293 |
+
image size, scaling factor, etc.
|
294 |
+
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
|
295 |
+
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
296 |
+
data that is ignored during training and testing.
|
297 |
+
Defaults to None.
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
dict[str, Tensor]: A dictionary of loss components.
|
301 |
+
"""
|
302 |
+
num_imgs = len(batch_img_metas)
|
303 |
+
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
|
304 |
+
assert len(featmap_sizes) == self.prior_generator.num_levels
|
305 |
+
|
306 |
+
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
|
307 |
+
gt_labels = gt_info[:, :, :1]
|
308 |
+
gt_bboxes = gt_info[:, :, 1:] # xyxy
|
309 |
+
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
|
310 |
+
|
311 |
+
device = cls_scores[0].device
|
312 |
+
|
313 |
+
# If the shape does not equal, generate new one
|
314 |
+
if featmap_sizes != self.featmap_sizes_train:
|
315 |
+
self.featmap_sizes_train = featmap_sizes
|
316 |
+
mlvl_priors_with_stride = self.prior_generator.grid_priors(
|
317 |
+
featmap_sizes, device=device, with_stride=True)
|
318 |
+
self.flatten_priors_train = torch.cat(
|
319 |
+
mlvl_priors_with_stride, dim=0)
|
320 |
+
|
321 |
+
flatten_cls_scores = torch.cat([
|
322 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
323 |
+
self.cls_out_channels)
|
324 |
+
for cls_score in cls_scores
|
325 |
+
], 1).contiguous()
|
326 |
+
|
327 |
+
flatten_bboxes = torch.cat([
|
328 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
329 |
+
for bbox_pred in bbox_preds
|
330 |
+
], 1)
|
331 |
+
flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
|
332 |
+
None]
|
333 |
+
flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
|
334 |
+
flatten_bboxes)
|
335 |
+
|
336 |
+
assigned_result = self.assigner(flatten_bboxes.detach(),
|
337 |
+
flatten_cls_scores.detach(),
|
338 |
+
self.flatten_priors_train, gt_labels,
|
339 |
+
gt_bboxes, pad_bbox_flag)
|
340 |
+
|
341 |
+
labels = assigned_result['assigned_labels'].reshape(-1)
|
342 |
+
label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
|
343 |
+
bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
|
344 |
+
assign_metrics = assigned_result['assign_metrics'].reshape(-1)
|
345 |
+
cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
|
346 |
+
bbox_preds = flatten_bboxes.reshape(-1, 4)
|
347 |
+
|
348 |
+
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
|
349 |
+
bg_class_ind = self.num_classes
|
350 |
+
pos_inds = ((labels >= 0)
|
351 |
+
& (labels < bg_class_ind)).nonzero().squeeze(1)
|
352 |
+
avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
|
353 |
+
|
354 |
+
loss_cls = self.loss_cls(
|
355 |
+
cls_preds, (labels, assign_metrics),
|
356 |
+
label_weights,
|
357 |
+
avg_factor=avg_factor)
|
358 |
+
|
359 |
+
if len(pos_inds) > 0:
|
360 |
+
loss_bbox = self.loss_bbox(
|
361 |
+
bbox_preds[pos_inds],
|
362 |
+
bbox_targets[pos_inds],
|
363 |
+
weight=assign_metrics[pos_inds],
|
364 |
+
avg_factor=avg_factor)
|
365 |
+
else:
|
366 |
+
loss_bbox = bbox_preds.sum() * 0
|
367 |
+
|
368 |
+
return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
|
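As a quick orientation for the module above, the following is a minimal usage sketch of `RTMDetSepBNHeadModule` on dummy neck outputs. The concrete numbers (80 classes, 256-channel inputs, a 640x640 image) are assumptions chosen for illustration; with `share_conv=True` the stacked 3x3 conv weights are shared across the three scale levels while every level keeps its own BN layers.

import torch
from mmyolo.models.dense_heads.rtmdet_head import RTMDetSepBNHeadModule

# Assumed settings: COCO-style 80 classes, 256-channel neck outputs,
# strides 8/16/32 as in the defaults above.
head_module = RTMDetSepBNHeadModule(
    num_classes=80,
    in_channels=256,
    feat_channels=256,
    stacked_convs=2,
    featmap_strides=[8, 16, 32],
    share_conv=True)
head_module.init_weights()

# Dummy multi-level features for a hypothetical 640x640 input.
feats = tuple(torch.rand(1, 256, 640 // s, 640 // s) for s in (8, 16, 32))
cls_scores, bbox_preds = head_module(feats)
# cls_scores[i]: (1, 80, H_i, W_i); bbox_preds[i]: (1, 4, H_i, W_i)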
mmyolo/models/dense_heads/rtmdet_ins_head.py
ADDED
@@ -0,0 +1,725 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
from typing import List, Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from mmcv.cnn import ConvModule, is_norm
|
10 |
+
from mmcv.ops import batched_nms
|
11 |
+
from mmdet.models.utils import filter_scores_and_topk
|
12 |
+
from mmdet.structures.bbox import get_box_tensor, get_box_wh, scale_boxes
|
13 |
+
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
|
14 |
+
OptInstanceList, OptMultiConfig)
|
15 |
+
from mmengine import ConfigDict
|
16 |
+
from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
|
17 |
+
normal_init)
|
18 |
+
from mmengine.structures import InstanceData
|
19 |
+
from torch import Tensor
|
20 |
+
|
21 |
+
from mmyolo.registry import MODELS
|
22 |
+
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
23 |
+
|
24 |
+
|
25 |
+
class MaskFeatModule(BaseModule):
|
26 |
+
"""Mask feature head used in RTMDet-Ins. Copy from mmdet.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
in_channels (int): Number of channels in the input feature map.
|
30 |
+
feat_channels (int): Number of hidden channels of the mask feature
|
31 |
+
map branch.
|
32 |
+
stacked_convs (int): Number of convs in mask feature branch.
|
33 |
+
num_levels (int): The starting feature map level from RPN that
|
34 |
+
will be used to predict the mask feature map.
|
35 |
+
num_prototypes (int): Number of output channel of the mask feature
|
36 |
+
map branch. This is the channel count of the mask
|
37 |
+
feature map that to be dynamically convolved with the predicted
|
38 |
+
kernel.
|
39 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
40 |
+
Default: dict(type='ReLU', inplace=True)
|
41 |
+
norm_cfg (dict): Config dict for normalization layer. Default: None.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
in_channels: int,
|
47 |
+
feat_channels: int = 256,
|
48 |
+
stacked_convs: int = 4,
|
49 |
+
num_levels: int = 3,
|
50 |
+
num_prototypes: int = 8,
|
51 |
+
act_cfg: ConfigType = dict(type='ReLU', inplace=True),
|
52 |
+
norm_cfg: ConfigType = dict(type='BN')
|
53 |
+
) -> None:
|
54 |
+
super().__init__(init_cfg=None)
|
55 |
+
self.num_levels = num_levels
|
56 |
+
self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)
|
57 |
+
convs = []
|
58 |
+
for i in range(stacked_convs):
|
59 |
+
in_c = in_channels if i == 0 else feat_channels
|
60 |
+
convs.append(
|
61 |
+
ConvModule(
|
62 |
+
in_c,
|
63 |
+
feat_channels,
|
64 |
+
3,
|
65 |
+
padding=1,
|
66 |
+
act_cfg=act_cfg,
|
67 |
+
norm_cfg=norm_cfg))
|
68 |
+
self.stacked_convs = nn.Sequential(*convs)
|
69 |
+
self.projection = nn.Conv2d(
|
70 |
+
feat_channels, num_prototypes, kernel_size=1)
|
71 |
+
|
72 |
+
def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
|
73 |
+
# multi-level feature fusion
|
74 |
+
fusion_feats = [features[0]]
|
75 |
+
size = features[0].shape[-2:]
|
76 |
+
for i in range(1, self.num_levels):
|
77 |
+
f = F.interpolate(features[i], size=size, mode='bilinear')
|
78 |
+
fusion_feats.append(f)
|
79 |
+
fusion_feats = torch.cat(fusion_feats, dim=1)
|
80 |
+
fusion_feats = self.fusion_conv(fusion_feats)
|
81 |
+
# pred mask feats
|
82 |
+
mask_features = self.stacked_convs(fusion_feats)
|
83 |
+
mask_features = self.projection(mask_features)
|
84 |
+
return mask_features
|
85 |
+
|
86 |
+
|
87 |
+
@MODELS.register_module()
|
88 |
+
class RTMDetInsSepBNHeadModule(RTMDetSepBNHeadModule):
|
89 |
+
"""Detection and Instance Segmentation Head of RTMDet.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
num_classes (int): Number of categories excluding the background
|
93 |
+
category.
|
94 |
+
num_prototypes (int): Number of mask prototype features extracted
|
95 |
+
from the mask head. Defaults to 8.
|
96 |
+
dyconv_channels (int): Channel of the dynamic conv layers.
|
97 |
+
Defaults to 8.
|
98 |
+
num_dyconvs (int): Number of the dynamic convolution layers.
|
99 |
+
Defaults to 3.
|
100 |
+
use_sigmoid_cls (bool): Use sigmoid for class prediction.
|
101 |
+
Defaults to True.
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self,
|
105 |
+
num_classes: int,
|
106 |
+
*args,
|
107 |
+
num_prototypes: int = 8,
|
108 |
+
dyconv_channels: int = 8,
|
109 |
+
num_dyconvs: int = 3,
|
110 |
+
use_sigmoid_cls: bool = True,
|
111 |
+
**kwargs):
|
112 |
+
self.num_prototypes = num_prototypes
|
113 |
+
self.num_dyconvs = num_dyconvs
|
114 |
+
self.dyconv_channels = dyconv_channels
|
115 |
+
self.use_sigmoid_cls = use_sigmoid_cls
|
116 |
+
if self.use_sigmoid_cls:
|
117 |
+
self.cls_out_channels = num_classes
|
118 |
+
else:
|
119 |
+
self.cls_out_channels = num_classes + 1
|
120 |
+
super().__init__(num_classes=num_classes, *args, **kwargs)
|
121 |
+
|
122 |
+
def _init_layers(self):
|
123 |
+
"""Initialize layers of the head."""
|
124 |
+
self.cls_convs = nn.ModuleList()
|
125 |
+
self.reg_convs = nn.ModuleList()
|
126 |
+
self.kernel_convs = nn.ModuleList()
|
127 |
+
|
128 |
+
self.rtm_cls = nn.ModuleList()
|
129 |
+
self.rtm_reg = nn.ModuleList()
|
130 |
+
self.rtm_kernel = nn.ModuleList()
|
131 |
+
self.rtm_obj = nn.ModuleList()
|
132 |
+
|
133 |
+
# calculate num dynamic parameters
|
134 |
+
weight_nums, bias_nums = [], []
|
135 |
+
for i in range(self.num_dyconvs):
|
136 |
+
if i == 0:
|
137 |
+
weight_nums.append(
|
138 |
+
(self.num_prototypes + 2) * self.dyconv_channels)
|
139 |
+
bias_nums.append(self.dyconv_channels)
|
140 |
+
elif i == self.num_dyconvs - 1:
|
141 |
+
weight_nums.append(self.dyconv_channels)
|
142 |
+
bias_nums.append(1)
|
143 |
+
else:
|
144 |
+
weight_nums.append(self.dyconv_channels * self.dyconv_channels)
|
145 |
+
bias_nums.append(self.dyconv_channels)
|
146 |
+
self.weight_nums = weight_nums
|
147 |
+
self.bias_nums = bias_nums
|
148 |
+
self.num_gen_params = sum(weight_nums) + sum(bias_nums)
|
149 |
+
pred_pad_size = self.pred_kernel_size // 2
|
150 |
+
|
151 |
+
for n in range(len(self.featmap_strides)):
|
152 |
+
cls_convs = nn.ModuleList()
|
153 |
+
reg_convs = nn.ModuleList()
|
154 |
+
kernel_convs = nn.ModuleList()
|
155 |
+
for i in range(self.stacked_convs):
|
156 |
+
chn = self.in_channels if i == 0 else self.feat_channels
|
157 |
+
cls_convs.append(
|
158 |
+
ConvModule(
|
159 |
+
chn,
|
160 |
+
self.feat_channels,
|
161 |
+
3,
|
162 |
+
stride=1,
|
163 |
+
padding=1,
|
164 |
+
conv_cfg=self.conv_cfg,
|
165 |
+
norm_cfg=self.norm_cfg,
|
166 |
+
act_cfg=self.act_cfg))
|
167 |
+
reg_convs.append(
|
168 |
+
ConvModule(
|
169 |
+
chn,
|
170 |
+
self.feat_channels,
|
171 |
+
3,
|
172 |
+
stride=1,
|
173 |
+
padding=1,
|
174 |
+
conv_cfg=self.conv_cfg,
|
175 |
+
norm_cfg=self.norm_cfg,
|
176 |
+
act_cfg=self.act_cfg))
|
177 |
+
kernel_convs.append(
|
178 |
+
ConvModule(
|
179 |
+
chn,
|
180 |
+
self.feat_channels,
|
181 |
+
3,
|
182 |
+
stride=1,
|
183 |
+
padding=1,
|
184 |
+
conv_cfg=self.conv_cfg,
|
185 |
+
norm_cfg=self.norm_cfg,
|
186 |
+
act_cfg=self.act_cfg))
|
187 |
+
self.cls_convs.append(cls_convs)
|
188 |
+
self.reg_convs.append(cls_convs)
|
189 |
+
self.kernel_convs.append(kernel_convs)
|
190 |
+
|
191 |
+
self.rtm_cls.append(
|
192 |
+
nn.Conv2d(
|
193 |
+
self.feat_channels,
|
194 |
+
self.num_base_priors * self.cls_out_channels,
|
195 |
+
self.pred_kernel_size,
|
196 |
+
padding=pred_pad_size))
|
197 |
+
self.rtm_reg.append(
|
198 |
+
nn.Conv2d(
|
199 |
+
self.feat_channels,
|
200 |
+
self.num_base_priors * 4,
|
201 |
+
self.pred_kernel_size,
|
202 |
+
padding=pred_pad_size))
|
203 |
+
self.rtm_kernel.append(
|
204 |
+
nn.Conv2d(
|
205 |
+
self.feat_channels,
|
206 |
+
self.num_gen_params,
|
207 |
+
self.pred_kernel_size,
|
208 |
+
padding=pred_pad_size))
|
209 |
+
|
210 |
+
if self.share_conv:
|
211 |
+
for n in range(len(self.featmap_strides)):
|
212 |
+
for i in range(self.stacked_convs):
|
213 |
+
self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
|
214 |
+
self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
|
215 |
+
|
216 |
+
self.mask_head = MaskFeatModule(
|
217 |
+
in_channels=self.in_channels,
|
218 |
+
feat_channels=self.feat_channels,
|
219 |
+
stacked_convs=4,
|
220 |
+
num_levels=len(self.featmap_strides),
|
221 |
+
num_prototypes=self.num_prototypes,
|
222 |
+
act_cfg=self.act_cfg,
|
223 |
+
norm_cfg=self.norm_cfg)
|
224 |
+
|
225 |
+
def init_weights(self) -> None:
|
226 |
+
"""Initialize weights of the head."""
|
227 |
+
for m in self.modules():
|
228 |
+
if isinstance(m, nn.Conv2d):
|
229 |
+
normal_init(m, mean=0, std=0.01)
|
230 |
+
if is_norm(m):
|
231 |
+
constant_init(m, 1)
|
232 |
+
bias_cls = bias_init_with_prob(0.01)
|
233 |
+
for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg,
|
234 |
+
self.rtm_kernel):
|
235 |
+
normal_init(rtm_cls, std=0.01, bias=bias_cls)
|
236 |
+
normal_init(rtm_reg, std=0.01, bias=1)
|
237 |
+
|
238 |
+
def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
|
239 |
+
"""Forward features from the upstream network.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
feats (tuple[Tensor]): Features from the upstream network, each is
|
243 |
+
a 4D-tensor.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
tuple: Usually a tuple of classification scores and bbox prediction
|
247 |
+
- cls_scores (list[Tensor]): Classification scores for all scale
|
248 |
+
levels, each is a 4D-tensor, the channels number is
|
249 |
+
num_base_priors * num_classes.
|
250 |
+
- bbox_preds (list[Tensor]): Box energies / deltas for all scale
|
251 |
+
levels, each is a 4D-tensor, the channels number is
|
252 |
+
num_base_priors * 4.
|
253 |
+
- kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
|
254 |
+
levels, each is a 4D-tensor, the channels number is
|
255 |
+
num_gen_params.
|
256 |
+
- mask_feat (Tensor): Mask prototype features.
|
257 |
+
Has shape (batch_size, num_prototypes, H, W).
|
258 |
+
"""
|
259 |
+
mask_feat = self.mask_head(feats)
|
260 |
+
|
261 |
+
cls_scores = []
|
262 |
+
bbox_preds = []
|
263 |
+
kernel_preds = []
|
264 |
+
for idx, (x, stride) in enumerate(zip(feats, self.featmap_strides)):
|
265 |
+
cls_feat = x
|
266 |
+
reg_feat = x
|
267 |
+
kernel_feat = x
|
268 |
+
|
269 |
+
for cls_layer in self.cls_convs[idx]:
|
270 |
+
cls_feat = cls_layer(cls_feat)
|
271 |
+
cls_score = self.rtm_cls[idx](cls_feat)
|
272 |
+
|
273 |
+
for kernel_layer in self.kernel_convs[idx]:
|
274 |
+
kernel_feat = kernel_layer(kernel_feat)
|
275 |
+
kernel_pred = self.rtm_kernel[idx](kernel_feat)
|
276 |
+
|
277 |
+
for reg_layer in self.reg_convs[idx]:
|
278 |
+
reg_feat = reg_layer(reg_feat)
|
279 |
+
reg_dist = self.rtm_reg[idx](reg_feat)
|
280 |
+
|
281 |
+
cls_scores.append(cls_score)
|
282 |
+
bbox_preds.append(reg_dist)
|
283 |
+
kernel_preds.append(kernel_pred)
|
284 |
+
return tuple(cls_scores), tuple(bbox_preds), tuple(
|
285 |
+
kernel_preds), mask_feat
|
286 |
+
|
287 |
+
|
288 |
+
@MODELS.register_module()
|
289 |
+
class RTMDetInsSepBNHead(RTMDetHead):
|
290 |
+
"""RTMDet Instance Segmentation head.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
head_module(ConfigType): Base module used for RTMDetInsSepBNHead
|
294 |
+
prior_generator: Points generator feature maps in
|
295 |
+
2D points-based detectors.
|
296 |
+
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
|
297 |
+
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
|
298 |
+
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
|
299 |
+
loss_mask (:obj:`ConfigDict` or dict): Config of mask loss.
|
300 |
+
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
301 |
+
anchor head. Defaults to None.
|
302 |
+
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
303 |
+
anchor head. Defaults to None.
|
304 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
305 |
+
list[dict], optional): Initialization config dict.
|
306 |
+
Defaults to None.
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(self,
|
310 |
+
head_module: ConfigType,
|
311 |
+
prior_generator: ConfigType = dict(
|
312 |
+
type='mmdet.MlvlPointGenerator',
|
313 |
+
offset=0,
|
314 |
+
strides=[8, 16, 32]),
|
315 |
+
bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
|
316 |
+
loss_cls: ConfigType = dict(
|
317 |
+
type='mmdet.QualityFocalLoss',
|
318 |
+
use_sigmoid=True,
|
319 |
+
beta=2.0,
|
320 |
+
loss_weight=1.0),
|
321 |
+
loss_bbox: ConfigType = dict(
|
322 |
+
type='mmdet.GIoULoss', loss_weight=2.0),
|
323 |
+
loss_mask=dict(
|
324 |
+
type='mmdet.DiceLoss',
|
325 |
+
loss_weight=2.0,
|
326 |
+
eps=5e-6,
|
327 |
+
reduction='mean'),
|
328 |
+
train_cfg: OptConfigType = None,
|
329 |
+
test_cfg: OptConfigType = None,
|
330 |
+
init_cfg: OptMultiConfig = None):
|
331 |
+
|
332 |
+
super().__init__(
|
333 |
+
head_module=head_module,
|
334 |
+
prior_generator=prior_generator,
|
335 |
+
bbox_coder=bbox_coder,
|
336 |
+
loss_cls=loss_cls,
|
337 |
+
loss_bbox=loss_bbox,
|
338 |
+
train_cfg=train_cfg,
|
339 |
+
test_cfg=test_cfg,
|
340 |
+
init_cfg=init_cfg)
|
341 |
+
|
342 |
+
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
|
343 |
+
if isinstance(self.head_module, RTMDetInsSepBNHeadModule):
|
344 |
+
assert self.use_sigmoid_cls == self.head_module.use_sigmoid_cls
|
345 |
+
self.loss_mask = MODELS.build(loss_mask)
|
346 |
+
|
347 |
+
def predict_by_feat(self,
|
348 |
+
cls_scores: List[Tensor],
|
349 |
+
bbox_preds: List[Tensor],
|
350 |
+
kernel_preds: List[Tensor],
|
351 |
+
mask_feats: Tensor,
|
352 |
+
score_factors: Optional[List[Tensor]] = None,
|
353 |
+
batch_img_metas: Optional[List[dict]] = None,
|
354 |
+
cfg: Optional[ConfigDict] = None,
|
355 |
+
rescale: bool = True,
|
356 |
+
with_nms: bool = True) -> List[InstanceData]:
|
357 |
+
"""Transform a batch of output features extracted from the head into
|
358 |
+
bbox results.
|
359 |
+
|
360 |
+
Note: When score_factors is not None, the cls_scores are
|
361 |
+
usually multiplied by it then obtain the real score used in NMS.
|
362 |
+
|
363 |
+
Args:
|
364 |
+
cls_scores (list[Tensor]): Classification scores for all
|
365 |
+
scale levels, each is a 4D-tensor, has shape
|
366 |
+
(batch_size, num_priors * num_classes, H, W).
|
367 |
+
bbox_preds (list[Tensor]): Box energies / deltas for all
|
368 |
+
scale levels, each is a 4D-tensor, has shape
|
369 |
+
(batch_size, num_priors * 4, H, W).
|
370 |
+
kernel_preds (list[Tensor]): Kernel predictions of dynamic
|
371 |
+
convs for all scale levels, each is a 4D-tensor, has shape
|
372 |
+
(batch_size, num_params, H, W).
|
373 |
+
mask_feats (Tensor): Mask prototype features extracted from the
|
374 |
+
mask head, has shape (batch_size, num_prototypes, H, W).
|
375 |
+
score_factors (list[Tensor], optional): Score factor for
|
376 |
+
all scale level, each is a 4D-tensor, has shape
|
377 |
+
(batch_size, num_priors * 1, H, W). Defaults to None.
|
378 |
+
batch_img_metas (list[dict], Optional): Batch image meta info.
|
379 |
+
Defaults to None.
|
380 |
+
cfg (ConfigDict, optional): Test / postprocessing
|
381 |
+
configuration, if None, test_cfg would be used.
|
382 |
+
Defaults to None.
|
383 |
+
rescale (bool): If True, return boxes in original image space.
|
384 |
+
Defaults to False.
|
385 |
+
with_nms (bool): If True, do nms before return boxes.
|
386 |
+
Defaults to True.
|
387 |
+
|
388 |
+
Returns:
|
389 |
+
list[:obj:`InstanceData`]: Object detection and instance
|
390 |
+
segmentation results of each image after the post process.
|
391 |
+
Each item usually contains following keys.
|
392 |
+
|
393 |
+
- scores (Tensor): Classification scores, has a shape
|
394 |
+
(num_instance, )
|
395 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
396 |
+
(num_instances, ).
|
397 |
+
- bboxes (Tensor): Has a shape (num_instances, 4),
|
398 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
399 |
+
- masks (Tensor): Has a shape (num_instances, h, w).
|
400 |
+
"""
|
401 |
+
cfg = self.test_cfg if cfg is None else cfg
|
402 |
+
cfg = copy.deepcopy(cfg)
|
403 |
+
|
404 |
+
multi_label = cfg.multi_label
|
405 |
+
multi_label &= self.num_classes > 1
|
406 |
+
cfg.multi_label = multi_label
|
407 |
+
|
408 |
+
num_imgs = len(batch_img_metas)
|
409 |
+
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
410 |
+
|
411 |
+
# If the shape does not change, use the previous mlvl_priors
|
412 |
+
if featmap_sizes != self.featmap_sizes:
|
413 |
+
self.mlvl_priors = self.prior_generator.grid_priors(
|
414 |
+
featmap_sizes,
|
415 |
+
dtype=cls_scores[0].dtype,
|
416 |
+
device=cls_scores[0].device,
|
417 |
+
with_stride=True)
|
418 |
+
self.featmap_sizes = featmap_sizes
|
419 |
+
flatten_priors = torch.cat(self.mlvl_priors)
|
420 |
+
|
421 |
+
mlvl_strides = [
|
422 |
+
flatten_priors.new_full(
|
423 |
+
(featmap_size.numel() * self.num_base_priors, ), stride) for
|
424 |
+
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
|
425 |
+
]
|
426 |
+
flatten_stride = torch.cat(mlvl_strides)
|
427 |
+
|
428 |
+
# flatten cls_scores, bbox_preds
|
429 |
+
flatten_cls_scores = [
|
430 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
431 |
+
self.num_classes)
|
432 |
+
for cls_score in cls_scores
|
433 |
+
]
|
434 |
+
flatten_bbox_preds = [
|
435 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
436 |
+
for bbox_pred in bbox_preds
|
437 |
+
]
|
438 |
+
flatten_kernel_preds = [
|
439 |
+
kernel_pred.permute(0, 2, 3,
|
440 |
+
1).reshape(num_imgs, -1,
|
441 |
+
self.head_module.num_gen_params)
|
442 |
+
for kernel_pred in kernel_preds
|
443 |
+
]
|
444 |
+
|
445 |
+
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
446 |
+
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
|
447 |
+
flatten_decoded_bboxes = self.bbox_coder.decode(
|
448 |
+
flatten_priors[..., :2].unsqueeze(0), flatten_bbox_preds,
|
449 |
+
flatten_stride)
|
450 |
+
|
451 |
+
flatten_kernel_preds = torch.cat(flatten_kernel_preds, dim=1)
|
452 |
+
|
453 |
+
results_list = []
|
454 |
+
for (bboxes, scores, kernel_pred, mask_feat,
|
455 |
+
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
|
456 |
+
flatten_kernel_preds, mask_feats,
|
457 |
+
batch_img_metas):
|
458 |
+
ori_shape = img_meta['ori_shape']
|
459 |
+
scale_factor = img_meta['scale_factor']
|
460 |
+
if 'pad_param' in img_meta:
|
461 |
+
pad_param = img_meta['pad_param']
|
462 |
+
else:
|
463 |
+
pad_param = None
|
464 |
+
|
465 |
+
score_thr = cfg.get('score_thr', -1)
|
466 |
+
if scores.shape[0] == 0:
|
467 |
+
empty_results = InstanceData()
|
468 |
+
empty_results.bboxes = bboxes
|
469 |
+
empty_results.scores = scores[:, 0]
|
470 |
+
empty_results.labels = scores[:, 0].int()
|
471 |
+
h, w = ori_shape[:2] if rescale else img_meta['img_shape'][:2]
|
472 |
+
empty_results.masks = torch.zeros(
|
473 |
+
size=(0, h, w), dtype=torch.bool, device=bboxes.device)
|
474 |
+
results_list.append(empty_results)
|
475 |
+
continue
|
476 |
+
|
477 |
+
nms_pre = cfg.get('nms_pre', 100000)
|
478 |
+
if cfg.multi_label is False:
|
479 |
+
scores, labels = scores.max(1, keepdim=True)
|
480 |
+
scores, _, keep_idxs, results = filter_scores_and_topk(
|
481 |
+
scores,
|
482 |
+
score_thr,
|
483 |
+
nms_pre,
|
484 |
+
results=dict(
|
485 |
+
labels=labels[:, 0],
|
486 |
+
kernel_pred=kernel_pred,
|
487 |
+
priors=flatten_priors))
|
488 |
+
labels = results['labels']
|
489 |
+
kernel_pred = results['kernel_pred']
|
490 |
+
priors = results['priors']
|
491 |
+
else:
|
492 |
+
out = filter_scores_and_topk(
|
493 |
+
scores,
|
494 |
+
score_thr,
|
495 |
+
nms_pre,
|
496 |
+
results=dict(
|
497 |
+
kernel_pred=kernel_pred, priors=flatten_priors))
|
498 |
+
scores, labels, keep_idxs, filtered_results = out
|
499 |
+
kernel_pred = filtered_results['kernel_pred']
|
500 |
+
priors = filtered_results['priors']
|
501 |
+
|
502 |
+
results = InstanceData(
|
503 |
+
scores=scores,
|
504 |
+
labels=labels,
|
505 |
+
bboxes=bboxes[keep_idxs],
|
506 |
+
kernels=kernel_pred,
|
507 |
+
priors=priors)
|
508 |
+
|
509 |
+
if rescale:
|
510 |
+
if pad_param is not None:
|
511 |
+
results.bboxes -= results.bboxes.new_tensor([
|
512 |
+
pad_param[2], pad_param[0], pad_param[2], pad_param[0]
|
513 |
+
])
|
514 |
+
results.bboxes /= results.bboxes.new_tensor(
|
515 |
+
scale_factor).repeat((1, 2))
|
516 |
+
|
517 |
+
if cfg.get('yolox_style', False):
|
518 |
+
# do not need max_per_img
|
519 |
+
cfg.max_per_img = len(results)
|
520 |
+
|
521 |
+
results = self._bbox_mask_post_process(
|
522 |
+
results=results,
|
523 |
+
mask_feat=mask_feat,
|
524 |
+
cfg=cfg,
|
525 |
+
rescale_bbox=False,
|
526 |
+
rescale_mask=rescale,
|
527 |
+
with_nms=with_nms,
|
528 |
+
pad_param=pad_param,
|
529 |
+
img_meta=img_meta)
|
530 |
+
results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
|
531 |
+
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
|
532 |
+
|
533 |
+
results_list.append(results)
|
534 |
+
return results_list
|
535 |
+
|
536 |
+
def _bbox_mask_post_process(
|
537 |
+
self,
|
538 |
+
results: InstanceData,
|
539 |
+
mask_feat: Tensor,
|
540 |
+
cfg: ConfigDict,
|
541 |
+
rescale_bbox: bool = False,
|
542 |
+
rescale_mask: bool = True,
|
543 |
+
with_nms: bool = True,
|
544 |
+
pad_param: Optional[np.ndarray] = None,
|
545 |
+
img_meta: Optional[dict] = None) -> InstanceData:
|
546 |
+
"""bbox and mask post-processing method.
|
547 |
+
|
548 |
+
The boxes would be rescaled to the original image scale and do
|
549 |
+
the nms operation. Usually `with_nms` is False is used for aug test.
|
550 |
+
|
551 |
+
Args:
|
552 |
+
results (:obj:`InstaceData`): Detection instance results,
|
553 |
+
each item has shape (num_bboxes, ).
|
554 |
+
mask_feat (Tensor): Mask prototype features extracted from the
|
555 |
+
mask head, has shape (batch_size, num_prototypes, H, W).
|
556 |
+
cfg (ConfigDict): Test / postprocessing configuration,
|
557 |
+
if None, test_cfg would be used.
|
558 |
+
rescale_bbox (bool): If True, return boxes in original image space.
|
559 |
+
Default to False.
|
560 |
+
rescale_mask (bool): If True, return masks in original image space.
|
561 |
+
Default to True.
|
562 |
+
with_nms (bool): If True, do nms before return boxes.
|
563 |
+
Default to True.
|
564 |
+
img_meta (dict, optional): Image meta info. Defaults to None.
|
565 |
+
|
566 |
+
Returns:
|
567 |
+
:obj:`InstanceData`: Detection results of each image
|
568 |
+
after the post process.
|
569 |
+
Each item usually contains following keys.
|
570 |
+
|
571 |
+
- scores (Tensor): Classification scores, has a shape
|
572 |
+
(num_instance, )
|
573 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
574 |
+
(num_instances, ).
|
575 |
+
- bboxes (Tensor): Has a shape (num_instances, 4),
|
576 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
577 |
+
- masks (Tensor): Has a shape (num_instances, h, w).
|
578 |
+
"""
|
579 |
+
if rescale_bbox:
|
580 |
+
assert img_meta.get('scale_factor') is not None
|
581 |
+
scale_factor = [1 / s for s in img_meta['scale_factor']]
|
582 |
+
results.bboxes = scale_boxes(results.bboxes, scale_factor)
|
583 |
+
|
584 |
+
if hasattr(results, 'score_factors'):
|
585 |
+
# TODO: Add sqrt operation in order to be consistent with
|
586 |
+
# the paper.
|
587 |
+
score_factors = results.pop('score_factors')
|
588 |
+
results.scores = results.scores * score_factors
|
589 |
+
|
590 |
+
# filter small size bboxes
|
591 |
+
if cfg.get('min_bbox_size', -1) >= 0:
|
592 |
+
w, h = get_box_wh(results.bboxes)
|
593 |
+
valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
|
594 |
+
if not valid_mask.all():
|
595 |
+
results = results[valid_mask]
|
596 |
+
|
597 |
+
# TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
|
598 |
+
assert with_nms, 'with_nms must be True for RTMDet-Ins'
|
599 |
+
if results.bboxes.numel() > 0:
|
600 |
+
bboxes = get_box_tensor(results.bboxes)
|
601 |
+
det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
|
602 |
+
results.labels, cfg.nms)
|
603 |
+
results = results[keep_idxs]
|
604 |
+
# some nms would reweight the score, such as softnms
|
605 |
+
results.scores = det_bboxes[:, -1]
|
606 |
+
results = results[:cfg.max_per_img]
|
607 |
+
|
608 |
+
# process masks
|
609 |
+
mask_logits = self._mask_predict_by_feat(mask_feat,
|
610 |
+
results.kernels,
|
611 |
+
results.priors)
|
612 |
+
|
613 |
+
stride = self.prior_generator.strides[0][0]
|
614 |
+
mask_logits = F.interpolate(
|
615 |
+
mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
|
616 |
+
if rescale_mask:
|
617 |
+
# TODO: When use mmdet.Resize or mmdet.Pad, will meet bug
|
618 |
+
# Use img_meta to crop and resize
|
619 |
+
ori_h, ori_w = img_meta['ori_shape'][:2]
|
620 |
+
if isinstance(pad_param, np.ndarray):
|
621 |
+
pad_param = pad_param.astype(np.int32)
|
622 |
+
crop_y1, crop_y2 = pad_param[
|
623 |
+
0], mask_logits.shape[-2] - pad_param[1]
|
624 |
+
crop_x1, crop_x2 = pad_param[
|
625 |
+
2], mask_logits.shape[-1] - pad_param[3]
|
626 |
+
mask_logits = mask_logits[..., crop_y1:crop_y2,
|
627 |
+
crop_x1:crop_x2]
|
628 |
+
mask_logits = F.interpolate(
|
629 |
+
mask_logits,
|
630 |
+
size=[ori_h, ori_w],
|
631 |
+
mode='bilinear',
|
632 |
+
align_corners=False)
|
633 |
+
|
634 |
+
masks = mask_logits.sigmoid().squeeze(0)
|
635 |
+
masks = masks > cfg.mask_thr_binary
|
636 |
+
results.masks = masks
|
637 |
+
else:
|
638 |
+
h, w = img_meta['ori_shape'][:2] if rescale_mask else img_meta[
|
639 |
+
'img_shape'][:2]
|
640 |
+
results.masks = torch.zeros(
|
641 |
+
size=(results.bboxes.shape[0], h, w),
|
642 |
+
dtype=torch.bool,
|
643 |
+
device=results.bboxes.device)
|
644 |
+
return results
|
645 |
+
|
646 |
+
def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor,
|
647 |
+
priors: Tensor) -> Tensor:
|
648 |
+
"""Generate mask logits from mask features with dynamic convs.
|
649 |
+
|
650 |
+
Args:
|
651 |
+
mask_feat (Tensor): Mask prototype features.
|
652 |
+
Has shape (num_prototypes, H, W).
|
653 |
+
kernels (Tensor): Kernel parameters for each instance.
|
654 |
+
Has shape (num_instance, num_params)
|
655 |
+
priors (Tensor): Center priors for each instance.
|
656 |
+
Has shape (num_instance, 4).
|
657 |
+
Returns:
|
658 |
+
Tensor: Instance segmentation masks for each instance.
|
659 |
+
Has shape (num_instance, H, W).
|
660 |
+
"""
|
661 |
+
num_inst = kernels.shape[0]
|
662 |
+
h, w = mask_feat.size()[-2:]
|
663 |
+
if num_inst < 1:
|
664 |
+
return torch.empty(
|
665 |
+
size=(num_inst, h, w),
|
666 |
+
dtype=mask_feat.dtype,
|
667 |
+
device=mask_feat.device)
|
668 |
+
if len(mask_feat.shape) < 4:
|
669 |
+
mask_feat.unsqueeze(0)
|
670 |
+
|
671 |
+
coord = self.prior_generator.single_level_grid_priors(
|
672 |
+
(h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
|
673 |
+
num_inst = priors.shape[0]
|
674 |
+
points = priors[:, :2].reshape(-1, 1, 2)
|
675 |
+
strides = priors[:, 2:].reshape(-1, 1, 2)
|
676 |
+
relative_coord = (points - coord).permute(0, 2, 1) / (
|
677 |
+
strides[..., 0].reshape(-1, 1, 1) * 8)
|
678 |
+
relative_coord = relative_coord.reshape(num_inst, 2, h, w)
|
679 |
+
|
680 |
+
mask_feat = torch.cat(
|
681 |
+
[relative_coord,
|
682 |
+
mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
|
683 |
+
weights, biases = self.parse_dynamic_params(kernels)
|
684 |
+
|
685 |
+
n_layers = len(weights)
|
686 |
+
x = mask_feat.reshape(1, -1, h, w)
|
687 |
+
for i, (weight, bias) in enumerate(zip(weights, biases)):
|
688 |
+
x = F.conv2d(
|
689 |
+
x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
|
690 |
+
if i < n_layers - 1:
|
691 |
+
x = F.relu(x)
|
692 |
+
x = x.reshape(num_inst, h, w)
|
693 |
+
return x
|
694 |
+
|
695 |
+
def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
|
696 |
+
"""split kernel head prediction to conv weight and bias."""
|
697 |
+
n_inst = flatten_kernels.size(0)
|
698 |
+
n_layers = len(self.head_module.weight_nums)
|
699 |
+
params_splits = list(
|
700 |
+
torch.split_with_sizes(
|
701 |
+
flatten_kernels,
|
702 |
+
self.head_module.weight_nums + self.head_module.bias_nums,
|
703 |
+
dim=1))
|
704 |
+
weight_splits = params_splits[:n_layers]
|
705 |
+
bias_splits = params_splits[n_layers:]
|
706 |
+
for i in range(n_layers):
|
707 |
+
if i < n_layers - 1:
|
708 |
+
weight_splits[i] = weight_splits[i].reshape(
|
709 |
+
n_inst * self.head_module.dyconv_channels, -1, 1, 1)
|
710 |
+
bias_splits[i] = bias_splits[i].reshape(
|
711 |
+
n_inst * self.head_module.dyconv_channels)
|
712 |
+
else:
|
713 |
+
weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
|
714 |
+
bias_splits[i] = bias_splits[i].reshape(n_inst)
|
715 |
+
|
716 |
+
return weight_splits, bias_splits
|
717 |
+
|
718 |
+
def loss_by_feat(
|
719 |
+
self,
|
720 |
+
cls_scores: List[Tensor],
|
721 |
+
bbox_preds: List[Tensor],
|
722 |
+
batch_gt_instances: InstanceList,
|
723 |
+
batch_img_metas: List[dict],
|
724 |
+
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
725 |
+
raise NotImplementedError
|
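`parse_dynamic_params` and `_mask_predict_by_feat` above implement per-instance dynamic convolution: the flattened kernel prediction of each instance is split into 1x1 conv weights and biases, and all instances are evaluated in a single call by folding them into the channel dimension of a grouped convolution. Below is a minimal standalone sketch of that grouped-conv trick with hypothetical sizes (2 instances, 8 prototype channels plus 2 relative-coordinate channels, 8 dynamic channels, and only two dynamic layers instead of the default three).

import torch
import torch.nn.functional as F

num_inst, in_ch, dy_ch, h, w = 2, 10, 8, 40, 40  # hypothetical sizes
feat = torch.rand(num_inst, in_ch, h, w)         # per-instance mask features

# Per-instance 1x1 kernels, shaped the way parse_dynamic_params reshapes them.
w1 = torch.rand(num_inst * dy_ch, in_ch, 1, 1)   # first dynamic layer
b1 = torch.rand(num_inst * dy_ch)
w2 = torch.rand(num_inst, dy_ch, 1, 1)           # last layer -> 1 mask channel
b2 = torch.rand(num_inst)

# Fold instances into the channel dim; groups=num_inst guarantees each
# instance is convolved only with its own kernels.
x = feat.reshape(1, num_inst * in_ch, h, w)
x = F.relu(F.conv2d(x, w1, bias=b1, stride=1, padding=0, groups=num_inst))
mask_logits = F.conv2d(
    x, w2, bias=b2, stride=1, padding=0,
    groups=num_inst).reshape(num_inst, h, w)

Using `groups=num_inst` is equivalent to looping over instances and convolving each one with its own kernels, but it runs as a single fused operation.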
mmyolo/models/dense_heads/rtmdet_rotated_head.py
ADDED
@@ -0,0 +1,641 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import copy
|
3 |
+
import warnings
|
4 |
+
from typing import List, Optional, Sequence, Tuple
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from mmdet.models.utils import filter_scores_and_topk
|
9 |
+
from mmdet.structures.bbox import HorizontalBoxes, distance2bbox
|
10 |
+
from mmdet.structures.bbox.transforms import bbox_cxcywh_to_xyxy, scale_boxes
|
11 |
+
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
|
12 |
+
OptInstanceList, OptMultiConfig, reduce_mean)
|
13 |
+
from mmengine.config import ConfigDict
|
14 |
+
from mmengine.model import normal_init
|
15 |
+
from mmengine.structures import InstanceData
|
16 |
+
from torch import Tensor
|
17 |
+
|
18 |
+
from mmyolo.registry import MODELS, TASK_UTILS
|
19 |
+
from ..utils import gt_instances_preprocess
|
20 |
+
from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
|
21 |
+
|
22 |
+
try:
|
23 |
+
from mmrotate.structures.bbox import RotatedBoxes, distance2obb
|
24 |
+
MMROTATE_AVAILABLE = True
|
25 |
+
except ImportError:
|
26 |
+
RotatedBoxes = None
|
27 |
+
distance2obb = None
|
28 |
+
MMROTATE_AVAILABLE = False
|
29 |
+
|
30 |
+
|
31 |
+
@MODELS.register_module()
|
32 |
+
class RTMDetRotatedSepBNHeadModule(RTMDetSepBNHeadModule):
|
33 |
+
"""Detection Head Module of RTMDet-R.
|
34 |
+
|
35 |
+
Compared with RTMDet Detection Head Module, RTMDet-R adds
|
36 |
+
a conv for angle prediction.
|
37 |
+
An `angle_out_dim` arg is added, which is generated by the
|
38 |
+
angle_coder module and controls the angle pred dim.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
num_classes (int): Number of categories excluding the background
|
42 |
+
category.
|
43 |
+
in_channels (int): Number of channels in the input feature map.
|
44 |
+
widen_factor (float): Width multiplier, multiply number of
|
45 |
+
channels in each layer by this amount. Defaults to 1.0.
|
46 |
+
num_base_priors (int): The number of priors (points) at a point
|
47 |
+
on the feature grid. Defaults to 1.
|
48 |
+
feat_channels (int): Number of hidden channels. Used in child classes.
|
49 |
+
Defaults to 256
|
50 |
+
stacked_convs (int): Number of stacking convs of the head.
|
51 |
+
Defaults to 2.
|
52 |
+
featmap_strides (Sequence[int]): Downsample factor of each feature map.
|
53 |
+
Defaults to (8, 16, 32).
|
54 |
+
share_conv (bool): Whether to share conv layers between stages.
|
55 |
+
Defaults to True.
|
56 |
+
pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
|
57 |
+
angle_out_dim (int): Encoded length of angle, will passed by head.
|
58 |
+
Defaults to 1.
|
59 |
+
conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
|
60 |
+
convolution layer. Defaults to None.
|
61 |
+
norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
|
62 |
+
layer. Defaults to ``dict(type='BN')``.
|
63 |
+
act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
|
64 |
+
Default: dict(type='SiLU', inplace=True).
|
65 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
66 |
+
list[dict], optional): Initialization config dict.
|
67 |
+
Defaults to None.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
num_classes: int,
|
73 |
+
in_channels: int,
|
74 |
+
widen_factor: float = 1.0,
|
75 |
+
num_base_priors: int = 1,
|
76 |
+
feat_channels: int = 256,
|
77 |
+
stacked_convs: int = 2,
|
78 |
+
featmap_strides: Sequence[int] = [8, 16, 32],
|
79 |
+
share_conv: bool = True,
|
80 |
+
pred_kernel_size: int = 1,
|
81 |
+
angle_out_dim: int = 1,
|
82 |
+
conv_cfg: OptConfigType = None,
|
83 |
+
norm_cfg: ConfigType = dict(type='BN'),
|
84 |
+
act_cfg: ConfigType = dict(type='SiLU', inplace=True),
|
85 |
+
init_cfg: OptMultiConfig = None,
|
86 |
+
):
|
87 |
+
self.angle_out_dim = angle_out_dim
|
88 |
+
super().__init__(
|
89 |
+
num_classes=num_classes,
|
90 |
+
in_channels=in_channels,
|
91 |
+
widen_factor=widen_factor,
|
92 |
+
num_base_priors=num_base_priors,
|
93 |
+
feat_channels=feat_channels,
|
94 |
+
stacked_convs=stacked_convs,
|
95 |
+
featmap_strides=featmap_strides,
|
96 |
+
share_conv=share_conv,
|
97 |
+
pred_kernel_size=pred_kernel_size,
|
98 |
+
conv_cfg=conv_cfg,
|
99 |
+
norm_cfg=norm_cfg,
|
100 |
+
act_cfg=act_cfg,
|
101 |
+
init_cfg=init_cfg)
|
102 |
+
|
103 |
+
def _init_layers(self):
|
104 |
+
"""Initialize layers of the head."""
|
105 |
+
super()._init_layers()
|
106 |
+
self.rtm_ang = nn.ModuleList()
|
107 |
+
for _ in range(len(self.featmap_strides)):
|
108 |
+
self.rtm_ang.append(
|
109 |
+
nn.Conv2d(
|
110 |
+
self.feat_channels,
|
111 |
+
self.num_base_priors * self.angle_out_dim,
|
112 |
+
self.pred_kernel_size,
|
113 |
+
padding=self.pred_kernel_size // 2))
|
114 |
+
|
115 |
+
def init_weights(self) -> None:
|
116 |
+
"""Initialize weights of the head."""
|
117 |
+
# Use prior in model initialization to improve stability
|
118 |
+
super().init_weights()
|
119 |
+
for rtm_ang in self.rtm_ang:
|
120 |
+
normal_init(rtm_ang, std=0.01)
|
121 |
+
|
122 |
+
def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
|
123 |
+
"""Forward features from the upstream network.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
feats (tuple[Tensor]): Features from the upstream network, each is
|
127 |
+
a 4D-tensor.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
tuple: Usually a tuple of classification scores and bbox prediction
|
131 |
+
- cls_scores (list[Tensor]): Classification scores for all scale
|
132 |
+
levels, each is a 4D-tensor, the channels number is
|
133 |
+
num_base_priors * num_classes.
|
134 |
+
- bbox_preds (list[Tensor]): Box energies / deltas for all scale
|
135 |
+
levels, each is a 4D-tensor, the channels number is
|
136 |
+
num_base_priors * 4.
|
137 |
+
- angle_preds (list[Tensor]): Angle prediction for all scale
|
138 |
+
levels, each is a 4D-tensor, the channels number is
|
139 |
+
num_base_priors * angle_out_dim.
|
140 |
+
"""
|
141 |
+
|
142 |
+
cls_scores = []
|
143 |
+
bbox_preds = []
|
144 |
+
angle_preds = []
|
145 |
+
for idx, x in enumerate(feats):
|
146 |
+
cls_feat = x
|
147 |
+
reg_feat = x
|
148 |
+
|
149 |
+
for cls_layer in self.cls_convs[idx]:
|
150 |
+
cls_feat = cls_layer(cls_feat)
|
151 |
+
cls_score = self.rtm_cls[idx](cls_feat)
|
152 |
+
|
153 |
+
for reg_layer in self.reg_convs[idx]:
|
154 |
+
reg_feat = reg_layer(reg_feat)
|
155 |
+
|
156 |
+
reg_dist = self.rtm_reg[idx](reg_feat)
|
157 |
+
angle_pred = self.rtm_ang[idx](reg_feat)
|
158 |
+
|
159 |
+
cls_scores.append(cls_score)
|
160 |
+
bbox_preds.append(reg_dist)
|
161 |
+
angle_preds.append(angle_pred)
|
162 |
+
return tuple(cls_scores), tuple(bbox_preds), tuple(angle_preds)
|


@MODELS.register_module()
class RTMDetRotatedHead(RTMDetHead):
    """RTMDet-R head.

    Compared with RTMDetHead, RTMDetRotatedHead adds some args to support
    rotated object detection.

    - `angle_version` used to limit angle_range during training.
    - `angle_coder` used to encode and decode angle, which is similar
      to bbox_coder.
    - `use_hbbox_loss` and `loss_angle` allow custom regression loss
      calculation for rotated box.

    There are three combination options for regression:

    1. `use_hbbox_loss=False` and loss_angle is None.

    .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred         decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            └────►decode──(a)─┘

    2. `use_hbbox_loss=False` and loss_angle is specified.
       An angle loss is added on angle_pred.

    .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred         decode──►rbox_pred──(xywha)─►loss_bbox
            │                 ▲
            ├────►decode──(a)─┘
            │
            └───────────────────────────────────────────►loss_angle

    3. `use_hbbox_loss=True` and loss_angle is specified.
       In this case the loss_angle must be set.

    .. code:: text

        bbox_pred──(tblr)──►decode──►hbox_pred──(xyxy)──►loss_bbox

        angle_pred──────────────────────────────────────►loss_angle

    - There's a `decoded_with_angle` flag in test_cfg, which is similar
      to the training process.

    When `decoded_with_angle=True`:

    .. code:: text

        bbox_pred────(tblr)───┐
                              ▼
        angle_pred         decode──(xywha)──►rbox_pred
            │                 ▲
            └────►decode──(a)─┘

    When `decoded_with_angle=False`:

    .. code:: text

        bbox_pred──(tblr)─►decode
                              │ (xyxy)
                              ▼
                           format───(xywh)──►concat──(xywha)──►rbox_pred
                                                ▲
        angle_pred────────►decode────(a)────────┘

    Args:
        head_module(ConfigType): Base module used for RTMDetRotatedHead.
        prior_generator: Points generator feature maps in
            2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        angle_version (str): Angle representations. Defaults to 'le90'.
        use_hbbox_loss (bool): If true, use horizontal bbox loss and
            loss_angle should not be None. Defaults to False.
        angle_coder (:obj:`ConfigDict` or dict): Config of angle coder.
        loss_angle (:obj:`ConfigDict` or dict, optional): Config of angle loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

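    # Illustrative (hypothetical) config fragments for the three regression
    # combinations listed in the docstring above; field names follow the
    # arguments below, the loss settings are placeholders, not tuned values:
    #   1) rotated IoU loss only:
    #        use_hbbox_loss=False, loss_angle=None
    #   2) rotated IoU loss plus an extra angle loss:
    #        use_hbbox_loss=False,
    #        loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.2)
    #   3) horizontal IoU loss, angle supervised separately:
    #        use_hbbox_loss=True,
    #        loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.2)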
    def __init__(
            self,
            head_module: ConfigType,
            prior_generator: ConfigType = dict(
                type='mmdet.MlvlPointGenerator', strides=[8, 16, 32],
                offset=0),
            bbox_coder: ConfigType = dict(type='DistanceAnglePointCoder'),
            loss_cls: ConfigType = dict(
                type='mmdet.QualityFocalLoss',
                use_sigmoid=True,
                beta=2.0,
                loss_weight=1.0),
            loss_bbox: ConfigType = dict(
                type='mmrotate.RotatedIoULoss', mode='linear',
                loss_weight=2.0),
            angle_version: str = 'le90',
            use_hbbox_loss: bool = False,
            angle_coder: ConfigType = dict(type='mmrotate.PseudoAngleCoder'),
            loss_angle: OptConfigType = None,
            train_cfg: OptConfigType = None,
            test_cfg: OptConfigType = None,
            init_cfg: OptMultiConfig = None):
        if not MMROTATE_AVAILABLE:
            raise ImportError(
                'Please run "mim install -r requirements/mmrotate.txt" '
                'to install mmrotate first for rotated detection.')

        self.angle_version = angle_version
        self.use_hbbox_loss = use_hbbox_loss
        if self.use_hbbox_loss:
            assert loss_angle is not None, \
                'loss_angle needs to be specified when use_hbbox_loss is True'
        self.angle_coder = TASK_UTILS.build(angle_coder)
        self.angle_out_dim = self.angle_coder.encode_size
        if head_module.get('angle_out_dim') is not None:
            warnings.warn('angle_out_dim will be overridden by angle_coder '
                          'and does not need to be set manually')

        head_module['angle_out_dim'] = self.angle_out_dim
        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

        if loss_angle is not None:
            self.loss_angle = MODELS.build(loss_angle)
        else:
            self.loss_angle = None

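    # Note (illustration, not a shipped config): the head module's
    # angle_out_dim always follows the angle coder, e.g. a PseudoAngleCoder
    # regresses the angle directly (encode_size == 1), while a coder that
    # bins the angle into N classes reports encode_size == N, so configs only
    # pick the coder and the prediction conv width follows automatically.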
311 |
+
def predict_by_feat(self,
|
312 |
+
cls_scores: List[Tensor],
|
313 |
+
bbox_preds: List[Tensor],
|
314 |
+
angle_preds: List[Tensor],
|
315 |
+
objectnesses: Optional[List[Tensor]] = None,
|
316 |
+
batch_img_metas: Optional[List[dict]] = None,
|
317 |
+
cfg: Optional[ConfigDict] = None,
|
318 |
+
rescale: bool = True,
|
319 |
+
with_nms: bool = True) -> List[InstanceData]:
|
320 |
+
"""Transform a batch of output features extracted by the head into bbox
|
321 |
+
results.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
cls_scores (list[Tensor]): Classification scores for all
|
325 |
+
scale levels, each is a 4D-tensor, has shape
|
326 |
+
(batch_size, num_priors * num_classes, H, W).
|
327 |
+
bbox_preds (list[Tensor]): Box energies / deltas for all
|
328 |
+
scale levels, each is a 4D-tensor, has shape
|
329 |
+
(batch_size, num_priors * 4, H, W).
|
330 |
+
angle_preds (list[Tensor]): Box angle for each scale level
|
331 |
+
with shape (N, num_points * angle_dim, H, W)
|
332 |
+
objectnesses (list[Tensor], Optional): Score factor for
|
333 |
+
all scale level, each is a 4D-tensor, has shape
|
334 |
+
(batch_size, 1, H, W).
|
335 |
+
batch_img_metas (list[dict], Optional): Batch image meta info.
|
336 |
+
Defaults to None.
|
337 |
+
cfg (ConfigDict, optional): Test / postprocessing
|
338 |
+
configuration, if None, test_cfg would be used.
|
339 |
+
Defaults to None.
|
340 |
+
rescale (bool): If True, return boxes in original image space.
|
341 |
+
Defaults to True.
|
342 |
+
with_nms (bool): If True, do nms before return boxes.
|
343 |
+
Defaults to True.
|
344 |
+
|
345 |
+
Returns:
|
346 |
+
list[:obj:`InstanceData`]: Object detection results of each image
|
347 |
+
after the post process. Each item usually contains following keys.
|
348 |
+
- scores (Tensor): Classification scores, has a shape
|
349 |
+
(num_instance, )
|
350 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
351 |
+
(num_instances, ).
|
352 |
+
- bboxes (Tensor): Has a shape (num_instances, 5),
|
353 |
+
the last dimension 5 arranged as (x, y, w, h, angle).
|
354 |
+
"""
|
355 |
+
assert len(cls_scores) == len(bbox_preds)
|
356 |
+
if objectnesses is None:
|
357 |
+
with_objectnesses = False
|
358 |
+
else:
|
359 |
+
with_objectnesses = True
|
360 |
+
assert len(cls_scores) == len(objectnesses)
|
361 |
+
|
362 |
+
cfg = self.test_cfg if cfg is None else cfg
|
363 |
+
cfg = copy.deepcopy(cfg)
|
364 |
+
|
365 |
+
multi_label = cfg.multi_label
|
366 |
+
multi_label &= self.num_classes > 1
|
367 |
+
cfg.multi_label = multi_label
|
368 |
+
|
369 |
+
# Whether to decode rbox with angle.
|
370 |
+
# different setting lead to different final results.
|
371 |
+
# Defaults to True.
|
372 |
+
decode_with_angle = cfg.get('decode_with_angle', True)
|
373 |
+
|
374 |
+
num_imgs = len(batch_img_metas)
|
375 |
+
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
376 |
+
|
377 |
+
# If the shape does not change, use the previous mlvl_priors
|
378 |
+
if featmap_sizes != self.featmap_sizes:
|
379 |
+
self.mlvl_priors = self.prior_generator.grid_priors(
|
380 |
+
featmap_sizes,
|
381 |
+
dtype=cls_scores[0].dtype,
|
382 |
+
device=cls_scores[0].device)
|
383 |
+
self.featmap_sizes = featmap_sizes
|
384 |
+
flatten_priors = torch.cat(self.mlvl_priors)
|
385 |
+
|
386 |
+
mlvl_strides = [
|
387 |
+
flatten_priors.new_full(
|
388 |
+
(featmap_size.numel() * self.num_base_priors, ), stride) for
|
389 |
+
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
|
390 |
+
]
|
391 |
+
flatten_stride = torch.cat(mlvl_strides)
|
392 |
+
|
393 |
+
# flatten cls_scores, bbox_preds and objectness
|
394 |
+
flatten_cls_scores = [
|
395 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
396 |
+
self.num_classes)
|
397 |
+
for cls_score in cls_scores
|
398 |
+
]
|
399 |
+
flatten_bbox_preds = [
|
400 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
401 |
+
for bbox_pred in bbox_preds
|
402 |
+
]
|
403 |
+
flatten_angle_preds = [
|
404 |
+
angle_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
405 |
+
self.angle_out_dim)
|
406 |
+
for angle_pred in angle_preds
|
407 |
+
]
|
408 |
+
|
409 |
+
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
410 |
+
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
|
411 |
+
flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
|
412 |
+
flatten_angle_preds = self.angle_coder.decode(
|
413 |
+
flatten_angle_preds, keepdim=True)
|
414 |
+
|
415 |
+
if decode_with_angle:
|
416 |
+
flatten_rbbox_preds = torch.cat(
|
417 |
+
[flatten_bbox_preds, flatten_angle_preds], dim=-1)
|
418 |
+
flatten_decoded_bboxes = self.bbox_coder.decode(
|
419 |
+
flatten_priors[None], flatten_rbbox_preds, flatten_stride)
|
420 |
+
else:
|
421 |
+
flatten_decoded_hbboxes = self.bbox_coder.decode(
|
422 |
+
flatten_priors[None], flatten_bbox_preds, flatten_stride)
|
423 |
+
flatten_decoded_hbboxes = HorizontalBoxes.xyxy_to_cxcywh(
|
424 |
+
flatten_decoded_hbboxes)
|
425 |
+
flatten_decoded_bboxes = torch.cat(
|
426 |
+
[flatten_decoded_hbboxes, flatten_angle_preds], dim=-1)
|
427 |
+
|
428 |
+
if with_objectnesses:
|
429 |
+
flatten_objectness = [
|
430 |
+
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
|
431 |
+
for objectness in objectnesses
|
432 |
+
]
|
433 |
+
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
|
434 |
+
else:
|
435 |
+
flatten_objectness = [None for _ in range(num_imgs)]
|
436 |
+
|
437 |
+
results_list = []
|
438 |
+
for (bboxes, scores, objectness,
|
439 |
+
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
|
440 |
+
flatten_objectness, batch_img_metas):
|
441 |
+
scale_factor = img_meta['scale_factor']
|
442 |
+
if 'pad_param' in img_meta:
|
443 |
+
pad_param = img_meta['pad_param']
|
444 |
+
else:
|
445 |
+
pad_param = None
|
446 |
+
|
447 |
+
score_thr = cfg.get('score_thr', -1)
|
448 |
+
# yolox_style does not require the following operations
|
449 |
+
if objectness is not None and score_thr > 0 and not cfg.get(
|
450 |
+
'yolox_style', False):
|
451 |
+
conf_inds = objectness > score_thr
|
452 |
+
bboxes = bboxes[conf_inds, :]
|
453 |
+
scores = scores[conf_inds, :]
|
454 |
+
objectness = objectness[conf_inds]
|
455 |
+
|
456 |
+
if objectness is not None:
|
457 |
+
# conf = obj_conf * cls_conf
|
458 |
+
scores *= objectness[:, None]
|
459 |
+
|
460 |
+
if scores.shape[0] == 0:
|
461 |
+
empty_results = InstanceData()
|
462 |
+
empty_results.bboxes = RotatedBoxes(bboxes)
|
463 |
+
empty_results.scores = scores[:, 0]
|
464 |
+
empty_results.labels = scores[:, 0].int()
|
465 |
+
results_list.append(empty_results)
|
466 |
+
continue
|
467 |
+
|
468 |
+
nms_pre = cfg.get('nms_pre', 100000)
|
469 |
+
if cfg.multi_label is False:
|
470 |
+
scores, labels = scores.max(1, keepdim=True)
|
471 |
+
scores, _, keep_idxs, results = filter_scores_and_topk(
|
472 |
+
scores,
|
473 |
+
score_thr,
|
474 |
+
nms_pre,
|
475 |
+
results=dict(labels=labels[:, 0]))
|
476 |
+
labels = results['labels']
|
477 |
+
else:
|
478 |
+
scores, labels, keep_idxs, _ = filter_scores_and_topk(
|
479 |
+
scores, score_thr, nms_pre)
|
480 |
+
|
481 |
+
results = InstanceData(
|
482 |
+
scores=scores,
|
483 |
+
labels=labels,
|
484 |
+
bboxes=RotatedBoxes(bboxes[keep_idxs]))
|
485 |
+
|
486 |
+
if rescale:
|
487 |
+
if pad_param is not None:
|
488 |
+
results.bboxes.translate_([-pad_param[2], -pad_param[0]])
|
489 |
+
|
490 |
+
scale_factor = [1 / s for s in img_meta['scale_factor']]
|
491 |
+
results.bboxes = scale_boxes(results.bboxes, scale_factor)
|
492 |
+
|
493 |
+
if cfg.get('yolox_style', False):
|
494 |
+
# do not need max_per_img
|
495 |
+
cfg.max_per_img = len(results)
|
496 |
+
|
497 |
+
results = self._bbox_post_process(
|
498 |
+
results=results,
|
499 |
+
cfg=cfg,
|
500 |
+
rescale=False,
|
501 |
+
with_nms=with_nms,
|
502 |
+
img_meta=img_meta)
|
503 |
+
|
504 |
+
results_list.append(results)
|
505 |
+
return results_list
|
506 |
+
|
507 |
+
def loss_by_feat(
|
508 |
+
self,
|
509 |
+
cls_scores: List[Tensor],
|
510 |
+
bbox_preds: List[Tensor],
|
511 |
+
angle_preds: List[Tensor],
|
512 |
+
batch_gt_instances: InstanceList,
|
513 |
+
batch_img_metas: List[dict],
|
514 |
+
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
515 |
+
"""Compute losses of the head.
|
516 |
+
|
517 |
+
Args:
|
518 |
+
cls_scores (list[Tensor]): Box scores for each scale level
|
519 |
+
Has shape (N, num_anchors * num_classes, H, W)
|
520 |
+
bbox_preds (list[Tensor]): Decoded box for each scale
|
521 |
+
level with shape (N, num_anchors * 4, H, W) in
|
522 |
+
[tl_x, tl_y, br_x, br_y] format.
|
523 |
+
angle_preds (list[Tensor]): Angle prediction for each scale
|
524 |
+
level with shape (N, num_anchors * angle_out_dim, H, W).
|
525 |
+
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
|
526 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
527 |
+
attributes.
|
528 |
+
batch_img_metas (list[dict]): Meta information of each image, e.g.,
|
529 |
+
image size, scaling factor, etc.
|
530 |
+
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
|
531 |
+
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
532 |
+
data that is ignored during training and testing.
|
533 |
+
Defaults to None.
|
534 |
+
|
535 |
+
Returns:
|
536 |
+
dict[str, Tensor]: A dictionary of loss components.
|
537 |
+
"""
|
538 |
+
num_imgs = len(batch_img_metas)
|
539 |
+
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
|
540 |
+
assert len(featmap_sizes) == self.prior_generator.num_levels
|
541 |
+
|
542 |
+
gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
|
543 |
+
gt_labels = gt_info[:, :, :1]
|
544 |
+
gt_bboxes = gt_info[:, :, 1:] # xywha
|
545 |
+
pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
|
546 |
+
|
547 |
+
device = cls_scores[0].device
|
548 |
+
|
549 |
+
# If the shape does not equal, generate new one
|
550 |
+
if featmap_sizes != self.featmap_sizes_train:
|
551 |
+
self.featmap_sizes_train = featmap_sizes
|
552 |
+
mlvl_priors_with_stride = self.prior_generator.grid_priors(
|
553 |
+
featmap_sizes, device=device, with_stride=True)
|
554 |
+
self.flatten_priors_train = torch.cat(
|
555 |
+
mlvl_priors_with_stride, dim=0)
|
556 |
+
|
557 |
+
flatten_cls_scores = torch.cat([
|
558 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
559 |
+
self.cls_out_channels)
|
560 |
+
for cls_score in cls_scores
|
561 |
+
], 1).contiguous()
|
562 |
+
|
563 |
+
flatten_tblrs = torch.cat([
|
564 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
565 |
+
for bbox_pred in bbox_preds
|
566 |
+
], 1)
|
567 |
+
flatten_tblrs = flatten_tblrs * self.flatten_priors_train[..., -1,
|
568 |
+
None]
|
569 |
+
flatten_angles = torch.cat([
|
570 |
+
angle_pred.permute(0, 2, 3, 1).reshape(
|
571 |
+
num_imgs, -1, self.angle_out_dim) for angle_pred in angle_preds
|
572 |
+
], 1)
|
573 |
+
flatten_decoded_angle = self.angle_coder.decode(
|
574 |
+
flatten_angles, keepdim=True)
|
575 |
+
flatten_tblra = torch.cat([flatten_tblrs, flatten_decoded_angle],
|
576 |
+
dim=-1)
|
577 |
+
flatten_rbboxes = distance2obb(
|
578 |
+
self.flatten_priors_train[..., :2],
|
579 |
+
flatten_tblra,
|
580 |
+
angle_version=self.angle_version)
|
581 |
+
if self.use_hbbox_loss:
|
582 |
+
flatten_hbboxes = distance2bbox(self.flatten_priors_train[..., :2],
|
583 |
+
flatten_tblrs)
|
584 |
+
|
585 |
+
assigned_result = self.assigner(flatten_rbboxes.detach(),
|
586 |
+
flatten_cls_scores.detach(),
|
587 |
+
self.flatten_priors_train, gt_labels,
|
588 |
+
gt_bboxes, pad_bbox_flag)
|
589 |
+
|
590 |
+
labels = assigned_result['assigned_labels'].reshape(-1)
|
591 |
+
label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
|
592 |
+
bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 5)
|
593 |
+
assign_metrics = assigned_result['assign_metrics'].reshape(-1)
|
594 |
+
cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
|
595 |
+
|
596 |
+
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes
|
597 |
+
bg_class_ind = self.num_classes
|
598 |
+
pos_inds = ((labels >= 0)
|
599 |
+
& (labels < bg_class_ind)).nonzero().squeeze(1)
|
600 |
+
avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
|
601 |
+
|
602 |
+
loss_cls = self.loss_cls(
|
603 |
+
cls_preds, (labels, assign_metrics),
|
604 |
+
label_weights,
|
605 |
+
avg_factor=avg_factor)
|
606 |
+
|
607 |
+
pos_bbox_targets = bbox_targets[pos_inds]
|
608 |
+
|
609 |
+
if self.use_hbbox_loss:
|
610 |
+
bbox_preds = flatten_hbboxes.reshape(-1, 4)
|
611 |
+
pos_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets[:, :4])
|
612 |
+
else:
|
613 |
+
bbox_preds = flatten_rbboxes.reshape(-1, 5)
|
614 |
+
angle_preds = flatten_angles.reshape(-1, self.angle_out_dim)
|
615 |
+
|
616 |
+
if len(pos_inds) > 0:
|
617 |
+
loss_bbox = self.loss_bbox(
|
618 |
+
bbox_preds[pos_inds],
|
619 |
+
pos_bbox_targets,
|
620 |
+
weight=assign_metrics[pos_inds],
|
621 |
+
avg_factor=avg_factor)
|
622 |
+
loss_angle = angle_preds.sum() * 0
|
623 |
+
if self.loss_angle is not None:
|
624 |
+
pos_angle_targets = bbox_targets[pos_inds][:, 4:5]
|
625 |
+
pos_angle_targets = self.angle_coder.encode(pos_angle_targets)
|
626 |
+
loss_angle = self.loss_angle(
|
627 |
+
angle_preds[pos_inds],
|
628 |
+
pos_angle_targets,
|
629 |
+
weight=assign_metrics[pos_inds],
|
630 |
+
avg_factor=avg_factor)
|
631 |
+
else:
|
632 |
+
loss_bbox = bbox_preds.sum() * 0
|
633 |
+
loss_angle = angle_preds.sum() * 0
|
634 |
+
|
635 |
+
losses = dict()
|
636 |
+
losses['loss_cls'] = loss_cls
|
637 |
+
losses['loss_bbox'] = loss_bbox
|
638 |
+
if self.loss_angle is not None:
|
639 |
+
losses['loss_angle'] = loss_angle
|
640 |
+
|
641 |
+
return losses
|
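The two test-time decode paths sketched in the RTMDetRotatedHead docstring can be illustrated with a few lines of plain PyTorch. This is a rough, self-contained sketch on dummy tensors, not the mmrotate DistanceAnglePointCoder implementation; the rotation of the side offsets only approximates what distance2obb does:

import torch

# Dummy inputs: 3 prior points (cx, cy), per-point side distances (l, t, r, b)
# already multiplied by the stride, and an angle decoded by the angle coder.
points = torch.tensor([[8., 8.], [16., 8.], [24., 8.]])
tblr = torch.rand(3, 4) * 16
angle = (torch.rand(3, 1) - 0.5) * 3.14

# decode_with_angle=True path: rotate the side offsets by the predicted angle
# before composing the centre, giving (cx, cy, w, h, a) in one step.
cos, sin = torch.cos(angle), torch.sin(angle)
left, top, right, bottom = tblr.split(1, dim=-1)
dx, dy = (right - left) / 2, (bottom - top) / 2
ctr = points + torch.cat([dx * cos - dy * sin, dx * sin + dy * cos], dim=-1)
wh = torch.cat([left + right, top + bottom], dim=-1)
rbox_a = torch.cat([ctr, wh, angle], dim=-1)

# decode_with_angle=False path: decode an axis-aligned box first, convert it
# to (cx, cy, w, h) and only then append the angle, as in the second diagram.
x1y1 = points - tblr[:, :2]
x2y2 = points + tblr[:, 2:]
cxcywh = torch.cat([(x1y1 + x2y2) / 2, x2y2 - x1y1], dim=-1)
rbox_b = torch.cat([cxcywh, angle], dim=-1)

print(rbox_a.shape, rbox_b.shape)  # both torch.Size([3, 5]): x, y, w, h, a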
mmyolo/models/dense_heads/yolov5_head.py
ADDED
@@ -0,0 +1,890 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
from mmdet.models.utils import filter_scores_and_topk, multi_apply
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig)
from mmengine.config import ConfigDict
from mmengine.dist import get_dist_info
from mmengine.logging import print_log
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import make_divisible


def get_prior_xy_info(index: int, num_base_priors: int,
                      featmap_sizes: int) -> Tuple[int, int, int]:
    """Get prior index and xy index in feature map by flatten index."""
    _, featmap_w = featmap_sizes
    priors = index % num_base_priors
    xy_index = index // num_base_priors
    grid_y = xy_index // featmap_w
    grid_x = xy_index % featmap_w
    return priors, grid_x, grid_y


@MODELS.register_module()
class YOLOv5HeadModule(BaseModule):
    """YOLOv5Head head module used in `YOLOv5`.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid.
        featmap_strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to (8, 16, 32).
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 3,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.widen_factor = widen_factor

        self.featmap_strides = featmap_strides
        self.num_out_attrib = 5 + self.num_classes
        self.num_levels = len(self.featmap_strides)
        self.num_base_priors = num_base_priors

        if isinstance(in_channels, int):
            self.in_channels = [make_divisible(in_channels, widen_factor)
                                ] * self.num_levels
        else:
            self.in_channels = [
                make_divisible(i, widen_factor) for i in in_channels
            ]

        self._init_layers()

    def _init_layers(self):
        """Initialize conv layers in YOLOv5 head."""
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_pred = nn.Conv2d(self.in_channels[i],
                                  self.num_base_priors * self.num_out_attrib,
                                  1)

            self.convs_pred.append(conv_pred)

    def init_weights(self):
        """Initialize the bias of YOLOv5 head."""
        super().init_weights()
        for mi, s in zip(self.convs_pred, self.featmap_strides):  # from
            b = mi.bias.data.view(self.num_base_priors, -1)
            # obj (8 objects per 640 image)
            b.data[:, 4] += math.log(8 / (640 / s)**2)
            b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.999999))

            mi.bias.data = b.view(-1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.convs_pred)

    def forward_single(self, x: Tensor,
                       convs: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level."""

        pred_map = convs(x)
        bs, _, ny, nx = pred_map.shape
        pred_map = pred_map.view(bs, self.num_base_priors, self.num_out_attrib,
                                 ny, nx)

        cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
        bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
        objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)

        return cls_score, bbox_pred, objectness


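# Worked numbers for the prior-probability bias initialisation in
# init_weights above (illustrative, assuming the 640x640 training resolution
# the comment implies): at stride 8 the grid has (640 / 8) ** 2 = 6400 cells,
# so "8 objects per 640 image" is an objectness prior of 8 / 6400 = 0.00125
# and a bias of math.log(0.00125) ~= -6.68, i.e. sigmoid(bias) ~= 0.00125.
# For the class logits with e.g. 80 classes, math.log(0.6 / 79.0) ~= -4.88,
# i.e. a per-class prior of roughly 0.0076 before any training.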
131 |
+
@MODELS.register_module()
|
132 |
+
class YOLOv5Head(BaseDenseHead):
|
133 |
+
"""YOLOv5Head head used in `YOLOv5`.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
head_module(ConfigType): Base module used for YOLOv5Head
|
137 |
+
prior_generator(dict): Points generator feature maps in
|
138 |
+
2D points-based detectors.
|
139 |
+
bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
|
140 |
+
loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
|
141 |
+
loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
|
142 |
+
loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
|
143 |
+
prior_match_thr (float): Defaults to 4.0.
|
144 |
+
ignore_iof_thr (float): Defaults to -1.0.
|
145 |
+
obj_level_weights (List[float]): Defaults to [4.0, 1.0, 0.4].
|
146 |
+
train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
|
147 |
+
anchor head. Defaults to None.
|
148 |
+
test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
|
149 |
+
anchor head. Defaults to None.
|
150 |
+
init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
|
151 |
+
list[dict], optional): Initialization config dict.
|
152 |
+
Defaults to None.
|
153 |
+
"""
|
154 |
+
|
155 |
+
def __init__(self,
|
156 |
+
head_module: ConfigType,
|
157 |
+
prior_generator: ConfigType = dict(
|
158 |
+
type='mmdet.YOLOAnchorGenerator',
|
159 |
+
base_sizes=[[(10, 13), (16, 30), (33, 23)],
|
160 |
+
[(30, 61), (62, 45), (59, 119)],
|
161 |
+
[(116, 90), (156, 198), (373, 326)]],
|
162 |
+
strides=[8, 16, 32]),
|
163 |
+
bbox_coder: ConfigType = dict(type='YOLOv5BBoxCoder'),
|
164 |
+
loss_cls: ConfigType = dict(
|
165 |
+
type='mmdet.CrossEntropyLoss',
|
166 |
+
use_sigmoid=True,
|
167 |
+
reduction='mean',
|
168 |
+
loss_weight=0.5),
|
169 |
+
loss_bbox: ConfigType = dict(
|
170 |
+
type='IoULoss',
|
171 |
+
iou_mode='ciou',
|
172 |
+
bbox_format='xywh',
|
173 |
+
eps=1e-7,
|
174 |
+
reduction='mean',
|
175 |
+
loss_weight=0.05,
|
176 |
+
return_iou=True),
|
177 |
+
loss_obj: ConfigType = dict(
|
178 |
+
type='mmdet.CrossEntropyLoss',
|
179 |
+
use_sigmoid=True,
|
180 |
+
reduction='mean',
|
181 |
+
loss_weight=1.0),
|
182 |
+
prior_match_thr: float = 4.0,
|
183 |
+
near_neighbor_thr: float = 0.5,
|
184 |
+
ignore_iof_thr: float = -1.0,
|
185 |
+
obj_level_weights: List[float] = [4.0, 1.0, 0.4],
|
186 |
+
train_cfg: OptConfigType = None,
|
187 |
+
test_cfg: OptConfigType = None,
|
188 |
+
init_cfg: OptMultiConfig = None):
|
189 |
+
super().__init__(init_cfg=init_cfg)
|
190 |
+
|
191 |
+
self.head_module = MODELS.build(head_module)
|
192 |
+
self.num_classes = self.head_module.num_classes
|
193 |
+
self.featmap_strides = self.head_module.featmap_strides
|
194 |
+
self.num_levels = len(self.featmap_strides)
|
195 |
+
|
196 |
+
self.train_cfg = train_cfg
|
197 |
+
self.test_cfg = test_cfg
|
198 |
+
|
199 |
+
self.loss_cls: nn.Module = MODELS.build(loss_cls)
|
200 |
+
self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
|
201 |
+
self.loss_obj: nn.Module = MODELS.build(loss_obj)
|
202 |
+
|
203 |
+
self.prior_generator = TASK_UTILS.build(prior_generator)
|
204 |
+
self.bbox_coder = TASK_UTILS.build(bbox_coder)
|
205 |
+
self.num_base_priors = self.prior_generator.num_base_priors[0]
|
206 |
+
|
207 |
+
self.featmap_sizes = [torch.empty(1)] * self.num_levels
|
208 |
+
|
209 |
+
self.prior_match_thr = prior_match_thr
|
210 |
+
self.near_neighbor_thr = near_neighbor_thr
|
211 |
+
self.obj_level_weights = obj_level_weights
|
212 |
+
self.ignore_iof_thr = ignore_iof_thr
|
213 |
+
|
214 |
+
self.special_init()
|
215 |
+
|
216 |
+
def special_init(self):
|
217 |
+
"""Since YOLO series algorithms will inherit from YOLOv5Head, but
|
218 |
+
different algorithms have special initialization process.
|
219 |
+
|
220 |
+
The special_init function is designed to deal with this situation.
|
221 |
+
"""
|
222 |
+
assert len(self.obj_level_weights) == len(
|
223 |
+
self.featmap_strides) == self.num_levels
|
224 |
+
if self.prior_match_thr != 4.0:
|
225 |
+
print_log(
|
226 |
+
"!!!Now, you've changed the prior_match_thr "
|
227 |
+
'parameter to something other than 4.0. Please make sure '
|
228 |
+
'that you have modified both the regression formula in '
|
229 |
+
'bbox_coder and before loss_box computation, '
|
230 |
+
'otherwise the accuracy may be degraded!!!')
|
231 |
+
|
232 |
+
if self.num_classes == 1:
|
233 |
+
print_log('!!!You are using `YOLOv5Head` with num_classes == 1.'
|
234 |
+
' The loss_cls will be 0. This is a normal phenomenon.')
|
235 |
+
|
236 |
+
priors_base_sizes = torch.tensor(
|
237 |
+
self.prior_generator.base_sizes, dtype=torch.float)
|
238 |
+
featmap_strides = torch.tensor(
|
239 |
+
self.featmap_strides, dtype=torch.float)[:, None, None]
|
240 |
+
self.register_buffer(
|
241 |
+
'priors_base_sizes',
|
242 |
+
priors_base_sizes / featmap_strides,
|
243 |
+
persistent=False)
|
244 |
+
|
245 |
+
grid_offset = torch.tensor([
|
246 |
+
[0, 0], # center
|
247 |
+
[1, 0], # left
|
248 |
+
[0, 1], # up
|
249 |
+
[-1, 0], # right
|
250 |
+
[0, -1], # bottom
|
251 |
+
]).float()
|
252 |
+
self.register_buffer(
|
253 |
+
'grid_offset', grid_offset[:, None], persistent=False)
|
254 |
+
|
255 |
+
prior_inds = torch.arange(self.num_base_priors).float().view(
|
256 |
+
self.num_base_priors, 1)
|
257 |
+
self.register_buffer('prior_inds', prior_inds, persistent=False)
|
258 |
+
|
259 |
+
def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
|
260 |
+
"""Forward features from the upstream network.
|
261 |
+
|
262 |
+
Args:
|
263 |
+
x (Tuple[Tensor]): Features from the upstream network, each is
|
264 |
+
a 4D-tensor.
|
265 |
+
Returns:
|
266 |
+
Tuple[List]: A tuple of multi-level classification scores, bbox
|
267 |
+
predictions, and objectnesses.
|
268 |
+
"""
|
269 |
+
return self.head_module(x)
|
270 |
+
|
271 |
+
def predict_by_feat(self,
|
272 |
+
cls_scores: List[Tensor],
|
273 |
+
bbox_preds: List[Tensor],
|
274 |
+
objectnesses: Optional[List[Tensor]] = None,
|
275 |
+
batch_img_metas: Optional[List[dict]] = None,
|
276 |
+
cfg: Optional[ConfigDict] = None,
|
277 |
+
rescale: bool = True,
|
278 |
+
with_nms: bool = True) -> List[InstanceData]:
|
279 |
+
"""Transform a batch of output features extracted by the head into
|
280 |
+
bbox results.
|
281 |
+
Args:
|
282 |
+
cls_scores (list[Tensor]): Classification scores for all
|
283 |
+
scale levels, each is a 4D-tensor, has shape
|
284 |
+
(batch_size, num_priors * num_classes, H, W).
|
285 |
+
bbox_preds (list[Tensor]): Box energies / deltas for all
|
286 |
+
scale levels, each is a 4D-tensor, has shape
|
287 |
+
(batch_size, num_priors * 4, H, W).
|
288 |
+
objectnesses (list[Tensor], Optional): Score factor for
|
289 |
+
all scale level, each is a 4D-tensor, has shape
|
290 |
+
(batch_size, 1, H, W).
|
291 |
+
batch_img_metas (list[dict], Optional): Batch image meta info.
|
292 |
+
Defaults to None.
|
293 |
+
cfg (ConfigDict, optional): Test / postprocessing
|
294 |
+
configuration, if None, test_cfg would be used.
|
295 |
+
Defaults to None.
|
296 |
+
rescale (bool): If True, return boxes in original image space.
|
297 |
+
Defaults to False.
|
298 |
+
with_nms (bool): If True, do nms before return boxes.
|
299 |
+
Defaults to True.
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
list[:obj:`InstanceData`]: Object detection results of each image
|
303 |
+
after the post process. Each item usually contains following keys.
|
304 |
+
|
305 |
+
- scores (Tensor): Classification scores, has a shape
|
306 |
+
(num_instance, )
|
307 |
+
- labels (Tensor): Labels of bboxes, has a shape
|
308 |
+
(num_instances, ).
|
309 |
+
- bboxes (Tensor): Has a shape (num_instances, 4),
|
310 |
+
the last dimension 4 arrange as (x1, y1, x2, y2).
|
311 |
+
"""
|
312 |
+
assert len(cls_scores) == len(bbox_preds)
|
313 |
+
if objectnesses is None:
|
314 |
+
with_objectnesses = False
|
315 |
+
else:
|
316 |
+
with_objectnesses = True
|
317 |
+
assert len(cls_scores) == len(objectnesses)
|
318 |
+
|
319 |
+
cfg = self.test_cfg if cfg is None else cfg
|
320 |
+
cfg = copy.deepcopy(cfg)
|
321 |
+
|
322 |
+
multi_label = cfg.multi_label
|
323 |
+
multi_label &= self.num_classes > 1
|
324 |
+
cfg.multi_label = multi_label
|
325 |
+
|
326 |
+
num_imgs = len(batch_img_metas)
|
327 |
+
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
328 |
+
|
329 |
+
# If the shape does not change, use the previous mlvl_priors
|
330 |
+
if featmap_sizes != self.featmap_sizes:
|
331 |
+
self.mlvl_priors = self.prior_generator.grid_priors(
|
332 |
+
featmap_sizes,
|
333 |
+
dtype=cls_scores[0].dtype,
|
334 |
+
device=cls_scores[0].device)
|
335 |
+
self.featmap_sizes = featmap_sizes
|
336 |
+
flatten_priors = torch.cat(self.mlvl_priors)
|
337 |
+
|
338 |
+
mlvl_strides = [
|
339 |
+
flatten_priors.new_full(
|
340 |
+
(featmap_size.numel() * self.num_base_priors, ), stride) for
|
341 |
+
featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
|
342 |
+
]
|
343 |
+
flatten_stride = torch.cat(mlvl_strides)
|
344 |
+
|
345 |
+
# flatten cls_scores, bbox_preds and objectness
|
346 |
+
flatten_cls_scores = [
|
347 |
+
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
|
348 |
+
self.num_classes)
|
349 |
+
for cls_score in cls_scores
|
350 |
+
]
|
351 |
+
flatten_bbox_preds = [
|
352 |
+
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
|
353 |
+
for bbox_pred in bbox_preds
|
354 |
+
]
|
355 |
+
|
356 |
+
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
|
357 |
+
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
|
358 |
+
flatten_decoded_bboxes = self.bbox_coder.decode(
|
359 |
+
flatten_priors[None], flatten_bbox_preds, flatten_stride)
|
360 |
+
|
361 |
+
if with_objectnesses:
|
362 |
+
flatten_objectness = [
|
363 |
+
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
|
364 |
+
for objectness in objectnesses
|
365 |
+
]
|
366 |
+
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
|
367 |
+
else:
|
368 |
+
flatten_objectness = [None for _ in range(num_imgs)]
|
369 |
+
|
370 |
+
results_list = []
|
371 |
+
for (bboxes, scores, objectness,
|
372 |
+
img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
|
373 |
+
flatten_objectness, batch_img_metas):
|
374 |
+
ori_shape = img_meta['ori_shape']
|
375 |
+
scale_factor = img_meta['scale_factor']
|
376 |
+
if 'pad_param' in img_meta:
|
377 |
+
pad_param = img_meta['pad_param']
|
378 |
+
else:
|
379 |
+
pad_param = None
|
380 |
+
|
381 |
+
score_thr = cfg.get('score_thr', -1)
|
382 |
+
# yolox_style does not require the following operations
|
383 |
+
if objectness is not None and score_thr > 0 and not cfg.get(
|
384 |
+
'yolox_style', False):
|
385 |
+
conf_inds = objectness > score_thr
|
386 |
+
bboxes = bboxes[conf_inds, :]
|
387 |
+
scores = scores[conf_inds, :]
|
388 |
+
objectness = objectness[conf_inds]
|
389 |
+
|
390 |
+
if objectness is not None:
|
391 |
+
# conf = obj_conf * cls_conf
|
392 |
+
scores *= objectness[:, None]
|
393 |
+
|
394 |
+
if scores.shape[0] == 0:
|
395 |
+
empty_results = InstanceData()
|
396 |
+
empty_results.bboxes = bboxes
|
397 |
+
empty_results.scores = scores[:, 0]
|
398 |
+
empty_results.labels = scores[:, 0].int()
|
399 |
+
results_list.append(empty_results)
|
400 |
+
continue
|
401 |
+
|
402 |
+
nms_pre = cfg.get('nms_pre', 100000)
|
403 |
+
if cfg.multi_label is False:
|
404 |
+
scores, labels = scores.max(1, keepdim=True)
|
405 |
+
scores, _, keep_idxs, results = filter_scores_and_topk(
|
406 |
+
scores,
|
407 |
+
score_thr,
|
408 |
+
nms_pre,
|
409 |
+
results=dict(labels=labels[:, 0]))
|
410 |
+
labels = results['labels']
|
411 |
+
else:
|
412 |
+
scores, labels, keep_idxs, _ = filter_scores_and_topk(
|
413 |
+
scores, score_thr, nms_pre)
|
414 |
+
|
415 |
+
results = InstanceData(
|
416 |
+
scores=scores, labels=labels, bboxes=bboxes[keep_idxs])
|
417 |
+
|
418 |
+
if rescale:
|
419 |
+
if pad_param is not None:
|
420 |
+
results.bboxes -= results.bboxes.new_tensor([
|
421 |
+
pad_param[2], pad_param[0], pad_param[2], pad_param[0]
|
422 |
+
])
|
423 |
+
results.bboxes /= results.bboxes.new_tensor(
|
424 |
+
scale_factor).repeat((1, 2))
|
425 |
+
|
426 |
+
if cfg.get('yolox_style', False):
|
427 |
+
# do not need max_per_img
|
428 |
+
cfg.max_per_img = len(results)
|
429 |
+
|
430 |
+
results = self._bbox_post_process(
|
431 |
+
results=results,
|
432 |
+
cfg=cfg,
|
433 |
+
rescale=False,
|
434 |
+
with_nms=with_nms,
|
435 |
+
img_meta=img_meta)
|
436 |
+
results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
|
437 |
+
results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
|
438 |
+
|
439 |
+
results_list.append(results)
|
440 |
+
return results_list
|
441 |
+
|
442 |
+
def loss(self, x: Tuple[Tensor], batch_data_samples: Union[list,
|
443 |
+
dict]) -> dict:
|
444 |
+
"""Perform forward propagation and loss calculation of the detection
|
445 |
+
head on the features of the upstream network.
|
446 |
+
|
447 |
+
Args:
|
448 |
+
x (tuple[Tensor]): Features from the upstream network, each is
|
449 |
+
a 4D-tensor.
|
450 |
+
batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
|
451 |
+
Samples. It usually includes information such as
|
452 |
+
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
|
453 |
+
|
454 |
+
Returns:
|
455 |
+
dict: A dictionary of loss components.
|
456 |
+
"""
|
457 |
+
|
458 |
+
if isinstance(batch_data_samples, list):
|
459 |
+
losses = super().loss(x, batch_data_samples)
|
460 |
+
else:
|
461 |
+
outs = self(x)
|
462 |
+
# Fast version
|
463 |
+
loss_inputs = outs + (batch_data_samples['bboxes_labels'],
|
464 |
+
batch_data_samples['img_metas'])
|
465 |
+
losses = self.loss_by_feat(*loss_inputs)
|
466 |
+
|
467 |
+
return losses
|
468 |
+
|
469 |
+
def loss_by_feat(
|
470 |
+
self,
|
471 |
+
cls_scores: Sequence[Tensor],
|
472 |
+
bbox_preds: Sequence[Tensor],
|
473 |
+
objectnesses: Sequence[Tensor],
|
474 |
+
batch_gt_instances: Sequence[InstanceData],
|
475 |
+
batch_img_metas: Sequence[dict],
|
476 |
+
batch_gt_instances_ignore: OptInstanceList = None) -> dict:
|
477 |
+
"""Calculate the loss based on the features extracted by the detection
|
478 |
+
head.
|
479 |
+
|
480 |
+
Args:
|
481 |
+
cls_scores (Sequence[Tensor]): Box scores for each scale level,
|
482 |
+
each is a 4D-tensor, the channel number is
|
483 |
+
num_priors * num_classes.
|
484 |
+
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
|
485 |
+
level, each is a 4D-tensor, the channel number is
|
486 |
+
num_priors * 4.
|
487 |
+
objectnesses (Sequence[Tensor]): Score factor for
|
488 |
+
all scale level, each is a 4D-tensor, has shape
|
489 |
+
(batch_size, 1, H, W).
|
490 |
+
batch_gt_instances (Sequence[InstanceData]): Batch of
|
491 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
492 |
+
attributes.
|
493 |
+
batch_img_metas (Sequence[dict]): Meta information of each image,
|
494 |
+
e.g., image size, scaling factor, etc.
|
495 |
+
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
|
496 |
+
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
|
497 |
+
data that is ignored during training and testing.
|
498 |
+
Defaults to None.
|
499 |
+
Returns:
|
500 |
+
dict[str, Tensor]: A dictionary of losses.
|
501 |
+
"""
|
502 |
+
if self.ignore_iof_thr != -1:
|
503 |
+
# TODO: Support fast version
|
504 |
+
# convert ignore gt
|
505 |
+
batch_target_ignore_list = []
|
506 |
+
for i, gt_instances_ignore in enumerate(batch_gt_instances_ignore):
|
507 |
+
bboxes = gt_instances_ignore.bboxes
|
508 |
+
labels = gt_instances_ignore.labels
|
509 |
+
index = bboxes.new_full((len(bboxes), 1), i)
|
510 |
+
# (batch_idx, label, bboxes)
|
511 |
+
target = torch.cat((index, labels[:, None].float(), bboxes),
|
512 |
+
dim=1)
|
513 |
+
batch_target_ignore_list.append(target)
|
514 |
+
|
515 |
+
# (num_bboxes, 6)
|
516 |
+
batch_gt_targets_ignore = torch.cat(
|
517 |
+
batch_target_ignore_list, dim=0)
|
518 |
+
if batch_gt_targets_ignore.shape[0] != 0:
|
519 |
+
# Consider regions with ignore in annotations
|
520 |
+
return self._loss_by_feat_with_ignore(
|
521 |
+
cls_scores,
|
522 |
+
bbox_preds,
|
523 |
+
objectnesses,
|
524 |
+
batch_gt_instances=batch_gt_instances,
|
525 |
+
batch_img_metas=batch_img_metas,
|
526 |
+
batch_gt_instances_ignore=batch_gt_targets_ignore)
|
527 |
+
|
528 |
+
# 1. Convert gt to norm format
|
529 |
+
batch_targets_normed = self._convert_gt_to_norm_format(
|
530 |
+
batch_gt_instances, batch_img_metas)
|
531 |
+
|
532 |
+
device = cls_scores[0].device
|
533 |
+
loss_cls = torch.zeros(1, device=device)
|
534 |
+
loss_box = torch.zeros(1, device=device)
|
535 |
+
loss_obj = torch.zeros(1, device=device)
|
536 |
+
scaled_factor = torch.ones(7, device=device)
|
537 |
+
|
538 |
+
for i in range(self.num_levels):
|
539 |
+
batch_size, _, h, w = bbox_preds[i].shape
|
540 |
+
target_obj = torch.zeros_like(objectnesses[i])
|
541 |
+
|
542 |
+
# empty gt bboxes
|
543 |
+
if batch_targets_normed.shape[1] == 0:
|
544 |
+
loss_box += bbox_preds[i].sum() * 0
|
545 |
+
loss_cls += cls_scores[i].sum() * 0
|
546 |
+
loss_obj += self.loss_obj(
|
547 |
+
objectnesses[i], target_obj) * self.obj_level_weights[i]
|
548 |
+
continue
|
549 |
+
|
550 |
+
priors_base_sizes_i = self.priors_base_sizes[i]
|
551 |
+
# feature map scale whwh
|
552 |
+
scaled_factor[2:6] = torch.tensor(
|
553 |
+
bbox_preds[i].shape)[[3, 2, 3, 2]]
|
554 |
+
# Scale batch_targets from range 0-1 to range 0-features_maps size.
|
555 |
+
# (num_base_priors, num_bboxes, 7)
|
556 |
+
batch_targets_scaled = batch_targets_normed * scaled_factor
|
557 |
+
|
558 |
+
# 2. Shape match
|
559 |
+
wh_ratio = batch_targets_scaled[...,
|
560 |
+
4:6] / priors_base_sizes_i[:, None]
|
561 |
+
match_inds = torch.max(
|
562 |
+
wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
|
563 |
+
batch_targets_scaled = batch_targets_scaled[match_inds]
|
564 |
+
|
565 |
+
# no gt bbox matches anchor
|
566 |
+
if batch_targets_scaled.shape[0] == 0:
|
567 |
+
loss_box += bbox_preds[i].sum() * 0
|
568 |
+
loss_cls += cls_scores[i].sum() * 0
|
569 |
+
loss_obj += self.loss_obj(
|
570 |
+
objectnesses[i], target_obj) * self.obj_level_weights[i]
|
571 |
+
continue
|
572 |
+
|
573 |
+
# 3. Positive samples with additional neighbors
|
574 |
+
|
575 |
+
# check the left, up, right, bottom sides of the
|
576 |
+
# targets grid, and determine whether assigned
|
577 |
+
# them as positive samples as well.
|
578 |
+
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
|
579 |
+
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
|
580 |
+
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
|
581 |
+
(batch_targets_cxcy > 1)).T
|
582 |
+
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
|
583 |
+
(grid_xy > 1)).T
|
584 |
+
offset_inds = torch.stack(
|
585 |
+
(torch.ones_like(left), left, up, right, bottom))
|
586 |
+
|
587 |
+
batch_targets_scaled = batch_targets_scaled.repeat(
|
588 |
+
(5, 1, 1))[offset_inds]
|
589 |
+
retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
|
590 |
+
1)[offset_inds]
|
591 |
+
|
592 |
+
# prepare pred results and positive sample indexes to
|
593 |
+
# calculate class loss and bbox loss
|
594 |
+
_chunk_targets = batch_targets_scaled.chunk(4, 1)
|
595 |
+
img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
|
596 |
+
priors_inds, (img_inds, class_inds) = priors_inds.long().view(
|
597 |
+
-1), img_class_inds.long().T
|
598 |
+
|
599 |
+
grid_xy_long = (grid_xy -
|
600 |
+
retained_offsets * self.near_neighbor_thr).long()
|
601 |
+
grid_x_inds, grid_y_inds = grid_xy_long.T
|
602 |
+
bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
|
603 |
+
|
604 |
+
# 4. Calculate loss
|
605 |
+
# bbox loss
|
606 |
+
retained_bbox_pred = bbox_preds[i].reshape(
|
607 |
+
batch_size, self.num_base_priors, -1, h,
|
608 |
+
w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
|
609 |
+
priors_base_sizes_i = priors_base_sizes_i[priors_inds]
|
610 |
+
decoded_bbox_pred = self._decode_bbox_to_xywh(
|
611 |
+
retained_bbox_pred, priors_base_sizes_i)
|
612 |
+
loss_box_i, iou = self.loss_bbox(decoded_bbox_pred, bboxes_targets)
|
613 |
+
loss_box += loss_box_i
|
614 |
+
|
615 |
+
# obj loss
|
616 |
+
iou = iou.detach().clamp(0)
|
617 |
+
target_obj[img_inds, priors_inds, grid_y_inds,
|
618 |
+
grid_x_inds] = iou.type(target_obj.dtype)
|
619 |
+
loss_obj += self.loss_obj(objectnesses[i],
|
620 |
+
target_obj) * self.obj_level_weights[i]
|
621 |
+
|
622 |
+
# cls loss
|
623 |
+
if self.num_classes > 1:
|
624 |
+
pred_cls_scores = cls_scores[i].reshape(
|
625 |
+
batch_size, self.num_base_priors, -1, h,
|
626 |
+
w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
|
627 |
+
|
628 |
+
target_class = torch.full_like(pred_cls_scores, 0.)
|
629 |
+
target_class[range(batch_targets_scaled.shape[0]),
|
630 |
+
class_inds] = 1.
|
631 |
+
loss_cls += self.loss_cls(pred_cls_scores, target_class)
|
632 |
+
else:
|
633 |
+
loss_cls += cls_scores[i].sum() * 0
|
634 |
+
|
635 |
+
_, world_size = get_dist_info()
|
636 |
+
return dict(
|
637 |
+
loss_cls=loss_cls * batch_size * world_size,
|
638 |
+
loss_obj=loss_obj * batch_size * world_size,
|
639 |
+
loss_bbox=loss_box * batch_size * world_size)
|
640 |
+
|
641 |
+
def _convert_gt_to_norm_format(self,
|
642 |
+
batch_gt_instances: Sequence[InstanceData],
|
643 |
+
batch_img_metas: Sequence[dict]) -> Tensor:
|
644 |
+
if isinstance(batch_gt_instances, torch.Tensor):
|
645 |
+
# fast version
|
646 |
+
img_shape = batch_img_metas[0]['batch_input_shape']
|
647 |
+
gt_bboxes_xyxy = batch_gt_instances[:, 2:]
|
648 |
+
xy1, xy2 = gt_bboxes_xyxy.split((2, 2), dim=-1)
|
649 |
+
gt_bboxes_xywh = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
|
650 |
+
gt_bboxes_xywh[:, 1::2] /= img_shape[0]
|
651 |
+
gt_bboxes_xywh[:, 0::2] /= img_shape[1]
|
652 |
+
batch_gt_instances[:, 2:] = gt_bboxes_xywh
|
653 |
+
|
654 |
+
# (num_base_priors, num_bboxes, 6)
|
655 |
+
batch_targets_normed = batch_gt_instances.repeat(
|
656 |
+
self.num_base_priors, 1, 1)
|
657 |
+
else:
|
658 |
+
batch_target_list = []
|
659 |
+
# Convert xyxy bbox to yolo format.
|
660 |
+
for i, gt_instances in enumerate(batch_gt_instances):
|
661 |
+
img_shape = batch_img_metas[i]['batch_input_shape']
|
662 |
+
bboxes = gt_instances.bboxes
|
663 |
+
labels = gt_instances.labels
|
664 |
+
|
665 |
+
xy1, xy2 = bboxes.split((2, 2), dim=-1)
|
666 |
+
bboxes = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
|
667 |
+
# normalized to 0-1
|
668 |
+
bboxes[:, 1::2] /= img_shape[0]
|
669 |
+
bboxes[:, 0::2] /= img_shape[1]
|
670 |
+
|
671 |
+
index = bboxes.new_full((len(bboxes), 1), i)
|
672 |
+
# (batch_idx, label, normed_bbox)
|
673 |
+
target = torch.cat((index, labels[:, None].float(), bboxes),
|
674 |
+
dim=1)
|
675 |
+
batch_target_list.append(target)
|
676 |
+
|
677 |
+
# (num_base_priors, num_bboxes, 6)
|
678 |
+
batch_targets_normed = torch.cat(
|
679 |
+
batch_target_list, dim=0).repeat(self.num_base_priors, 1, 1)
|
680 |
+
|
681 |
+
# (num_base_priors, num_bboxes, 1)
|
682 |
+
batch_targets_prior_inds = self.prior_inds.repeat(
|
683 |
+
1, batch_targets_normed.shape[1])[..., None]
|
684 |
+
# (num_base_priors, num_bboxes, 7)
|
685 |
+
# (img_ind, labels, bbox_cx, bbox_cy, bbox_w, bbox_h, prior_ind)
|
686 |
+
batch_targets_normed = torch.cat(
|
687 |
+
(batch_targets_normed, batch_targets_prior_inds), 2)
|
688 |
+
return batch_targets_normed
|
689 |
+
|
690 |
+
def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes) -> Tensor:
|
691 |
+
bbox_pred = bbox_pred.sigmoid()
|
692 |
+
pred_xy = bbox_pred[:, :2] * 2 - 0.5
|
693 |
+
pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
|
694 |
+
decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
|
695 |
+
return decoded_bbox_pred
|
696 |
+
|
697 |
+
def _loss_by_feat_with_ignore(
|
698 |
+
self, cls_scores: Sequence[Tensor], bbox_preds: Sequence[Tensor],
|
699 |
+
objectnesses: Sequence[Tensor],
|
700 |
+
batch_gt_instances: Sequence[InstanceData],
|
701 |
+
batch_img_metas: Sequence[dict],
|
702 |
+
batch_gt_instances_ignore: Sequence[Tensor]) -> dict:
|
703 |
+
"""Calculate the loss based on the features extracted by the detection
|
704 |
+
head.
|
705 |
+
|
706 |
+
Args:
|
707 |
+
cls_scores (Sequence[Tensor]): Box scores for each scale level,
|
708 |
+
each is a 4D-tensor, the channel number is
|
709 |
+
num_priors * num_classes.
|
710 |
+
bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
|
711 |
+
level, each is a 4D-tensor, the channel number is
|
712 |
+
num_priors * 4.
|
713 |
+
objectnesses (Sequence[Tensor]): Score factor for
|
714 |
+
all scale level, each is a 4D-tensor, has shape
|
715 |
+
(batch_size, 1, H, W).
|
716 |
+
batch_gt_instances (Sequence[InstanceData]): Batch of
|
717 |
+
gt_instance. It usually includes ``bboxes`` and ``labels``
|
718 |
+
attributes.
|
719 |
+
batch_img_metas (Sequence[dict]): Meta information of each image,
|
720 |
+
e.g., image size, scaling factor, etc.
|
721 |
+
batch_gt_instances_ignore (Sequence[Tensor]): Ignore boxes with
|
722 |
+
batch_ids and labels, each is a 2D-tensor, the channel number
|
723 |
+
is 6, means that (batch_id, label, xmin, ymin, xmax, ymax).
|
724 |
+
Returns:
|
725 |
+
dict[str, Tensor]: A dictionary of losses.
|
726 |
+
"""
|
727 |
+
# 1. Convert gt to norm format
|
728 |
+
batch_targets_normed = self._convert_gt_to_norm_format(
|
729 |
+
batch_gt_instances, batch_img_metas)
|
730 |
+
|
731 |
+
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
|
732 |
+
if featmap_sizes != self.featmap_sizes:
|
733 |
+
self.mlvl_priors = self.prior_generator.grid_priors(
|
734 |
+
featmap_sizes,
|
735 |
+
dtype=cls_scores[0].dtype,
|
736 |
+
device=cls_scores[0].device)
|
737 |
+
self.featmap_sizes = featmap_sizes
|
738 |
+
|
739 |
+
device = cls_scores[0].device
|
740 |
+
loss_cls = torch.zeros(1, device=device)
|
741 |
+
loss_box = torch.zeros(1, device=device)
|
742 |
+
loss_obj = torch.zeros(1, device=device)
|
743 |
+
scaled_factor = torch.ones(7, device=device)
|
744 |
+
|
745 |
+
for i in range(self.num_levels):
|
746 |
+
batch_size, _, h, w = bbox_preds[i].shape
|
747 |
+
target_obj = torch.zeros_like(objectnesses[i])
|
748 |
+
|
749 |
+
not_ignore_flags = bbox_preds[i].new_ones(batch_size,
|
750 |
+
self.num_base_priors, h,
|
751 |
+
w)
|
752 |
+
|
753 |
+
ignore_overlaps = bbox_overlaps(self.mlvl_priors[i],
|
754 |
+
batch_gt_instances_ignore[..., 2:],
|
755 |
+
'iof')
|
756 |
+
ignore_max_overlaps, ignore_max_ignore_index = ignore_overlaps.max(
|
757 |
+
dim=1)
|
758 |
+
|
759 |
+
batch_inds = batch_gt_instances_ignore[:,
|
760 |
+
0][ignore_max_ignore_index]
|
761 |
+
ignore_inds = (ignore_max_overlaps > self.ignore_iof_thr).nonzero(
|
762 |
+
as_tuple=True)[0]
|
763 |
+
batch_inds = batch_inds[ignore_inds].long()
|
764 |
+
ignore_priors, ignore_grid_xs, ignore_grid_ys = get_prior_xy_info(
|
765 |
+
ignore_inds, self.num_base_priors, self.featmap_sizes[i])
|
766 |
+
not_ignore_flags[batch_inds, ignore_priors, ignore_grid_ys,
|
767 |
+
ignore_grid_xs] = 0
|
768 |
+
|
769 |
+
# empty gt bboxes
|
770 |
+
if batch_targets_normed.shape[1] == 0:
|
771 |
+
loss_box += bbox_preds[i].sum() * 0
|
772 |
+
loss_cls += cls_scores[i].sum() * 0
|
773 |
+
loss_obj += self.loss_obj(
|
774 |
+
objectnesses[i],
|
775 |
+
target_obj,
|
776 |
+
weight=not_ignore_flags,
|
777 |
+
avg_factor=max(not_ignore_flags.sum(),
|
778 |
+
1)) * self.obj_level_weights[i]
|
779 |
+
continue
|
780 |
+
|
781 |
+
priors_base_sizes_i = self.priors_base_sizes[i]
|
782 |
+
# feature map scale whwh
|
783 |
+
scaled_factor[2:6] = torch.tensor(
|
784 |
+
bbox_preds[i].shape)[[3, 2, 3, 2]]
|
785 |
+
# Scale batch_targets from range 0-1 to range 0-features_maps size.
|
786 |
+
# (num_base_priors, num_bboxes, 7)
|
787 |
+
batch_targets_scaled = batch_targets_normed * scaled_factor
|
788 |
+
|
789 |
+
# 2. Shape match
|
790 |
+
wh_ratio = batch_targets_scaled[...,
|
791 |
+
4:6] / priors_base_sizes_i[:, None]
|
792 |
+
match_inds = torch.max(
|
793 |
+
wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
|
794 |
+
batch_targets_scaled = batch_targets_scaled[match_inds]
|
795 |
+
|
796 |
+
# no gt bbox matches anchor
|
797 |
+
if batch_targets_scaled.shape[0] == 0:
|
798 |
+
loss_box += bbox_preds[i].sum() * 0
|
799 |
+
loss_cls += cls_scores[i].sum() * 0
|
800 |
+
loss_obj += self.loss_obj(
|
801 |
+
objectnesses[i],
|
802 |
+
target_obj,
|
803 |
+
weight=not_ignore_flags,
|
804 |
+
avg_factor=max(not_ignore_flags.sum(),
|
805 |
+
1)) * self.obj_level_weights[i]
|
806 |
+
continue
|
807 |
+
|
808 |
+
# 3. Positive samples with additional neighbors
|
809 |
+
|
810 |
+
# check the left, up, right, bottom sides of the
|
811 |
+
# targets grid, and determine whether assigned
|
812 |
+
# them as positive samples as well.
|
813 |
+
batch_targets_cxcy = batch_targets_scaled[:, 2:4]
|
814 |
+
grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
|
815 |
+
left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
|
816 |
+
(batch_targets_cxcy > 1)).T
|
817 |
+
right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
|
818 |
+
(grid_xy > 1)).T
|
819 |
+
offset_inds = torch.stack(
|
820 |
+
(torch.ones_like(left), left, up, right, bottom))
|
821 |
+
|
822 |
+
batch_targets_scaled = batch_targets_scaled.repeat(
|
823 |
+
(5, 1, 1))[offset_inds]
|
824 |
+
retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
|
825 |
+
1)[offset_inds]
|
826 |
+
|
827 |
+
# prepare pred results and positive sample indexes to
|
828 |
+
# calculate class loss and bbox loss
|
829 |
+
_chunk_targets = batch_targets_scaled.chunk(4, 1)
|
830 |
+
img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
|
831 |
+
priors_inds, (img_inds, class_inds) = priors_inds.long().view(
|
832 |
+
-1), img_class_inds.long().T
|
833 |
+
|
834 |
+
grid_xy_long = (grid_xy -
|
835 |
+
retained_offsets * self.near_neighbor_thr).long()
|
836 |
+
grid_x_inds, grid_y_inds = grid_xy_long.T
|
837 |
+
bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
|
838 |
+
|
839 |
+
# 4. Calculate loss
|
840 |
+
# bbox loss
|
841 |
+
retained_bbox_pred = bbox_preds[i].reshape(
|
842 |
+
batch_size, self.num_base_priors, -1, h,
|
843 |
+
w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
|
844 |
+
priors_base_sizes_i = priors_base_sizes_i[priors_inds]
|
845 |
+
decoded_bbox_pred = self._decode_bbox_to_xywh(
|
846 |
+
retained_bbox_pred, priors_base_sizes_i)
|
847 |
+
|
848 |
+
not_ignore_weights = not_ignore_flags[img_inds, priors_inds,
|
849 |
+
grid_y_inds, grid_x_inds]
|
850 |
+
loss_box_i, iou = self.loss_bbox(
|
851 |
+
decoded_bbox_pred,
|
852 |
+
bboxes_targets,
|
853 |
+
weight=not_ignore_weights,
|
854 |
+
avg_factor=max(not_ignore_weights.sum(), 1))
|
855 |
+
loss_box += loss_box_i
|
856 |
+
|
857 |
+
# obj loss
|
858 |
+
iou = iou.detach().clamp(0)
|
859 |
+
target_obj[img_inds, priors_inds, grid_y_inds,
|
860 |
+
grid_x_inds] = iou.type(target_obj.dtype)
|
861 |
+
loss_obj += self.loss_obj(
|
862 |
+
objectnesses[i],
|
863 |
+
target_obj,
|
864 |
+
weight=not_ignore_flags,
|
865 |
+
avg_factor=max(not_ignore_flags.sum(),
|
866 |
+
1)) * self.obj_level_weights[i]
|
867 |
+
|
868 |
+
# cls loss
|
869 |
+
if self.num_classes > 1:
|
870 |
+
pred_cls_scores = cls_scores[i].reshape(
|
871 |
+
batch_size, self.num_base_priors, -1, h,
|
872 |
+
w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
|
873 |
+
|
874 |
+
target_class = torch.full_like(pred_cls_scores, 0.)
|
875 |
+
target_class[range(batch_targets_scaled.shape[0]),
|
876 |
+
class_inds] = 1.
|
877 |
+
loss_cls += self.loss_cls(
|
878 |
+
pred_cls_scores,
|
879 |
+
target_class,
|
880 |
+
weight=not_ignore_weights[:, None].repeat(
|
881 |
+
1, self.num_classes),
|
882 |
+
avg_factor=max(not_ignore_weights.sum(), 1))
|
883 |
+
else:
|
884 |
+
loss_cls += cls_scores[i].sum() * 0
|
885 |
+
|
886 |
+
_, world_size = get_dist_info()
|
887 |
+
return dict(
|
888 |
+
loss_cls=loss_cls * batch_size * world_size,
|
889 |
+
loss_obj=loss_obj * batch_size * world_size,
|
890 |
+
loss_bbox=loss_box * batch_size * world_size)
|
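The "shape match" step above is easiest to see with toy numbers. The following is a minimal standalone sketch (not part of mmyolo): a gt box only keeps an anchor whose width and height each differ from the gt size by less than `prior_match_thr` in either direction.

import torch

def shape_match(gt_wh: torch.Tensor, prior_wh: torch.Tensor,
                prior_match_thr: float = 4.0) -> torch.Tensor:
    """Return a (num_priors, num_gt) bool mask of anchor/gt pairs whose
    width and height ratios both stay below ``prior_match_thr``."""
    # (num_priors, num_gt, 2): per-pair w/h ratios of gt size over anchor size
    wh_ratio = gt_wh[None, :, :] / prior_wh[:, None, :]
    worst_ratio = torch.max(wh_ratio, 1 / wh_ratio).max(dim=2)[0]
    return worst_ratio < prior_match_thr

gt_wh = torch.tensor([[6.0, 4.0], [40.0, 3.0]])      # two toy gt boxes (feat scale)
prior_wh = torch.tensor([[8.0, 8.0], [32.0, 16.0]])  # two toy anchor base sizes
print(shape_match(gt_wh, prior_wh))  # tensor([[ True, False], [False, False]])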
mmyolo/models/dense_heads/yolov6_head.py
ADDED
@@ -0,0 +1,369 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig)
from mmengine import MessageHub
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import gt_instances_preprocess
from .yolov5_head import YOLOv5Head


@MODELS.register_module()
class YOLOv6HeadModule(BaseModule):
    """YOLOv6Head head module used in `YOLOv6.

    <https://arxiv.org/pdf/2209.02976>`_.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid.
        featmap_strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to [8, 16, 32].
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 1,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)

        self.num_classes = num_classes
        self.featmap_strides = featmap_strides
        self.num_levels = len(self.featmap_strides)
        self.num_base_priors = num_base_priors
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        if isinstance(in_channels, int):
            self.in_channels = [int(in_channels * widen_factor)
                                ] * self.num_levels
        else:
            self.in_channels = [int(i * widen_factor) for i in in_channels]

        self._init_layers()

    def _init_layers(self):
        """initialize conv layers in YOLOv6 head."""
        # Init decouple head
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()
        self.stems = nn.ModuleList()
        for i in range(self.num_levels):
            self.stems.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.in_channels[i],
                    kernel_size=1,
                    stride=1,
                    padding=1 // 2,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            self.cls_convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.in_channels[i],
                    kernel_size=3,
                    stride=1,
                    padding=3 // 2,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            self.reg_convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.in_channels[i],
                    kernel_size=3,
                    stride=1,
                    padding=3 // 2,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
            self.cls_preds.append(
                nn.Conv2d(
                    in_channels=self.in_channels[i],
                    out_channels=self.num_base_priors * self.num_classes,
                    kernel_size=1))
            self.reg_preds.append(
                nn.Conv2d(
                    in_channels=self.in_channels[i],
                    out_channels=self.num_base_priors * 4,
                    kernel_size=1))

    def init_weights(self):
        super().init_weights()
        bias_init = bias_init_with_prob(0.01)
        for conv in self.cls_preds:
            conv.bias.data.fill_(bias_init)
            conv.weight.data.fill_(0.)

        for conv in self.reg_preds:
            conv.bias.data.fill_(1.0)
            conv.weight.data.fill_(0.)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions.
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
                           self.cls_preds, self.reg_convs, self.reg_preds)

    def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
                       cls_pred: nn.Module, reg_conv: nn.Module,
                       reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
        """Forward feature of a single scale level."""
        y = stem(x)
        cls_x = y
        reg_x = y
        cls_feat = cls_conv(cls_x)
        reg_feat = reg_conv(reg_x)

        cls_score = cls_pred(cls_feat)
        bbox_pred = reg_pred(reg_feat)

        return cls_score, bbox_pred


@MODELS.register_module()
class YOLOv6Head(YOLOv5Head):
    """YOLOv6Head head used in `YOLOv6 <https://arxiv.org/pdf/2209.02976>`_.

    Args:
        head_module(ConfigType): Base module used for YOLOv6Head
        prior_generator(dict): Points generator feature maps
            in 2D points-based detectors.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0.5,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.VarifocalLoss',
                     use_sigmoid=True,
                     alpha=0.75,
                     gamma=2.0,
                     iou_weighted=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='giou',
                     bbox_format='xyxy',
                     reduction='mean',
                     loss_weight=2.5,
                     return_iou=False),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        # yolov6 doesn't need loss_obj
        self.loss_obj = None

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        """
        if self.train_cfg:
            self.initial_epoch = self.train_cfg['initial_epoch']
            self.initial_assigner = TASK_UTILS.build(
                self.train_cfg.initial_assigner)
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)

        # Add common attributes to reduce calculation
        self.featmap_sizes_train = None
        self.num_level_priors = None
        self.flatten_priors_train = None
        self.stride_tensor = None

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """

        # get epoch information from message hub
        message_hub = MessageHub.get_current_instance()
        current_epoch = message_hub.get_info('epoch')

        num_imgs = len(batch_img_metas)
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the shape does not equal, generate new one
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes

            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)

            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]

        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]

        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[:, 0])
        pred_scores = torch.sigmoid(flatten_cls_preds)

        if current_epoch < self.initial_epoch:
            assigned_result = self.initial_assigner(
                flatten_pred_bboxes.detach(), self.flatten_priors_train,
                self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
        else:
            assigned_result = self.assigner(flatten_pred_bboxes.detach(),
                                            pred_scores.detach(),
                                            self.flatten_priors_train,
                                            gt_labels, gt_bboxes,
                                            pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']

        # cls loss
        with torch.cuda.amp.autocast(enabled=False):
            loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)

        # rescale bbox
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        # TODO: Add all_reduce makes training more stable
        assigned_scores_sum = assigned_scores.sum()
        if assigned_scores_sum > 0:
            loss_cls /= assigned_scores_sum

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
            # will not report an error
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            bbox_weight = torch.masked_select(
                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos,
                assigned_bboxes_pos,
                weight=bbox_weight,
                avg_factor=assigned_scores_sum)
        else:
            loss_bbox = flatten_pred_bboxes.sum() * 0

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)
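A rough usage sketch for the head module defined above (not part of the repository; it assumes mmyolo, mmdet and mmcv are importable and uses made-up channel numbers): instantiate YOLOv6HeadModule and run it on dummy FPN features to inspect the per-level output shapes.

import torch
from mmyolo.models.dense_heads.yolov6_head import YOLOv6HeadModule

head = YOLOv6HeadModule(
    num_classes=80,
    in_channels=[128, 256, 512],   # hypothetical neck channels
    featmap_strides=(8, 16, 32))
feats = [torch.rand(1, c, 640 // s, 640 // s)
         for c, s in zip([128, 256, 512], (8, 16, 32))]
cls_scores, bbox_preds = head(feats)
# one num_classes-channel score map and one 4-channel box map per level,
# e.g. (1, 80, 80, 80), (1, 80, 40, 40), (1, 80, 20, 20) for a 640 input
print([t.shape for t in cls_scores])
print([t.shape for t in bbox_preds])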
mmyolo/models/dense_heads/yolov7_head.py
ADDED
@@ -0,0 +1,404 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import ConfigType, OptInstanceList
from mmengine.dist import get_dist_info
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS
from ..layers import ImplicitA, ImplicitM
from ..task_modules.assigners.batch_yolov7_assigner import BatchYOLOv7Assigner
from .yolov5_head import YOLOv5Head, YOLOv5HeadModule


@MODELS.register_module()
class YOLOv7HeadModule(YOLOv5HeadModule):
    """YOLOv7Head head module used in YOLOv7."""

    def _init_layers(self):
        """initialize conv layers in YOLOv7 head."""
        self.convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_pred = nn.Sequential(
                ImplicitA(self.in_channels[i]),
                nn.Conv2d(self.in_channels[i],
                          self.num_base_priors * self.num_out_attrib, 1),
                ImplicitM(self.num_base_priors * self.num_out_attrib),
            )
            self.convs_pred.append(conv_pred)

    def init_weights(self):
        """Initialize the bias of YOLOv7 head."""
        super(YOLOv5HeadModule, self).init_weights()
        for mi, s in zip(self.convs_pred, self.featmap_strides):  # from
            mi = mi[1]  # nn.Conv2d

            b = mi.bias.data.view(3, -1)
            # obj (8 objects per 640 image)
            b.data[:, 4] += math.log(8 / (640 / s)**2)
            b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))

            mi.bias.data = b.view(-1)


@MODELS.register_module()
class YOLOv7p6HeadModule(YOLOv5HeadModule):
    """YOLOv7Head head module used in YOLOv7."""

    def __init__(self,
                 *args,
                 main_out_channels: Sequence[int] = [256, 512, 768, 1024],
                 aux_out_channels: Sequence[int] = [320, 640, 960, 1280],
                 use_aux: bool = True,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 **kwargs):
        self.main_out_channels = main_out_channels
        self.aux_out_channels = aux_out_channels
        self.use_aux = use_aux
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        super().__init__(*args, **kwargs)

    def _init_layers(self):
        """initialize conv layers in YOLOv7 head."""
        self.main_convs_pred = nn.ModuleList()
        for i in range(self.num_levels):
            conv_pred = nn.Sequential(
                ConvModule(
                    self.in_channels[i],
                    self.main_out_channels[i],
                    3,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                ImplicitA(self.main_out_channels[i]),
                nn.Conv2d(self.main_out_channels[i],
                          self.num_base_priors * self.num_out_attrib, 1),
                ImplicitM(self.num_base_priors * self.num_out_attrib),
            )
            self.main_convs_pred.append(conv_pred)

        if self.use_aux:
            self.aux_convs_pred = nn.ModuleList()
            for i in range(self.num_levels):
                aux_pred = nn.Sequential(
                    ConvModule(
                        self.in_channels[i],
                        self.aux_out_channels[i],
                        3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    nn.Conv2d(self.aux_out_channels[i],
                              self.num_base_priors * self.num_out_attrib, 1))
                self.aux_convs_pred.append(aux_pred)
        else:
            self.aux_convs_pred = [None] * len(self.main_convs_pred)

    def init_weights(self):
        """Initialize the bias of YOLOv5 head."""
        super(YOLOv5HeadModule, self).init_weights()
        for mi, aux, s in zip(self.main_convs_pred, self.aux_convs_pred,
                              self.featmap_strides):  # from
            mi = mi[2]  # nn.Conv2d
            b = mi.bias.data.view(3, -1)
            # obj (8 objects per 640 image)
            b.data[:, 4] += math.log(8 / (640 / s)**2)
            b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
            mi.bias.data = b.view(-1)

            if self.use_aux:
                aux = aux[1]  # nn.Conv2d
                b = aux.bias.data.view(3, -1)
                # obj (8 objects per 640 image)
                b.data[:, 4] += math.log(8 / (640 / s)**2)
                b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
                mi.bias.data = b.view(-1)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.main_convs_pred,
                           self.aux_convs_pred)

    def forward_single(self, x: Tensor, convs: nn.Module,
                       aux_convs: Optional[nn.Module]) \
            -> Tuple[Union[Tensor, List], Union[Tensor, List],
                     Union[Tensor, List]]:
        """Forward feature of a single scale level."""

        pred_map = convs(x)
        bs, _, ny, nx = pred_map.shape
        pred_map = pred_map.view(bs, self.num_base_priors,
                                 self.num_out_attrib, ny, nx)

        cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
        bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
        objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)

        if not self.training or not self.use_aux:
            return cls_score, bbox_pred, objectness
        else:
            aux_pred_map = aux_convs(x)
            aux_pred_map = aux_pred_map.view(bs, self.num_base_priors,
                                             self.num_out_attrib, ny, nx)
            aux_cls_score = aux_pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
            aux_bbox_pred = aux_pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
            aux_objectness = aux_pred_map[:, :, 4:5,
                                          ...].reshape(bs, -1, ny, nx)

            return [cls_score,
                    aux_cls_score], [bbox_pred, aux_bbox_pred
                                     ], [objectness, aux_objectness]


@MODELS.register_module()
class YOLOv7Head(YOLOv5Head):
    """YOLOv7Head head used in `YOLOv7 <https://arxiv.org/abs/2207.02696>`_.

    Args:
        simota_candidate_topk (int): The candidate top-k which used to
            get top-k ious to calculate dynamic-k in BatchYOLOv7Assigner.
            Defaults to 10.
        simota_iou_weight (float): The scale factor for regression
            iou cost in BatchYOLOv7Assigner. Defaults to 3.0.
        simota_cls_weight (float): The scale factor for classification
            cost in BatchYOLOv7Assigner. Defaults to 1.0.
    """

    def __init__(self,
                 *args,
                 simota_candidate_topk: int = 20,
                 simota_iou_weight: float = 3.0,
                 simota_cls_weight: float = 1.0,
                 aux_loss_weights: float = 0.25,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.aux_loss_weights = aux_loss_weights
        self.assigner = BatchYOLOv7Assigner(
            num_classes=self.num_classes,
            num_base_priors=self.num_base_priors,
            featmap_strides=self.featmap_strides,
            prior_match_thr=self.prior_match_thr,
            candidate_topk=simota_candidate_topk,
            iou_weight=simota_iou_weight,
            cls_weight=simota_cls_weight)

    def loss_by_feat(
            self,
            cls_scores: Sequence[Union[Tensor, List]],
            bbox_preds: Sequence[Union[Tensor, List]],
            objectnesses: Sequence[Union[Tensor, List]],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """

        if isinstance(cls_scores[0], Sequence):
            with_aux = True
            batch_size = cls_scores[0][0].shape[0]
            device = cls_scores[0][0].device

            bbox_preds_main, bbox_preds_aux = zip(*bbox_preds)
            objectnesses_main, objectnesses_aux = zip(*objectnesses)
            cls_scores_main, cls_scores_aux = zip(*cls_scores)

            head_preds = self._merge_predict_results(bbox_preds_main,
                                                     objectnesses_main,
                                                     cls_scores_main)
            head_preds_aux = self._merge_predict_results(
                bbox_preds_aux, objectnesses_aux, cls_scores_aux)
        else:
            with_aux = False
            batch_size = cls_scores[0].shape[0]
            device = cls_scores[0].device

            head_preds = self._merge_predict_results(bbox_preds, objectnesses,
                                                     cls_scores)

        # Convert gt to norm xywh format
        # (num_base_priors, num_batch_gt, 7)
        # 7 is mean (batch_idx, cls_id, x_norm, y_norm,
        # w_norm, h_norm, prior_idx)
        batch_targets_normed = self._convert_gt_to_norm_format(
            batch_gt_instances, batch_img_metas)

        scaled_factors = [
            torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
            for head_pred in head_preds
        ]

        loss_cls, loss_obj, loss_box = self._calc_loss(
            head_preds=head_preds,
            head_preds_aux=None,
            batch_targets_normed=batch_targets_normed,
            near_neighbor_thr=self.near_neighbor_thr,
            scaled_factors=scaled_factors,
            batch_img_metas=batch_img_metas,
            device=device)

        if with_aux:
            loss_cls_aux, loss_obj_aux, loss_box_aux = self._calc_loss(
                head_preds=head_preds,
                head_preds_aux=head_preds_aux,
                batch_targets_normed=batch_targets_normed,
                near_neighbor_thr=self.near_neighbor_thr * 2,
                scaled_factors=scaled_factors,
                batch_img_metas=batch_img_metas,
                device=device)
            loss_cls += self.aux_loss_weights * loss_cls_aux
            loss_obj += self.aux_loss_weights * loss_obj_aux
            loss_box += self.aux_loss_weights * loss_box_aux

        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * batch_size * world_size,
            loss_obj=loss_obj * batch_size * world_size,
            loss_bbox=loss_box * batch_size * world_size)

    def _calc_loss(self, head_preds, head_preds_aux, batch_targets_normed,
                   near_neighbor_thr, scaled_factors, batch_img_metas, device):
        loss_cls = torch.zeros(1, device=device)
        loss_box = torch.zeros(1, device=device)
        loss_obj = torch.zeros(1, device=device)

        assigner_results = self.assigner(
            head_preds,
            batch_targets_normed,
            batch_img_metas[0]['batch_input_shape'],
            self.priors_base_sizes,
            self.grid_offset,
            near_neighbor_thr=near_neighbor_thr)
        # mlvl is mean multi_level
        mlvl_positive_infos = assigner_results['mlvl_positive_infos']
        mlvl_priors = assigner_results['mlvl_priors']
        mlvl_targets_normed = assigner_results['mlvl_targets_normed']

        if head_preds_aux is not None:
            # This is mean calc aux branch loss
            head_preds = head_preds_aux

        for i, head_pred in enumerate(head_preds):
            batch_inds, proir_idx, grid_x, grid_y = mlvl_positive_infos[i].T
            num_pred_positive = batch_inds.shape[0]
            target_obj = torch.zeros_like(head_pred[..., 0])
            # empty positive sampler
            if num_pred_positive == 0:
                loss_box += head_pred[..., :4].sum() * 0
                loss_cls += head_pred[..., 5:].sum() * 0
                loss_obj += self.loss_obj(
                    head_pred[..., 4], target_obj) * self.obj_level_weights[i]
                continue

            priors = mlvl_priors[i]
            targets_normed = mlvl_targets_normed[i]

            head_pred_positive = head_pred[batch_inds, proir_idx, grid_y,
                                           grid_x]

            # calc bbox loss
            grid_xy = torch.stack([grid_x, grid_y], dim=1)
            decoded_pred_bbox = self._decode_bbox_to_xywh(
                head_pred_positive[:, :4], priors, grid_xy)
            target_bbox_scaled = targets_normed[:, 2:6] * scaled_factors[i]

            loss_box_i, iou = self.loss_bbox(decoded_pred_bbox,
                                             target_bbox_scaled)
            loss_box += loss_box_i

            # calc obj loss
            target_obj[batch_inds, proir_idx, grid_y,
                       grid_x] = iou.detach().clamp(0).type(target_obj.dtype)
            loss_obj += self.loss_obj(head_pred[..., 4],
                                      target_obj) * self.obj_level_weights[i]

            # calc cls loss
            if self.num_classes > 1:
                pred_cls_scores = targets_normed[:, 1].long()
                target_class = torch.full_like(
                    head_pred_positive[:, 5:], 0., device=device)
                target_class[range(num_pred_positive), pred_cls_scores] = 1.
                loss_cls += self.loss_cls(head_pred_positive[:, 5:],
                                          target_class)
            else:
                loss_cls += head_pred_positive[:, 5:].sum() * 0
        return loss_cls, loss_obj, loss_box

    def _merge_predict_results(self, bbox_preds: Sequence[Tensor],
                               objectnesses: Sequence[Tensor],
                               cls_scores: Sequence[Tensor]) -> List[Tensor]:
        """Merge predict output from 3 heads.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).

        Returns:
            List[Tensor]: Merged output.
        """
        head_preds = []
        for bbox_pred, objectness, cls_score in zip(bbox_preds, objectnesses,
                                                    cls_scores):
            b, _, h, w = bbox_pred.shape
            bbox_pred = bbox_pred.reshape(b, self.num_base_priors, -1, h, w)
            objectness = objectness.reshape(b, self.num_base_priors, -1, h, w)
            cls_score = cls_score.reshape(b, self.num_base_priors, -1, h, w)
            head_pred = torch.cat([bbox_pred, objectness, cls_score],
                                  dim=2).permute(0, 1, 3, 4, 2).contiguous()
            head_preds.append(head_pred)
        return head_preds

    def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes,
                             grid_xy) -> Tensor:
        bbox_pred = bbox_pred.sigmoid()
        pred_xy = bbox_pred[:, :2] * 2 - 0.5 + grid_xy
        pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
        decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
        return decoded_bbox_pred
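A minimal numeric sketch (not mmyolo code) of the decode rule used in _decode_bbox_to_xywh above: raw predictions pass through a sigmoid, the centre lands in (-0.5, 1.5) around its grid cell, and the width/height in (0, 4) times the matched prior size, all on the feature-map scale.

import torch

raw = torch.tensor([[0.0, 0.0, 0.0, 0.0]])    # one positive raw prediction
grid_xy = torch.tensor([[10.0, 7.0]])          # grid cell of that sample
prior_wh = torch.tensor([[16.0, 32.0]])        # matched anchor base size

pred = raw.sigmoid()                           # sigmoid(0) = 0.5 everywhere
pred_xy = pred[:, :2] * 2 - 0.5 + grid_xy      # [[10.5, 7.5]]
pred_wh = (pred[:, 2:] * 2) ** 2 * prior_wh    # [[16.0, 32.0]]
print(torch.cat((pred_xy, pred_wh), dim=-1))   # cx, cy, w, h on the feat scale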
mmyolo/models/dense_heads/yolov8_head.py
ADDED
@@ -0,0 +1,398 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Sequence, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.models.utils import multi_apply
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig)
from mmengine.dist import get_dist_info
from mmengine.model import BaseModule
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from ..utils import gt_instances_preprocess, make_divisible
from .yolov5_head import YOLOv5Head


@MODELS.register_module()
class YOLOv8HeadModule(BaseModule):
    """YOLOv8HeadModule head module used in `YOLOv8`.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid.
        featmap_strides (Sequence[int]): Downsample factor of each feature map.
            Defaults to [8, 16, 32].
        reg_max (int): Max value of integral set :math: ``{0, ..., reg_max-1}``
            in QFL setting. Defaults to 16.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: Union[int, Sequence],
                 widen_factor: float = 1.0,
                 num_base_priors: int = 1,
                 featmap_strides: Sequence[int] = (8, 16, 32),
                 reg_max: int = 16,
                 norm_cfg: ConfigType = dict(
                     type='BN', momentum=0.03, eps=0.001),
                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.featmap_strides = featmap_strides
        self.num_levels = len(self.featmap_strides)
        self.num_base_priors = num_base_priors
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_channels = in_channels
        self.reg_max = reg_max

        in_channels = []
        for channel in self.in_channels:
            channel = make_divisible(channel, widen_factor)
            in_channels.append(channel)
        self.in_channels = in_channels

        self._init_layers()

    def init_weights(self, prior_prob=0.01):
        """Initialize the weight and bias of the YOLOv8 head."""
        super().init_weights()
        for reg_pred, cls_pred, stride in zip(self.reg_preds, self.cls_preds,
                                              self.featmap_strides):
            reg_pred[-1].bias.data[:] = 1.0  # box
            # cls (.01 objects, 80 classes, 640 img)
            cls_pred[-1].bias.data[:self.num_classes] = math.log(
                5 / self.num_classes / (640 / stride)**2)

    def _init_layers(self):
        """initialize conv layers in YOLOv8 head."""
        # Init decouple head
        self.cls_preds = nn.ModuleList()
        self.reg_preds = nn.ModuleList()

        reg_out_channels = max(
            (16, self.in_channels[0] // 4, self.reg_max * 4))
        cls_out_channels = max(self.in_channels[0], self.num_classes)

        for i in range(self.num_levels):
            self.reg_preds.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=reg_out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    ConvModule(
                        in_channels=reg_out_channels,
                        out_channels=reg_out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    nn.Conv2d(
                        in_channels=reg_out_channels,
                        out_channels=4 * self.reg_max,
                        kernel_size=1)))
            self.cls_preds.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=self.in_channels[i],
                        out_channels=cls_out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    ConvModule(
                        in_channels=cls_out_channels,
                        out_channels=cls_out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    nn.Conv2d(
                        in_channels=cls_out_channels,
                        out_channels=self.num_classes,
                        kernel_size=1)))

        proj = torch.arange(self.reg_max, dtype=torch.float)
        self.register_buffer('proj', proj, persistent=False)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions
        """
        assert len(x) == self.num_levels
        return multi_apply(self.forward_single, x, self.cls_preds,
                           self.reg_preds)

    def forward_single(self, x: torch.Tensor, cls_pred: nn.ModuleList,
                       reg_pred: nn.ModuleList) -> Tuple:
        """Forward feature of a single scale level."""
        b, _, h, w = x.shape
        cls_logit = cls_pred(x)
        bbox_dist_preds = reg_pred(x)
        if self.reg_max > 1:
            bbox_dist_preds = bbox_dist_preds.reshape(
                [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)

            # TODO: The get_flops script cannot handle the situation of
            # matmul, and needs to be fixed later
            # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
            bbox_preds = bbox_dist_preds.softmax(3).matmul(
                self.proj.view([-1, 1])).squeeze(-1)
            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
        else:
            bbox_preds = bbox_dist_preds
        if self.training:
            return cls_logit, bbox_preds, bbox_dist_preds
        else:
            return cls_logit, bbox_preds


@MODELS.register_module()
class YOLOv8Head(YOLOv5Head):
    """YOLOv8Head head used in `YOLOv8`.

    Args:
        head_module(:obj:`ConfigDict` or dict): Base module used for YOLOv8Head
        prior_generator(dict): Points generator feature maps
            in 2D points-based detectors.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_dfl (:obj:`ConfigDict` or dict): Config of Distribution Focal
            Loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0.5,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='none',
                     loss_weight=0.5),
                 loss_bbox: ConfigType = dict(
                     type='IoULoss',
                     iou_mode='ciou',
                     bbox_format='xyxy',
                     reduction='sum',
                     loss_weight=7.5,
                     return_iou=False),
                 loss_dfl=dict(
                     type='mmdet.DistributionFocalLoss',
                     reduction='mean',
                     loss_weight=1.5 / 4),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)
        self.loss_dfl = MODELS.build(loss_dfl)
        # YOLOv8 doesn't need loss_obj
        self.loss_obj = None

    def special_init(self):
        """Since YOLO series algorithms will inherit from YOLOv5Head, but
        different algorithms have special initialization process.

        The special_init function is designed to deal with this situation.
        """

        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)

        # Add common attributes to reduce calculation
        self.featmap_sizes_train = None
        self.num_level_priors = None
        self.flatten_priors_train = None
        self.stride_tensor = None

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            bbox_dist_preds: Sequence[Tensor],
            batch_gt_instances: Sequence[InstanceData],
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
                each scale level with shape (bs, reg_max + 1, H*W, 4).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        num_imgs = len(batch_img_metas)

        current_featmap_sizes = [
            cls_score.shape[2:] for cls_score in cls_scores
        ]
        # If the shape does not equal, generate new one
        if current_featmap_sizes != self.featmap_sizes_train:
            self.featmap_sizes_train = current_featmap_sizes

            mlvl_priors_with_stride = self.prior_generator.grid_priors(
                self.featmap_sizes_train,
                dtype=cls_scores[0].dtype,
                device=cls_scores[0].device,
                with_stride=True)

            self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
            self.flatten_priors_train = torch.cat(
                mlvl_priors_with_stride, dim=0)
            self.stride_tensor = self.flatten_priors_train[..., [2]]

        # gt info
        gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
        gt_labels = gt_info[:, :, :1]
        gt_bboxes = gt_info[:, :, 1:]  # xyxy
        pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()

        # pred info
        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_pred_bboxes = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        # (bs, n, 4 * reg_max)
        flatten_pred_dists = [
            bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
            for bbox_pred_org in bbox_dist_preds
        ]

        flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
        flatten_pred_bboxes = self.bbox_coder.decode(
            self.flatten_priors_train[..., :2], flatten_pred_bboxes,
            self.stride_tensor[..., 0])

        assigned_result = self.assigner(
            (flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
            flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
            gt_labels, gt_bboxes, pad_bbox_flag)

        assigned_bboxes = assigned_result['assigned_bboxes']
        assigned_scores = assigned_result['assigned_scores']
        fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']

        assigned_scores_sum = assigned_scores.sum().clamp(min=1)

        loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
        loss_cls /= assigned_scores_sum

        # rescale bbox
        assigned_bboxes /= self.stride_tensor
        flatten_pred_bboxes /= self.stride_tensor

        # select positive samples mask
        num_pos = fg_mask_pre_prior.sum()
        if num_pos > 0:
            # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
            # will not report an error
            # iou loss
            prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
            pred_bboxes_pos = torch.masked_select(
                flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
            assigned_bboxes_pos = torch.masked_select(
                assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
            bbox_weight = torch.masked_select(
                assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
            loss_bbox = self.loss_bbox(
                pred_bboxes_pos, assigned_bboxes_pos,
                weight=bbox_weight) / assigned_scores_sum

            # dfl loss
            pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
            assigned_ltrb = self.bbox_coder.encode(
                self.flatten_priors_train[..., :2] / self.stride_tensor,
                assigned_bboxes,
                max_dis=self.head_module.reg_max - 1,
                eps=0.01)
            assigned_ltrb_pos = torch.masked_select(
                assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
            loss_dfl = self.loss_dfl(
                pred_dist_pos.reshape(-1, self.head_module.reg_max),
                assigned_ltrb_pos.reshape(-1),
                weight=bbox_weight.expand(-1, 4).reshape(-1),
                avg_factor=assigned_scores_sum)
        else:
            loss_bbox = flatten_pred_bboxes.sum() * 0
            loss_dfl = flatten_pred_bboxes.sum() * 0
        _, world_size = get_dist_info()
        return dict(
            loss_cls=loss_cls * num_imgs * world_size,
            loss_bbox=loss_bbox * num_imgs * world_size,
            loss_dfl=loss_dfl * num_imgs * world_size)
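A minimal sketch (not mmyolo code) of the integral step performed in YOLOv8HeadModule.forward_single above: the per-side distribution logits over reg_max bins are turned into a single expected distance via softmax followed by a dot product with the `proj` buffer.

import torch

reg_max = 16
proj = torch.arange(reg_max, dtype=torch.float)   # bin centres 0..15
logits = torch.zeros(reg_max)                     # one side of one toy box
logits[3], logits[4] = 6.0, 6.0                   # mass peaked around bins 3 and 4
expected = logits.softmax(0) @ proj               # probability-weighted mean of bins
print(expected)  # ~3.58: between the two peaked bins, nudged up by the flat tail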
mmyolo/models/dense_heads/yolox_head.py
ADDED
@@ -0,0 +1,514 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmdet.models.task_modules.samplers import PseudoSampler
from mmdet.models.utils import multi_apply
from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
                         OptMultiConfig, reduce_mean)
from mmengine.model import BaseModule, bias_init_with_prob
from mmengine.structures import InstanceData
from torch import Tensor

from mmyolo.registry import MODELS, TASK_UTILS
from .yolov5_head import YOLOv5Head


@MODELS.register_module()
class YOLOXHeadModule(BaseModule):
    """YOLOXHead head module used in `YOLOX.

    `<https://arxiv.org/abs/2107.08430>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (Union[int, Sequence]): Number of channels in the input
            feature map.
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Defaults to 1.0.
        num_base_priors (int): The number of priors (points) at a point
            on the feature grid.
        stacked_convs (int): Number of stacking convs of the head.
            Defaults to 2.
        featmap_strides (Sequence[int]): Downsample factor of each feature
            map. Defaults to [8, 16, 32].
        use_depthwise (bool): Whether to use depthwise separable convolution
            in blocks. Defaults to False.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Defaults to False.
        conv_bias (bool or str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Defaults to "auto".
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        num_classes: int,
        in_channels: Union[int, Sequence],
        widen_factor: float = 1.0,
        num_base_priors: int = 1,
        feat_channels: int = 256,
        stacked_convs: int = 2,
        featmap_strides: Sequence[int] = [8, 16, 32],
        use_depthwise: bool = False,
        dcn_on_last_conv: bool = False,
        conv_bias: Union[bool, str] = 'auto',
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='SiLU', inplace=True),
        init_cfg: OptMultiConfig = None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.num_classes = num_classes
        self.feat_channels = int(feat_channels * widen_factor)
        self.stacked_convs = stacked_convs
        self.use_depthwise = use_depthwise
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias
        self.num_base_priors = num_base_priors

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.featmap_strides = featmap_strides

        if isinstance(in_channels, int):
            in_channels = int(in_channels * widen_factor)
        self.in_channels = in_channels

        self._init_layers()

    def _init_layers(self):
        """Initialize heads for all level feature maps."""
        self.multi_level_cls_convs = nn.ModuleList()
        self.multi_level_reg_convs = nn.ModuleList()
        self.multi_level_conv_cls = nn.ModuleList()
        self.multi_level_conv_reg = nn.ModuleList()
        self.multi_level_conv_obj = nn.ModuleList()
        for _ in self.featmap_strides:
            self.multi_level_cls_convs.append(self._build_stacked_convs())
            self.multi_level_reg_convs.append(self._build_stacked_convs())
            conv_cls, conv_reg, conv_obj = self._build_predictor()
            self.multi_level_conv_cls.append(conv_cls)
            self.multi_level_conv_reg.append(conv_reg)
            self.multi_level_conv_obj.append(conv_obj)

    def _build_stacked_convs(self) -> nn.Sequential:
        """Initialize conv layers of a single level head."""
        conv = DepthwiseSeparableConvModule \
            if self.use_depthwise else ConvModule
        stacked_convs = []
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            stacked_convs.append(
                conv(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.conv_bias))
        return nn.Sequential(*stacked_convs)

    def _build_predictor(self) -> Tuple[nn.Module, nn.Module, nn.Module]:
        """Initialize predictor layers of a single level head."""
        conv_cls = nn.Conv2d(self.feat_channels, self.num_classes, 1)
        conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
        conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
        return conv_cls, conv_reg, conv_obj

    def init_weights(self):
        """Initialize weights of the head."""
        # Use prior in model initialization to improve stability
        super().init_weights()
        bias_init = bias_init_with_prob(0.01)
        for conv_cls, conv_obj in zip(self.multi_level_conv_cls,
                                      self.multi_level_conv_obj):
            conv_cls.bias.data.fill_(bias_init)
            conv_obj.bias.data.fill_(bias_init)

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        """Forward features from the upstream network.

        Args:
            x (Tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.
        Returns:
            Tuple[List]: A tuple of multi-level classification scores, bbox
            predictions, and objectnesses.
        """

        return multi_apply(self.forward_single, x, self.multi_level_cls_convs,
                           self.multi_level_reg_convs,
                           self.multi_level_conv_cls,
                           self.multi_level_conv_reg,
                           self.multi_level_conv_obj)

    def forward_single(self, x: Tensor, cls_convs: nn.Module,
                       reg_convs: nn.Module, conv_cls: nn.Module,
                       conv_reg: nn.Module,
                       conv_obj: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
        """Forward feature of a single scale level."""

        cls_feat = cls_convs(x)
        reg_feat = reg_convs(x)

        cls_score = conv_cls(cls_feat)
        bbox_pred = conv_reg(reg_feat)
        objectness = conv_obj(reg_feat)

        return cls_score, bbox_pred, objectness

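Editor's note: a quick, hedged usage sketch of the module above, assuming mmyolo and its mmcv/mmdet dependencies are installed; the image size and channel counts are made up for illustration.

import torch
from mmyolo.models.dense_heads.yolox_head import YOLOXHeadModule

head = YOLOXHeadModule(num_classes=80, in_channels=256, feat_channels=256)
# three dummy FPN levels with strides 8, 16 and 32 for a 640x640 input
feats = [torch.rand(1, 256, 640 // s, 640 // s) for s in (8, 16, 32)]
cls_scores, bbox_preds, objectnesses = head(feats)
# per-level class logits: (1, 80, 80, 80), (1, 80, 40, 40), (1, 80, 20, 20)
print([tuple(t.shape) for t in cls_scores])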
@MODELS.register_module()
class YOLOXHead(YOLOv5Head):
    """YOLOXHead head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.

    Args:
        head_module (ConfigType): Base module used for YOLOXHead.
        prior_generator: Points generator for feature maps in
            2D points-based detectors.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
        loss_bbox_aux (:obj:`ConfigDict` or dict): Config of bbox aux loss.
        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
            anchor head. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
            anchor head. Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 head_module: ConfigType,
                 prior_generator: ConfigType = dict(
                     type='mmdet.MlvlPointGenerator',
                     offset=0,
                     strides=[8, 16, 32]),
                 bbox_coder: ConfigType = dict(type='YOLOXBBoxCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='mmdet.IoULoss',
                     mode='square',
                     eps=1e-16,
                     reduction='sum',
                     loss_weight=5.0),
                 loss_obj: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='sum',
                     loss_weight=1.0),
                 loss_bbox_aux: ConfigType = dict(
                     type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 init_cfg: OptMultiConfig = None):
        self.use_bbox_aux = False
        self.loss_bbox_aux = loss_bbox_aux

        super().__init__(
            head_module=head_module,
            prior_generator=prior_generator,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_obj=loss_obj,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            init_cfg=init_cfg)

    def special_init(self):
        """YOLO series algorithms inherit from YOLOv5Head, but each algorithm
        has its own special initialization process.

        The special_init function is designed to deal with this situation.
        """
        self.loss_bbox_aux: nn.Module = MODELS.build(self.loss_bbox_aux)
        if self.train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
            # YOLOX does not support sampling
            self.sampler = PseudoSampler()

    def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
        return self.head_module(x)

    def loss_by_feat(
            self,
            cls_scores: Sequence[Tensor],
            bbox_preds: Sequence[Tensor],
            objectnesses: Sequence[Tensor],
            batch_gt_instances: Tensor,
            batch_img_metas: Sequence[dict],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (Sequence[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_priors * num_classes.
            bbox_preds (Sequence[Tensor]): Box energies / deltas for each
                scale level, each is a 4D-tensor, the channel number is
                num_priors * 4.
            objectnesses (Sequence[Tensor]): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, 1, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
        Returns:
            dict[str, Tensor]: A dictionary of losses.
        """
        num_imgs = len(batch_img_metas)
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs

        batch_gt_instances = self.gt_instances_preprocess(
            batch_gt_instances, len(batch_img_metas))

        featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device,
            with_stride=True)

        flatten_cls_preds = [
            cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
                                                 self.num_classes)
            for cls_pred in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_objectness = [
            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
            for objectness in objectnesses
        ]

        flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
        flatten_objectness = torch.cat(flatten_objectness, dim=1)
        flatten_priors = torch.cat(mlvl_priors)
        flatten_bboxes = self.bbox_coder.decode(flatten_priors[..., :2],
                                                flatten_bbox_preds,
                                                flatten_priors[..., 2])

        (pos_masks, cls_targets, obj_targets, bbox_targets, bbox_aux_target,
         num_fg_imgs) = multi_apply(
             self._get_targets_single,
             flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
             flatten_cls_preds.detach(), flatten_bboxes.detach(),
             flatten_objectness.detach(), batch_gt_instances, batch_img_metas,
             batch_gt_instances_ignore)

        # The experimental results show that 'reduce_mean' can improve
        # performance on the COCO dataset.
        num_pos = torch.tensor(
            sum(num_fg_imgs),
            dtype=torch.float,
            device=flatten_cls_preds.device)
        num_total_samples = max(reduce_mean(num_pos), 1.0)

        pos_masks = torch.cat(pos_masks, 0)
        cls_targets = torch.cat(cls_targets, 0)
        obj_targets = torch.cat(obj_targets, 0)
        bbox_targets = torch.cat(bbox_targets, 0)
        if self.use_bbox_aux:
            bbox_aux_target = torch.cat(bbox_aux_target, 0)

        loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
                                 obj_targets) / num_total_samples
        if num_pos > 0:
            loss_cls = self.loss_cls(
                flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
                cls_targets) / num_total_samples
            loss_bbox = self.loss_bbox(
                flatten_bboxes.view(-1, 4)[pos_masks],
                bbox_targets) / num_total_samples
        else:
            # Avoid cls and reg branch not participating in the gradient
            # propagation when there is no ground-truth in the images.
            # For more details, please refer to
            # https://github.com/open-mmlab/mmdetection/issues/7298
            loss_cls = flatten_cls_preds.sum() * 0
            loss_bbox = flatten_bboxes.sum() * 0

        loss_dict = dict(
            loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)

        if self.use_bbox_aux:
            if num_pos > 0:
                loss_bbox_aux = self.loss_bbox_aux(
                    flatten_bbox_preds.view(-1, 4)[pos_masks],
                    bbox_aux_target) / num_total_samples
            else:
                # Avoid cls and reg branch not participating in the gradient
                # propagation when there is no ground-truth in the images.
                # For more details, please refer to
                # https://github.com/open-mmlab/mmdetection/issues/7298
                loss_bbox_aux = flatten_bbox_preds.sum() * 0
            loss_dict.update(loss_bbox_aux=loss_bbox_aux)

        return loss_dict

    @torch.no_grad()
    def _get_targets_single(
            self,
            priors: Tensor,
            cls_preds: Tensor,
            decoded_bboxes: Tensor,
            objectness: Tensor,
            gt_instances: InstanceData,
            img_meta: dict,
            gt_instances_ignore: Optional[InstanceData] = None) -> tuple:
        """Compute classification, regression, and objectness targets for
        priors in a single image.

        Args:
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
            cls_preds (Tensor): Classification predictions of one image,
                a 2D-Tensor with shape [num_priors, num_classes].
            decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
                a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
                br_x, br_y] format.
            objectness (Tensor): Objectness predictions of one image,
                a 1D-Tensor with shape [num_priors].
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It should include ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
        Returns:
            tuple:
                foreground_mask (list[Tensor]): Binary mask of foreground
                targets.
                cls_target (list[Tensor]): Classification targets of an image.
                obj_target (list[Tensor]): Objectness targets of an image.
                bbox_target (list[Tensor]): BBox targets of an image.
                bbox_aux_target (int): BBox aux targets of an image.
                num_pos_per_img (int): Number of positive samples in an image.
        """

        num_priors = priors.size(0)
        num_gts = len(gt_instances)
        # No target
        if num_gts == 0:
            cls_target = cls_preds.new_zeros((0, self.num_classes))
            bbox_target = cls_preds.new_zeros((0, 4))
            bbox_aux_target = cls_preds.new_zeros((0, 4))
            obj_target = cls_preds.new_zeros((num_priors, 1))
            foreground_mask = cls_preds.new_zeros(num_priors).bool()
            return (foreground_mask, cls_target, obj_target, bbox_target,
                    bbox_aux_target, 0)

        # YOLOX uses center priors with 0.5 offset to assign targets,
        # but use center priors without offset to regress bboxes.
        offset_priors = torch.cat(
            [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)

        scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid()
        pred_instances = InstanceData(
            bboxes=decoded_bboxes, scores=scores.sqrt_(), priors=offset_priors)
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=gt_instances,
            gt_instances_ignore=gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)
        pos_inds = sampling_result.pos_inds
        num_pos_per_img = pos_inds.size(0)

        pos_ious = assign_result.max_overlaps[pos_inds]
        # IOU aware classification score
        cls_target = F.one_hot(sampling_result.pos_gt_labels,
                               self.num_classes) * pos_ious.unsqueeze(-1)
        obj_target = torch.zeros_like(objectness).unsqueeze(-1)
        obj_target[pos_inds] = 1
        bbox_target = sampling_result.pos_gt_bboxes
        bbox_aux_target = cls_preds.new_zeros((num_pos_per_img, 4))
        if self.use_bbox_aux:
            bbox_aux_target = self._get_bbox_aux_target(
                bbox_aux_target, bbox_target, priors[pos_inds])
        foreground_mask = torch.zeros_like(objectness).to(torch.bool)
        foreground_mask[pos_inds] = 1
        return (foreground_mask, cls_target, obj_target, bbox_target,
                bbox_aux_target, num_pos_per_img)

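Editor's note: the cls_target built above is an IoU-aware one-hot vector, where the positive class is scored with the assigned IoU instead of 1. A tiny standalone illustration with made-up values:

import torch
import torch.nn.functional as F

pos_gt_labels = torch.tensor([2, 0])   # assigned gt classes of two positive priors
pos_ious = torch.tensor([0.81, 0.55])  # IoU between each prior's box and its gt
cls_target = F.one_hot(pos_gt_labels, 4) * pos_ious.unsqueeze(-1)
print(cls_target)
# tensor([[0.0000, 0.0000, 0.8100, 0.0000],
#         [0.5500, 0.0000, 0.0000, 0.0000]])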
    def _get_bbox_aux_target(self,
                             bbox_aux_target: Tensor,
                             gt_bboxes: Tensor,
                             priors: Tensor,
                             eps: float = 1e-8) -> Tensor:
        """Convert gt bboxes to center offset and log width height."""
        gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
        bbox_aux_target[:, :2] = (gt_cxcywh[:, :2] -
                                  priors[:, :2]) / priors[:, 2:]
        bbox_aux_target[:,
                        2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
        return bbox_aux_target

    @staticmethod
    def gt_instances_preprocess(batch_gt_instances: Tensor,
                                batch_size: int) -> List[InstanceData]:
        """Split batch_gt_instances with batch size.

        Args:
            batch_gt_instances (Tensor): Ground truth
                a 2D-Tensor for whole batch, shape [all_gt_bboxes, 6]
            batch_size (int): Batch size.

        Returns:
            List: batch gt instances data, shape [batch_size, InstanceData]
        """
        # faster version
        batch_instance_list = []
        for i in range(batch_size):
            batch_gt_instance_ = InstanceData()
            single_batch_instance = \
                batch_gt_instances[batch_gt_instances[:, 0] == i, :]
            batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
            batch_gt_instance_.labels = single_batch_instance[:, 1]
            batch_instance_list.append(batch_gt_instance_)

        return batch_instance_list
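Editor's note: in configs the head above is wired together with a label assigner in train_cfg, which special_init builds through TASK_UTILS. A hedged config sketch follows; the SimOTA assigner settings shown here are assumptions drawn from typical YOLOX setups, so check the shipped YOLOX configs for the exact values.

# sketch of a bbox_head / train_cfg pair for YOLOXHead (values illustrative)
bbox_head = dict(
    type='YOLOXHead',
    head_module=dict(
        type='YOLOXHeadModule',
        num_classes=80,
        in_channels=256,
        feat_channels=256,
        widen_factor=1.0,
        stacked_convs=2,
        featmap_strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='SiLU', inplace=True)))

train_cfg = dict(
    assigner=dict(
        type='mmdet.SimOTAAssigner',          # assumed dynamic assigner
        center_radius=2.5,
        iou_calculator=dict(type='mmdet.BboxOverlaps2D')))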
mmyolo/models/detectors/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .yolo_detector import YOLODetector

__all__ = ['YOLODetector']
mmyolo/models/detectors/yolo_detector.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.dist import get_world_size
from mmengine.logging import print_log

from mmyolo.registry import MODELS


@MODELS.register_module()
class YOLODetector(SingleStageDetector):
    r"""Implementation of YOLO Series

    Args:
        backbone (:obj:`ConfigDict` or dict): The backbone config.
        neck (:obj:`ConfigDict` or dict): The neck config.
        bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
        train_cfg (:obj:`ConfigDict` or dict, optional): The training config
            of YOLO. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
            of YOLO. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
            :class:`DetDataPreprocessor` to process the input data.
            Defaults to None.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
        use_syncbn (bool): Whether to use SyncBatchNorm. Defaults to True.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None,
                 use_syncbn: bool = True):
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # TODO: Waiting for mmengine support
        if use_syncbn and get_world_size() > 1:
            torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
            print_log('Using SyncBatchNorm()', 'current')
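Editor's note: the only YOLO-specific logic in this detector is the optional SyncBatchNorm conversion, which is skipped on single-GPU runs because get_world_size() returns 1. A minimal sketch of the same pattern outside the detector; the toy module below is an illustration, not mmyolo code.

import torch.nn as nn
from mmengine.dist import get_world_size

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
if get_world_size() > 1:
    # returns the model with every BatchNorm*d replaced by SyncBatchNorm,
    # keeping the existing weights and running statistics
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # BatchNorm2d when running on a single process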
mmyolo/models/layers/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ema import ExpMomentumEMA
from .yolo_bricks import (BepC3StageBlock, CSPLayerWithTwoConv,
                          DarknetBottleneck, EELANBlock, EffectiveSELayer,
                          ELANBlock, ImplicitA, ImplicitM,
                          MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
                          RepStageBlock, RepVGGBlock, SPPFBottleneck,
                          SPPFCSPBlock, TinyDownSampleBlock)

__all__ = [
    'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
    'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
    'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
    'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock',
    'CSPLayerWithTwoConv', 'DarknetBottleneck'
]
mmyolo/models/layers/ema.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

import torch
import torch.nn as nn
from mmdet.models.layers import ExpMomentumEMA as MMDET_ExpMomentumEMA
from torch import Tensor

from mmyolo.registry import MODELS


@MODELS.register_module()
class ExpMomentumEMA(MMDET_ExpMomentumEMA):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLO.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        gamma (int): Use a larger momentum early in training and gradually
            anneal to a smaller value to update the ema model smoothly. The
            momentum is calculated as
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval=1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False):
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        self.gamma = gamma

        # Note: There is no need to re-fetch every update,
        # as most models do not change their structure
        # during the training process.
        self.src_parameters = (
            model.state_dict()
            if self.update_buffers else dict(model.named_parameters()))
        if not self.update_buffers:
            self.src_buffers = model.buffers()

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int):
        """Compute the moving average of the parameters using the exponential
        momentum strategy.

        Args:
            averaged_param (Tensor): The averaged parameters.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        averaged_param.lerp_(source_param, momentum)

    def update_parameters(self, model: nn.Module):
        """Update the parameters after each training step.

        Args:
            model (nn.Module): The model whose parameters need to be updated.
        """
        if self.steps == 0:
            for k, p_avg in self.avg_parameters.items():
                p_avg.data.copy_(self.src_parameters[k].data)
        elif self.steps % self.interval == 0:
            for k, p_avg in self.avg_parameters.items():
                if p_avg.dtype.is_floating_point:
                    self.avg_func(p_avg.data, self.src_parameters[k].data,
                                  self.steps)
            if not self.update_buffers:
                # If not update the buffers,
                # keep the buffers in sync with the source model.
                for b_avg, b_src in zip(self.module.buffers(), self.src_buffers):
                    b_avg.data.copy_(b_src.data)
        self.steps += 1
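Editor's note: the interpolation weight used by avg_func above starts close to 1 (the EMA copy tracks the source model almost exactly early in training) and decays towards the configured momentum, so the average changes very slowly late in training. A small numeric check with the default momentum=0.0002 and gamma=2000:

import math

momentum, gamma = 0.0002, 2000

def ema_momentum(steps: int) -> float:
    # interpolation weight passed to lerp_ at a given step
    return (1 - momentum) * math.exp(-float(1 + steps) / gamma) + momentum

for steps in (0, 1000, 5000, 20000):
    print(steps, round(ema_momentum(steps), 6))
# 0 -> ~0.9995, 1000 -> ~0.606, 5000 -> ~0.082, 20000 -> ~0.00025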