KyanChen committed
Commit 3094730
1 Parent(s): 2ae34e9

Upload 89 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. mmyolo/__init__.py +39 -0
  2. mmyolo/datasets/__init__.py +12 -0
  3. mmyolo/datasets/transforms/__init__.py +14 -0
  4. mmyolo/datasets/transforms/mix_img_transforms.py +1150 -0
  5. mmyolo/datasets/transforms/transforms.py +1557 -0
  6. mmyolo/datasets/utils.py +114 -0
  7. mmyolo/datasets/yolov5_coco.py +65 -0
  8. mmyolo/datasets/yolov5_crowdhuman.py +15 -0
  9. mmyolo/datasets/yolov5_dota.py +29 -0
  10. mmyolo/datasets/yolov5_voc.py +15 -0
  11. mmyolo/deploy/__init__.py +7 -0
  12. mmyolo/deploy/models/__init__.py +2 -0
  13. mmyolo/deploy/models/dense_heads/__init__.py +4 -0
  14. mmyolo/deploy/models/dense_heads/yolov5_head.py +189 -0
  15. mmyolo/deploy/models/layers/__init__.py +4 -0
  16. mmyolo/deploy/models/layers/bbox_nms.py +113 -0
  17. mmyolo/deploy/object_detection.py +132 -0
  18. mmyolo/engine/__init__.py +3 -0
  19. mmyolo/engine/hooks/__init__.py +10 -0
  20. mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py +96 -0
  21. mmyolo/engine/hooks/switch_to_deploy_hook.py +21 -0
  22. mmyolo/engine/hooks/yolov5_param_scheduler_hook.py +130 -0
  23. mmyolo/engine/hooks/yolox_mode_switch_hook.py +54 -0
  24. mmyolo/engine/optimizers/__init__.py +5 -0
  25. mmyolo/engine/optimizers/yolov5_optim_constructor.py +132 -0
  26. mmyolo/engine/optimizers/yolov7_optim_wrapper_constructor.py +139 -0
  27. mmyolo/models/__init__.py +10 -0
  28. mmyolo/models/backbones/__init__.py +13 -0
  29. mmyolo/models/backbones/base_backbone.py +225 -0
  30. mmyolo/models/backbones/csp_darknet.py +427 -0
  31. mmyolo/models/backbones/csp_resnet.py +169 -0
  32. mmyolo/models/backbones/cspnext.py +187 -0
  33. mmyolo/models/backbones/efficient_rep.py +287 -0
  34. mmyolo/models/backbones/yolov7_backbone.py +285 -0
  35. mmyolo/models/data_preprocessors/__init__.py +10 -0
  36. mmyolo/models/data_preprocessors/data_preprocessor.py +302 -0
  37. mmyolo/models/dense_heads/__init__.py +20 -0
  38. mmyolo/models/dense_heads/ppyoloe_head.py +374 -0
  39. mmyolo/models/dense_heads/rtmdet_head.py +368 -0
  40. mmyolo/models/dense_heads/rtmdet_ins_head.py +725 -0
  41. mmyolo/models/dense_heads/rtmdet_rotated_head.py +641 -0
  42. mmyolo/models/dense_heads/yolov5_head.py +890 -0
  43. mmyolo/models/dense_heads/yolov6_head.py +369 -0
  44. mmyolo/models/dense_heads/yolov7_head.py +404 -0
  45. mmyolo/models/dense_heads/yolov8_head.py +398 -0
  46. mmyolo/models/dense_heads/yolox_head.py +514 -0
  47. mmyolo/models/detectors/__init__.py +4 -0
  48. mmyolo/models/detectors/yolo_detector.py +53 -0
  49. mmyolo/models/layers/__init__.py +16 -0
  50. mmyolo/models/layers/ema.py +96 -0
mmyolo/__init__.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import mmcv
+ import mmdet
+ import mmengine
+ from mmengine.utils import digit_version
+
+ from .version import __version__, version_info
+
+ mmcv_minimum_version = '2.0.0rc4'
+ mmcv_maximum_version = '2.1.0'
+ mmcv_version = digit_version(mmcv.__version__)
+
+ mmengine_minimum_version = '0.6.0'
+ mmengine_maximum_version = '1.0.0'
+ mmengine_version = digit_version(mmengine.__version__)
+
+ mmdet_minimum_version = '3.0.0rc6'
+ mmdet_maximum_version = '3.1.0'
+ mmdet_version = digit_version(mmdet.__version__)
+
+
+ assert (mmcv_version >= digit_version(mmcv_minimum_version)
+         and mmcv_version < digit_version(mmcv_maximum_version)), \
+     f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+     f'Please install mmcv>={mmcv_minimum_version}, <{mmcv_maximum_version}.'
+
+ assert (mmengine_version >= digit_version(mmengine_minimum_version)
+         and mmengine_version < digit_version(mmengine_maximum_version)), \
+     f'MMEngine=={mmengine.__version__} is used but incompatible. ' \
+     f'Please install mmengine>={mmengine_minimum_version}, ' \
+     f'<{mmengine_maximum_version}.'
+
+ assert (mmdet_version >= digit_version(mmdet_minimum_version)
+         and mmdet_version < digit_version(mmdet_maximum_version)), \
+     f'MMDetection=={mmdet.__version__} is used but incompatible. ' \
+     f'Please install mmdet>={mmdet_minimum_version}, ' \
+     f'<{mmdet_maximum_version}.'
+
+ __all__ = ['__version__', 'version_info', 'digit_version']
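
The import-time assertions above gate mmyolo on half-open version ranges of its dependencies. A minimal sketch of the same pattern, assuming only that mmengine is installed (the version strings here are illustrative):

    from mmengine.utils import digit_version

    minimum, maximum = '2.0.0rc4', '2.1.0'
    installed = digit_version('2.0.0')  # e.g. digit_version(mmcv.__version__)

    # digit_version turns a version string into a comparable tuple, so a plain
    # range check rejects unsupported releases before any other module loads.
    assert digit_version(minimum) <= installed < digit_version(maximum), (
        f'Installed version is outside the supported range [{minimum}, {maximum}).')
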
mmyolo/datasets/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .transforms import *  # noqa: F401,F403
+ from .utils import BatchShapePolicy, yolov5_collate
+ from .yolov5_coco import YOLOv5CocoDataset
+ from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset
+ from .yolov5_dota import YOLOv5DOTADataset
+ from .yolov5_voc import YOLOv5VOCDataset
+
+ __all__ = [
+     'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy',
+     'yolov5_collate', 'YOLOv5CrowdHumanDataset', 'YOLOv5DOTADataset'
+ ]
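
The datasets registered here are normally referenced by their type name from a config file rather than imported directly. A hedged sketch of such a fragment follows; the file paths are placeholders and the argument names follow the upstream MMDetection CocoDataset interface that YOLOv5CocoDataset builds on:

    train_dataset = dict(
        type='YOLOv5CocoDataset',          # registered name exported above
        data_root='data/coco/',            # placeholder path
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[])                       # filled with the transforms below
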
mmyolo/datasets/transforms/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
+ from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
+                          PPYOLOERandomDistort, RegularizeRotatedBox,
+                          RemoveDataElement, YOLOv5CopyPaste,
+                          YOLOv5HSVRandomAug, YOLOv5KeepRatioResize,
+                          YOLOv5RandomAffine)
+
+ __all__ = [
+     'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
+     'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
+     'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
+     'Mosaic9', 'YOLOv5CopyPaste', 'RemoveDataElement', 'RegularizeRotatedBox'
+ ]
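
All of these transforms are registered in the TRANSFORMS registry, so a pipeline refers to them by name. As a small illustration (a sketch only; the scale and padding values are the defaults documented in the transform docstrings, not taken from a specific config in this commit), a YOLOv5-style test-time resize could look like:

    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='YOLOv5KeepRatioResize', scale=(640, 640)),
        dict(type='LetterResize',
             scale=(640, 640),
             allow_scale_up=False,         # only scale down at test time
             pad_val=dict(img=114)),
        dict(type='LoadAnnotations', with_bbox=True),
    ]
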
mmyolo/datasets/transforms/mix_img_transforms.py ADDED
@@ -0,0 +1,1150 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import collections
3
+ import copy
4
+ from abc import ABCMeta, abstractmethod
5
+ from typing import Optional, Sequence, Tuple, Union
6
+
7
+ import mmcv
8
+ import numpy as np
9
+ from mmcv.transforms import BaseTransform
10
+ from mmdet.structures.bbox import autocast_box_type
11
+ from mmengine.dataset import BaseDataset
12
+ from mmengine.dataset.base_dataset import Compose
13
+ from numpy import random
14
+
15
+ from mmyolo.registry import TRANSFORMS
16
+
17
+
18
+ class BaseMixImageTransform(BaseTransform, metaclass=ABCMeta):
19
+ """A Base Transform of multiple images mixed.
20
+
21
+ Suitable for training on multiple images mixed data augmentation like
22
+ mosaic and mixup.
23
+
24
+ Cached mosaic transform will random select images from the cache
25
+ and combine them into one output image if use_cached is True.
26
+
27
+ Args:
28
+ pre_transform(Sequence[str]): Sequence of transform object or
29
+ config dict to be composed. Defaults to None.
30
+ prob(float): The transformation probability. Defaults to 1.0.
31
+ use_cached (bool): Whether to use cache. Defaults to False.
32
+ max_cached_images (int): The maximum length of the cache. The larger
33
+ the cache, the stronger the randomness of this transform. As a
34
+ rule of thumb, providing 10 caches for each image suffices for
35
+ randomness. Defaults to 40.
36
+ random_pop (bool): Whether to randomly pop a result from the cache
37
+ when the cache is full. If set to False, use FIFO popping method.
38
+ Defaults to True.
39
+ max_refetch (int): The maximum number of retry iterations for getting
40
+ valid results from the pipeline. If the number of iterations is
41
+ greater than `max_refetch`, but results is still None, then the
42
+ iteration is terminated and raise the error. Defaults to 15.
43
+ """
44
+
45
+ def __init__(self,
46
+ pre_transform: Optional[Sequence[str]] = None,
47
+ prob: float = 1.0,
48
+ use_cached: bool = False,
49
+ max_cached_images: int = 40,
50
+ random_pop: bool = True,
51
+ max_refetch: int = 15):
52
+
53
+ self.max_refetch = max_refetch
54
+ self.prob = prob
55
+
56
+ self.use_cached = use_cached
57
+ self.max_cached_images = max_cached_images
58
+ self.random_pop = random_pop
59
+ self.results_cache = []
60
+
61
+ if pre_transform is None:
62
+ self.pre_transform = None
63
+ else:
64
+ self.pre_transform = Compose(pre_transform)
65
+
66
+ @abstractmethod
67
+ def get_indexes(self, dataset: Union[BaseDataset,
68
+ list]) -> Union[list, int]:
69
+ """Call function to collect indexes.
70
+
71
+ Args:
72
+ dataset (:obj:`Dataset` or list): The dataset or cached list.
73
+
74
+ Returns:
75
+ list or int: indexes.
76
+ """
77
+ pass
78
+
79
+ @abstractmethod
80
+ def mix_img_transform(self, results: dict) -> dict:
81
+ """Mixed image data transformation.
82
+
83
+ Args:
84
+ results (dict): Result dict.
85
+
86
+ Returns:
87
+ results (dict): Updated result dict.
88
+ """
89
+ pass
90
+
91
+ @autocast_box_type()
92
+ def transform(self, results: dict) -> dict:
93
+ """Data augmentation function.
94
+
95
+ The transform steps are as follows:
96
+ 1. Randomly generate index list of other images.
97
+ 2. Before Mosaic or MixUp need to go through the necessary
98
+ pre_transform, such as MixUp' pre_transform pipeline
99
+ include: 'LoadImageFromFile','LoadAnnotations',
100
+ 'Mosaic' and 'RandomAffine'.
101
+ 3. Use mix_img_transform function to implement specific
102
+ mix operations.
103
+
104
+ Args:
105
+ results (dict): Result dict.
106
+
107
+ Returns:
108
+ results (dict): Updated result dict.
109
+ """
110
+
111
+ if random.uniform(0, 1) > self.prob:
112
+ return results
113
+
114
+ if self.use_cached:
115
+ # Be careful: deep copying can be very time-consuming
116
+ # if results includes dataset.
117
+ dataset = results.pop('dataset', None)
118
+ self.results_cache.append(copy.deepcopy(results))
119
+ if len(self.results_cache) > self.max_cached_images:
120
+ if self.random_pop:
121
+ index = random.randint(0, len(self.results_cache) - 1)
122
+ else:
123
+ index = 0
124
+ self.results_cache.pop(index)
125
+
126
+ if len(self.results_cache) <= 4:
127
+ return results
128
+ else:
129
+ assert 'dataset' in results
130
+ # Be careful: deep copying can be very time-consuming
131
+ # if results includes dataset.
132
+ dataset = results.pop('dataset', None)
133
+
134
+ for _ in range(self.max_refetch):
135
+ # get index of one or three other images
136
+ if self.use_cached:
137
+ indexes = self.get_indexes(self.results_cache)
138
+ else:
139
+ indexes = self.get_indexes(dataset)
140
+
141
+ if not isinstance(indexes, collections.abc.Sequence):
142
+ indexes = [indexes]
143
+
144
+ if self.use_cached:
145
+ mix_results = [
146
+ copy.deepcopy(self.results_cache[i]) for i in indexes
147
+ ]
148
+ else:
149
+ # get images information will be used for Mosaic or MixUp
150
+ mix_results = [
151
+ copy.deepcopy(dataset.get_data_info(index))
152
+ for index in indexes
153
+ ]
154
+
155
+ if self.pre_transform is not None:
156
+ for i, data in enumerate(mix_results):
157
+ # pre_transform may also require dataset
158
+ data.update({'dataset': dataset})
159
+ # before Mosaic or MixUp need to go through
160
+ # the necessary pre_transform
161
+ _results = self.pre_transform(data)
162
+ _results.pop('dataset')
163
+ mix_results[i] = _results
164
+
165
+ if None not in mix_results:
166
+ results['mix_results'] = mix_results
167
+ break
168
+ print('Repeated calculation')
169
+ else:
170
+ raise RuntimeError(
171
+ 'The loading pipeline of the original dataset'
172
+ ' always return None. Please check the correctness '
173
+ 'of the dataset and its pipeline.')
174
+
175
+ # Mosaic or MixUp
176
+ results = self.mix_img_transform(results)
177
+
178
+ if 'mix_results' in results:
179
+ results.pop('mix_results')
180
+ results['dataset'] = dataset
181
+
182
+ return results
183
+
184
+
185
+ @TRANSFORMS.register_module()
186
+ class Mosaic(BaseMixImageTransform):
187
+ """Mosaic augmentation.
188
+
189
+ Given 4 images, mosaic transform combines them into
190
+ one output image. The output image is composed of the parts from each sub-
191
+ image.
192
+
193
+ .. code:: text
194
+
195
+ mosaic transform
196
+ center_x
197
+ +------------------------------+
198
+ | pad | |
199
+ | +-----------+ pad |
200
+ | | | |
201
+ | | image1 +-----------+
202
+ | | | |
203
+ | | | image2 |
204
+ center_y |----+-+-----------+-----------+
205
+ | | cropped | |
206
+ |pad | image3 | image4 |
207
+ | | | |
208
+ +----|-------------+-----------+
209
+ | |
210
+ +-------------+
211
+
212
+ The mosaic transform steps are as follows:
213
+
214
+ 1. Choose the mosaic center as the intersections of 4 images
215
+ 2. Get the left top image according to the index, and randomly
216
+ sample another 3 images from the custom dataset.
217
+ 3. Sub image will be cropped if image is larger than mosaic patch
218
+
219
+ Required Keys:
220
+
221
+ - img
222
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
223
+ - gt_bboxes_labels (np.int64) (optional)
224
+ - gt_ignore_flags (bool) (optional)
225
+ - mix_results (List[dict])
226
+
227
+ Modified Keys:
228
+
229
+ - img
230
+ - img_shape
231
+ - gt_bboxes (optional)
232
+ - gt_bboxes_labels (optional)
233
+ - gt_ignore_flags (optional)
234
+
235
+ Args:
236
+ img_scale (Sequence[int]): Image size after mosaic pipeline of single
237
+ image. The shape order should be (width, height).
238
+ Defaults to (640, 640).
239
+ center_ratio_range (Sequence[float]): Center ratio range of mosaic
240
+ output. Defaults to (0.5, 1.5).
241
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
242
+ the border of the image. In some dataset like MOT17, the gt bboxes
243
+ are allowed to cross the border of images. Therefore, we don't
244
+ need to clip the gt bboxes in these cases. Defaults to True.
245
+ pad_val (int): Pad value. Defaults to 114.
246
+ pre_transform(Sequence[dict]): Sequence of transform object or
247
+ config dict to be composed.
248
+ prob (float): Probability of applying this transformation.
249
+ Defaults to 1.0.
250
+ use_cached (bool): Whether to use cache. Defaults to False.
251
+ max_cached_images (int): The maximum length of the cache. The larger
252
+ the cache, the stronger the randomness of this transform. As a
253
+ rule of thumb, providing 10 caches for each image suffices for
254
+ randomness. Defaults to 40.
255
+ random_pop (bool): Whether to randomly pop a result from the cache
256
+ when the cache is full. If set to False, use FIFO popping method.
257
+ Defaults to True.
258
+ max_refetch (int): The maximum number of retry iterations for getting
259
+ valid results from the pipeline. If the number of iterations is
260
+ greater than `max_refetch`, but results is still None, then the
261
+ iteration is terminated and raise the error. Defaults to 15.
262
+ """
263
+
264
+ def __init__(self,
265
+ img_scale: Tuple[int, int] = (640, 640),
266
+ center_ratio_range: Tuple[float, float] = (0.5, 1.5),
267
+ bbox_clip_border: bool = True,
268
+ pad_val: float = 114.0,
269
+ pre_transform: Sequence[dict] = None,
270
+ prob: float = 1.0,
271
+ use_cached: bool = False,
272
+ max_cached_images: int = 40,
273
+ random_pop: bool = True,
274
+ max_refetch: int = 15):
275
+ assert isinstance(img_scale, tuple)
276
+ assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
277
+ f'got {prob}.'
278
+ if use_cached:
279
+ assert max_cached_images >= 4, 'The length of cache must >= 4, ' \
280
+ f'but got {max_cached_images}.'
281
+
282
+ super().__init__(
283
+ pre_transform=pre_transform,
284
+ prob=prob,
285
+ use_cached=use_cached,
286
+ max_cached_images=max_cached_images,
287
+ random_pop=random_pop,
288
+ max_refetch=max_refetch)
289
+
290
+ self.img_scale = img_scale
291
+ self.center_ratio_range = center_ratio_range
292
+ self.bbox_clip_border = bbox_clip_border
293
+ self.pad_val = pad_val
294
+
295
+ def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
296
+ """Call function to collect indexes.
297
+
298
+ Args:
299
+ dataset (:obj:`Dataset` or list): The dataset or cached list.
300
+
301
+ Returns:
302
+ list: indexes.
303
+ """
304
+ indexes = [random.randint(0, len(dataset)) for _ in range(3)]
305
+ return indexes
306
+
307
+ def mix_img_transform(self, results: dict) -> dict:
308
+ """Mixed image data transformation.
309
+
310
+ Args:
311
+ results (dict): Result dict.
312
+
313
+ Returns:
314
+ results (dict): Updated result dict.
315
+ """
316
+ assert 'mix_results' in results
317
+ mosaic_bboxes = []
318
+ mosaic_bboxes_labels = []
319
+ mosaic_ignore_flags = []
320
+ mosaic_masks = []
321
+ with_mask = True if 'gt_masks' in results else False
322
+ # self.img_scale is wh format
323
+ img_scale_w, img_scale_h = self.img_scale
324
+
325
+ if len(results['img'].shape) == 3:
326
+ mosaic_img = np.full(
327
+ (int(img_scale_h * 2), int(img_scale_w * 2), 3),
328
+ self.pad_val,
329
+ dtype=results['img'].dtype)
330
+ else:
331
+ mosaic_img = np.full((int(img_scale_h * 2), int(img_scale_w * 2)),
332
+ self.pad_val,
333
+ dtype=results['img'].dtype)
334
+
335
+ # mosaic center x, y
336
+ center_x = int(random.uniform(*self.center_ratio_range) * img_scale_w)
337
+ center_y = int(random.uniform(*self.center_ratio_range) * img_scale_h)
338
+ center_position = (center_x, center_y)
339
+
340
+ loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
341
+ for i, loc in enumerate(loc_strs):
342
+ if loc == 'top_left':
343
+ results_patch = results
344
+ else:
345
+ results_patch = results['mix_results'][i - 1]
346
+
347
+ img_i = results_patch['img']
348
+ h_i, w_i = img_i.shape[:2]
349
+ # keep_ratio resize
350
+ scale_ratio_i = min(img_scale_h / h_i, img_scale_w / w_i)
351
+ img_i = mmcv.imresize(
352
+ img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
353
+
354
+ # compute the combine parameters
355
+ paste_coord, crop_coord = self._mosaic_combine(
356
+ loc, center_position, img_i.shape[:2][::-1])
357
+ x1_p, y1_p, x2_p, y2_p = paste_coord
358
+ x1_c, y1_c, x2_c, y2_c = crop_coord
359
+
360
+ # crop and paste image
361
+ mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
362
+
363
+ # adjust coordinate
364
+ gt_bboxes_i = results_patch['gt_bboxes']
365
+ gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
366
+ gt_ignore_flags_i = results_patch['gt_ignore_flags']
367
+
368
+ padw = x1_p - x1_c
369
+ padh = y1_p - y1_c
370
+ gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
371
+ gt_bboxes_i.translate_([padw, padh])
372
+ mosaic_bboxes.append(gt_bboxes_i)
373
+ mosaic_bboxes_labels.append(gt_bboxes_labels_i)
374
+ mosaic_ignore_flags.append(gt_ignore_flags_i)
375
+ if with_mask and results_patch.get('gt_masks', None) is not None:
376
+ gt_masks_i = results_patch['gt_masks']
377
+ gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i))
378
+ gt_masks_i = gt_masks_i.translate(
379
+ out_shape=(int(self.img_scale[0] * 2),
380
+ int(self.img_scale[1] * 2)),
381
+ offset=padw,
382
+ direction='horizontal')
383
+ gt_masks_i = gt_masks_i.translate(
384
+ out_shape=(int(self.img_scale[0] * 2),
385
+ int(self.img_scale[1] * 2)),
386
+ offset=padh,
387
+ direction='vertical')
388
+ mosaic_masks.append(gt_masks_i)
389
+
390
+ mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
391
+ mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
392
+ mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
393
+
394
+ if self.bbox_clip_border:
395
+ mosaic_bboxes.clip_([2 * img_scale_h, 2 * img_scale_w])
396
+ if with_mask:
397
+ mosaic_masks = mosaic_masks[0].cat(mosaic_masks)
398
+ results['gt_masks'] = mosaic_masks
399
+ else:
400
+ # remove outside bboxes
401
+ inside_inds = mosaic_bboxes.is_inside(
402
+ [2 * img_scale_h, 2 * img_scale_w]).numpy()
403
+ mosaic_bboxes = mosaic_bboxes[inside_inds]
404
+ mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
405
+ mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
406
+ if with_mask:
407
+ mosaic_masks = mosaic_masks[0].cat(mosaic_masks)[inside_inds]
408
+ results['gt_masks'] = mosaic_masks
409
+
410
+ results['img'] = mosaic_img
411
+ results['img_shape'] = mosaic_img.shape
412
+ results['gt_bboxes'] = mosaic_bboxes
413
+ results['gt_bboxes_labels'] = mosaic_bboxes_labels
414
+ results['gt_ignore_flags'] = mosaic_ignore_flags
415
+
416
+ return results
417
+
418
+ def _mosaic_combine(
419
+ self, loc: str, center_position_xy: Sequence[float],
420
+ img_shape_wh: Sequence[int]) -> Tuple[Tuple[int], Tuple[int]]:
421
+ """Calculate global coordinate of mosaic image and local coordinate of
422
+ cropped sub-image.
423
+
424
+ Args:
425
+ loc (str): Index for the sub-image, loc in ('top_left',
426
+ 'top_right', 'bottom_left', 'bottom_right').
427
+ center_position_xy (Sequence[float]): Mixing center for 4 images,
428
+ (x, y).
429
+ img_shape_wh (Sequence[int]): Width and height of sub-image
430
+
431
+ Returns:
432
+ tuple[tuple[float]]: Corresponding coordinate of pasting and
433
+ cropping
434
+ - paste_coord (tuple): paste corner coordinate in mosaic image.
435
+ - crop_coord (tuple): crop corner coordinate in mosaic image.
436
+ """
437
+ assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
438
+ if loc == 'top_left':
439
+ # index0 to top left part of image
440
+ x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
441
+ max(center_position_xy[1] - img_shape_wh[1], 0), \
442
+ center_position_xy[0], \
443
+ center_position_xy[1]
444
+ crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
445
+ y2 - y1), img_shape_wh[0], img_shape_wh[1]
446
+
447
+ elif loc == 'top_right':
448
+ # index1 to top right part of image
449
+ x1, y1, x2, y2 = center_position_xy[0], \
450
+ max(center_position_xy[1] - img_shape_wh[1], 0), \
451
+ min(center_position_xy[0] + img_shape_wh[0],
452
+ self.img_scale[0] * 2), \
453
+ center_position_xy[1]
454
+ crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
455
+ img_shape_wh[0], x2 - x1), img_shape_wh[1]
456
+
457
+ elif loc == 'bottom_left':
458
+ # index2 to bottom left part of image
459
+ x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
460
+ center_position_xy[1], \
461
+ center_position_xy[0], \
462
+ min(self.img_scale[1] * 2, center_position_xy[1] +
463
+ img_shape_wh[1])
464
+ crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
465
+ y2 - y1, img_shape_wh[1])
466
+
467
+ else:
468
+ # index3 to bottom right part of image
469
+ x1, y1, x2, y2 = center_position_xy[0], \
470
+ center_position_xy[1], \
471
+ min(center_position_xy[0] + img_shape_wh[0],
472
+ self.img_scale[0] * 2), \
473
+ min(self.img_scale[1] * 2, center_position_xy[1] +
474
+ img_shape_wh[1])
475
+ crop_coord = 0, 0, min(img_shape_wh[0],
476
+ x2 - x1), min(y2 - y1, img_shape_wh[1])
477
+
478
+ paste_coord = x1, y1, x2, y2
479
+ return paste_coord, crop_coord
480
+
481
+ def __repr__(self) -> str:
482
+ repr_str = self.__class__.__name__
483
+ repr_str += f'(img_scale={self.img_scale}, '
484
+ repr_str += f'center_ratio_range={self.center_ratio_range}, '
485
+ repr_str += f'pad_val={self.pad_val}, '
486
+ repr_str += f'prob={self.prob})'
487
+ return repr_str
488
+
489
+
490
+ @TRANSFORMS.register_module()
491
+ class Mosaic9(BaseMixImageTransform):
492
+ """Mosaic9 augmentation.
493
+
494
+ Given 9 images, mosaic transform combines them into
495
+ one output image. The output image is composed of the parts from each sub-
496
+ image.
497
+
498
+ .. code:: text
499
+
500
+ +-------------------------------+------------+
501
+ | pad | pad | |
502
+ | +----------+ | |
503
+ | | +---------------+ top_right |
504
+ | | | top | image2 |
505
+ | | top_left | image1 | |
506
+ | | image8 o--------+------+--------+---+
507
+ | | | | | |
508
+ +----+----------+ | right |pad|
509
+ | | center | image3 | |
510
+ | left | image0 +---------------+---|
511
+ | image7 | | | |
512
+ +---+-----------+---+--------+ | |
513
+ | | cropped | | bottom_right |pad|
514
+ | |bottom_left| | image4 | |
515
+ | | image6 | bottom | | |
516
+ +---|-----------+ image5 +---------------+---|
517
+ | pad | | pad |
518
+ +-----------+------------+-------------------+
519
+
520
+ The mosaic transform steps are as follows:
521
+
522
+ 1. Get the center image according to the index, and randomly
523
+ sample another 8 images from the custom dataset.
524
+ 2. Randomly offset the image after Mosaic
525
+
526
+ Required Keys:
527
+
528
+ - img
529
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
530
+ - gt_bboxes_labels (np.int64) (optional)
531
+ - gt_ignore_flags (bool) (optional)
532
+ - mix_results (List[dict])
533
+
534
+ Modified Keys:
535
+
536
+ - img
537
+ - img_shape
538
+ - gt_bboxes (optional)
539
+ - gt_bboxes_labels (optional)
540
+ - gt_ignore_flags (optional)
541
+
542
+ Args:
543
+ img_scale (Sequence[int]): Image size after mosaic pipeline of single
544
+ image. The shape order should be (width, height).
545
+ Defaults to (640, 640).
546
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
547
+ the border of the image. In some dataset like MOT17, the gt bboxes
548
+ are allowed to cross the border of images. Therefore, we don't
549
+ need to clip the gt bboxes in these cases. Defaults to True.
550
+ pad_val (int): Pad value. Defaults to 114.
551
+ pre_transform(Sequence[dict]): Sequence of transform object or
552
+ config dict to be composed.
553
+ prob (float): Probability of applying this transformation.
554
+ Defaults to 1.0.
555
+ use_cached (bool): Whether to use cache. Defaults to False.
556
+ max_cached_images (int): The maximum length of the cache. The larger
557
+ the cache, the stronger the randomness of this transform. As a
558
+ rule of thumb, providing 5 caches for each image suffices for
559
+ randomness. Defaults to 50.
560
+ random_pop (bool): Whether to randomly pop a result from the cache
561
+ when the cache is full. If set to False, use FIFO popping method.
562
+ Defaults to True.
563
+ max_refetch (int): The maximum number of retry iterations for getting
564
+ valid results from the pipeline. If the number of iterations is
565
+ greater than `max_refetch`, but results is still None, then the
566
+ iteration is terminated and raise the error. Defaults to 15.
567
+ """
568
+
569
+ def __init__(self,
570
+ img_scale: Tuple[int, int] = (640, 640),
571
+ bbox_clip_border: bool = True,
572
+ pad_val: Union[float, int] = 114.0,
573
+ pre_transform: Sequence[dict] = None,
574
+ prob: float = 1.0,
575
+ use_cached: bool = False,
576
+ max_cached_images: int = 50,
577
+ random_pop: bool = True,
578
+ max_refetch: int = 15):
579
+ assert isinstance(img_scale, tuple)
580
+ assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \
581
+ f'got {prob}.'
582
+ if use_cached:
583
+ assert max_cached_images >= 9, 'The length of cache must >= 9, ' \
584
+ f'but got {max_cached_images}.'
585
+
586
+ super().__init__(
587
+ pre_transform=pre_transform,
588
+ prob=prob,
589
+ use_cached=use_cached,
590
+ max_cached_images=max_cached_images,
591
+ random_pop=random_pop,
592
+ max_refetch=max_refetch)
593
+
594
+ self.img_scale = img_scale
595
+ self.bbox_clip_border = bbox_clip_border
596
+ self.pad_val = pad_val
597
+
598
+ # intermediate variables
599
+ self._current_img_shape = [0, 0]
600
+ self._center_img_shape = [0, 0]
601
+ self._previous_img_shape = [0, 0]
602
+
603
+ def get_indexes(self, dataset: Union[BaseDataset, list]) -> list:
604
+ """Call function to collect indexes.
605
+
606
+ Args:
607
+ dataset (:obj:`Dataset` or list): The dataset or cached list.
608
+
609
+ Returns:
610
+ list: indexes.
611
+ """
612
+ indexes = [random.randint(0, len(dataset)) for _ in range(8)]
613
+ return indexes
614
+
615
+ def mix_img_transform(self, results: dict) -> dict:
616
+ """Mixed image data transformation.
617
+
618
+ Args:
619
+ results (dict): Result dict.
620
+
621
+ Returns:
622
+ results (dict): Updated result dict.
623
+ """
624
+ assert 'mix_results' in results
625
+
626
+ mosaic_bboxes = []
627
+ mosaic_bboxes_labels = []
628
+ mosaic_ignore_flags = []
629
+
630
+ img_scale_w, img_scale_h = self.img_scale
631
+
632
+ if len(results['img'].shape) == 3:
633
+ mosaic_img = np.full(
634
+ (int(img_scale_h * 3), int(img_scale_w * 3), 3),
635
+ self.pad_val,
636
+ dtype=results['img'].dtype)
637
+ else:
638
+ mosaic_img = np.full((int(img_scale_h * 3), int(img_scale_w * 3)),
639
+ self.pad_val,
640
+ dtype=results['img'].dtype)
641
+
642
+ # index = 0 is mean original image
643
+ # len(results['mix_results']) = 8
644
+ loc_strs = ('center', 'top', 'top_right', 'right', 'bottom_right',
645
+ 'bottom', 'bottom_left', 'left', 'top_left')
646
+
647
+ results_all = [results, *results['mix_results']]
648
+ for index, results_patch in enumerate(results_all):
649
+ img_i = results_patch['img']
650
+ # keep_ratio resize
651
+ img_i_h, img_i_w = img_i.shape[:2]
652
+ scale_ratio_i = min(img_scale_h / img_i_h, img_scale_w / img_i_w)
653
+ img_i = mmcv.imresize(
654
+ img_i,
655
+ (int(img_i_w * scale_ratio_i), int(img_i_h * scale_ratio_i)))
656
+
657
+ paste_coord = self._mosaic_combine(loc_strs[index],
658
+ img_i.shape[:2])
659
+
660
+ padw, padh = paste_coord[:2]
661
+ x1, y1, x2, y2 = (max(x, 0) for x in paste_coord)
662
+ mosaic_img[y1:y2, x1:x2] = img_i[y1 - padh:, x1 - padw:]
663
+
664
+ gt_bboxes_i = results_patch['gt_bboxes']
665
+ gt_bboxes_labels_i = results_patch['gt_bboxes_labels']
666
+ gt_ignore_flags_i = results_patch['gt_ignore_flags']
667
+ gt_bboxes_i.rescale_([scale_ratio_i, scale_ratio_i])
668
+ gt_bboxes_i.translate_([padw, padh])
669
+
670
+ mosaic_bboxes.append(gt_bboxes_i)
671
+ mosaic_bboxes_labels.append(gt_bboxes_labels_i)
672
+ mosaic_ignore_flags.append(gt_ignore_flags_i)
673
+
674
+ # Offset
675
+ offset_x = int(random.uniform(0, img_scale_w))
676
+ offset_y = int(random.uniform(0, img_scale_h))
677
+ mosaic_img = mosaic_img[offset_y:offset_y + 2 * img_scale_h,
678
+ offset_x:offset_x + 2 * img_scale_w]
679
+
680
+ mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0)
681
+ mosaic_bboxes.translate_([-offset_x, -offset_y])
682
+ mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0)
683
+ mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0)
684
+
685
+ if self.bbox_clip_border:
686
+ mosaic_bboxes.clip_([2 * img_scale_h, 2 * img_scale_w])
687
+ else:
688
+ # remove outside bboxes
689
+ inside_inds = mosaic_bboxes.is_inside(
690
+ [2 * img_scale_h, 2 * img_scale_w]).numpy()
691
+ mosaic_bboxes = mosaic_bboxes[inside_inds]
692
+ mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds]
693
+ mosaic_ignore_flags = mosaic_ignore_flags[inside_inds]
694
+
695
+ results['img'] = mosaic_img
696
+ results['img_shape'] = mosaic_img.shape
697
+ results['gt_bboxes'] = mosaic_bboxes
698
+ results['gt_bboxes_labels'] = mosaic_bboxes_labels
699
+ results['gt_ignore_flags'] = mosaic_ignore_flags
700
+ return results
701
+
702
+ def _mosaic_combine(self, loc: str,
703
+ img_shape_hw: Tuple[int, int]) -> Tuple[int, ...]:
704
+ """Calculate global coordinate of mosaic image.
705
+
706
+ Args:
707
+ loc (str): Index for the sub-image.
708
+ img_shape_hw (Sequence[int]): Height and width of sub-image
709
+
710
+ Returns:
711
+ paste_coord (tuple): paste corner coordinate in mosaic image.
712
+ """
713
+ assert loc in ('center', 'top', 'top_right', 'right', 'bottom_right',
714
+ 'bottom', 'bottom_left', 'left', 'top_left')
715
+
716
+ img_scale_w, img_scale_h = self.img_scale
717
+
718
+ self._current_img_shape = img_shape_hw
719
+ current_img_h, current_img_w = self._current_img_shape
720
+ previous_img_h, previous_img_w = self._previous_img_shape
721
+ center_img_h, center_img_w = self._center_img_shape
722
+
723
+ if loc == 'center':
724
+ self._center_img_shape = self._current_img_shape
725
+ # xmin, ymin, xmax, ymax
726
+ paste_coord = img_scale_w, \
727
+ img_scale_h, \
728
+ img_scale_w + current_img_w, \
729
+ img_scale_h + current_img_h
730
+ elif loc == 'top':
731
+ paste_coord = img_scale_w, \
732
+ img_scale_h - current_img_h, \
733
+ img_scale_w + current_img_w, \
734
+ img_scale_h
735
+ elif loc == 'top_right':
736
+ paste_coord = img_scale_w + previous_img_w, \
737
+ img_scale_h - current_img_h, \
738
+ img_scale_w + previous_img_w + current_img_w, \
739
+ img_scale_h
740
+ elif loc == 'right':
741
+ paste_coord = img_scale_w + center_img_w, \
742
+ img_scale_h, \
743
+ img_scale_w + center_img_w + current_img_w, \
744
+ img_scale_h + current_img_h
745
+ elif loc == 'bottom_right':
746
+ paste_coord = img_scale_w + center_img_w, \
747
+ img_scale_h + previous_img_h, \
748
+ img_scale_w + center_img_w + current_img_w, \
749
+ img_scale_h + previous_img_h + current_img_h
750
+ elif loc == 'bottom':
751
+ paste_coord = img_scale_w + center_img_w - current_img_w, \
752
+ img_scale_h + center_img_h, \
753
+ img_scale_w + center_img_w, \
754
+ img_scale_h + center_img_h + current_img_h
755
+ elif loc == 'bottom_left':
756
+ paste_coord = img_scale_w + center_img_w - \
757
+ previous_img_w - current_img_w, \
758
+ img_scale_h + center_img_h, \
759
+ img_scale_w + center_img_w - previous_img_w, \
760
+ img_scale_h + center_img_h + current_img_h
761
+ elif loc == 'left':
762
+ paste_coord = img_scale_w - current_img_w, \
763
+ img_scale_h + center_img_h - current_img_h, \
764
+ img_scale_w, \
765
+ img_scale_h + center_img_h
766
+ elif loc == 'top_left':
767
+ paste_coord = img_scale_w - current_img_w, \
768
+ img_scale_h + center_img_h - \
769
+ previous_img_h - current_img_h, \
770
+ img_scale_w, \
771
+ img_scale_h + center_img_h - previous_img_h
772
+
773
+ self._previous_img_shape = self._current_img_shape
774
+ # xmin, ymin, xmax, ymax
775
+ return paste_coord
776
+
777
+ def __repr__(self) -> str:
778
+ repr_str = self.__class__.__name__
779
+ repr_str += f'(img_scale={self.img_scale}, '
780
+ repr_str += f'pad_val={self.pad_val}, '
781
+ repr_str += f'prob={self.prob})'
782
+ return repr_str
783
+
784
+
785
+ @TRANSFORMS.register_module()
786
+ class YOLOv5MixUp(BaseMixImageTransform):
787
+ """MixUp data augmentation for YOLOv5.
788
+
789
+ .. code:: text
790
+
791
+ The mixup transform steps are as follows:
792
+
793
+ 1. Another random image is picked by dataset.
794
+ 2. Randomly obtain the fusion ratio from the beta distribution,
795
+ then fuse the target
796
+ of the original image and mixup image through this ratio.
797
+
798
+ Required Keys:
799
+
800
+ - img
801
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
802
+ - gt_bboxes_labels (np.int64) (optional)
803
+ - gt_ignore_flags (bool) (optional)
804
+ - mix_results (List[dict])
805
+
806
+
807
+ Modified Keys:
808
+
809
+ - img
810
+ - img_shape
811
+ - gt_bboxes (optional)
812
+ - gt_bboxes_labels (optional)
813
+ - gt_ignore_flags (optional)
814
+
815
+
816
+ Args:
817
+ alpha (float): parameter of beta distribution to get mixup ratio.
818
+ Defaults to 32.
819
+ beta (float): parameter of beta distribution to get mixup ratio.
820
+ Defaults to 32.
821
+ pre_transform (Sequence[dict]): Sequence of transform object or
822
+ config dict to be composed.
823
+ prob (float): Probability of applying this transformation.
824
+ Defaults to 1.0.
825
+ use_cached (bool): Whether to use cache. Defaults to False.
826
+ max_cached_images (int): The maximum length of the cache. The larger
827
+ the cache, the stronger the randomness of this transform. As a
828
+ rule of thumb, providing 10 caches for each image suffices for
829
+ randomness. Defaults to 20.
830
+ random_pop (bool): Whether to randomly pop a result from the cache
831
+ when the cache is full. If set to False, use FIFO popping method.
832
+ Defaults to True.
833
+ max_refetch (int): The maximum number of iterations. If the number of
834
+ iterations is greater than `max_refetch`, but gt_bbox is still
835
+ empty, then the iteration is terminated. Defaults to 15.
836
+ """
837
+
838
+ def __init__(self,
839
+ alpha: float = 32.0,
840
+ beta: float = 32.0,
841
+ pre_transform: Sequence[dict] = None,
842
+ prob: float = 1.0,
843
+ use_cached: bool = False,
844
+ max_cached_images: int = 20,
845
+ random_pop: bool = True,
846
+ max_refetch: int = 15):
847
+ if use_cached:
848
+ assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
849
+ f'but got {max_cached_images}.'
850
+ super().__init__(
851
+ pre_transform=pre_transform,
852
+ prob=prob,
853
+ use_cached=use_cached,
854
+ max_cached_images=max_cached_images,
855
+ random_pop=random_pop,
856
+ max_refetch=max_refetch)
857
+ self.alpha = alpha
858
+ self.beta = beta
859
+
860
+ def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
861
+ """Call function to collect indexes.
862
+
863
+ Args:
864
+ dataset (:obj:`Dataset` or list): The dataset or cached list.
865
+
866
+ Returns:
867
+ int: indexes.
868
+ """
869
+ return random.randint(0, len(dataset))
870
+
871
+ def mix_img_transform(self, results: dict) -> dict:
872
+ """YOLOv5 MixUp transform function.
873
+
874
+ Args:
875
+ results (dict): Result dict
876
+
877
+ Returns:
878
+ results (dict): Updated result dict.
879
+ """
880
+ assert 'mix_results' in results
881
+
882
+ retrieve_results = results['mix_results'][0]
883
+ retrieve_img = retrieve_results['img']
884
+ ori_img = results['img']
885
+ assert ori_img.shape == retrieve_img.shape
886
+
887
+ # Randomly obtain the fusion ratio from the beta distribution,
888
+ # which is around 0.5
889
+ ratio = np.random.beta(self.alpha, self.beta)
890
+ mixup_img = (ori_img * ratio + retrieve_img * (1 - ratio))
891
+
892
+ retrieve_gt_bboxes = retrieve_results['gt_bboxes']
893
+ retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
894
+ retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
895
+
896
+ mixup_gt_bboxes = retrieve_gt_bboxes.cat(
897
+ (results['gt_bboxes'], retrieve_gt_bboxes), dim=0)
898
+ mixup_gt_bboxes_labels = np.concatenate(
899
+ (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
900
+ mixup_gt_ignore_flags = np.concatenate(
901
+ (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
902
+ if 'gt_masks' in results:
903
+ assert 'gt_masks' in retrieve_results
904
+ mixup_gt_masks = results['gt_masks'].cat(
905
+ [results['gt_masks'], retrieve_results['gt_masks']])
906
+ results['gt_masks'] = mixup_gt_masks
907
+
908
+ results['img'] = mixup_img.astype(np.uint8)
909
+ results['img_shape'] = mixup_img.shape
910
+ results['gt_bboxes'] = mixup_gt_bboxes
911
+ results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
912
+ results['gt_ignore_flags'] = mixup_gt_ignore_flags
913
+
914
+ return results
915
+
916
+
917
+ @TRANSFORMS.register_module()
918
+ class YOLOXMixUp(BaseMixImageTransform):
919
+ """MixUp data augmentation for YOLOX.
920
+
921
+ .. code:: text
922
+
923
+ mixup transform
924
+ +---------------+--------------+
925
+ | mixup image | |
926
+ | +--------|--------+ |
927
+ | | | | |
928
+ +---------------+ | |
929
+ | | | |
930
+ | | image | |
931
+ | | | |
932
+ | | | |
933
+ | +-----------------+ |
934
+ | pad |
935
+ +------------------------------+
936
+
937
+ The mixup transform steps are as follows:
938
+
939
+ 1. Another random image is picked by dataset and embedded in
940
+ the top left patch(after padding and resizing)
941
+ 2. The target of mixup transform is the weighted average of mixup
942
+ image and origin image.
943
+
944
+ Required Keys:
945
+
946
+ - img
947
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
948
+ - gt_bboxes_labels (np.int64) (optional)
949
+ - gt_ignore_flags (bool) (optional)
950
+ - mix_results (List[dict])
951
+
952
+
953
+ Modified Keys:
954
+
955
+ - img
956
+ - img_shape
957
+ - gt_bboxes (optional)
958
+ - gt_bboxes_labels (optional)
959
+ - gt_ignore_flags (optional)
960
+
961
+
962
+ Args:
963
+ img_scale (Sequence[int]): Image output size after mixup pipeline.
964
+ The shape order should be (width, height). Defaults to (640, 640).
965
+ ratio_range (Sequence[float]): Scale ratio of mixup image.
966
+ Defaults to (0.5, 1.5).
967
+ flip_ratio (float): Horizontal flip ratio of mixup image.
968
+ Defaults to 0.5.
969
+ pad_val (int): Pad value. Defaults to 114.
970
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
971
+ the border of the image. In some dataset like MOT17, the gt bboxes
972
+ are allowed to cross the border of images. Therefore, we don't
973
+ need to clip the gt bboxes in these cases. Defaults to True.
974
+ pre_transform(Sequence[dict]): Sequence of transform object or
975
+ config dict to be composed.
976
+ prob (float): Probability of applying this transformation.
977
+ Defaults to 1.0.
978
+ use_cached (bool): Whether to use cache. Defaults to False.
979
+ max_cached_images (int): The maximum length of the cache. The larger
980
+ the cache, the stronger the randomness of this transform. As a
981
+ rule of thumb, providing 10 caches for each image suffices for
982
+ randomness. Defaults to 20.
983
+ random_pop (bool): Whether to randomly pop a result from the cache
984
+ when the cache is full. If set to False, use FIFO popping method.
985
+ Defaults to True.
986
+ max_refetch (int): The maximum number of iterations. If the number of
987
+ iterations is greater than `max_refetch`, but gt_bbox is still
988
+ empty, then the iteration is terminated. Defaults to 15.
989
+ """
990
+
991
+ def __init__(self,
992
+ img_scale: Tuple[int, int] = (640, 640),
993
+ ratio_range: Tuple[float, float] = (0.5, 1.5),
994
+ flip_ratio: float = 0.5,
995
+ pad_val: float = 114.0,
996
+ bbox_clip_border: bool = True,
997
+ pre_transform: Sequence[dict] = None,
998
+ prob: float = 1.0,
999
+ use_cached: bool = False,
1000
+ max_cached_images: int = 20,
1001
+ random_pop: bool = True,
1002
+ max_refetch: int = 15):
1003
+ assert isinstance(img_scale, tuple)
1004
+ if use_cached:
1005
+ assert max_cached_images >= 2, 'The length of cache must >= 2, ' \
1006
+ f'but got {max_cached_images}.'
1007
+ super().__init__(
1008
+ pre_transform=pre_transform,
1009
+ prob=prob,
1010
+ use_cached=use_cached,
1011
+ max_cached_images=max_cached_images,
1012
+ random_pop=random_pop,
1013
+ max_refetch=max_refetch)
1014
+ self.img_scale = img_scale
1015
+ self.ratio_range = ratio_range
1016
+ self.flip_ratio = flip_ratio
1017
+ self.pad_val = pad_val
1018
+ self.bbox_clip_border = bbox_clip_border
1019
+
1020
+ def get_indexes(self, dataset: Union[BaseDataset, list]) -> int:
1021
+ """Call function to collect indexes.
1022
+
1023
+ Args:
1024
+ dataset (:obj:`Dataset` or list): The dataset or cached list.
1025
+
1026
+ Returns:
1027
+ int: indexes.
1028
+ """
1029
+ return random.randint(0, len(dataset))
1030
+
1031
+ def mix_img_transform(self, results: dict) -> dict:
1032
+ """YOLOX MixUp transform function.
1033
+
1034
+ Args:
1035
+ results (dict): Result dict.
1036
+
1037
+ Returns:
1038
+ results (dict): Updated result dict.
1039
+ """
1040
+ assert 'mix_results' in results
1041
+ assert len(
1042
+ results['mix_results']) == 1, 'MixUp only support 2 images now !'
1043
+
1044
+ if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
1045
+ # empty bbox
1046
+ return results
1047
+
1048
+ retrieve_results = results['mix_results'][0]
1049
+ retrieve_img = retrieve_results['img']
1050
+
1051
+ jit_factor = random.uniform(*self.ratio_range)
1052
+ is_filp = random.uniform(0, 1) > self.flip_ratio
1053
+
1054
+ if len(retrieve_img.shape) == 3:
1055
+ out_img = np.ones((self.img_scale[1], self.img_scale[0], 3),
1056
+ dtype=retrieve_img.dtype) * self.pad_val
1057
+ else:
1058
+ out_img = np.ones(
1059
+ self.img_scale[::-1], dtype=retrieve_img.dtype) * self.pad_val
1060
+
1061
+ # 1. keep_ratio resize
1062
+ scale_ratio = min(self.img_scale[1] / retrieve_img.shape[0],
1063
+ self.img_scale[0] / retrieve_img.shape[1])
1064
+ retrieve_img = mmcv.imresize(
1065
+ retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
1066
+ int(retrieve_img.shape[0] * scale_ratio)))
1067
+
1068
+ # 2. paste
1069
+ out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img
1070
+
1071
+ # 3. scale jit
1072
+ scale_ratio *= jit_factor
1073
+ out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
1074
+ int(out_img.shape[0] * jit_factor)))
1075
+
1076
+ # 4. flip
1077
+ if is_filp:
1078
+ out_img = out_img[:, ::-1, :]
1079
+
1080
+ # 5. random crop
1081
+ ori_img = results['img']
1082
+ origin_h, origin_w = out_img.shape[:2]
1083
+ target_h, target_w = ori_img.shape[:2]
1084
+ padded_img = np.ones((max(origin_h, target_h), max(
1085
+ origin_w, target_w), 3)) * self.pad_val
1086
+ padded_img = padded_img.astype(np.uint8)
1087
+ padded_img[:origin_h, :origin_w] = out_img
1088
+
1089
+ x_offset, y_offset = 0, 0
1090
+ if padded_img.shape[0] > target_h:
1091
+ y_offset = random.randint(0, padded_img.shape[0] - target_h)
1092
+ if padded_img.shape[1] > target_w:
1093
+ x_offset = random.randint(0, padded_img.shape[1] - target_w)
1094
+ padded_cropped_img = padded_img[y_offset:y_offset + target_h,
1095
+ x_offset:x_offset + target_w]
1096
+
1097
+ # 6. adjust bbox
1098
+ retrieve_gt_bboxes = retrieve_results['gt_bboxes']
1099
+ retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio])
1100
+ if self.bbox_clip_border:
1101
+ retrieve_gt_bboxes.clip_([origin_h, origin_w])
1102
+
1103
+ if is_filp:
1104
+ retrieve_gt_bboxes.flip_([origin_h, origin_w],
1105
+ direction='horizontal')
1106
+
1107
+ # 7. filter
1108
+ cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone()
1109
+ cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset])
1110
+ if self.bbox_clip_border:
1111
+ cp_retrieve_gt_bboxes.clip_([target_h, target_w])
1112
+
1113
+ # 8. mix up
1114
+ mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img
1115
+
1116
+ retrieve_gt_bboxes_labels = retrieve_results['gt_bboxes_labels']
1117
+ retrieve_gt_ignore_flags = retrieve_results['gt_ignore_flags']
1118
+
1119
+ mixup_gt_bboxes = cp_retrieve_gt_bboxes.cat(
1120
+ (results['gt_bboxes'], cp_retrieve_gt_bboxes), dim=0)
1121
+ mixup_gt_bboxes_labels = np.concatenate(
1122
+ (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0)
1123
+ mixup_gt_ignore_flags = np.concatenate(
1124
+ (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0)
1125
+
1126
+ if not self.bbox_clip_border:
1127
+ # remove outside bbox
1128
+ inside_inds = mixup_gt_bboxes.is_inside([target_h,
1129
+ target_w]).numpy()
1130
+ mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
1131
+ mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds]
1132
+ mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds]
1133
+
1134
+ results['img'] = mixup_img.astype(np.uint8)
1135
+ results['img_shape'] = mixup_img.shape
1136
+ results['gt_bboxes'] = mixup_gt_bboxes
1137
+ results['gt_bboxes_labels'] = mixup_gt_bboxes_labels
1138
+ results['gt_ignore_flags'] = mixup_gt_ignore_flags
1139
+
1140
+ return results
1141
+
1142
+ def __repr__(self) -> str:
1143
+ repr_str = self.__class__.__name__
1144
+ repr_str += f'(img_scale={self.img_scale}, '
1145
+ repr_str += f'ratio_range={self.ratio_range}, '
1146
+ repr_str += f'flip_ratio={self.flip_ratio}, '
1147
+ repr_str += f'pad_val={self.pad_val}, '
1148
+ repr_str += f'max_refetch={self.max_refetch}, '
1149
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
1150
+ return repr_str
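
The mixing transforms above need extra, fully loaded images, which is why they take a pre_transform pipeline (or draw from a results cache when use_cached=True). A hedged sketch of how Mosaic and YOLOv5MixUp are chained in a training pipeline; argument values are the documented defaults, not taken from a particular config in this commit:

    pre_transform = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
    ]

    train_pipeline = [
        *pre_transform,
        dict(type='Mosaic',
             img_scale=(640, 640),         # (width, height)
             pad_val=114.0,
             pre_transform=pre_transform),
        # MixUp fuses the current sample with one more image; its pre_transform
        # must reproduce the loading (and here, Mosaic) steps for that image.
        dict(type='YOLOv5MixUp',
             prob=0.1,                     # illustrative probability
             pre_transform=[
                 *pre_transform,
                 dict(type='Mosaic',
                      img_scale=(640, 640),
                      pad_val=114.0,
                      pre_transform=pre_transform),
             ]),
    ]
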
mmyolo/datasets/transforms/transforms.py ADDED
@@ -0,0 +1,1557 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from copy import deepcopy
4
+ from typing import List, Sequence, Tuple, Union
5
+
6
+ import cv2
7
+ import mmcv
8
+ import numpy as np
9
+ import torch
10
+ from mmcv.transforms import BaseTransform, Compose
11
+ from mmcv.transforms.utils import cache_randomness
12
+ from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
13
+ from mmdet.datasets.transforms import Resize as MMDET_Resize
14
+ from mmdet.structures.bbox import (HorizontalBoxes, autocast_box_type,
15
+ get_box_type)
16
+ from mmdet.structures.mask import PolygonMasks
17
+ from numpy import random
18
+
19
+ from mmyolo.registry import TRANSFORMS
20
+
21
+ # TODO: Waiting for MMCV support
22
+ TRANSFORMS.register_module(module=Compose, force=True)
23
+
24
+
25
+ @TRANSFORMS.register_module()
26
+ class YOLOv5KeepRatioResize(MMDET_Resize):
27
+ """Resize images & bbox(if existed).
28
+
29
+ This transform resizes the input image according to ``scale``.
30
+ Bboxes (if existed) are then resized with the same scale factor.
31
+
32
+ Required Keys:
33
+
34
+ - img (np.uint8)
35
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
36
+
37
+ Modified Keys:
38
+
39
+ - img (np.uint8)
40
+ - img_shape (tuple)
41
+ - gt_bboxes (optional)
42
+ - scale (float)
43
+
44
+ Added Keys:
45
+
46
+ - scale_factor (np.float32)
47
+
48
+ Args:
49
+ scale (Union[int, Tuple[int, int]]): Images scales for resizing.
50
+ """
51
+
52
+ def __init__(self,
53
+ scale: Union[int, Tuple[int, int]],
54
+ keep_ratio: bool = True,
55
+ **kwargs):
56
+ assert keep_ratio is True
57
+ super().__init__(scale=scale, keep_ratio=True, **kwargs)
58
+
59
+ @staticmethod
60
+ def _get_rescale_ratio(old_size: Tuple[int, int],
61
+ scale: Union[float, Tuple[int]]) -> float:
62
+ """Calculate the ratio for rescaling.
63
+
64
+ Args:
65
+ old_size (tuple[int]): The old size (w, h) of image.
66
+ scale (float | tuple[int]): The scaling factor or maximum size.
67
+ If it is a float number, then the image will be rescaled by
68
+ this factor, else if it is a tuple of 2 integers, then
69
+ the image will be rescaled as large as possible within
70
+ the scale.
71
+
72
+ Returns:
73
+ float: The resize ratio.
74
+ """
75
+ w, h = old_size
76
+ if isinstance(scale, (float, int)):
77
+ if scale <= 0:
78
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
79
+ scale_factor = scale
80
+ elif isinstance(scale, tuple):
81
+ max_long_edge = max(scale)
82
+ max_short_edge = min(scale)
83
+ scale_factor = min(max_long_edge / max(h, w),
84
+ max_short_edge / min(h, w))
85
+ else:
86
+ raise TypeError('Scale must be a number or tuple of int, '
87
+ f'but got {type(scale)}')
88
+
89
+ return scale_factor
90
+
91
+ def _resize_img(self, results: dict):
92
+ """Resize images with ``results['scale']``."""
93
+ assert self.keep_ratio is True
94
+
95
+ if results.get('img', None) is not None:
96
+ image = results['img']
97
+ original_h, original_w = image.shape[:2]
98
+ ratio = self._get_rescale_ratio((original_h, original_w),
99
+ self.scale)
100
+
101
+ if ratio != 1:
102
+ # resize image according to the ratio
103
+ image = mmcv.imrescale(
104
+ img=image,
105
+ scale=ratio,
106
+ interpolation='area' if ratio < 1 else 'bilinear',
107
+ backend=self.backend)
108
+
109
+ resized_h, resized_w = image.shape[:2]
110
+ scale_ratio = resized_h / original_h
111
+
112
+ scale_factor = (scale_ratio, scale_ratio)
113
+
114
+ results['img'] = image
115
+ results['img_shape'] = image.shape[:2]
116
+ results['scale_factor'] = scale_factor
117
+
118
+
119
+ @TRANSFORMS.register_module()
120
+ class LetterResize(MMDET_Resize):
121
+ """Resize and pad image while meeting stride-multiple constraints.
122
+
123
+ Required Keys:
124
+
125
+ - img (np.uint8)
126
+ - batch_shape (np.int64) (optional)
127
+
128
+ Modified Keys:
129
+
130
+ - img (np.uint8)
131
+ - img_shape (tuple)
132
+ - gt_bboxes (optional)
133
+
134
+ Added Keys:
135
+ - pad_param (np.float32)
136
+
137
+ Args:
138
+ scale (Union[int, Tuple[int, int]]): Images scales for resizing.
139
+ pad_val (dict): Padding value. Defaults to dict(img=0, mask=0, seg=255).
140
+ use_mini_pad (bool): Whether using minimum rectangle padding.
141
+ Defaults to False.
142
+ stretch_only (bool): Whether stretch to the specified size directly.
143
+ Defaults to False
144
+ allow_scale_up (bool): Allow scale up when ratio > 1. Defaults to True
145
+ """
146
+
147
+ def __init__(self,
148
+ scale: Union[int, Tuple[int, int]],
149
+ pad_val: dict = dict(img=0, mask=0, seg=255),
150
+ use_mini_pad: bool = False,
151
+ stretch_only: bool = False,
152
+ allow_scale_up: bool = True,
153
+ **kwargs):
154
+ super().__init__(scale=scale, keep_ratio=True, **kwargs)
155
+
156
+ self.pad_val = pad_val
157
+ if isinstance(pad_val, (int, float)):
158
+ self.pad_val = pad_val = dict(img=pad_val, seg=255)  # keep self.pad_val in sync for scalar input
159
+ assert isinstance(
160
+ pad_val, dict), f'pad_val must be dict, but got {type(pad_val)}'
161
+
162
+ self.use_mini_pad = use_mini_pad
163
+ self.stretch_only = stretch_only
164
+ self.allow_scale_up = allow_scale_up
165
+
166
+ def _resize_img(self, results: dict):
167
+ """Resize images with ``results['scale']``."""
168
+ image = results.get('img', None)
169
+ if image is None:
170
+ return
171
+
172
+ # Use batch_shape if a batch_shape policy is configured
173
+ if 'batch_shape' in results:
174
+ scale = tuple(results['batch_shape']) # hw
175
+ else:
176
+ scale = self.scale[::-1] # wh -> hw
177
+
178
+ image_shape = image.shape[:2] # height, width
179
+
180
+ # Scale ratio (new / old)
181
+ ratio = min(scale[0] / image_shape[0], scale[1] / image_shape[1])
182
+
183
+ # only scale down, do not scale up (for better test mAP)
184
+ if not self.allow_scale_up:
185
+ ratio = min(ratio, 1.0)
186
+
187
+ ratio = [ratio, ratio] # float -> (float, float) for (height, width)
188
+
189
+ # compute the best size of the image
190
+ no_pad_shape = (int(round(image_shape[0] * ratio[0])),
191
+ int(round(image_shape[1] * ratio[1])))
192
+
193
+ # padding height & width
194
+ padding_h, padding_w = [
195
+ scale[0] - no_pad_shape[0], scale[1] - no_pad_shape[1]
196
+ ]
197
+ if self.use_mini_pad:
198
+ # minimum rectangle padding
199
+ padding_w, padding_h = np.mod(padding_w, 32), np.mod(padding_h, 32)
200
+
201
+ elif self.stretch_only:
202
+ # stretch to the specified size directly
203
+ padding_h, padding_w = 0.0, 0.0
204
+ no_pad_shape = (scale[0], scale[1])
205
+ ratio = [scale[0] / image_shape[0],
206
+ scale[1] / image_shape[1]] # height, width ratios
207
+
208
+ if image_shape != no_pad_shape:
209
+ # compare with no resize and padding size
210
+ image = mmcv.imresize(
211
+ image, (no_pad_shape[1], no_pad_shape[0]),
212
+ interpolation=self.interpolation,
213
+ backend=self.backend)
214
+
215
+ scale_factor = (ratio[1], ratio[0]) # mmcv scale factor is (w, h)
216
+
217
+ if 'scale_factor' in results:
218
+ results['scale_factor_origin'] = results['scale_factor']
219
+ results['scale_factor'] = scale_factor
220
+
221
+ # padding
222
+ top_padding, left_padding = int(round(padding_h // 2 - 0.1)), int(
223
+ round(padding_w // 2 - 0.1))
224
+ bottom_padding = padding_h - top_padding
225
+ right_padding = padding_w - left_padding
226
+
227
+ padding_list = [
228
+ top_padding, bottom_padding, left_padding, right_padding
229
+ ]
230
+ if top_padding != 0 or bottom_padding != 0 or \
231
+ left_padding != 0 or right_padding != 0:
232
+
233
+ pad_val = self.pad_val.get('img', 0)
234
+ if isinstance(pad_val, int) and image.ndim == 3:
235
+ pad_val = tuple(pad_val for _ in range(image.shape[2]))
236
+
237
+ image = mmcv.impad(
238
+ img=image,
239
+ padding=(padding_list[2], padding_list[0], padding_list[3],
240
+ padding_list[1]),
241
+ pad_val=pad_val,
242
+ padding_mode='constant')
243
+
244
+ results['img'] = image
245
+ results['img_shape'] = image.shape
246
+ if 'pad_param' in results:
247
+ results['pad_param_origin'] = results['pad_param'] * \
248
+ np.repeat(ratio, 2)
249
+ results['pad_param'] = np.array(padding_list, dtype=np.float32)
250
+
251
+ def _resize_masks(self, results: dict):
252
+ """Resize masks with ``results['scale']``"""
253
+ if results.get('gt_masks', None) is None:
254
+ return
255
+
256
+ gt_masks = results['gt_masks']
257
+ assert isinstance(
258
+ gt_masks, PolygonMasks
259
+ ), f'Only supports PolygonMasks, but got {type(gt_masks)}'
260
+
261
+ # resize the gt_masks
262
+ gt_mask_h = results['gt_masks'].height * results['scale_factor'][1]
263
+ gt_mask_w = results['gt_masks'].width * results['scale_factor'][0]
264
+ gt_masks = results['gt_masks'].resize(
265
+ (int(round(gt_mask_h)), int(round(gt_mask_w))))
266
+
267
+ top_padding, _, left_padding, _ = results['pad_param']
268
+ if int(left_padding) != 0:
269
+ gt_masks = gt_masks.translate(
270
+ out_shape=results['img_shape'][:2],
271
+ offset=int(left_padding),
272
+ direction='horizontal')
273
+ if int(top_padding) != 0:
274
+ gt_masks = gt_masks.translate(
275
+ out_shape=results['img_shape'][:2],
276
+ offset=int(top_padding),
277
+ direction='vertical')
278
+ results['gt_masks'] = gt_masks
279
+
280
+ def _resize_bboxes(self, results: dict):
281
+ """Resize bounding boxes with ``results['scale_factor']``."""
282
+ if results.get('gt_bboxes', None) is None:
283
+ return
284
+ results['gt_bboxes'].rescale_(results['scale_factor'])
285
+
286
+ if len(results['pad_param']) != 4:
287
+ return
288
+ results['gt_bboxes'].translate_(
289
+ (results['pad_param'][2], results['pad_param'][0]))
290
+
291
+ if self.clip_object_border:
292
+ results['gt_bboxes'].clip_(results['img_shape'])
293
+
294
+ def transform(self, results: dict) -> dict:
295
+ results = super().transform(results)
296
+ if 'scale_factor_origin' in results:
297
+ scale_factor_origin = results.pop('scale_factor_origin')
298
+ results['scale_factor'] = (results['scale_factor'][0] *
299
+ scale_factor_origin[0],
300
+ results['scale_factor'][1] *
301
+ scale_factor_origin[1])
302
+ if 'pad_param_origin' in results:
303
+ pad_param_origin = results.pop('pad_param_origin')
304
+ results['pad_param'] += pad_param_origin
305
+ return results
306
+
307
+
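A companion sketch for `LetterResize` (again only the arithmetic, assuming the default 640x640 target with neither mini-pad nor stretch enabled): it shows how the ratio, the un-padded shape and the `pad_param` order `[top, bottom, left, right]` are derived:

```python
def letterbox_params(img_hw, target_hw=(640, 640), allow_scale_up=True):
    """Mirror the ratio/padding arithmetic of LetterResize."""
    ratio = min(target_hw[0] / img_hw[0], target_hw[1] / img_hw[1])
    if not allow_scale_up:
        ratio = min(ratio, 1.0)  # only scale down, for better test mAP
    no_pad_hw = (int(round(img_hw[0] * ratio)), int(round(img_hw[1] * ratio)))
    pad_h, pad_w = target_hw[0] - no_pad_hw[0], target_hw[1] - no_pad_hw[1]
    top = int(round(pad_h // 2 - 0.1))
    left = int(round(pad_w // 2 - 0.1))
    return no_pad_hw, [top, pad_h - top, left, pad_w - left]

print(letterbox_params((480, 640)))   # ((480, 640), [80, 80, 0, 0])
print(letterbox_params((720, 1280)))  # ((360, 640), [140, 140, 0, 0])
```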
308
+ # TODO: Check if it can be merged with mmdet.YOLOXHSVRandomAug
309
+ @TRANSFORMS.register_module()
310
+ class YOLOv5HSVRandomAug(BaseTransform):
311
+ """Apply HSV augmentation to image sequentially.
312
+
313
+ Required Keys:
314
+
315
+ - img
316
+
317
+ Modified Keys:
318
+
319
+ - img
320
+
321
+ Args:
322
+ hue_delta ([int, float]): delta of hue. Defaults to 0.015.
323
+ saturation_delta ([int, float]): delta of saturation. Defaults to 0.7.
324
+ value_delta ([int, float]): delta of value. Defaults to 0.4.
325
+ """
326
+
327
+ def __init__(self,
328
+ hue_delta: Union[int, float] = 0.015,
329
+ saturation_delta: Union[int, float] = 0.7,
330
+ value_delta: Union[int, float] = 0.4):
331
+ self.hue_delta = hue_delta
332
+ self.saturation_delta = saturation_delta
333
+ self.value_delta = value_delta
334
+
335
+ def transform(self, results: dict) -> dict:
336
+ """The HSV augmentation transform function.
337
+
338
+ Args:
339
+ results (dict): The result dict.
340
+
341
+ Returns:
342
+ dict: The result dict.
343
+ """
344
+ hsv_gains = \
345
+ random.uniform(-1, 1, 3) * \
346
+ [self.hue_delta, self.saturation_delta, self.value_delta] + 1
347
+ hue, sat, val = cv2.split(
348
+ cv2.cvtColor(results['img'], cv2.COLOR_BGR2HSV))
349
+
350
+ table_list = np.arange(0, 256, dtype=hsv_gains.dtype)
351
+ lut_hue = ((table_list * hsv_gains[0]) % 180).astype(np.uint8)
352
+ lut_sat = np.clip(table_list * hsv_gains[1], 0, 255).astype(np.uint8)
353
+ lut_val = np.clip(table_list * hsv_gains[2], 0, 255).astype(np.uint8)
354
+
355
+ im_hsv = cv2.merge(
356
+ (cv2.LUT(hue, lut_hue), cv2.LUT(sat,
357
+ lut_sat), cv2.LUT(val, lut_val)))
358
+ results['img'] = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR)
359
+ return results
360
+
361
+ def __repr__(self) -> str:
362
+ repr_str = self.__class__.__name__
363
+ repr_str += f'(hue_delta={self.hue_delta}, '
364
+ repr_str += f'saturation_delta={self.saturation_delta}, '
365
+ repr_str += f'value_delta={self.value_delta})'
366
+ return repr_str
367
+
368
+
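A minimal usage sketch for the HSV augmentation above; it only needs `results['img']`. The import path assumes the transform is re-exported by `mmyolo.datasets.transforms`, as the package `__init__` files in this commit suggest:

```python
import numpy as np

from mmyolo.datasets.transforms import YOLOv5HSVRandomAug  # assumed export

aug = YOLOv5HSVRandomAug(hue_delta=0.015, saturation_delta=0.7, value_delta=0.4)
results = {'img': np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)}
results = aug.transform(results)
print(results['img'].shape, results['img'].dtype)  # (32, 32, 3) uint8
```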
369
+ @TRANSFORMS.register_module()
370
+ class LoadAnnotations(MMDET_LoadAnnotations):
371
+ """Because the yolo series does not need to consider ignore bboxes for the
372
+ time being, in order to speed up the pipeline, it can be excluded in
373
+ advance."""
374
+
375
+ def __init__(self,
376
+ mask2bbox: bool = False,
377
+ poly2mask: bool = False,
378
+ **kwargs) -> None:
379
+ self.mask2bbox = mask2bbox
380
+ assert not poly2mask, 'Does not support BitmapMasks considering ' \
381
+ 'that bitmap consumes more memory.'
382
+ super().__init__(poly2mask=poly2mask, **kwargs)
383
+ if self.mask2bbox:
384
+ assert self.with_mask, 'Using mask2bbox requires ' \
385
+ 'with_mask is True.'
386
+ self._mask_ignore_flag = None
387
+
388
+ def transform(self, results: dict) -> dict:
389
+ """Function to load multiple types annotations.
390
+
391
+ Args:
392
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
393
+
394
+ Returns:
395
+ dict: The dict contains loaded bounding box, label and
396
+ semantic segmentation.
397
+ """
398
+ if self.mask2bbox:
399
+ self._load_masks(results)
400
+ if self.with_label:
401
+ self._load_labels(results)
402
+ self._update_mask_ignore_data(results)
403
+ gt_bboxes = results['gt_masks'].get_bboxes(dst_type='hbox')
404
+ results['gt_bboxes'] = gt_bboxes
405
+ else:
406
+ results = super().transform(results)
407
+ self._update_mask_ignore_data(results)
408
+ return results
409
+
410
+ def _update_mask_ignore_data(self, results: dict) -> None:
411
+ if 'gt_masks' not in results:
412
+ return
413
+
414
+ if 'gt_bboxes_labels' in results and len(
415
+ results['gt_bboxes_labels']) != len(results['gt_masks']):
416
+ assert len(results['gt_bboxes_labels']) == len(
417
+ self._mask_ignore_flag)
418
+ results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
419
+ self._mask_ignore_flag]
420
+
421
+ if 'gt_bboxes' in results and len(results['gt_bboxes']) != len(
422
+ results['gt_masks']):
423
+ assert len(results['gt_bboxes']) == len(self._mask_ignore_flag)
424
+ results['gt_bboxes'] = results['gt_bboxes'][self._mask_ignore_flag]
425
+
426
+ def _load_bboxes(self, results: dict):
427
+ """Private function to load bounding box annotations.
428
+ Note: BBoxes with ignore_flag of 1 is not considered.
429
+ Args:
430
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
431
+
432
+ Returns:
433
+ dict: The dict contains loaded bounding box annotations.
434
+ """
435
+ gt_bboxes = []
436
+ gt_ignore_flags = []
437
+ for instance in results.get('instances', []):
438
+ if instance['ignore_flag'] == 0:
439
+ gt_bboxes.append(instance['bbox'])
440
+ gt_ignore_flags.append(instance['ignore_flag'])
441
+ results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
442
+
443
+ if self.box_type is None:
444
+ results['gt_bboxes'] = np.array(
445
+ gt_bboxes, dtype=np.float32).reshape((-1, 4))
446
+ else:
447
+ _, box_type_cls = get_box_type(self.box_type)
448
+ results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32)
449
+
450
+ def _load_labels(self, results: dict):
451
+ """Private function to load label annotations.
452
+
453
+ Note: BBoxes with ignore_flag of 1 is not considered.
454
+ Args:
455
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
456
+ Returns:
457
+ dict: The dict contains loaded label annotations.
458
+ """
459
+ gt_bboxes_labels = []
460
+ for instance in results.get('instances', []):
461
+ if instance['ignore_flag'] == 0:
462
+ gt_bboxes_labels.append(instance['bbox_label'])
463
+ results['gt_bboxes_labels'] = np.array(
464
+ gt_bboxes_labels, dtype=np.int64)
465
+
466
+ def _load_masks(self, results: dict) -> None:
467
+ """Private function to load mask annotations.
468
+
469
+ Args:
470
+ results (dict): Result dict from :obj:``mmengine.BaseDataset``.
471
+ """
472
+ gt_masks = []
473
+ gt_ignore_flags = []
474
+ self._mask_ignore_flag = []
475
+ for instance in results.get('instances', []):
476
+ if instance['ignore_flag'] == 0:
477
+ if 'mask' in instance:
478
+ gt_mask = instance['mask']
479
+ if isinstance(gt_mask, list):
480
+ gt_mask = [
481
+ np.array(polygon) for polygon in gt_mask
482
+ if len(polygon) % 2 == 0 and len(polygon) >= 6
483
+ ]
484
+ if len(gt_mask) == 0:
485
+ # ignore
486
+ self._mask_ignore_flag.append(0)
487
+ else:
488
+ gt_masks.append(gt_mask)
489
+ gt_ignore_flags.append(instance['ignore_flag'])
490
+ self._mask_ignore_flag.append(1)
491
+ else:
492
+ raise NotImplementedError(
493
+ 'Only supports mask annotations in polygon '
494
+ 'format currently')
495
+ else:
496
+ # TODO: Actually, gt with bbox and without mask needs
497
+ # to be retained
498
+ self._mask_ignore_flag.append(0)
499
+ self._mask_ignore_flag = np.array(self._mask_ignore_flag, dtype=bool)
500
+ results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool)
501
+
502
+ h, w = results['ori_shape']
503
+ gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
504
+ results['gt_masks'] = gt_masks
505
+
506
+ def __repr__(self) -> str:
507
+ repr_str = self.__class__.__name__
508
+ repr_str += f'(with_bbox={self.with_bbox}, '
509
+ repr_str += f'with_label={self.with_label}, '
510
+ repr_str += f'with_mask={self.with_mask}, '
511
+ repr_str += f'with_seg={self.with_seg}, '
512
+ repr_str += f'mask2bbox={self.mask2bbox}, '
513
+ repr_str += f'poly2mask={self.poly2mask}, '
514
+ repr_str += f"imdecode_backend='{self.imdecode_backend}', "
515
+ repr_str += f'file_client_args={self.file_client_args})'
516
+ return repr_str
517
+
518
+
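To make the ignore-bbox behaviour concrete, here is a hypothetical two-instance annotation run through this `LoadAnnotations`; the instance with `ignore_flag=1` is dropped up front (import path assumed as above):

```python
import numpy as np

from mmyolo.datasets.transforms import LoadAnnotations  # assumed export

results = {
    'ori_shape': (100, 100),
    'instances': [
        dict(bbox=[10., 10., 50., 60.], bbox_label=0, ignore_flag=0),
        dict(bbox=[0., 0., 5., 5.], bbox_label=1, ignore_flag=1),  # skipped
    ],
}
loader = LoadAnnotations(with_bbox=True, with_label=True, with_mask=False)
out = loader.transform(results)
print(len(out['gt_bboxes']), out['gt_bboxes_labels'])  # 1 [0]
```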
519
+ @TRANSFORMS.register_module()
520
+ class YOLOv5RandomAffine(BaseTransform):
521
+ """Random affine transform data augmentation in YOLOv5 and YOLOv8. It is
522
+ different from the implementation in YOLOX.
523
+
524
+ This operation randomly generates affine transform matrix which including
525
+ rotation, translation, shear and scaling transforms.
526
+ If you set use_mask_refine == True, the code will use the masks
527
+ annotation to refine the bbox.
528
+ Our implementation is slightly different from the official. In COCO
529
+ dataset, a gt may have multiple mask tags. The official YOLOv5
530
+ annotation file already combines the masks that an object has,
531
+ but our code takes into account the fact that an object has multiple masks.
532
+
533
+ Required Keys:
534
+
535
+ - img
536
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
537
+ - gt_bboxes_labels (np.int64) (optional)
538
+ - gt_ignore_flags (bool) (optional)
539
+ - gt_masks (PolygonMasks) (optional)
540
+
541
+ Modified Keys:
542
+
543
+ - img
544
+ - img_shape
545
+ - gt_bboxes (optional)
546
+ - gt_bboxes_labels (optional)
547
+ - gt_ignore_flags (optional)
548
+ - gt_masks (PolygonMasks) (optional)
549
+
550
+ Args:
551
+ max_rotate_degree (float): Maximum degrees of rotation transform.
552
+ Defaults to 10.
553
+ max_translate_ratio (float): Maximum ratio of translation.
554
+ Defaults to 0.1.
555
+ scaling_ratio_range (tuple[float]): Min and max ratio of
556
+ scaling transform. Defaults to (0.5, 1.5).
557
+ max_shear_degree (float): Maximum degrees of shear
558
+ transform. Defaults to 2.
559
+ border (tuple[int]): Distance from width and height sides of input
560
+ image to adjust output shape. Only used in mosaic dataset.
561
+ Defaults to (0, 0).
562
+ border_val (tuple[int]): Border padding values of 3 channels.
563
+ Defaults to (114, 114, 114).
564
+ bbox_clip_border (bool, optional): Whether to clip the objects outside
565
+ the border of the image. In some dataset like MOT17, the gt bboxes
566
+ are allowed to cross the border of images. Therefore, we don't
567
+ need to clip the gt bboxes in these cases. Defaults to True.
568
+ min_bbox_size (float): Width and height threshold to filter bboxes.
569
+ If the height or width of a box is smaller than this value, it
570
+ will be removed. Defaults to 2.
571
+ min_area_ratio (float): Threshold of area ratio between
572
+ original bboxes and wrapped bboxes. If smaller than this value,
573
+ the box will be removed. Defaults to 0.1.
574
+ use_mask_refine (bool): Whether to refine bbox by mask.
575
+ max_aspect_ratio (float): Aspect ratio of width and height
576
+ threshold to filter bboxes. If max(h/w, w/h) larger than this
577
+ value, the box will be removed. Defaults to 20.
578
+ resample_num (int): Number of points each polygon is resampled to. Defaults to 1000.
579
+ """
580
+
581
+ def __init__(self,
582
+ max_rotate_degree: float = 10.0,
583
+ max_translate_ratio: float = 0.1,
584
+ scaling_ratio_range: Tuple[float, float] = (0.5, 1.5),
585
+ max_shear_degree: float = 2.0,
586
+ border: Tuple[int, int] = (0, 0),
587
+ border_val: Tuple[int, int, int] = (114, 114, 114),
588
+ bbox_clip_border: bool = True,
589
+ min_bbox_size: int = 2,
590
+ min_area_ratio: float = 0.1,
591
+ use_mask_refine: bool = False,
592
+ max_aspect_ratio: float = 20.,
593
+ resample_num: int = 1000):
594
+ assert 0 <= max_translate_ratio <= 1
595
+ assert scaling_ratio_range[0] <= scaling_ratio_range[1]
596
+ assert scaling_ratio_range[0] > 0
597
+ self.max_rotate_degree = max_rotate_degree
598
+ self.max_translate_ratio = max_translate_ratio
599
+ self.scaling_ratio_range = scaling_ratio_range
600
+ self.max_shear_degree = max_shear_degree
601
+ self.border = border
602
+ self.border_val = border_val
603
+ self.bbox_clip_border = bbox_clip_border
604
+ self.min_bbox_size = min_bbox_size
605
+ self.min_area_ratio = min_area_ratio
606
+ self.use_mask_refine = use_mask_refine
607
+ self.max_aspect_ratio = max_aspect_ratio
608
+ self.resample_num = resample_num
609
+
610
+ @autocast_box_type()
611
+ def transform(self, results: dict) -> dict:
612
+ """The YOLOv5 random affine transform function.
613
+
614
+ Args:
615
+ results (dict): The result dict.
616
+
617
+ Returns:
618
+ dict: The result dict.
619
+ """
620
+ img = results['img']
621
+ # self.border is wh format
622
+ height = img.shape[0] + self.border[1] * 2
623
+ width = img.shape[1] + self.border[0] * 2
624
+
625
+ # Note: Different from YOLOX
626
+ center_matrix = np.eye(3, dtype=np.float32)
627
+ center_matrix[0, 2] = -img.shape[1] / 2
628
+ center_matrix[1, 2] = -img.shape[0] / 2
629
+
630
+ warp_matrix, scaling_ratio = self._get_random_homography_matrix(
631
+ height, width)
632
+ warp_matrix = warp_matrix @ center_matrix
633
+
634
+ img = cv2.warpPerspective(
635
+ img,
636
+ warp_matrix,
637
+ dsize=(width, height),
638
+ borderValue=self.border_val)
639
+ results['img'] = img
640
+ results['img_shape'] = img.shape
641
+ img_h, img_w = img.shape[:2]
642
+
643
+ bboxes = results['gt_bboxes']
644
+ num_bboxes = len(bboxes)
645
+ if num_bboxes:
646
+ orig_bboxes = bboxes.clone()
647
+ if self.use_mask_refine and 'gt_masks' in results:
648
+ # If the dataset has annotations of mask,
649
+ # the mask will be used to refine bbox.
650
+ gt_masks = results['gt_masks']
651
+
652
+ gt_masks_resample = self.resample_masks(gt_masks)
653
+ gt_masks = self.warp_mask(gt_masks_resample, warp_matrix,
654
+ img_h, img_w)
655
+
656
+ # refine bboxes by masks
657
+ bboxes = gt_masks.get_bboxes(dst_type='hbox')
658
+ # filter bboxes outside image
659
+ valid_index = self.filter_gt_bboxes(orig_bboxes,
660
+ bboxes).numpy()
661
+ results['gt_masks'] = gt_masks[valid_index]
662
+ else:
663
+ bboxes.project_(warp_matrix)
664
+ if self.bbox_clip_border:
665
+ bboxes.clip_([height, width])
666
+
667
+ # filter bboxes
668
+ orig_bboxes.rescale_([scaling_ratio, scaling_ratio])
669
+
670
+ # Be careful: valid_index must convert to numpy,
671
+ # otherwise it will raise out of bounds when len(valid_index)=1
672
+ valid_index = self.filter_gt_bboxes(orig_bboxes,
673
+ bboxes).numpy()
674
+ if 'gt_masks' in results:
675
+ results['gt_masks'] = PolygonMasks(
676
+ results['gt_masks'].masks, img_h, img_w)
677
+
678
+ results['gt_bboxes'] = bboxes[valid_index]
679
+ results['gt_bboxes_labels'] = results['gt_bboxes_labels'][
680
+ valid_index]
681
+ results['gt_ignore_flags'] = results['gt_ignore_flags'][
682
+ valid_index]
683
+
684
+ return results
685
+
686
+ @staticmethod
687
+ def warp_poly(poly: np.ndarray, warp_matrix: np.ndarray, img_w: int,
688
+ img_h: int) -> np.ndarray:
689
+ """Function to warp one mask and filter points outside image.
690
+
691
+ Args:
692
+ poly (np.ndarray): Segmentation annotation with shape (n, ) and
693
+ with format (x1, y1, x2, y2, ...).
694
+ warp_matrix (np.ndarray): Affine transformation matrix.
695
+ Shape: (3, 3).
696
+ img_w (int): Width of output image.
697
+ img_h (int): Height of output image.
698
+ """
699
+ # TODO: Current logic may cause retained masks unusable for
700
+ # semantic segmentation training, which is same as official
701
+ # implementation.
702
+ poly = poly.reshape((-1, 2))
703
+ poly = np.concatenate((poly, np.ones(
704
+ (len(poly), 1), dtype=poly.dtype)),
705
+ axis=-1)
706
+ # transform poly
707
+ poly = poly @ warp_matrix.T
708
+ poly = poly[:, :2] / poly[:, 2:3]
709
+
710
+ # filter point outside image
711
+ x, y = poly.T
712
+ valid_ind_point = (x >= 0) & (y >= 0) & (x <= img_w) & (y <= img_h)
713
+ return poly[valid_ind_point].reshape(-1)
714
+
715
+ def warp_mask(self, gt_masks: PolygonMasks, warp_matrix: np.ndarray,
716
+ img_w: int, img_h: int) -> PolygonMasks:
717
+ """Warp masks by warp_matrix and retain masks inside image after
718
+ warping.
719
+
720
+ Args:
721
+ gt_masks (PolygonMasks): Annotations of semantic segmentation.
722
+ warp_matrix (np.ndarray): Affine transformation matrix.
723
+ Shape: (3, 3).
724
+ img_w (int): Width of output image.
725
+ img_h (int): Height of output image.
726
+
727
+ Returns:
728
+ PolygonMasks: Masks after warping.
729
+ """
730
+ masks = gt_masks.masks
731
+
732
+ new_masks = []
733
+ for poly_per_obj in masks:
734
+ warpped_poly_per_obj = []
735
+ # One gt may have multiple masks.
736
+ for poly in poly_per_obj:
737
+ valid_poly = self.warp_poly(poly, warp_matrix, img_w, img_h)
738
+ if len(valid_poly):
739
+ warpped_poly_per_obj.append(valid_poly.reshape(-1))
740
+ # If all the masks are invalid,
741
+ # add [0, 0, 0, 0, 0, 0,] here.
742
+ if not warpped_poly_per_obj:
743
+ # This will be filtered in function `filter_gt_bboxes`.
744
+ warpped_poly_per_obj = [
745
+ np.zeros(6, dtype=poly_per_obj[0].dtype)
746
+ ]
747
+ new_masks.append(warpped_poly_per_obj)
748
+
749
+ gt_masks = PolygonMasks(new_masks, img_h, img_w)
750
+ return gt_masks
751
+
752
+ def resample_masks(self, gt_masks: PolygonMasks) -> PolygonMasks:
753
+ """Function to resample each mask annotation with shape (2 * n, ) to
754
+ shape (resample_num * 2, ).
755
+
756
+ Args:
757
+ gt_masks (PolygonMasks): Annotations of semantic segmentation.
758
+ """
759
+ masks = gt_masks.masks
760
+ new_masks = []
761
+ for poly_per_obj in masks:
762
+ resample_poly_per_obj = []
763
+ for poly in poly_per_obj:
764
+ poly = poly.reshape((-1, 2)) # xy
765
+ poly = np.concatenate((poly, poly[0:1, :]), axis=0)
766
+ x = np.linspace(0, len(poly) - 1, self.resample_num)
767
+ xp = np.arange(len(poly))
768
+ poly = np.concatenate([
769
+ np.interp(x, xp, poly[:, i]) for i in range(2)
770
+ ]).reshape(2, -1).T.reshape(-1)
771
+ resample_poly_per_obj.append(poly)
772
+ new_masks.append(resample_poly_per_obj)
773
+ return PolygonMasks(new_masks, gt_masks.height, gt_masks.width)
774
+
775
+ def filter_gt_bboxes(self, origin_bboxes: HorizontalBoxes,
776
+ wrapped_bboxes: HorizontalBoxes) -> torch.Tensor:
777
+ """Filter gt bboxes.
778
+
779
+ Args:
780
+ origin_bboxes (HorizontalBoxes): Origin bboxes.
781
+ wrapped_bboxes (HorizontalBoxes): Wrapped bboxes
782
+
783
+ Returns:
784
+ dict: The result dict.
785
+ """
786
+ origin_w = origin_bboxes.widths
787
+ origin_h = origin_bboxes.heights
788
+ wrapped_w = wrapped_bboxes.widths
789
+ wrapped_h = wrapped_bboxes.heights
790
+ aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16),
791
+ wrapped_h / (wrapped_w + 1e-16))
792
+
793
+ wh_valid_idx = (wrapped_w > self.min_bbox_size) & \
794
+ (wrapped_h > self.min_bbox_size)
795
+ area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h +
796
+ 1e-16) > self.min_area_ratio
797
+ aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio
798
+ return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx
799
+
800
+ @cache_randomness
801
+ def _get_random_homography_matrix(self, height: int,
802
+ width: int) -> Tuple[np.ndarray, float]:
803
+ """Get random homography matrix.
804
+
805
+ Args:
806
+ height (int): Image height.
807
+ width (int): Image width.
808
+
809
+ Returns:
810
+ Tuple[np.ndarray, float]: The result of warp_matrix and
811
+ scaling_ratio.
812
+ """
813
+ # Rotation
814
+ rotation_degree = random.uniform(-self.max_rotate_degree,
815
+ self.max_rotate_degree)
816
+ rotation_matrix = self._get_rotation_matrix(rotation_degree)
817
+
818
+ # Scaling
819
+ scaling_ratio = random.uniform(self.scaling_ratio_range[0],
820
+ self.scaling_ratio_range[1])
821
+ scaling_matrix = self._get_scaling_matrix(scaling_ratio)
822
+
823
+ # Shear
824
+ x_degree = random.uniform(-self.max_shear_degree,
825
+ self.max_shear_degree)
826
+ y_degree = random.uniform(-self.max_shear_degree,
827
+ self.max_shear_degree)
828
+ shear_matrix = self._get_shear_matrix(x_degree, y_degree)
829
+
830
+ # Translation
831
+ trans_x = random.uniform(0.5 - self.max_translate_ratio,
832
+ 0.5 + self.max_translate_ratio) * width
833
+ trans_y = random.uniform(0.5 - self.max_translate_ratio,
834
+ 0.5 + self.max_translate_ratio) * height
835
+ translate_matrix = self._get_translation_matrix(trans_x, trans_y)
836
+ warp_matrix = (
837
+ translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix)
838
+ return warp_matrix, scaling_ratio
839
+
840
+ @staticmethod
841
+ def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
842
+ """Get rotation matrix.
843
+
844
+ Args:
845
+ rotate_degrees (float): Rotate degrees.
846
+
847
+ Returns:
848
+ np.ndarray: The rotation matrix.
849
+ """
850
+ radian = math.radians(rotate_degrees)
851
+ rotation_matrix = np.array(
852
+ [[np.cos(radian), -np.sin(radian), 0.],
853
+ [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
854
+ dtype=np.float32)
855
+ return rotation_matrix
856
+
857
+ @staticmethod
858
+ def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
859
+ """Get scaling matrix.
860
+
861
+ Args:
862
+ scale_ratio (float): Scale ratio.
863
+
864
+ Returns:
865
+ np.ndarray: The scaling matrix.
866
+ """
867
+ scaling_matrix = np.array(
868
+ [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
869
+ dtype=np.float32)
870
+ return scaling_matrix
871
+
872
+ @staticmethod
873
+ def _get_shear_matrix(x_shear_degrees: float,
874
+ y_shear_degrees: float) -> np.ndarray:
875
+ """Get shear matrix.
876
+
877
+ Args:
878
+ x_shear_degrees (float): X shear degrees.
879
+ y_shear_degrees (float): Y shear degrees.
880
+
881
+ Returns:
882
+ np.ndarray: The shear matrix.
883
+ """
884
+ x_radian = math.radians(x_shear_degrees)
885
+ y_radian = math.radians(y_shear_degrees)
886
+ shear_matrix = np.array([[1, np.tan(x_radian), 0.],
887
+ [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
888
+ dtype=np.float32)
889
+ return shear_matrix
890
+
891
+ @staticmethod
892
+ def _get_translation_matrix(x: float, y: float) -> np.ndarray:
893
+ """Get translation matrix.
894
+
895
+ Args:
896
+ x (float): X translation.
897
+ y (float): Y translation.
898
+
899
+ Returns:
900
+ np.ndarray: The translation matrix.
901
+ """
902
+ translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
903
+ dtype=np.float32)
904
+ return translation_matrix
905
+
906
+ def __repr__(self) -> str:
907
+ repr_str = self.__class__.__name__
908
+ repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
909
+ repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
910
+ repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
911
+ repr_str += f'max_shear_degree={self.max_shear_degree}, '
912
+ repr_str += f'border={self.border}, '
913
+ repr_str += f'border_val={self.border_val}, '
914
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
915
+ return repr_str
916
+
917
+
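The warp matrix above is a composition of simple 3x3 homogeneous matrices. A condensed sketch with fixed values standing in for the random draws of `_get_random_homography_matrix` (including the extra centering step that distinguishes this transform from the YOLOX version):

```python
import math

import numpy as np


def compose_warp(height, width, rot_deg=5.0, scale=1.1, shear_deg=1.0,
                 translate_ratio=0.05):
    """Center -> rotate -> scale -> shear -> translate, as in the class above."""
    center = np.eye(3, dtype=np.float32)
    center[0, 2], center[1, 2] = -width / 2, -height / 2
    r = math.radians(rot_deg)
    rotation = np.array([[math.cos(r), -math.sin(r), 0.],
                         [math.sin(r), math.cos(r), 0.],
                         [0., 0., 1.]], dtype=np.float32)
    scaling = np.diag([scale, scale, 1.]).astype(np.float32)
    s = math.radians(shear_deg)
    shear = np.array([[1., math.tan(s), 0.],
                      [math.tan(s), 1., 0.],
                      [0., 0., 1.]], dtype=np.float32)
    translate = np.array([[1., 0., (0.5 + translate_ratio) * width],
                          [0., 1., (0.5 + translate_ratio) * height],
                          [0., 0., 1.]], dtype=np.float32)
    return translate @ shear @ rotation @ scaling @ center


warp = compose_warp(640, 640)
print(warp.shape)  # (3, 3) -- ready for cv2.warpPerspective
```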
918
+ @TRANSFORMS.register_module()
919
+ class PPYOLOERandomDistort(BaseTransform):
920
+ """Random hue, saturation, contrast and brightness distortion.
921
+
922
+ Required Keys:
923
+
924
+ - img
925
+
926
+ Modified Keys:
927
+
928
+ - img (np.float32)
929
+
930
+ Args:
931
+ hue_cfg (dict): Hue settings. Defaults to dict(min=-18,
932
+ max=18, prob=0.5).
933
+ saturation_cfg (dict): Saturation settings. Defaults to dict(
934
+ min=0.5, max=1.5, prob=0.5).
935
+ contrast_cfg (dict): Contrast settings. Defaults to dict(
936
+ min=0.5, max=1.5, prob=0.5).
937
+ brightness_cfg (dict): Brightness settings. Defaults to dict(
938
+ min=0.5, max=1.5, prob=0.5).
939
+ num_distort_func (int): The number of distort function. Defaults
940
+ to 4.
941
+ """
942
+
943
+ def __init__(self,
944
+ hue_cfg: dict = dict(min=-18, max=18, prob=0.5),
945
+ saturation_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
946
+ contrast_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
947
+ brightness_cfg: dict = dict(min=0.5, max=1.5, prob=0.5),
948
+ num_distort_func: int = 4):
949
+ self.hue_cfg = hue_cfg
950
+ self.saturation_cfg = saturation_cfg
951
+ self.contrast_cfg = contrast_cfg
952
+ self.brightness_cfg = brightness_cfg
953
+ self.num_distort_func = num_distort_func
954
+ assert 0 < self.num_distort_func <= 4, \
955
+ 'num_distort_func must be > 0 and <= 4'
956
+ for cfg in [
957
+ self.hue_cfg, self.saturation_cfg, self.contrast_cfg,
958
+ self.brightness_cfg
959
+ ]:
960
+ assert 0. <= cfg['prob'] <= 1., 'prob must be >= 0 and <= 1'
961
+
962
+ def transform_hue(self, results):
963
+ """Transform hue randomly."""
964
+ if random.uniform(0., 1.) >= self.hue_cfg['prob']:
965
+ return results
966
+ img = results['img']
967
+ delta = random.uniform(self.hue_cfg['min'], self.hue_cfg['max'])
968
+ u = np.cos(delta * np.pi)
969
+ w = np.sin(delta * np.pi)
970
+ delta_iq = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
971
+ rgb2yiq_matrix = np.array([[0.114, 0.587, 0.299],
972
+ [-0.321, -0.274, 0.596],
973
+ [0.311, -0.523, 0.211]])
974
+ yiq2rgb_matric = np.array([[1.0, -1.107, 1.705], [1.0, -0.272, -0.647],
975
+ [1.0, 0.956, 0.621]])
976
+ t = np.dot(np.dot(yiq2rgb_matric, delta_iq), rgb2yiq_matrix).T
977
+ img = np.dot(img, t)
978
+ results['img'] = img
979
+ return results
980
+
981
+ def transform_saturation(self, results):
982
+ """Transform saturation randomly."""
983
+ if random.uniform(0., 1.) >= self.saturation_cfg['prob']:
984
+ return results
985
+ img = results['img']
986
+ delta = random.uniform(self.saturation_cfg['min'],
987
+ self.saturation_cfg['max'])
988
+
989
+ # convert bgr img to gray img
990
+ gray = img * np.array([[[0.114, 0.587, 0.299]]], dtype=np.float32)
991
+ gray = gray.sum(axis=2, keepdims=True)
992
+ gray *= (1.0 - delta)
993
+ img *= delta
994
+ img += gray
995
+ results['img'] = img
996
+ return results
997
+
998
+ def transform_contrast(self, results):
999
+ """Transform contrast randomly."""
1000
+ if random.uniform(0., 1.) >= self.contrast_cfg['prob']:
1001
+ return results
1002
+ img = results['img']
1003
+ delta = random.uniform(self.contrast_cfg['min'],
1004
+ self.contrast_cfg['max'])
1005
+ img *= delta
1006
+ results['img'] = img
1007
+ return results
1008
+
1009
+ def transform_brightness(self, results):
1010
+ """Transform brightness randomly."""
1011
+ if random.uniform(0., 1.) >= self.brightness_cfg['prob']:
1012
+ return results
1013
+ img = results['img']
1014
+ delta = random.uniform(self.brightness_cfg['min'],
1015
+ self.brightness_cfg['max'])
1016
+ img += delta
1017
+ results['img'] = img
1018
+ return results
1019
+
1020
+ def transform(self, results: dict) -> dict:
1021
+ """The hue, saturation, contrast and brightness distortion function.
1022
+
1023
+ Args:
1024
+ results (dict): The result dict.
1025
+
1026
+ Returns:
1027
+ dict: The result dict.
1028
+ """
1029
+ results['img'] = results['img'].astype(np.float32)
1030
+
1031
+ functions = [
1032
+ self.transform_brightness, self.transform_contrast,
1033
+ self.transform_saturation, self.transform_hue
1034
+ ]
1035
+ distortions = random.permutation(functions)[:self.num_distort_func]
1036
+ for func in distortions:
1037
+ results = func(results)
1038
+ return results
1039
+
1040
+ def __repr__(self) -> str:
1041
+ repr_str = self.__class__.__name__
1042
+ repr_str += f'(hue_cfg={self.hue_cfg}, '
1043
+ repr_str += f'saturation_cfg={self.saturation_cfg}, '
1044
+ repr_str += f'contrast_cfg={self.contrast_cfg}, '
1045
+ repr_str += f'brightness_cfg={self.brightness_cfg}, '
1046
+ repr_str += f'num_distort_func={self.num_distort_func})'
1047
+ return repr_str
1048
+
1049
+
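A short usage sketch for the PPYOLOE distortion; it only needs `results['img']` and converts it to a floating-point image before applying a random subset of the four distortions (import path assumed as before):

```python
import numpy as np

from mmyolo.datasets.transforms import PPYOLOERandomDistort  # assumed export

distort = PPYOLOERandomDistort(num_distort_func=4)
results = {'img': np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)}
results = distort.transform(results)
print(results['img'].shape)  # (64, 64, 3); pixel values are now floats
```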
1050
+ @TRANSFORMS.register_module()
1051
+ class PPYOLOERandomCrop(BaseTransform):
1052
+ """Random crop the img and bboxes. Different thresholds are used in PPYOLOE
1053
+ to judge whether the clipped image meets the requirements. This
1054
+ implementation is different from the implementation of RandomCrop in mmdet.
1055
+
1056
+ Required Keys:
1057
+
1058
+ - img
1059
+ - gt_bboxes (BaseBoxes[torch.float32]) (optional)
1060
+ - gt_bboxes_labels (np.int64) (optional)
1061
+ - gt_ignore_flags (bool) (optional)
1062
+
1063
+ Modified Keys:
1064
+
1065
+ - img
1066
+ - img_shape
1067
+ - gt_bboxes (optional)
1068
+ - gt_bboxes_labels (optional)
1069
+ - gt_ignore_flags (optional)
1070
+
1071
+ Added Keys:
1072
+ - pad_param (np.float32)
1073
+
1074
+ Args:
1075
+ aspect_ratio (List[float]): Aspect ratio of cropped region. Default to
1076
+ [.5, 2].
1077
+ thresholds (List[float]): Iou thresholds for deciding a valid bbox crop
1078
+ in [min, max] format. Defaults to [.0, .1, .3, .5, .7, .9].
1079
+ scaling (List[float]): Ratio between a cropped region and the original
1080
+ image in [min, max] format. Default to [.3, 1.].
1081
+ num_attempts (int): Number of tries for each threshold before
1082
+ giving up. Default to 50.
1083
+ allow_no_crop (bool): Allow return without actually cropping them.
1084
+ Default to True.
1085
+ cover_all_box (bool): Ensure all bboxes are covered in the final crop.
1086
+ Default to False.
1087
+ """
1088
+
1089
+ def __init__(self,
1090
+ aspect_ratio: List[float] = [.5, 2.],
1091
+ thresholds: List[float] = [.0, .1, .3, .5, .7, .9],
1092
+ scaling: List[float] = [.3, 1.],
1093
+ num_attempts: int = 50,
1094
+ allow_no_crop: bool = True,
1095
+ cover_all_box: bool = False):
1096
+ self.aspect_ratio = aspect_ratio
1097
+ self.thresholds = thresholds
1098
+ self.scaling = scaling
1099
+ self.num_attempts = num_attempts
1100
+ self.allow_no_crop = allow_no_crop
1101
+ self.cover_all_box = cover_all_box
1102
+
1103
+ def _crop_data(self, results: dict, crop_box: Tuple[int, int, int, int],
1104
+ valid_inds: np.ndarray) -> Union[dict, None]:
1105
+ """Function to randomly crop images, bounding boxes, masks, semantic
1106
+ segmentation maps.
1107
+
1108
+ Args:
1109
+ results (dict): Result dict from loading pipeline.
1110
+ crop_box (Tuple[int, int, int, int]): Expected absolute coordinates
1111
+ for cropping, (x1, y1, x2, y2).
1112
+ valid_inds (np.ndarray): The indexes of gt that needs to be
1113
+ retained.
1114
+
1115
+ Returns:
1116
+ results (Union[dict, None]): Randomly cropped results, 'img_shape'
1117
+ key in result dict is updated according to crop size. None will
1118
+ be returned when there is no valid bbox after cropping.
1119
+ """
1120
+ # crop the image
1121
+ img = results['img']
1122
+ crop_x1, crop_y1, crop_x2, crop_y2 = crop_box
1123
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
1124
+ results['img'] = img
1125
+ img_shape = img.shape
1126
+ results['img_shape'] = img.shape
1127
+
1128
+ # crop bboxes accordingly and clip to the image boundary
1129
+ if results.get('gt_bboxes', None) is not None:
1130
+ bboxes = results['gt_bboxes']
1131
+ bboxes.translate_([-crop_x1, -crop_y1])
1132
+ bboxes.clip_(img_shape[:2])
1133
+
1134
+ results['gt_bboxes'] = bboxes[valid_inds]
1135
+
1136
+ if results.get('gt_ignore_flags', None) is not None:
1137
+ results['gt_ignore_flags'] = \
1138
+ results['gt_ignore_flags'][valid_inds]
1139
+
1140
+ if results.get('gt_bboxes_labels', None) is not None:
1141
+ results['gt_bboxes_labels'] = \
1142
+ results['gt_bboxes_labels'][valid_inds]
1143
+
1144
+ if results.get('gt_masks', None) is not None:
1145
+ results['gt_masks'] = results['gt_masks'][
1146
+ valid_inds.nonzero()[0]].crop(
1147
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
1148
+
1149
+ # crop semantic seg
1150
+ if results.get('gt_seg_map', None) is not None:
1151
+ results['gt_seg_map'] = results['gt_seg_map'][crop_y1:crop_y2,
1152
+ crop_x1:crop_x2]
1153
+
1154
+ return results
1155
+
1156
+ @autocast_box_type()
1157
+ def transform(self, results: dict) -> Union[dict, None]:
1158
+ """The random crop transform function.
1159
+
1160
+ Args:
1161
+ results (dict): The result dict.
1162
+
1163
+ Returns:
1164
+ dict: The result dict.
1165
+ """
1166
+ if results.get('gt_bboxes', None) is None or len(
1167
+ results['gt_bboxes']) == 0:
1168
+ return results
1169
+
1170
+ orig_img_h, orig_img_w = results['img'].shape[:2]
1171
+ gt_bboxes = results['gt_bboxes']
1172
+
1173
+ thresholds = list(self.thresholds)
1174
+ if self.allow_no_crop:
1175
+ thresholds.append('no_crop')
1176
+ random.shuffle(thresholds)
1177
+
1178
+ for thresh in thresholds:
1179
+ # Determine the coordinates for cropping
1180
+ if thresh == 'no_crop':
1181
+ return results
1182
+
1183
+ found = False
1184
+ for i in range(self.num_attempts):
1185
+ crop_h, crop_w = self._get_crop_size((orig_img_h, orig_img_w))
1186
+ if self.aspect_ratio is None:
1187
+ if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
1188
+ continue
1189
+
1190
+ # get image crop_box
1191
+ margin_h = max(orig_img_h - crop_h, 0)
1192
+ margin_w = max(orig_img_w - crop_w, 0)
1193
+ offset_h, offset_w = self._rand_offset((margin_h, margin_w))
1194
+ crop_y1, crop_y2 = offset_h, offset_h + crop_h
1195
+ crop_x1, crop_x2 = offset_w, offset_w + crop_w
1196
+
1197
+ crop_box = [crop_x1, crop_y1, crop_x2, crop_y2]
1198
+ # Calculate the iou between gt_bboxes and crop_boxes
1199
+ iou = self._iou_matrix(gt_bboxes,
1200
+ np.array([crop_box], dtype=np.float32))
1201
+ # If the maximum value of the iou is less than thresh,
1202
+ # the current crop_box is considered invalid.
1203
+ if iou.max() < thresh:
1204
+ continue
1205
+
1206
+ # If cover_all_box == True and the minimum value of
1207
+ # the iou is less than thresh, the current crop_box
1208
+ # is considered invalid.
1209
+ if self.cover_all_box and iou.min() < thresh:
1210
+ continue
1211
+
1212
+ # Get which gt_bboxes to keep after cropping.
1213
+ valid_inds = self._get_valid_inds(
1214
+ gt_bboxes, np.array(crop_box, dtype=np.float32))
1215
+ if valid_inds.size > 0:
1216
+ found = True
1217
+ break
1218
+
1219
+ if found:
1220
+ results = self._crop_data(results, crop_box, valid_inds)
1221
+ return results
1222
+ return results
1223
+
1224
+ @cache_randomness
1225
+ def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]:
1226
+ """Randomly generate crop offset.
1227
+
1228
+ Args:
1229
+ margin (Tuple[int, int]): The upper bound for the offset generated
1230
+ randomly.
1231
+
1232
+ Returns:
1233
+ Tuple[int, int]: The random offset for the crop.
1234
+ """
1235
+ margin_h, margin_w = margin
1236
+ offset_h = np.random.randint(0, margin_h + 1)
1237
+ offset_w = np.random.randint(0, margin_w + 1)
1238
+
1239
+ return (offset_h, offset_w)
1240
+
1241
+ @cache_randomness
1242
+ def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]:
1243
+ """Randomly generates the crop size based on `image_size`.
1244
+
1245
+ Args:
1246
+ image_size (Tuple[int, int]): (h, w).
1247
+
1248
+ Returns:
1249
+ crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels.
1250
+ """
1251
+ h, w = image_size
1252
+ scale = random.uniform(*self.scaling)
1253
+ if self.aspect_ratio is not None:
1254
+ min_ar, max_ar = self.aspect_ratio
1255
+ aspect_ratio = random.uniform(
1256
+ max(min_ar, scale**2), min(max_ar, scale**-2))
1257
+ h_scale = scale / np.sqrt(aspect_ratio)
1258
+ w_scale = scale * np.sqrt(aspect_ratio)
1259
+ else:
1260
+ h_scale = random.uniform(*self.scaling)
1261
+ w_scale = random.uniform(*self.scaling)
1262
+ crop_h = h * h_scale
1263
+ crop_w = w * w_scale
1264
+ return int(crop_h), int(crop_w)
1265
+
1266
+ def _iou_matrix(self,
1267
+ gt_bbox: HorizontalBoxes,
1268
+ crop_bbox: np.ndarray,
1269
+ eps: float = 1e-10) -> np.ndarray:
1270
+ """Calculate iou between gt and image crop box.
1271
+
1272
+ Args:
1273
+ gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
1274
+ crop_bbox (np.ndarray): Image crop coordinates in
1275
+ [x1, y1, x2, y2] format.
1276
+ eps (float): Default to 1e-10.
1277
+ Return:
1278
+ (np.ndarray): IoU.
1279
+ """
1280
+ gt_bbox = gt_bbox.tensor.numpy()
1281
+ lefttop = np.maximum(gt_bbox[:, np.newaxis, :2], crop_bbox[:, :2])
1282
+ rightbottom = np.minimum(gt_bbox[:, np.newaxis, 2:], crop_bbox[:, 2:])
1283
+
1284
+ overlap = np.prod(
1285
+ rightbottom - lefttop,
1286
+ axis=2) * (lefttop < rightbottom).all(axis=2)
1287
+ area_gt_bbox = np.prod(gt_bbox[:, 2:] - gt_bbox[:, :2], axis=1)
1288
+ area_crop_bbox = np.prod(crop_bbox[:, 2:] - crop_bbox[:, :2], axis=1)
1289
+ area_o = (area_gt_bbox[:, np.newaxis] + area_crop_bbox - overlap)
1290
+ return overlap / (area_o + eps)
1291
+
1292
+ def _get_valid_inds(self, gt_bbox: HorizontalBoxes,
1293
+ img_crop_bbox: np.ndarray) -> np.ndarray:
1294
+ """Get which Bboxes to keep at the current cropping coordinates.
1295
+
1296
+ Args:
1297
+ gt_bbox (HorizontalBoxes): Ground truth bounding boxes.
1298
+ img_crop_bbox (np.ndarray): Image crop coordinates in
1299
+ [x1, y1, x2, y2] format.
1300
+
1301
+ Returns:
1302
+ (np.ndarray): Valid indexes.
1303
+ """
1304
+ cropped_box = gt_bbox.tensor.numpy().copy()
1305
+ gt_bbox = gt_bbox.tensor.numpy().copy()
1306
+
1307
+ cropped_box[:, :2] = np.maximum(gt_bbox[:, :2], img_crop_bbox[:2])
1308
+ cropped_box[:, 2:] = np.minimum(gt_bbox[:, 2:], img_crop_bbox[2:])
1309
+ cropped_box[:, :2] -= img_crop_bbox[:2]
1310
+ cropped_box[:, 2:] -= img_crop_bbox[:2]
1311
+
1312
+ centers = (gt_bbox[:, :2] + gt_bbox[:, 2:]) / 2
1313
+ valid = np.logical_and(img_crop_bbox[:2] <= centers,
1314
+ centers < img_crop_bbox[2:]).all(axis=1)
1315
+ valid = np.logical_and(
1316
+ valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
1317
+
1318
+ return np.where(valid)[0]
1319
+
1320
+ def __repr__(self) -> str:
1321
+ repr_str = self.__class__.__name__
1322
+ repr_str += f'(aspect_ratio={self.aspect_ratio}, '
1323
+ repr_str += f'thresholds={self.thresholds}, '
1324
+ repr_str += f'scaling={self.scaling}, '
1325
+ repr_str += f'num_attempts={self.num_attempts}, '
1326
+ repr_str += f'allow_no_crop={self.allow_no_crop}, '
1327
+ repr_str += f'cover_all_box={self.cover_all_box})'
1328
+ return repr_str
1329
+
1330
+
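A minimal end-to-end call of the crop transform on synthetic data (the gt boxes must already be box-type objects, e.g. `HorizontalBoxes`); because the crop is random, the output shape varies from run to run:

```python
import numpy as np
import torch
from mmdet.structures.bbox import HorizontalBoxes

from mmyolo.datasets.transforms import PPYOLOERandomCrop  # assumed export

crop = PPYOLOERandomCrop()
results = {
    'img': np.zeros((320, 320, 3), dtype=np.uint8),
    'gt_bboxes': HorizontalBoxes(torch.tensor([[40., 40., 200., 200.]])),
    'gt_bboxes_labels': np.array([0], dtype=np.int64),
    'gt_ignore_flags': np.array([False]),
}
results = crop(results)
print(results['img'].shape, len(results['gt_bboxes']))
```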
1331
+ @TRANSFORMS.register_module()
1332
+ class YOLOv5CopyPaste(BaseTransform):
1333
+ """Copy-Paste used in YOLOv5 and YOLOv8.
1334
+
1335
+ This transform randomly copies some objects in the image to the mirror
1337
+ position of the image. It is different from the `CopyPaste` in mmdet.
1337
+
1338
+ Required Keys:
1339
+
1340
+ - img (np.uint8)
1341
+ - gt_bboxes (BaseBoxes[torch.float32])
1342
+ - gt_bboxes_labels (np.int64) (optional)
1343
+ - gt_ignore_flags (bool) (optional)
1344
+ - gt_masks (PolygonMasks) (optional)
1345
+
1346
+ Modified Keys:
1347
+
1348
+ - img
1349
+ - gt_bboxes
1350
+ - gt_bboxes_labels (np.int64) (optional)
1351
+ - gt_ignore_flags (optional)
1352
+ - gt_masks (optional)
1353
+
1354
+ Args:
1355
+ ioa_thresh (float): IoA threshold for deciding a valid bbox. Defaults to 0.3.
1356
+ prob (float): Probability of choosing objects.
1357
+ Defaults to 0.5.
1358
+ """
1359
+
1360
+ def __init__(self, ioa_thresh: float = 0.3, prob: float = 0.5):
1361
+ self.ioa_thresh = ioa_thresh
1362
+ self.prob = prob
1363
+
1364
+ @autocast_box_type()
1365
+ def transform(self, results: dict) -> Union[dict, None]:
1366
+ """The YOLOv5 and YOLOv8 Copy-Paste transform function.
1367
+
1368
+ Args:
1369
+ results (dict): The result dict.
1370
+
1371
+ Returns:
1372
+ dict: The result dict.
1373
+ """
1374
+ if len(results.get('gt_masks', [])) == 0:
1375
+ return results
1376
+ gt_masks = results['gt_masks']
1377
+ assert isinstance(gt_masks, PolygonMasks),\
1378
+ 'only support type of PolygonMasks,' \
1379
+ ' but get type: %s' % type(gt_masks)
1380
+ gt_bboxes = results['gt_bboxes']
1381
+ gt_bboxes_labels = results.get('gt_bboxes_labels', None)
1382
+ img = results['img']
1383
+ img_h, img_w = img.shape[:2]
1384
+
1385
+ # calculate ioa
1386
+ gt_bboxes_flip = deepcopy(gt_bboxes)
1387
+ gt_bboxes_flip.flip_(img.shape)
1388
+
1389
+ ioa = self.bbox_ioa(gt_bboxes_flip, gt_bboxes)
1390
+ indexes = torch.nonzero((ioa < self.ioa_thresh).all(1))[:, 0]
1391
+ n = len(indexes)
1392
+ valid_inds = random.choice(
1393
+ indexes, size=round(self.prob * n), replace=False)
1394
+ if len(valid_inds) == 0:
1395
+ return results
1396
+
1397
+ if gt_bboxes_labels is not None:
1398
+ # prepare labels
1399
+ gt_bboxes_labels = np.concatenate(
1400
+ (gt_bboxes_labels, gt_bboxes_labels[valid_inds]), axis=0)
1401
+
1402
+ # prepare bboxes
1403
+ copypaste_bboxes = gt_bboxes_flip[valid_inds]
1404
+ gt_bboxes = gt_bboxes.cat([gt_bboxes, copypaste_bboxes])
1405
+
1406
+ # prepare images
1407
+ copypaste_gt_masks = gt_masks[valid_inds]
1408
+ copypaste_gt_masks_flip = copypaste_gt_masks.flip()
1409
+ # convert poly format to bitmap format
1410
+ # example: poly: [[array(0.0, 0.0, 10.0, 0.0, 10.0, 10.0, 0.0, 10.0]]
1411
+ # -> bitmap: a mask with shape equal to (1, img_h, img_w)
1412
+ # # type1 low speed
1413
+ # copypaste_gt_masks_bitmap = copypaste_gt_masks.to_ndarray()
1414
+ # copypaste_mask = np.sum(copypaste_gt_masks_bitmap, axis=0) > 0
1415
+
1416
+ # type2
1417
+ copypaste_mask = np.zeros((img_h, img_w), dtype=np.uint8)
1418
+ for poly in copypaste_gt_masks.masks:
1419
+ poly = [i.reshape((-1, 1, 2)).astype(np.int32) for i in poly]
1420
+ cv2.drawContours(copypaste_mask, poly, -1, (1, ), cv2.FILLED)
1421
+
1422
+ copypaste_mask = copypaste_mask.astype(bool)
1423
+
1424
+ # copy objects, and paste to the mirror position of the image
1425
+ copypaste_mask_flip = mmcv.imflip(
1426
+ copypaste_mask, direction='horizontal')
1427
+ copypaste_img = mmcv.imflip(img, direction='horizontal')
1428
+ img[copypaste_mask_flip] = copypaste_img[copypaste_mask_flip]
1429
+
1430
+ # prepare masks
1431
+ gt_masks = copypaste_gt_masks.cat([gt_masks, copypaste_gt_masks_flip])
1432
+
1433
+ if 'gt_ignore_flags' in results:
1434
+ # prepare gt_ignore_flags
1435
+ gt_ignore_flags = results['gt_ignore_flags']
1436
+ gt_ignore_flags = np.concatenate(
1437
+ [gt_ignore_flags, gt_ignore_flags[valid_inds]], axis=0)
1438
+ results['gt_ignore_flags'] = gt_ignore_flags
1439
+
1440
+ results['img'] = img
1441
+ results['gt_bboxes'] = gt_bboxes
1442
+ if gt_bboxes_labels is not None:
1443
+ results['gt_bboxes_labels'] = gt_bboxes_labels
1444
+ results['gt_masks'] = gt_masks
1445
+
1446
+ return results
1447
+
1448
+ @staticmethod
1449
+ def bbox_ioa(gt_bboxes_flip: HorizontalBoxes,
1450
+ gt_bboxes: HorizontalBoxes,
1451
+ eps: float = 1e-7) -> np.ndarray:
1452
+ """Calculate ioa between gt_bboxes_flip and gt_bboxes.
1453
+
1454
+ Args:
1455
+ gt_bboxes_flip (HorizontalBoxes): Flipped ground truth
1456
+ bounding boxes.
1457
+ gt_bboxes (HorizontalBoxes): Ground truth bounding boxes.
1458
+ eps (float): Defaults to 1e-7.
1459
+ Return:
1460
+ (Tensor): Ioa.
1461
+ """
1462
+ gt_bboxes_flip = gt_bboxes_flip.tensor
1463
+ gt_bboxes = gt_bboxes.tensor
1464
+
1465
+ # Get the coordinates of bounding boxes
1466
+ b1_x1, b1_y1, b1_x2, b1_y2 = gt_bboxes_flip.T
1467
+ b2_x1, b2_y1, b2_x2, b2_y2 = gt_bboxes.T
1468
+
1469
+ # Intersection area
1470
+ inter_area = (torch.minimum(b1_x2[:, None],
1471
+ b2_x2) - torch.maximum(b1_x1[:, None],
1472
+ b2_x1)).clip(0) * \
1473
+ (torch.minimum(b1_y2[:, None],
1474
+ b2_y2) - torch.maximum(b1_y1[:, None],
1475
+ b2_y1)).clip(0)
1476
+
1477
+ # box2 area
1478
+ box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
1479
+
1480
+ # Intersection over box2 area
1481
+ return inter_area / box2_area
1482
+
1483
+ def __repr__(self) -> str:
1484
+ repr_str = self.__class__.__name__
1485
+ repr_str += f'(ioa_thresh={self.ioa_thresh},'
1486
+ repr_str += f'prob={self.prob})'
1487
+ return repr_str
1488
+
1489
+
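The `bbox_ioa` staticmethod used here computes intersection over the *second* set's box areas (not IoU), which is what decides whether a mirrored object would overlap an existing one. A tiny numeric check:

```python
import torch
from mmdet.structures.bbox import HorizontalBoxes

from mmyolo.datasets.transforms import YOLOv5CopyPaste  # assumed export

a = HorizontalBoxes(torch.tensor([[0., 0., 10., 10.]]))
b = HorizontalBoxes(torch.tensor([[5., 0., 15., 10.], [20., 20., 30., 30.]]))
# half of the first box in `b` is covered by `a`, the second box not at all
print(YOLOv5CopyPaste.bbox_ioa(a, b))  # tensor([[0.5000, 0.0000]])
```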
1490
+ @TRANSFORMS.register_module()
1491
+ class RemoveDataElement(BaseTransform):
1492
+ """Remove unnecessary data element in results.
1493
+
1494
+ Args:
1495
+ keys (Union[str, Sequence[str]]): Keys need to be removed.
1496
+ """
1497
+
1498
+ def __init__(self, keys: Union[str, Sequence[str]]):
1499
+ self.keys = [keys] if isinstance(keys, str) else keys
1500
+
1501
+ def transform(self, results: dict) -> dict:
1502
+ for key in self.keys:
1503
+ results.pop(key, None)
1504
+ return results
1505
+
1506
+ def __repr__(self) -> str:
1507
+ repr_str = self.__class__.__name__
1508
+ repr_str += f'(keys={self.keys})'
1509
+ return repr_str
1510
+
1511
+
1512
+ @TRANSFORMS.register_module()
1513
+ class RegularizeRotatedBox(BaseTransform):
1514
+ """Regularize rotated boxes.
1515
+
1516
+ Due to the angle periodicity, one rotated box can be represented in
1517
+ many different (x, y, w, h, t). To make each rotated box unique,
1518
+ ``regularize_boxes`` will take the remainder of the angle divided by
1519
+ 180 degrees.
1520
+
1521
+ For convenience, three angle_version can be used here:
1522
+
1523
+ - 'oc': OpenCV Definition. Has the same box representation as
1524
+ ``cv2.minAreaRect`` the angle ranges in [-90, 0).
1525
+ - 'le90': Long Edge Definition (90). the angle ranges in [-90, 90).
1526
+ The width is always longer than the height.
1527
+ - 'le135': Long Edge Definition (135). the angle ranges in [-45, 135).
1528
+ The width is always longer than the height.
1529
+
1530
+ Required Keys:
1531
+
1532
+ - gt_bboxes (RotatedBoxes[torch.float32])
1533
+
1534
+ Modified Keys:
1535
+
1536
+ - gt_bboxes
1537
+
1538
+ Args:
1539
+ angle_version (str): Angle version. Can only be 'oc',
1540
+ 'le90', or 'le135'. Defaults to 'le90'.
1541
+ """
1542
+
1543
+ def __init__(self, angle_version='le90') -> None:
1544
+ self.angle_version = angle_version
1545
+ try:
1546
+ from mmrotate.structures.bbox import RotatedBoxes
1547
+ self.box_type = RotatedBoxes
1548
+ except ImportError:
1549
+ raise ImportError(
1550
+ 'Please run "mim install -r requirements/mmrotate.txt" '
1551
+ 'to install mmrotate first for rotated detection.')
1552
+
1553
+ def transform(self, results: dict) -> dict:
1554
+ assert isinstance(results['gt_bboxes'], self.box_type)
1555
+ results['gt_bboxes'] = self.box_type(
1556
+ results['gt_bboxes'].regularize_boxes(self.angle_version))
1557
+ return results
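Putting the file together: the sketch below shows how several of the transforms registered here are typically chained into an MMYOLO training pipeline config. It is loosely modeled on the stock YOLOv5 config but is illustrative only; exact transforms and values differ between config files (`Mosaic` comes from `mix_img_transforms.py` in this same commit):

```python
pre_transform = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
]

train_pipeline = [
    *pre_transform,
    dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0,
         pre_transform=pre_transform),
    dict(type='YOLOv5RandomAffine',
         max_rotate_degree=0.0,
         max_shear_degree=0.0,
         scaling_ratio_range=(0.5, 1.5),
         border=(-320, -320),
         border_val=(114, 114, 114)),
    dict(type='YOLOv5HSVRandomAug'),
    dict(type='mmdet.RandomFlip', prob=0.5),
    dict(type='mmdet.PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
                    'flip_direction')),
]
```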
mmyolo/datasets/utils.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Sequence
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mmengine.dataset import COLLATE_FUNCTIONS
7
+
8
+ from ..registry import TASK_UTILS
9
+
10
+
11
+ @COLLATE_FUNCTIONS.register_module()
12
+ def yolov5_collate(data_batch: Sequence,
13
+ use_ms_training: bool = False) -> dict:
14
+ """Rewrite collate_fn to get faster training speed.
15
+
16
+ Args:
17
+ data_batch (Sequence): Batch of data.
18
+ use_ms_training (bool): Whether to use multi-scale training.
19
+ """
20
+ batch_imgs = []
21
+ batch_bboxes_labels = []
22
+ batch_masks = []
23
+ for i in range(len(data_batch)):
24
+ datasamples = data_batch[i]['data_samples']
25
+ inputs = data_batch[i]['inputs']
26
+ batch_imgs.append(inputs)
27
+
28
+ gt_bboxes = datasamples.gt_instances.bboxes.tensor
29
+ gt_labels = datasamples.gt_instances.labels
30
+ if 'masks' in datasamples.gt_instances:
31
+ masks = datasamples.gt_instances.masks.to_tensor(
32
+ dtype=torch.bool, device=gt_bboxes.device)
33
+ batch_masks.append(masks)
34
+ batch_idx = gt_labels.new_full((len(gt_labels), 1), i)
35
+ bboxes_labels = torch.cat((batch_idx, gt_labels[:, None], gt_bboxes),
36
+ dim=1)
37
+ batch_bboxes_labels.append(bboxes_labels)
38
+
39
+ collated_results = {
40
+ 'data_samples': {
41
+ 'bboxes_labels': torch.cat(batch_bboxes_labels, 0)
42
+ }
43
+ }
44
+ if len(batch_masks) > 0:
45
+ collated_results['data_samples']['masks'] = torch.cat(batch_masks, 0)
46
+
47
+ if use_ms_training:
48
+ collated_results['inputs'] = batch_imgs
49
+ else:
50
+ collated_results['inputs'] = torch.stack(batch_imgs, 0)
51
+ return collated_results
52
+
53
+
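For reference, the `bboxes_labels` tensor produced by this collate function stacks one row per ground-truth box in the layout `[batch_index, class_label, x1, y1, x2, y2]`. A tiny shape check of that layout with synthetic values:

```python
import torch

batch_idx = torch.zeros((3, 1))            # 3 boxes, all from image 0
labels = torch.tensor([[0.], [2.], [5.]])  # class labels as a column
bboxes = torch.rand(3, 4) * 640            # xyxy boxes
bboxes_labels = torch.cat((batch_idx, labels, bboxes), dim=1)
print(bboxes_labels.shape)  # torch.Size([3, 6])
```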
54
+ @TASK_UTILS.register_module()
55
+ class BatchShapePolicy:
56
+ """BatchShapePolicy is only used in the testing phase, which can reduce the
57
+ number of pad pixels during batch inference.
58
+
59
+ Args:
60
+ batch_size (int): Single GPU batch size during batch inference.
61
+ Defaults to 32.
62
+ img_size (int): Expected output image size. Defaults to 640.
63
+ size_divisor (int): The minimum size that is divisible
64
+ by size_divisor. Defaults to 32.
65
+ extra_pad_ratio (float): Extra pad ratio. Defaults to 0.5.
66
+ """
67
+
68
+ def __init__(self,
69
+ batch_size: int = 32,
70
+ img_size: int = 640,
71
+ size_divisor: int = 32,
72
+ extra_pad_ratio: float = 0.5):
73
+ self.batch_size = batch_size
74
+ self.img_size = img_size
75
+ self.size_divisor = size_divisor
76
+ self.extra_pad_ratio = extra_pad_ratio
77
+
78
+ def __call__(self, data_list: List[dict]) -> List[dict]:
79
+ image_shapes = []
80
+ for data_info in data_list:
81
+ image_shapes.append((data_info['width'], data_info['height']))
82
+
83
+ image_shapes = np.array(image_shapes, dtype=np.float64)
84
+
85
+ n = len(image_shapes) # number of images
86
+ batch_index = np.floor(np.arange(n) / self.batch_size).astype(
87
+ np.int64) # batch index
88
+ number_of_batches = batch_index[-1] + 1 # number of batches
89
+
90
+ aspect_ratio = image_shapes[:, 1] / image_shapes[:, 0] # aspect ratio
91
+ irect = aspect_ratio.argsort()
92
+
93
+ data_list = [data_list[i] for i in irect]
94
+
95
+ aspect_ratio = aspect_ratio[irect]
96
+ # Set training image shapes
97
+ shapes = [[1, 1]] * number_of_batches
98
+ for i in range(number_of_batches):
99
+ aspect_ratio_index = aspect_ratio[batch_index == i]
100
+ min_index, max_index = aspect_ratio_index.min(
101
+ ), aspect_ratio_index.max()
102
+ if max_index < 1:
103
+ shapes[i] = [max_index, 1]
104
+ elif min_index > 1:
105
+ shapes[i] = [1, 1 / min_index]
106
+
107
+ batch_shapes = np.ceil(
108
+ np.array(shapes) * self.img_size / self.size_divisor +
109
+ self.extra_pad_ratio).astype(np.int64) * self.size_divisor
110
+
111
+ for i, data_info in enumerate(data_list):
112
+ data_info['batch_shape'] = batch_shapes[batch_index[i]]
113
+
114
+ return data_list
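A small, self-contained example of the policy above (assuming MMYOLO is installed): two landscape images with the same aspect ratio end up with a 384x672 batch shape instead of being padded all the way to 640x640:

```python
from mmyolo.datasets.utils import BatchShapePolicy

policy = BatchShapePolicy(batch_size=2, img_size=640, size_divisor=32)
data_list = [
    {'width': 1280, 'height': 720},
    {'width': 1920, 'height': 1080},
]
for info in policy(data_list):
    print(info['width'], info['height'], info['batch_shape'])
# 1280 720 [384 672]
# 1920 1080 [384 672]
```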
mmyolo/datasets/yolov5_coco.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Any, Optional
3
+
4
+ from mmdet.datasets import BaseDetDataset, CocoDataset
5
+
6
+ from ..registry import DATASETS, TASK_UTILS
7
+
8
+
9
+ class BatchShapePolicyDataset(BaseDetDataset):
10
+ """Dataset with the batch shape policy that makes paddings with least
11
+ pixels during batch inference process, which does not require the image
12
+ scales of all batches to be the same throughout validation."""
13
+
14
+ def __init__(self,
15
+ *args,
16
+ batch_shapes_cfg: Optional[dict] = None,
17
+ **kwargs):
18
+ self.batch_shapes_cfg = batch_shapes_cfg
19
+ super().__init__(*args, **kwargs)
20
+
21
+ def full_init(self):
22
+ """rewrite full_init() to be compatible with serialize_data in
23
+ BatchShapePolicy."""
24
+ if self._fully_initialized:
25
+ return
26
+ # load data information
27
+ self.data_list = self.load_data_list()
28
+
29
+ # batch_shapes_cfg
30
+ if self.batch_shapes_cfg:
31
+ batch_shapes_policy = TASK_UTILS.build(self.batch_shapes_cfg)
32
+ self.data_list = batch_shapes_policy(self.data_list)
33
+ del batch_shapes_policy
34
+
35
+ # filter illegal data, such as data that has no annotations.
36
+ self.data_list = self.filter_data()
37
+ # Get subset data according to indices.
38
+ if self._indices is not None:
39
+ self.data_list = self._get_unserialized_subset(self._indices)
40
+
41
+ # serialize data_list
42
+ if self.serialize_data:
43
+ self.data_bytes, self.data_address = self._serialize_data()
44
+
45
+ self._fully_initialized = True
46
+
47
+ def prepare_data(self, idx: int) -> Any:
48
+ """Pass the dataset to the pipeline during training to support mixed
49
+ data augmentation, such as Mosaic and MixUp."""
50
+ if self.test_mode is False:
51
+ data_info = self.get_data_info(idx)
52
+ data_info['dataset'] = self
53
+ return self.pipeline(data_info)
54
+ else:
55
+ return super().prepare_data(idx)
56
+
57
+
58
+ @DATASETS.register_module()
59
+ class YOLOv5CocoDataset(BatchShapePolicyDataset, CocoDataset):
60
+ """Dataset for YOLOv5 COCO Dataset.
61
+
62
+ We only add `BatchShapePolicy` function compared with CocoDataset. See
63
+ `mmyolo/datasets/utils.py#BatchShapePolicy` for details
64
+ """
65
+ pass
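For reference, a hedged config fragment showing how `batch_shapes_cfg` is typically attached to `YOLOv5CocoDataset` during evaluation; the paths and numbers below are placeholders, not values taken from a released config:

```python
# Illustrative val_dataloader fragment; all values are examples only.
batch_shapes_cfg = dict(
    type='BatchShapePolicy',
    batch_size=1,
    img_size=640,
    size_divisor=32,
    extra_pad_ratio=0.5)

val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    dataset=dict(
        type='YOLOv5CocoDataset',
        data_root='data/coco/',
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        batch_shapes_cfg=batch_shapes_cfg))
```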
mmyolo/datasets/yolov5_crowdhuman.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmdet.datasets import CrowdHumanDataset
3
+
4
+ from ..registry import DATASETS
5
+ from .yolov5_coco import BatchShapePolicyDataset
6
+
7
+
8
+ @DATASETS.register_module()
9
+ class YOLOv5CrowdHumanDataset(BatchShapePolicyDataset, CrowdHumanDataset):
10
+ """Dataset for YOLOv5 CrowdHuman Dataset.
11
+
12
+ We only add `BatchShapePolicy` function compared with CrowdHumanDataset.
13
+ See `mmyolo/datasets/utils.py#BatchShapePolicy` for details
14
+ """
15
+ pass
mmyolo/datasets/yolov5_dota.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
4
+ from ..registry import DATASETS
5
+
6
+ try:
7
+ from mmrotate.datasets import DOTADataset
8
+ MMROTATE_AVAILABLE = True
9
+ except ImportError:
10
+ from mmengine.dataset import BaseDataset
11
+ DOTADataset = BaseDataset
12
+ MMROTATE_AVAILABLE = False
13
+
14
+
15
+ @DATASETS.register_module()
16
+ class YOLOv5DOTADataset(BatchShapePolicyDataset, DOTADataset):
17
+ """Dataset for YOLOv5 DOTA Dataset.
18
+
19
+ We only add `BatchShapePolicy` function compared with DOTADataset. See
20
+ `mmyolo/datasets/utils.py#BatchShapePolicy` for details
21
+ """
22
+
23
+ def __init__(self, *args, **kwargs):
24
+ if not MMROTATE_AVAILABLE:
25
+ raise ImportError(
26
+ 'Please run "mim install -r requirements/mmrotate.txt" '
27
+ 'to install mmrotate first for rotated detection.')
28
+
29
+ super().__init__(*args, **kwargs)
mmyolo/datasets/yolov5_voc.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmdet.datasets import VOCDataset
3
+
4
+ from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
5
+ from ..registry import DATASETS
6
+
7
+
8
+ @DATASETS.register_module()
9
+ class YOLOv5VOCDataset(BatchShapePolicyDataset, VOCDataset):
10
+ """Dataset for YOLOv5 VOC Dataset.
11
+
12
+ We only add `BatchShapePolicy` function compared with VOCDataset. See
13
+ `mmyolo/datasets/utils.py#BatchShapePolicy` for details
14
+ """
15
+ pass
mmyolo/deploy/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from mmdeploy.codebase.base import MMCodebase
3
+
4
+ from .models import * # noqa: F401,F403
5
+ from .object_detection import MMYOLO, YOLOObjectDetection
6
+
7
+ __all__ = ['MMCodebase', 'MMYOLO', 'YOLOObjectDetection']
mmyolo/deploy/models/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from . import dense_heads # noqa: F401,F403
mmyolo/deploy/models/dense_heads/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from . import yolov5_head # noqa: F401,F403
3
+
4
+ __all__ = ['yolov5_head']
mmyolo/deploy/models/dense_heads/yolov5_head.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ from functools import partial
4
+ from typing import List, Optional, Tuple
5
+
6
+ import torch
7
+ from mmdeploy.codebase.mmdet import get_post_processing_params
8
+ from mmdeploy.codebase.mmdet.models.layers import multiclass_nms
9
+ from mmdeploy.core import FUNCTION_REWRITER
10
+ from mmengine.config import ConfigDict
11
+ from mmengine.structures import InstanceData
12
+ from torch import Tensor
13
+
14
+ from mmyolo.deploy.models.layers import efficient_nms
15
+ from mmyolo.models.dense_heads import YOLOv5Head
16
+
17
+
18
+ def yolov5_bbox_decoder(priors: Tensor, bbox_preds: Tensor,
19
+ stride: int) -> Tensor:
20
+ """Decode YOLOv5 bounding boxes.
21
+
22
+ Args:
23
+ priors (Tensor): Prior boxes in corner form (x1, y1, x2, y2).
24
+ bbox_preds (Tensor): Predicted bounding boxes.
25
+ stride (int): Stride of the feature map.
26
+
27
+ Returns:
28
+ Tensor: Decoded bounding boxes.
29
+ """
30
+ bbox_preds = bbox_preds.sigmoid()
31
+
32
+ x_center = (priors[..., 0] + priors[..., 2]) * 0.5
33
+ y_center = (priors[..., 1] + priors[..., 3]) * 0.5
34
+ w = priors[..., 2] - priors[..., 0]
35
+ h = priors[..., 3] - priors[..., 1]
36
+
37
+ x_center_pred = (bbox_preds[..., 0] - 0.5) * 2 * stride + x_center
38
+ y_center_pred = (bbox_preds[..., 1] - 0.5) * 2 * stride + y_center
39
+ w_pred = (bbox_preds[..., 2] * 2)**2 * w
40
+ h_pred = (bbox_preds[..., 3] * 2)**2 * h
41
+
42
+ decoded_bboxes = torch.stack(
43
+ [x_center_pred, y_center_pred, w_pred, h_pred], dim=-1)
44
+
45
+ return decoded_bboxes
46
+
47
+
48
+ @FUNCTION_REWRITER.register_rewriter(
49
+ func_name='mmyolo.models.dense_heads.yolov5_head.'
50
+ 'YOLOv5Head.predict_by_feat')
51
+ def yolov5_head__predict_by_feat(self,
52
+ cls_scores: List[Tensor],
53
+ bbox_preds: List[Tensor],
54
+ objectnesses: Optional[List[Tensor]] = None,
55
+ batch_img_metas: Optional[List[dict]] = None,
56
+ cfg: Optional[ConfigDict] = None,
57
+ rescale: bool = False,
58
+ with_nms: bool = True) -> Tuple[InstanceData]:
59
+ """Transform a batch of output features extracted by the head into
60
+ bbox results.
61
+ Args:
62
+ cls_scores (list[Tensor]): Classification scores for all
63
+ scale levels, each is a 4D-tensor, has shape
64
+ (batch_size, num_priors * num_classes, H, W).
65
+ bbox_preds (list[Tensor]): Box energies / deltas for all
66
+ scale levels, each is a 4D-tensor, has shape
67
+ (batch_size, num_priors * 4, H, W).
68
+ objectnesses (list[Tensor], Optional): Score factor for
69
+ all scale level, each is a 4D-tensor, has shape
70
+ (batch_size, 1, H, W).
71
+ batch_img_metas (list[dict], Optional): Batch image meta info.
72
+ Defaults to None.
73
+ cfg (ConfigDict, optional): Test / postprocessing
74
+ configuration, if None, test_cfg would be used.
75
+ Defaults to None.
76
+ rescale (bool): If True, return boxes in original image space.
77
+ Defaults to False.
78
+ with_nms (bool): If True, do nms before return boxes.
79
+ Defaults to True.
80
+ Returns:
81
+ tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
82
+ where 5 represent (tl_x, tl_y, br_x, br_y, score), N is batch
83
+ size and the score between 0 and 1. The shape of the second
84
+ tensor in the tuple is (N, num_box), and each element
85
+ represents the class label of the corresponding box.
86
+ """
87
+ ctx = FUNCTION_REWRITER.get_context()
88
+ detector_type = type(self)
89
+ deploy_cfg = ctx.cfg
90
+ use_efficientnms = deploy_cfg.get('use_efficientnms', False)
91
+ dtype = cls_scores[0].dtype
92
+ device = cls_scores[0].device
93
+ bbox_decoder = self.bbox_coder.decode
94
+ nms_func = multiclass_nms
95
+ if use_efficientnms:
96
+ if detector_type is YOLOv5Head:
97
+ nms_func = partial(efficient_nms, box_coding=0)
98
+ bbox_decoder = yolov5_bbox_decoder
99
+ else:
100
+ nms_func = efficient_nms
101
+
102
+ assert len(cls_scores) == len(bbox_preds)
103
+ cfg = self.test_cfg if cfg is None else cfg
104
+ cfg = copy.deepcopy(cfg)
105
+
106
+ num_imgs = cls_scores[0].shape[0]
107
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
108
+
109
+ mlvl_priors = self.prior_generator.grid_priors(
110
+ featmap_sizes, dtype=dtype, device=device)
111
+
112
+ flatten_priors = torch.cat(mlvl_priors)
113
+
114
+ mlvl_strides = [
115
+ flatten_priors.new_full(
116
+ (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
117
+ stride)
118
+ for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
119
+ ]
120
+ flatten_stride = torch.cat(mlvl_strides)
121
+
122
+ # flatten cls_scores, bbox_preds and objectness
123
+ flatten_cls_scores = [
124
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
125
+ for cls_score in cls_scores
126
+ ]
127
+ cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
128
+
129
+ flatten_bbox_preds = [
130
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
131
+ for bbox_pred in bbox_preds
132
+ ]
133
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
134
+
135
+ if objectnesses is not None:
136
+ flatten_objectness = [
137
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
138
+ for objectness in objectnesses
139
+ ]
140
+ flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
141
+ cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
142
+
143
+ scores = cls_scores
144
+
145
+ bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
146
+ flatten_stride)
147
+
148
+ if not with_nms:
149
+ return bboxes, scores
150
+
151
+ post_params = get_post_processing_params(deploy_cfg)
152
+ max_output_boxes_per_class = post_params.max_output_boxes_per_class
153
+ iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
154
+ score_threshold = cfg.get('score_thr', post_params.score_threshold)
155
+ pre_top_k = post_params.pre_top_k
156
+ keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
157
+
158
+ return nms_func(bboxes, scores, max_output_boxes_per_class, iou_threshold,
159
+ score_threshold, pre_top_k, keep_top_k)
160
+
161
+
162
+ @FUNCTION_REWRITER.register_rewriter(
163
+ func_name='mmyolo.models.dense_heads.yolov5_head.'
164
+ 'YOLOv5Head.predict',
165
+ backend='rknn')
166
+ def yolov5_head__predict__rknn(self, x: Tuple[Tensor], *args,
167
+ **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
168
+ """Perform forward propagation of the detection head and predict detection
169
+ results on the features of the upstream network.
170
+
171
+ Args:
172
+ x (tuple[Tensor]): Multi-level features from the
173
+ upstream network, each is a 4D-tensor.
174
+ """
175
+ outs = self(x)
176
+ return outs
177
+
178
+
179
+ @FUNCTION_REWRITER.register_rewriter(
180
+ func_name='mmyolo.models.dense_heads.yolov5_head.'
181
+ 'YOLOv5HeadModule.forward',
182
+ backend='rknn')
183
+ def yolov5_head_module__forward__rknn(
184
+ self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:
185
+ """Forward feature of a single scale level."""
186
+ out = []
187
+ for i, feat in enumerate(x):
188
+ out.append(self.convs_pred[i](feat))
189
+ return out
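As a sanity check of the arithmetic in `yolov5_bbox_decoder`, the output keeps the prior's center and size when the raw prediction is all zeros (sigmoid gives 0.5 everywhere). The tensors below are made up for illustration:

```python
import torch

# One prior in corner form (x1, y1, x2, y2): center (120, 140), size 40 x 80.
priors = torch.tensor([[100., 100., 140., 180.]])
raw_pred = torch.zeros(1, 4)   # sigmoid(0) = 0.5 for every channel
stride = 8

decoded = yolov5_bbox_decoder(priors, raw_pred, stride)
# Center offset is (0.5 - 0.5) * 2 * stride = 0 and the size scale is
# (2 * 0.5) ** 2 = 1, so the output is (cx, cy, w, h) = (120, 140, 40, 80).
print(decoded)
```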
mmyolo/deploy/models/layers/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .bbox_nms import efficient_nms
3
+
4
+ __all__ = ['efficient_nms']
mmyolo/deploy/models/layers/bbox_nms.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmdeploy.core import mark
4
+ from torch import Tensor
5
+
6
+
7
+ def _efficient_nms(
8
+ boxes: Tensor,
9
+ scores: Tensor,
10
+ max_output_boxes_per_class: int = 1000,
11
+ iou_threshold: float = 0.5,
12
+ score_threshold: float = 0.05,
13
+ pre_top_k: int = -1,
14
+ keep_top_k: int = 100,
15
+ box_coding: int = 0,
16
+ ):
17
+ """Wrapper for `efficient_nms` with TensorRT.
18
+
19
+ Args:
20
+ boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
21
+ scores (Tensor): The detection scores of shape
22
+ [N, num_boxes, num_classes].
23
+ max_output_boxes_per_class (int): Maximum number of output
24
+ boxes per class of nms. Defaults to 1000.
25
+ iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
26
+ score_threshold (float): score threshold of nms.
27
+ Defaults to 0.05.
28
+ pre_top_k (int): Number of top K boxes to keep before nms.
29
+ Defaults to -1.
30
+ keep_top_k (int): Number of top K boxes to keep after nms.
31
+ Defaults to 100.
32
+ box_coding (int): Bounding boxes format for nms.
33
+ Defaults to 0 means [x, y, w, h].
34
+ Set to 1 means [x1, y1, x2, y2].
35
+
36
+ Returns:
37
+ tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
38
+ and `labels` of shape [N, num_det].
39
+ """
40
+ boxes = boxes if boxes.dim() == 4 else boxes.unsqueeze(2)
41
+ _, det_boxes, det_scores, labels = TRTEfficientNMSop.apply(
42
+ boxes, scores, -1, box_coding, iou_threshold, keep_top_k, '1', 0,
43
+ score_threshold)
44
+ dets = torch.cat([det_boxes, det_scores.unsqueeze(2)], -1)
45
+
46
+ # retain shape info
47
+ batch_size = boxes.size(0)
48
+
49
+ dets_shape = dets.shape
50
+ label_shape = labels.shape
51
+ dets = dets.reshape([batch_size, *dets_shape[1:]])
52
+ labels = labels.reshape([batch_size, *label_shape[1:]])
53
+ return dets, labels
54
+
55
+
56
+ @mark('efficient_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels'])
57
+ def efficient_nms(*args, **kwargs):
58
+ """Wrapper function for `_efficient_nms`."""
59
+ return _efficient_nms(*args, **kwargs)
60
+
61
+
62
+ class TRTEfficientNMSop(torch.autograd.Function):
63
+ """Efficient NMS op for TensorRT."""
64
+
65
+ @staticmethod
66
+ def forward(
67
+ ctx,
68
+ boxes,
69
+ scores,
70
+ background_class=-1,
71
+ box_coding=0,
72
+ iou_threshold=0.45,
73
+ max_output_boxes=100,
74
+ plugin_version='1',
75
+ score_activation=0,
76
+ score_threshold=0.25,
77
+ ):
78
+ """Forward function of TRTEfficientNMSop."""
79
+ batch_size, num_boxes, num_classes = scores.shape
80
+ num_det = torch.randint(
81
+ 0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
82
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
83
+ det_scores = torch.randn(batch_size, max_output_boxes)
84
+ det_classes = torch.randint(
85
+ 0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
86
+ return num_det, det_boxes, det_scores, det_classes
87
+
88
+ @staticmethod
89
+ def symbolic(g,
90
+ boxes,
91
+ scores,
92
+ background_class=-1,
93
+ box_coding=0,
94
+ iou_threshold=0.45,
95
+ max_output_boxes=100,
96
+ plugin_version='1',
97
+ score_activation=0,
98
+ score_threshold=0.25):
99
+ """Symbolic function of TRTEfficientNMSop."""
100
+ out = g.op(
101
+ 'TRT::EfficientNMS_TRT',
102
+ boxes,
103
+ scores,
104
+ background_class_i=background_class,
105
+ box_coding_i=box_coding,
106
+ iou_threshold_f=iou_threshold,
107
+ max_output_boxes_i=max_output_boxes,
108
+ plugin_version_s=plugin_version,
109
+ score_activation_i=score_activation,
110
+ score_threshold_f=score_threshold,
111
+ outputs=4)
112
+ nums, boxes, scores, classes = out
113
+ return nums, boxes, scores, classes
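`box_coding=0` feeds the TensorRT plugin boxes in `[x, y, w, h]` form (as produced by the YOLOv5 decoder above), while `box_coding=1` expects corner form. A small, plugin-independent reference for converting between the two conventions:

```python
import torch

def xywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    """Convert center-size boxes (box_coding=0) to corner boxes (box_coding=1)."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)

boxes_xywh = torch.tensor([[120., 140., 40., 80.]])
print(xywh_to_xyxy(boxes_xywh))  # tensor([[100., 100., 140., 180.]])
```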
mmyolo/deploy/object_detection.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Callable, Dict, Optional
3
+
4
+ import torch
5
+ from mmdeploy.codebase.base import CODEBASE, MMCodebase
6
+ from mmdeploy.codebase.mmdet.deploy import ObjectDetection
7
+ from mmdeploy.utils import Codebase, Task
8
+ from mmengine import Config
9
+ from mmengine.registry import Registry
10
+
11
+ MMYOLO_TASK = Registry('mmyolo_tasks')
12
+
13
+
14
+ @CODEBASE.register_module(Codebase.MMYOLO.value)
15
+ class MMYOLO(MMCodebase):
16
+ """MMYOLO codebase class."""
17
+
18
+ task_registry = MMYOLO_TASK
19
+
20
+ @classmethod
21
+ def register_deploy_modules(cls):
22
+ """register all rewriters for mmdet."""
23
+ import mmdeploy.codebase.mmdet.models # noqa: F401
24
+ import mmdeploy.codebase.mmdet.ops # noqa: F401
25
+ import mmdeploy.codebase.mmdet.structures # noqa: F401
26
+
27
+ @classmethod
28
+ def register_all_modules(cls):
29
+ """register all modules."""
30
+ from mmdet.utils.setup_env import \
31
+ register_all_modules as register_all_modules_mmdet
32
+
33
+ from mmyolo.utils.setup_env import \
34
+ register_all_modules as register_all_modules_mmyolo
35
+
36
+ cls.register_deploy_modules()
37
+ register_all_modules_mmyolo(True)
38
+ register_all_modules_mmdet(False)
39
+
40
+
41
+ def _get_dataset_metainfo(model_cfg: Config):
42
+ """Get metainfo of dataset.
43
+
44
+ Args:
45
+ model_cfg (Config): Input model Config object.
46
+
47
+ Returns:
48
+ dict or None: Meta information of the dataset (e.g. class names), or None if it cannot be resolved.
49
+ """
50
+ from mmyolo import datasets # noqa
51
+ from mmyolo.registry import DATASETS
52
+
53
+ module_dict = DATASETS.module_dict
54
+ for dataloader_name in [
55
+ 'test_dataloader', 'val_dataloader', 'train_dataloader'
56
+ ]:
57
+ if dataloader_name not in model_cfg:
58
+ continue
59
+ dataloader_cfg = model_cfg[dataloader_name]
60
+ dataset_cfg = dataloader_cfg.dataset
61
+ dataset_cls = module_dict.get(dataset_cfg.type, None)
62
+ if dataset_cls is None:
63
+ continue
64
+ if hasattr(dataset_cls, '_load_metainfo') and isinstance(
65
+ dataset_cls._load_metainfo, Callable):
66
+ meta = dataset_cls._load_metainfo(
67
+ dataset_cfg.get('metainfo', None))
68
+ if meta is not None:
69
+ return meta
70
+ if hasattr(dataset_cls, 'METAINFO'):
71
+ return dataset_cls.METAINFO
72
+
73
+ return None
74
+
75
+
76
+ @MMYOLO_TASK.register_module(Task.OBJECT_DETECTION.value)
77
+ class YOLOObjectDetection(ObjectDetection):
78
+ """YOLO Object Detection task."""
79
+
80
+ def get_visualizer(self, name: str, save_dir: str):
81
+ """Get visualizer.
82
+
83
+ Args:
84
+ name (str): Name of visualizer.
85
+ save_dir (str): Directory to save visualization results.
86
+
87
+ Returns:
88
+ Visualizer: A visualizer instance.
89
+ """
90
+ from mmdet.visualization import DetLocalVisualizer # noqa: F401,F403
91
+ metainfo = _get_dataset_metainfo(self.model_cfg)
92
+ visualizer = super().get_visualizer(name, save_dir)
93
+ if metainfo is not None:
94
+ visualizer.dataset_meta = metainfo
95
+ return visualizer
96
+
97
+ def build_pytorch_model(self,
98
+ model_checkpoint: Optional[str] = None,
99
+ cfg_options: Optional[Dict] = None,
100
+ **kwargs) -> torch.nn.Module:
101
+ """Initialize torch model.
102
+
103
+ Args:
104
+ model_checkpoint (str): The checkpoint file of torch model,
105
+ defaults to `None`.
106
+ cfg_options (dict): Optional config key-pair parameters.
107
+ Returns:
108
+ nn.Module: An initialized torch model generated by other OpenMMLab
109
+ codebases.
110
+ """
111
+ from copy import deepcopy
112
+
113
+ from mmengine.model import revert_sync_batchnorm
114
+ from mmengine.registry import MODELS
115
+
116
+ from mmyolo.utils import switch_to_deploy
117
+
118
+ model = deepcopy(self.model_cfg.model)
119
+ preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
120
+ preprocess_cfg.update(
121
+ deepcopy(self.model_cfg.get('data_preprocessor', {})))
122
+ model.setdefault('data_preprocessor', preprocess_cfg)
123
+ model = MODELS.build(model)
124
+ if model_checkpoint is not None:
125
+ from mmengine.runner.checkpoint import load_checkpoint
126
+ load_checkpoint(model, model_checkpoint, map_location=self.device)
127
+
128
+ model = revert_sync_batchnorm(model)
129
+ switch_to_deploy(model)
130
+ model = model.to(self.device)
131
+ model.eval()
132
+ return model
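`_get_dataset_metainfo` walks the test/val/train dataloader configs and returns the first dataset metainfo it can resolve. A hedged sketch with a minimal in-memory config; the config content is invented and it assumes mmyolo is installed so the dataset class can be found in the registry:

```python
from mmengine import Config

# Minimal fake model config; only the dataloader/dataset fields are read.
model_cfg = Config(
    dict(test_dataloader=dict(dataset=dict(type='YOLOv5CocoDataset'))))

meta = _get_dataset_metainfo(model_cfg)
if meta is not None:
    # For COCO-style datasets the metainfo usually carries the class names.
    print(meta.get('classes'))
```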
mmyolo/engine/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .hooks import * # noqa: F401,F403
3
+ from .optimizers import * # noqa: F401,F403
mmyolo/engine/hooks/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .ppyoloe_param_scheduler_hook import PPYOLOEParamSchedulerHook
3
+ from .switch_to_deploy_hook import SwitchToDeployHook
4
+ from .yolov5_param_scheduler_hook import YOLOv5ParamSchedulerHook
5
+ from .yolox_mode_switch_hook import YOLOXModeSwitchHook
6
+
7
+ __all__ = [
8
+ 'YOLOv5ParamSchedulerHook', 'YOLOXModeSwitchHook', 'SwitchToDeployHook',
9
+ 'PPYOLOEParamSchedulerHook'
10
+ ]
mmyolo/engine/hooks/ppyoloe_param_scheduler_hook.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import Optional
4
+
5
+ from mmengine.hooks import ParamSchedulerHook
6
+ from mmengine.runner import Runner
7
+
8
+ from mmyolo.registry import HOOKS
9
+
10
+
11
+ @HOOKS.register_module()
12
+ class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
13
+ """A hook to update learning rate and momentum in optimizer of PPYOLOE. We
14
+ use this hook to implement adaptive computation for `warmup_total_iters`,
15
+ which is not possible with the built-in ParamScheduler in mmyolo.
16
+
17
+ Args:
18
+ warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
19
+ start_factor (float): The number we multiply learning rate in the
20
+ first epoch. The multiplication factor changes towards end_factor
21
+ in the following epochs. Defaults to 0.
22
+ warmup_epochs (int): Epochs for warmup. Defaults to 5.
23
+ min_lr_ratio (float): Minimum learning rate ratio.
24
+ total_epochs (int): In PPYOLOE, `total_epochs` is set to
25
+ training_epochs x 1.2. Defaults to 360.
26
+ """
27
+ priority = 9
28
+
29
+ def __init__(self,
30
+ warmup_min_iter: int = 1000,
31
+ start_factor: float = 0.,
32
+ warmup_epochs: int = 5,
33
+ min_lr_ratio: float = 0.0,
34
+ total_epochs: int = 360):
35
+
36
+ self.warmup_min_iter = warmup_min_iter
37
+ self.start_factor = start_factor
38
+ self.warmup_epochs = warmup_epochs
39
+ self.min_lr_ratio = min_lr_ratio
40
+ self.total_epochs = total_epochs
41
+
42
+ self._warmup_end = False
43
+ self._base_lr = None
44
+
45
+ def before_train(self, runner: Runner):
46
+ """Operations before train.
47
+
48
+ Args:
49
+ runner (Runner): The runner of the training process.
50
+ """
51
+ optimizer = runner.optim_wrapper.optimizer
52
+ for group in optimizer.param_groups:
53
+ # If the param is never be scheduled, record the current value
54
+ # as the initial value.
55
+ group.setdefault('initial_lr', group['lr'])
56
+
57
+ self._base_lr = [
58
+ group['initial_lr'] for group in optimizer.param_groups
59
+ ]
60
+ self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]
61
+
62
+ def before_train_iter(self,
63
+ runner: Runner,
64
+ batch_idx: int,
65
+ data_batch: Optional[dict] = None):
66
+ """Operations before each training iteration.
67
+
68
+ Args:
69
+ runner (Runner): The runner of the training process.
70
+ batch_idx (int): The index of the current batch in the train loop.
71
+ data_batch (dict or tuple or list, optional): Data from dataloader.
72
+ """
73
+ cur_iters = runner.iter
74
+ optimizer = runner.optim_wrapper.optimizer
75
+ dataloader_len = len(runner.train_dataloader)
76
+
77
+ # The minimum warmup is self.warmup_min_iter
78
+ warmup_total_iters = max(
79
+ round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)
80
+
81
+ if cur_iters <= warmup_total_iters:
82
+ # warm up
83
+ alpha = cur_iters / warmup_total_iters
84
+ factor = self.start_factor * (1 - alpha) + alpha
85
+
86
+ for group_idx, param in enumerate(optimizer.param_groups):
87
+ param['lr'] = self._base_lr[group_idx] * factor
88
+ else:
89
+ for group_idx, param in enumerate(optimizer.param_groups):
90
+ total_iters = self.total_epochs * dataloader_len
91
+ lr = self._min_lr[group_idx] + (
92
+ self._base_lr[group_idx] -
93
+ self._min_lr[group_idx]) * 0.5 * (
94
+ math.cos((cur_iters - warmup_total_iters) * math.pi /
95
+ (total_iters - warmup_total_iters)) + 1.0)
96
+ param['lr'] = lr
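The schedule above is a linear warmup up to `warmup_total_iters` followed by cosine annealing towards `min_lr_ratio * base_lr`. It can be reproduced in a few lines of plain Python; every number below is a placeholder, not a recommended setting:

```python
import math

base_lr = 0.01
min_lr = base_lr * 0.0            # min_lr_ratio = 0.0
start_factor = 0.0
warmup_total_iters = 1000
total_iters = 360 * 500           # total_epochs * (made-up) iters per epoch

def ppyoloe_lr(cur_iter: int) -> float:
    if cur_iter <= warmup_total_iters:
        alpha = cur_iter / warmup_total_iters
        return base_lr * (start_factor * (1 - alpha) + alpha)  # linear warmup
    cos = math.cos((cur_iter - warmup_total_iters) * math.pi /
                   (total_iters - warmup_total_iters))
    return min_lr + (base_lr - min_lr) * 0.5 * (cos + 1.0)     # cosine decay

for it in (0, 500, 1000, 90000, 179000):
    print(it, round(ppyoloe_lr(it), 6))
```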
mmyolo/engine/hooks/switch_to_deploy_hook.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from mmengine.hooks import Hook
4
+ from mmengine.runner import Runner
5
+
6
+ from mmyolo.registry import HOOKS
7
+ from mmyolo.utils import switch_to_deploy
8
+
9
+
10
+ @HOOKS.register_module()
11
+ class SwitchToDeployHook(Hook):
12
+ """Switch to deploy mode before testing.
13
+
14
+ This hook converts the multi-branch structure of the training network
15
+ (better accuracy) to the single-branch structure of the testing network
16
+ (faster inference and lower memory usage).
17
+ """
18
+
19
+ def before_test_epoch(self, runner: Runner):
20
+ """Switch to deploy mode before testing."""
21
+ switch_to_deploy(runner.model)
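In practice the hook is usually enabled from the config rather than constructed by hand; a typical (illustrative) snippet looks like the one below, relying on `switch_to_deploy` to fuse re-parameterizable blocks such as RepVGG-style convolutions:

```python
# Illustrative config fragment: reparameterize the model before testing.
custom_hooks = [dict(type='SwitchToDeployHook')]
```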
mmyolo/engine/hooks/yolov5_param_scheduler_hook.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import Optional
4
+
5
+ import numpy as np
6
+ from mmengine.hooks import ParamSchedulerHook
7
+ from mmengine.runner import Runner
8
+
9
+ from mmyolo.registry import HOOKS
10
+
11
+
12
+ def linear_fn(lr_factor: float, max_epochs: int):
13
+ """Generate linear function."""
14
+ return lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor
15
+
16
+
17
+ def cosine_fn(lr_factor: float, max_epochs: int):
18
+ """Generate cosine function."""
19
+ return lambda x: (
20
+ (1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1
21
+
22
+
23
+ @HOOKS.register_module()
24
+ class YOLOv5ParamSchedulerHook(ParamSchedulerHook):
25
+ """A hook to update learning rate and momentum in optimizer of YOLOv5."""
26
+ priority = 9
27
+
28
+ scheduler_maps = {'linear': linear_fn, 'cosine': cosine_fn}
29
+
30
+ def __init__(self,
31
+ scheduler_type: str = 'linear',
32
+ lr_factor: float = 0.01,
33
+ max_epochs: int = 300,
34
+ warmup_epochs: int = 3,
35
+ warmup_bias_lr: float = 0.1,
36
+ warmup_momentum: float = 0.8,
37
+ warmup_mim_iter: int = 1000,
38
+ **kwargs):
39
+
40
+ assert scheduler_type in self.scheduler_maps
41
+
42
+ self.warmup_epochs = warmup_epochs
43
+ self.warmup_bias_lr = warmup_bias_lr
44
+ self.warmup_momentum = warmup_momentum
45
+ self.warmup_mim_iter = warmup_mim_iter
46
+
47
+ kwargs.update({'lr_factor': lr_factor, 'max_epochs': max_epochs})
48
+ self.scheduler_fn = self.scheduler_maps[scheduler_type](**kwargs)
49
+
50
+ self._warmup_end = False
51
+ self._base_lr = None
52
+ self._base_momentum = None
53
+
54
+ def before_train(self, runner: Runner):
55
+ """Operations before train.
56
+
57
+ Args:
58
+ runner (Runner): The runner of the training process.
59
+ """
60
+ optimizer = runner.optim_wrapper.optimizer
61
+ for group in optimizer.param_groups:
62
+ # If the param is never be scheduled, record the current value
63
+ # as the initial value.
64
+ group.setdefault('initial_lr', group['lr'])
65
+ group.setdefault('initial_momentum', group.get('momentum', -1))
66
+
67
+ self._base_lr = [
68
+ group['initial_lr'] for group in optimizer.param_groups
69
+ ]
70
+ self._base_momentum = [
71
+ group['initial_momentum'] for group in optimizer.param_groups
72
+ ]
73
+
74
+ def before_train_iter(self,
75
+ runner: Runner,
76
+ batch_idx: int,
77
+ data_batch: Optional[dict] = None):
78
+ """Operations before each training iteration.
79
+
80
+ Args:
81
+ runner (Runner): The runner of the training process.
82
+ batch_idx (int): The index of the current batch in the train loop.
83
+ data_batch (dict or tuple or list, optional): Data from dataloader.
84
+ """
85
+ cur_iters = runner.iter
86
+ cur_epoch = runner.epoch
87
+ optimizer = runner.optim_wrapper.optimizer
88
+
89
+ # The minimum warmup is self.warmup_mim_iter
90
+ warmup_total_iters = max(
91
+ round(self.warmup_epochs * len(runner.train_dataloader)),
92
+ self.warmup_mim_iter)
93
+
94
+ if cur_iters <= warmup_total_iters:
95
+ xp = [0, warmup_total_iters]
96
+ for group_idx, param in enumerate(optimizer.param_groups):
97
+ if group_idx == 2:
98
+ # bias learning rate will be handled specially
99
+ yp = [
100
+ self.warmup_bias_lr,
101
+ self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
102
+ ]
103
+ else:
104
+ yp = [
105
+ 0.0,
106
+ self._base_lr[group_idx] * self.scheduler_fn(cur_epoch)
107
+ ]
108
+ param['lr'] = np.interp(cur_iters, xp, yp)
109
+
110
+ if 'momentum' in param:
111
+ param['momentum'] = np.interp(
112
+ cur_iters, xp,
113
+ [self.warmup_momentum, self._base_momentum[group_idx]])
114
+ else:
115
+ self._warmup_end = True
116
+
117
+ def after_train_epoch(self, runner: Runner):
118
+ """Operations after each training epoch.
119
+
120
+ Args:
121
+ runner (Runner): The runner of the training process.
122
+ """
123
+ if not self._warmup_end:
124
+ return
125
+
126
+ cur_epoch = runner.epoch
127
+ optimizer = runner.optim_wrapper.optimizer
128
+ for group_idx, param in enumerate(optimizer.param_groups):
129
+ param['lr'] = self._base_lr[group_idx] * self.scheduler_fn(
130
+ cur_epoch)
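After warmup, every parameter group's LR is `base_lr * scheduler_fn(epoch)`. The two scheduler factors defined at the top of this file behave as shown below; the hyperparameters are only examples:

```python
import math

lr_factor, max_epochs = 0.01, 300

linear = lambda x: (1 - x / max_epochs) * (1.0 - lr_factor) + lr_factor
cosine = lambda x: ((1 - math.cos(x * math.pi / max_epochs)) / 2) * (lr_factor - 1) + 1

for epoch in (0, 150, 299):
    print(epoch, round(linear(epoch), 4), round(cosine(epoch), 4))
# Both factors start at 1.0 and decay towards lr_factor as training ends;
# the cosine variant stays higher during the first half before dropping.
```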
mmyolo/engine/hooks/yolox_mode_switch_hook.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ from typing import Sequence
4
+
5
+ from mmengine.hooks import Hook
6
+ from mmengine.model import is_model_wrapper
7
+ from mmengine.runner import Runner
8
+
9
+ from mmyolo.registry import HOOKS
10
+
11
+
12
+ @HOOKS.register_module()
13
+ class YOLOXModeSwitchHook(Hook):
14
+ """Switch the mode of YOLOX during training.
15
+
16
+ This hook turns off the mosaic and mixup data augmentation and switches
17
+ to use L1 loss in bbox_head.
18
+
19
+ Args:
20
+ num_last_epochs (int): The number of latter epochs in the end of the
21
+ training to close the data augmentation and switch to L1 loss.
22
+ Defaults to 15.
23
+ """
24
+
25
+ def __init__(self,
26
+ num_last_epochs: int = 15,
27
+ new_train_pipeline: Sequence[dict] = None):
28
+ self.num_last_epochs = num_last_epochs
29
+ self.new_train_pipeline_cfg = new_train_pipeline
30
+
31
+ def before_train_epoch(self, runner: Runner):
32
+ """Close mosaic and mixup augmentation and switches to use L1 loss."""
33
+ epoch = runner.epoch
34
+ model = runner.model
35
+ if is_model_wrapper(model):
36
+ model = model.module
37
+
38
+ if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
39
+ runner.logger.info(f'New Pipeline: {self.new_train_pipeline_cfg}')
40
+
41
+ train_dataloader_cfg = copy.deepcopy(runner.cfg.train_dataloader)
42
+ train_dataloader_cfg.dataset.pipeline = self.new_train_pipeline_cfg
43
+ # Note: Why rebuild the dataset?
44
+ # When build_dataloader will make a deep copy of the dataset,
45
+ # it will lead to potential risks, such as the global instance
46
+ # object FileClient data is disordered.
47
+ # This problem needs to be solved in the future.
48
+ new_train_dataloader = Runner.build_dataloader(
49
+ train_dataloader_cfg)
50
+ runner.train_loop.dataloader = new_train_dataloader
51
+
52
+ runner.logger.info('recreate the dataloader!')
53
+ runner.logger.info('Add additional bbox reg loss now!')
54
+ model.bbox_head.use_bbox_aux = True
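A hedged example of how this hook is commonly wired up: a second, lighter pipeline without Mosaic/MixUp is defined and handed to the hook so it can rebuild the dataloader for the last epochs. The transform names and values are placeholders that depend on the model config:

```python
# Illustrative only; the exact transforms depend on the model config.
train_pipeline_stage2 = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='mmdet.Resize', scale=(640, 640), keep_ratio=True),
    dict(type='mmdet.PackDetInputs'),
]

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=15,
        new_train_pipeline=train_pipeline_stage2),
]
```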
mmyolo/engine/optimizers/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .yolov5_optim_constructor import YOLOv5OptimizerConstructor
3
+ from .yolov7_optim_wrapper_constructor import YOLOv7OptimWrapperConstructor
4
+
5
+ __all__ = ['YOLOv5OptimizerConstructor', 'YOLOv7OptimWrapperConstructor']
mmyolo/engine/optimizers/yolov5_optim_constructor.py ADDED
@@ -0,0 +1,132 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional
3
+
4
+ import torch.nn as nn
5
+ from mmengine.dist import get_world_size
6
+ from mmengine.logging import print_log
7
+ from mmengine.model import is_model_wrapper
8
+ from mmengine.optim import OptimWrapper
9
+
10
+ from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
11
+ OPTIMIZERS)
12
+
13
+
14
+ @OPTIM_WRAPPER_CONSTRUCTORS.register_module()
15
+ class YOLOv5OptimizerConstructor:
16
+ """YOLOv5 constructor for optimizers.
17
+
18
+ It has the following functions:
19
+
20
+ - divides the optimizer parameters into 3 groups:
21
+ Conv, Bias and BN
22
+
23
+ - support `weight_decay` parameter adaption based on
24
+ `batch_size_per_gpu`
25
+
26
+ Args:
27
+ optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
28
+ Positional fields are
29
+
30
+ - ``type``: class name of the OptimizerWrapper
31
+ - ``optimizer``: The configuration of optimizer.
32
+
33
+ Optional fields are
34
+
35
+ - any arguments of the corresponding optimizer wrapper type,
36
+ e.g., accumulative_counts, clip_grad, etc.
37
+
38
+ The positional fields of ``optimizer`` are
39
+
40
+ - `type`: class name of the optimizer.
41
+
42
+ Optional fields are
43
+
44
+ - any arguments of the corresponding optimizer type, e.g.,
45
+ lr, weight_decay, momentum, etc.
46
+
47
+ paramwise_cfg (dict, optional): Parameter-wise options. Must include
48
+ `base_total_batch_size` if not None. If the total input batch
49
+ is smaller than `base_total_batch_size`, the `weight_decay`
50
+ parameter will be kept unchanged, otherwise linear scaling.
51
+
52
+ Example:
53
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
54
+ >>> optim_wrapper_cfg = dict(
55
+ >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
56
+ >>> momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16))
57
+ >>> paramwise_cfg = dict(base_total_batch_size=64)
58
+ >>> optim_wrapper_builder = YOLOv5OptimizerConstructor(
59
+ >>> optim_wrapper_cfg, paramwise_cfg)
60
+ >>> optim_wrapper = optim_wrapper_builder(model)
61
+ """
62
+
63
+ def __init__(self,
64
+ optim_wrapper_cfg: dict,
65
+ paramwise_cfg: Optional[dict] = None):
66
+ if paramwise_cfg is None:
67
+ paramwise_cfg = {'base_total_batch_size': 64}
68
+ assert 'base_total_batch_size' in paramwise_cfg
69
+
70
+ if not isinstance(optim_wrapper_cfg, dict):
71
+ raise TypeError('optim_wrapper_cfg should be a dict, '
72
+ f'but got {type(optim_wrapper_cfg)}')
73
+ assert 'optimizer' in optim_wrapper_cfg, (
74
+ '`optim_wrapper_cfg` must contain "optimizer" config')
75
+
76
+ self.optim_wrapper_cfg = optim_wrapper_cfg
77
+ self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
78
+ self.base_total_batch_size = paramwise_cfg['base_total_batch_size']
79
+
80
+ def __call__(self, model: nn.Module) -> OptimWrapper:
81
+ if is_model_wrapper(model):
82
+ model = model.module
83
+ optimizer_cfg = self.optimizer_cfg.copy()
84
+ weight_decay = optimizer_cfg.pop('weight_decay', 0)
85
+
86
+ if 'batch_size_per_gpu' in optimizer_cfg:
87
+ batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
88
+ # No scaling if total_batch_size is less than
89
+ # base_total_batch_size, otherwise linear scaling.
90
+ total_batch_size = get_world_size() * batch_size_per_gpu
91
+ accumulate = max(
92
+ round(self.base_total_batch_size / total_batch_size), 1)
93
+ scale_factor = total_batch_size * \
94
+ accumulate / self.base_total_batch_size
95
+
96
+ if scale_factor != 1:
97
+ weight_decay *= scale_factor
98
+ print_log(f'Scaled weight_decay to {weight_decay}', 'current')
99
+
100
+ params_groups = [], [], []
101
+
102
+ for v in model.modules():
103
+ if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
104
+ params_groups[2].append(v.bias)
105
+ # Includes SyncBatchNorm
106
+ if isinstance(v, nn.modules.batchnorm._NormBase):
107
+ params_groups[1].append(v.weight)
108
+ elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
109
+ params_groups[0].append(v.weight)
110
+
111
+ # Note: Make sure bias is in the last parameter group
112
+ optimizer_cfg['params'] = []
113
+ # conv
114
+ optimizer_cfg['params'].append({
115
+ 'params': params_groups[0],
116
+ 'weight_decay': weight_decay
117
+ })
118
+ # bn
119
+ optimizer_cfg['params'].append({'params': params_groups[1]})
120
+ # bias
121
+ optimizer_cfg['params'].append({'params': params_groups[2]})
122
+
123
+ print_log(
124
+ 'Optimizer groups: %g .bias, %g conv.weight, %g other' %
125
+ (len(params_groups[2]), len(params_groups[0]), len(
126
+ params_groups[1])), 'current')
127
+ del params_groups
128
+
129
+ optimizer = OPTIMIZERS.build(optimizer_cfg)
130
+ optim_wrapper = OPTIM_WRAPPERS.build(
131
+ self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
132
+ return optim_wrapper
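The `weight_decay` scaling above is coupled to gradient accumulation: below `base_total_batch_size` the accumulation factor cancels the scaling, above it the decay grows linearly with the total batch. A small numeric check (GPU counts and per-GPU batch sizes are made up):

```python
base_total_batch_size = 64
weight_decay = 0.0005

for world_size, batch_size_per_gpu in [(1, 16), (4, 16), (8, 16), (8, 32)]:
    total = world_size * batch_size_per_gpu
    accumulate = max(round(base_total_batch_size / total), 1)
    scale = total * accumulate / base_total_batch_size
    print(total, accumulate, round(weight_decay * scale, 6))
# total=16  -> accumulate=4, scale=1.0, decay unchanged
# total=64  -> accumulate=1, scale=1.0, decay unchanged
# total=128 -> accumulate=1, scale=2.0, decay doubled
# total=256 -> accumulate=1, scale=4.0, decay quadrupled
```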
mmyolo/engine/optimizers/yolov7_optim_wrapper_constructor.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Optional
3
+
4
+ import torch.nn as nn
5
+ from mmengine.dist import get_world_size
6
+ from mmengine.logging import print_log
7
+ from mmengine.model import is_model_wrapper
8
+ from mmengine.optim import OptimWrapper
9
+
10
+ from mmyolo.models.dense_heads.yolov7_head import ImplicitA, ImplicitM
11
+ from mmyolo.registry import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS,
12
+ OPTIMIZERS)
13
+
14
+
15
+ # TODO: Consider merging into YOLOv5OptimizerConstructor
16
+ @OPTIM_WRAPPER_CONSTRUCTORS.register_module()
17
+ class YOLOv7OptimWrapperConstructor:
18
+ """YOLOv7 constructor for optimizer wrappers.
19
+
20
+ It has the following functions:
21
+
22
+ - divides the optimizer parameters into 3 groups:
23
+ Conv, Bias and BN/ImplicitA/ImplicitM
24
+
25
+ - support `weight_decay` parameter adaption based on
26
+ `batch_size_per_gpu`
27
+
28
+ Args:
29
+ optim_wrapper_cfg (dict): The config dict of the optimizer wrapper.
30
+ Positional fields are
31
+
32
+ - ``type``: class name of the OptimizerWrapper
33
+ - ``optimizer``: The configuration of optimizer.
34
+
35
+ Optional fields are
36
+
37
+ - any arguments of the corresponding optimizer wrapper type,
38
+ e.g., accumulative_counts, clip_grad, etc.
39
+
40
+ The positional fields of ``optimizer`` are
41
+
42
+ - `type`: class name of the optimizer.
43
+
44
+ Optional fields are
45
+
46
+ - any arguments of the corresponding optimizer type, e.g.,
47
+ lr, weight_decay, momentum, etc.
48
+
49
+ paramwise_cfg (dict, optional): Parameter-wise options. Must include
50
+ `base_total_batch_size` if not None. If the total input batch
51
+ is smaller than `base_total_batch_size`, the `weight_decay`
52
+ parameter will be kept unchanged, otherwise linear scaling.
53
+
54
+ Example:
55
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
56
+ >>> optim_wrapper_cfg = dict(
57
+ >>> dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01,
58
+ >>> momentum=0.9, weight_decay=0.0001, batch_size_per_gpu=16))
59
+ >>> paramwise_cfg = dict(base_total_batch_size=64)
60
+ >>> optim_wrapper_builder = YOLOv7OptimWrapperConstructor(
61
+ >>> optim_wrapper_cfg, paramwise_cfg)
62
+ >>> optim_wrapper = optim_wrapper_builder(model)
63
+ """
64
+
65
+ def __init__(self,
66
+ optim_wrapper_cfg: dict,
67
+ paramwise_cfg: Optional[dict] = None):
68
+ if paramwise_cfg is None:
69
+ paramwise_cfg = {'base_total_batch_size': 64}
70
+ assert 'base_total_batch_size' in paramwise_cfg
71
+
72
+ if not isinstance(optim_wrapper_cfg, dict):
73
+ raise TypeError('optim_wrapper_cfg should be a dict, '
74
+ f'but got {type(optim_wrapper_cfg)}')
75
+ assert 'optimizer' in optim_wrapper_cfg, (
76
+ '`optim_wrapper_cfg` must contain "optimizer" config')
77
+
78
+ self.optim_wrapper_cfg = optim_wrapper_cfg
79
+ self.optimizer_cfg = self.optim_wrapper_cfg.pop('optimizer')
80
+ self.base_total_batch_size = paramwise_cfg['base_total_batch_size']
81
+
82
+ def __call__(self, model: nn.Module) -> OptimWrapper:
83
+ if is_model_wrapper(model):
84
+ model = model.module
85
+ optimizer_cfg = self.optimizer_cfg.copy()
86
+ weight_decay = optimizer_cfg.pop('weight_decay', 0)
87
+
88
+ if 'batch_size_per_gpu' in optimizer_cfg:
89
+ batch_size_per_gpu = optimizer_cfg.pop('batch_size_per_gpu')
90
+ # No scaling if total_batch_size is less than
91
+ # base_total_batch_size, otherwise linear scaling.
92
+ total_batch_size = get_world_size() * batch_size_per_gpu
93
+ accumulate = max(
94
+ round(self.base_total_batch_size / total_batch_size), 1)
95
+ scale_factor = total_batch_size * \
96
+ accumulate / self.base_total_batch_size
97
+
98
+ if scale_factor != 1:
99
+ weight_decay *= scale_factor
100
+ print_log(f'Scaled weight_decay to {weight_decay}', 'current')
101
+
102
+ params_groups = [], [], []
103
+ for v in model.modules():
104
+ # no decay
105
+ # Caution: Coupling with model
106
+ if isinstance(v, (ImplicitA, ImplicitM)):
107
+ params_groups[0].append(v.implicit)
108
+ elif isinstance(v, nn.modules.batchnorm._NormBase):
109
+ params_groups[0].append(v.weight)
110
+ # apply decay
111
+ elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
112
+ params_groups[1].append(v.weight) # apply decay
113
+
114
+ # biases, no decay
115
+ if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
116
+ params_groups[2].append(v.bias)
117
+
118
+ # Note: Make sure bias is in the last parameter group
119
+ optimizer_cfg['params'] = []
120
+ # conv
121
+ optimizer_cfg['params'].append({
122
+ 'params': params_groups[1],
123
+ 'weight_decay': weight_decay
124
+ })
125
+ # bn ...
126
+ optimizer_cfg['params'].append({'params': params_groups[0]})
127
+ # bias
128
+ optimizer_cfg['params'].append({'params': params_groups[2]})
129
+
130
+ print_log(
131
+ 'Optimizer groups: %g .bias, %g conv.weight, %g other' %
132
+ (len(params_groups[2]), len(params_groups[1]), len(
133
+ params_groups[0])), 'current')
134
+ del params_groups
135
+
136
+ optimizer = OPTIMIZERS.build(optimizer_cfg)
137
+ optim_wrapper = OPTIM_WRAPPERS.build(
138
+ self.optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
139
+ return optim_wrapper
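The grouping rule can be checked on a toy module: normalization weights (and, in the real constructor, `ImplicitA`/`ImplicitM` parameters) land in the no-decay group, other weights in the decay group, and every bias in the bias group. The module below is arbitrary and only illustrates the traversal:

```python
import torch.nn as nn

toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, bias=True),
    nn.BatchNorm2d(8),
    nn.Conv2d(8, 8, 3, bias=False))

no_decay, decay, biases = [], [], []
for m in toy.modules():
    if isinstance(m, nn.modules.batchnorm._NormBase):
        no_decay.append(m.weight)
    elif hasattr(m, 'weight') and isinstance(m.weight, nn.Parameter):
        decay.append(m.weight)
    if hasattr(m, 'bias') and isinstance(m.bias, nn.Parameter):
        biases.append(m.bias)

# 2 conv weights with decay, 1 bn weight without, 2 biases (conv1 + bn).
print(len(decay), len(no_decay), len(biases))
```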
mmyolo/models/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .backbones import * # noqa: F401,F403
3
+ from .data_preprocessors import * # noqa: F401,F403
4
+ from .dense_heads import * # noqa: F401,F403
5
+ from .detectors import * # noqa: F401,F403
6
+ from .layers import * # noqa: F401,F403
7
+ from .losses import * # noqa: F401,F403
8
+ from .necks import * # noqa: F401,F403
9
+ from .plugins import * # noqa: F401,F403
10
+ from .task_modules import * # noqa: F401,F403
mmyolo/models/backbones/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base_backbone import BaseBackbone
3
+ from .csp_darknet import YOLOv5CSPDarknet, YOLOv8CSPDarknet, YOLOXCSPDarknet
4
+ from .csp_resnet import PPYOLOECSPResNet
5
+ from .cspnext import CSPNeXt
6
+ from .efficient_rep import YOLOv6CSPBep, YOLOv6EfficientRep
7
+ from .yolov7_backbone import YOLOv7Backbone
8
+
9
+ __all__ = [
10
+ 'YOLOv5CSPDarknet', 'BaseBackbone', 'YOLOv6EfficientRep', 'YOLOv6CSPBep',
11
+ 'YOLOXCSPDarknet', 'CSPNeXt', 'YOLOv7Backbone', 'PPYOLOECSPResNet',
12
+ 'YOLOv8CSPDarknet'
13
+ ]
mmyolo/models/backbones/base_backbone.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta, abstractmethod
3
+ from typing import List, Sequence, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmcv.cnn import build_plugin_layer
8
+ from mmdet.utils import ConfigType, OptMultiConfig
9
+ from mmengine.model import BaseModule
10
+ from torch.nn.modules.batchnorm import _BatchNorm
11
+
12
+ from mmyolo.registry import MODELS
13
+
14
+
15
+ @MODELS.register_module()
16
+ class BaseBackbone(BaseModule, metaclass=ABCMeta):
17
+ """BaseBackbone backbone used in YOLO series.
18
+
19
+ .. code:: text
20
+
21
+ Backbone model structure diagram
22
+ +-----------+
23
+ | input |
24
+ +-----------+
25
+ v
26
+ +-----------+
27
+ | stem |
28
+ | layer |
29
+ +-----------+
30
+ v
31
+ +-----------+
32
+ | stage |
33
+ | layer 1 |
34
+ +-----------+
35
+ v
36
+ +-----------+
37
+ | stage |
38
+ | layer 2 |
39
+ +-----------+
40
+ v
41
+ ......
42
+ v
43
+ +-----------+
44
+ | stage |
45
+ | layer n |
46
+ +-----------+
47
+ In P5 model, n=4
48
+ In P6 model, n=5
49
+
50
+ Args:
51
+ arch_setting (list): Architecture of BaseBackbone.
52
+ plugins (list[dict]): List of plugins for stages, each dict contains:
53
+
54
+ - cfg (dict, required): Cfg dict to build plugin.
55
+ - stages (tuple[bool], optional): Stages to apply plugin, length
56
+ should be same as 'num_stages'.
57
+ deepen_factor (float): Depth multiplier, multiply number of
58
+ blocks in CSP layer by this amount. Defaults to 1.0.
59
+ widen_factor (float): Width multiplier, multiply number of
60
+ channels in each layer by this amount. Defaults to 1.0.
61
+ input_channels: Number of input image channels. Defaults to 3.
62
+ out_indices (Sequence[int]): Output from which stages.
63
+ Defaults to (2, 3, 4).
64
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
65
+ mode). -1 means not freezing any parameters. Defaults to -1.
66
+ norm_cfg (dict): Dictionary to construct and config norm layer.
67
+ Defaults to None.
68
+ act_cfg (dict): Config dict for activation layer.
69
+ Defaults to None.
70
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
71
+ freeze running stats (mean and var). Note: Effect on Batch Norm
72
+ and its variants only. Defaults to False.
73
+ init_cfg (dict or list[dict], optional): Initialization config dict.
74
+ Defaults to None.
75
+ """
76
+
77
+ def __init__(self,
78
+ arch_setting: list,
79
+ deepen_factor: float = 1.0,
80
+ widen_factor: float = 1.0,
81
+ input_channels: int = 3,
82
+ out_indices: Sequence[int] = (2, 3, 4),
83
+ frozen_stages: int = -1,
84
+ plugins: Union[dict, List[dict]] = None,
85
+ norm_cfg: ConfigType = None,
86
+ act_cfg: ConfigType = None,
87
+ norm_eval: bool = False,
88
+ init_cfg: OptMultiConfig = None):
89
+ super().__init__(init_cfg)
90
+ self.num_stages = len(arch_setting)
91
+ self.arch_setting = arch_setting
92
+
93
+ assert set(out_indices).issubset(
94
+ i for i in range(len(arch_setting) + 1))
95
+
96
+ if frozen_stages not in range(-1, len(arch_setting) + 1):
97
+ raise ValueError('"frozen_stages" must be in range(-1, '
98
+ 'len(arch_setting) + 1). But received '
99
+ f'{frozen_stages}')
100
+
101
+ self.input_channels = input_channels
102
+ self.out_indices = out_indices
103
+ self.frozen_stages = frozen_stages
104
+ self.widen_factor = widen_factor
105
+ self.deepen_factor = deepen_factor
106
+ self.norm_eval = norm_eval
107
+ self.norm_cfg = norm_cfg
108
+ self.act_cfg = act_cfg
109
+ self.plugins = plugins
110
+
111
+ self.stem = self.build_stem_layer()
112
+ self.layers = ['stem']
113
+
114
+ for idx, setting in enumerate(arch_setting):
115
+ stage = []
116
+ stage += self.build_stage_layer(idx, setting)
117
+ if plugins is not None:
118
+ stage += self.make_stage_plugins(plugins, idx, setting)
119
+ self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
120
+ self.layers.append(f'stage{idx + 1}')
121
+
122
+ @abstractmethod
123
+ def build_stem_layer(self):
124
+ """Build a stem layer."""
125
+ pass
126
+
127
+ @abstractmethod
128
+ def build_stage_layer(self, stage_idx: int, setting: list):
129
+ """Build a stage layer.
130
+
131
+ Args:
132
+ stage_idx (int): The index of a stage layer.
133
+ setting (list): The architecture setting of a stage layer.
134
+ """
135
+ pass
136
+
137
+ def make_stage_plugins(self, plugins, stage_idx, setting):
138
+ """Make plugins for backbone ``stage_idx`` th stage.
139
+
140
+ Currently we support to insert ``context_block``,
141
+ ``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
142
+ into the backbone.
143
+
144
+
145
+ An example of plugins format could be:
146
+
147
+ Examples:
148
+ >>> plugins=[
149
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
150
+ ... stages=(False, True, True, True)),
151
+ ... dict(cfg=dict(type='yyy'),
152
+ ... stages=(True, True, True, True)),
153
+ ... ]
154
+ >>> model = YOLOv5CSPDarknet()
155
+ >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
156
+ >>> assert len(stage_plugins) == 1
157
+
158
+ Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
159
+
160
+ .. code-block:: none
161
+
162
+ conv1 -> conv2 -> conv3 -> yyy
163
+
164
+ Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
165
+
166
+ .. code-block:: none
167
+
168
+ conv1 -> conv2 -> conv3 -> xxx -> yyy
169
+
170
+
171
+ Args:
172
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
173
+ required if multiple same type plugins are inserted.
174
+ stage_idx (int): Index of stage to build
175
+ If stages is missing, the plugin would be applied to all
176
+ stages.
177
+ setting (list): The architecture setting of a stage layer.
178
+
179
+ Returns:
180
+ list[nn.Module]: Plugins for current stage
181
+ """
182
+ # TODO: It is not general enough to support any channel and needs
183
+ # to be refactored
184
+ in_channels = int(setting[1] * self.widen_factor)
185
+ plugin_layers = []
186
+ for plugin in plugins:
187
+ plugin = plugin.copy()
188
+ stages = plugin.pop('stages', None)
189
+ assert stages is None or len(stages) == self.num_stages
190
+ if stages is None or stages[stage_idx]:
191
+ name, layer = build_plugin_layer(
192
+ plugin['cfg'], in_channels=in_channels)
193
+ plugin_layers.append(layer)
194
+ return plugin_layers
195
+
196
+ def _freeze_stages(self):
197
+ """Freeze the parameters of the specified stage so that they are no
198
+ longer updated."""
199
+ if self.frozen_stages >= 0:
200
+ for i in range(self.frozen_stages + 1):
201
+ m = getattr(self, self.layers[i])
202
+ m.eval()
203
+ for param in m.parameters():
204
+ param.requires_grad = False
205
+
206
+ def train(self, mode: bool = True):
207
+ """Convert the model into training mode while keep normalization layer
208
+ frozen."""
209
+ super().train(mode)
210
+ self._freeze_stages()
211
+ if mode and self.norm_eval:
212
+ for m in self.modules():
213
+ if isinstance(m, _BatchNorm):
214
+ m.eval()
215
+
216
+ def forward(self, x: torch.Tensor) -> tuple:
217
+ """Forward batch_inputs from the data_preprocessor."""
218
+ outs = []
219
+ for i, layer_name in enumerate(self.layers):
220
+ layer = getattr(self, layer_name)
221
+ x = layer(x)
222
+ if i in self.out_indices:
223
+ outs.append(x)
224
+
225
+ return tuple(outs)
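Concrete backbones only need to provide the stem and per-stage builders; freezing, `out_indices` handling and plugins come from the base class. A toy subclass, purely hypothetical and not part of the library, shows the contract:

```python
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

class TinyBackbone(BaseBackbone):
    """Hypothetical minimal subclass, for illustration only."""

    # One [in_channels, out_channels] entry per stage.
    arch = [[32, 64], [64, 128], [128, 256]]

    def __init__(self, **kwargs):
        super().__init__(self.arch, out_indices=(1, 2, 3), **kwargs)

    def build_stem_layer(self) -> nn.Module:
        return ConvModule(self.input_channels, 32, 3, stride=2, padding=1,
                          norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))

    def build_stage_layer(self, stage_idx: int, setting: list) -> list:
        in_c, out_c = setting
        return [ConvModule(in_c, out_c, 3, stride=2, padding=1,
                           norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))]

outs = TinyBackbone()(torch.rand(1, 3, 64, 64))
print([o.shape for o in outs])  # feature maps at strides 4, 8 and 16
```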
mmyolo/models/backbones/csp_darknet.py ADDED
@@ -0,0 +1,427 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
7
+ from mmdet.models.backbones.csp_darknet import CSPLayer, Focus
8
+ from mmdet.utils import ConfigType, OptMultiConfig
9
+
10
+ from mmyolo.registry import MODELS
11
+ from ..layers import CSPLayerWithTwoConv, SPPFBottleneck
12
+ from ..utils import make_divisible, make_round
13
+ from .base_backbone import BaseBackbone
14
+
15
+
16
+ @MODELS.register_module()
17
+ class YOLOv5CSPDarknet(BaseBackbone):
18
+ """CSP-Darknet backbone used in YOLOv5.
19
+ Args:
20
+ arch (str): Architecture of CSP-Darknet, from {P5, P6}.
21
+ Defaults to P5.
22
+ plugins (list[dict]): List of plugins for stages, each dict contains:
23
+ - cfg (dict, required): Cfg dict to build plugin.
24
+ - stages (tuple[bool], optional): Stages to apply plugin, length
25
+ should be same as 'num_stages'.
26
+ deepen_factor (float): Depth multiplier, multiply number of
27
+ blocks in CSP layer by this amount. Defaults to 1.0.
28
+ widen_factor (float): Width multiplier, multiply number of
29
+ channels in each layer by this amount. Defaults to 1.0.
30
+ input_channels (int): Number of input image channels. Defaults to 3.
31
+ out_indices (Tuple[int]): Output from which stages.
32
+ Defaults to (2, 3, 4).
33
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
34
+ mode). -1 means not freezing any parameters. Defaults to -1.
35
+ norm_cfg (dict): Dictionary to construct and config norm layer.
36
+ Defaults to dict(type='BN', momentum=0.03, eps=0.001).
37
+ act_cfg (dict): Config dict for activation layer.
38
+ Defaults to dict(type='SiLU', inplace=True).
39
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
40
+ freeze running stats (mean and var). Note: Effect on Batch Norm
41
+ and its variants only. Defaults to False.
42
+ init_cfg (Union[dict,list[dict]], optional): Initialization config
43
+ dict. Defaults to None.
44
+ Example:
45
+ >>> from mmyolo.models import YOLOv5CSPDarknet
46
+ >>> import torch
47
+ >>> model = YOLOv5CSPDarknet()
48
+ >>> model.eval()
49
+ >>> inputs = torch.rand(1, 3, 416, 416)
50
+ >>> level_outputs = model(inputs)
51
+ >>> for level_out in level_outputs:
52
+ ... print(tuple(level_out.shape))
53
+ ...
54
+ (1, 256, 52, 52)
55
+ (1, 512, 26, 26)
56
+ (1, 1024, 13, 13)
57
+ """
58
+ # From left to right:
59
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
60
+ arch_settings = {
61
+ 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
62
+ [256, 512, 9, True, False], [512, 1024, 3, True, True]],
63
+ 'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
64
+ [256, 512, 9, True, False], [512, 768, 3, True, False],
65
+ [768, 1024, 3, True, True]]
66
+ }
67
+
68
+ def __init__(self,
69
+ arch: str = 'P5',
70
+ plugins: Union[dict, List[dict]] = None,
71
+ deepen_factor: float = 1.0,
72
+ widen_factor: float = 1.0,
73
+ input_channels: int = 3,
74
+ out_indices: Tuple[int] = (2, 3, 4),
75
+ frozen_stages: int = -1,
76
+ norm_cfg: ConfigType = dict(
77
+ type='BN', momentum=0.03, eps=0.001),
78
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
79
+ norm_eval: bool = False,
80
+ init_cfg: OptMultiConfig = None):
81
+ super().__init__(
82
+ self.arch_settings[arch],
83
+ deepen_factor,
84
+ widen_factor,
85
+ input_channels=input_channels,
86
+ out_indices=out_indices,
87
+ plugins=plugins,
88
+ frozen_stages=frozen_stages,
89
+ norm_cfg=norm_cfg,
90
+ act_cfg=act_cfg,
91
+ norm_eval=norm_eval,
92
+ init_cfg=init_cfg)
93
+
94
+ def build_stem_layer(self) -> nn.Module:
95
+ """Build a stem layer."""
96
+ return ConvModule(
97
+ self.input_channels,
98
+ make_divisible(self.arch_setting[0][0], self.widen_factor),
99
+ kernel_size=6,
100
+ stride=2,
101
+ padding=2,
102
+ norm_cfg=self.norm_cfg,
103
+ act_cfg=self.act_cfg)
104
+
105
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
106
+ """Build a stage layer.
107
+
108
+ Args:
109
+ stage_idx (int): The index of a stage layer.
110
+ setting (list): The architecture setting of a stage layer.
111
+ """
112
+ in_channels, out_channels, num_blocks, add_identity, use_spp = setting
113
+
114
+ in_channels = make_divisible(in_channels, self.widen_factor)
115
+ out_channels = make_divisible(out_channels, self.widen_factor)
116
+ num_blocks = make_round(num_blocks, self.deepen_factor)
117
+ stage = []
118
+ conv_layer = ConvModule(
119
+ in_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=2,
123
+ padding=1,
124
+ norm_cfg=self.norm_cfg,
125
+ act_cfg=self.act_cfg)
126
+ stage.append(conv_layer)
127
+ csp_layer = CSPLayer(
128
+ out_channels,
129
+ out_channels,
130
+ num_blocks=num_blocks,
131
+ add_identity=add_identity,
132
+ norm_cfg=self.norm_cfg,
133
+ act_cfg=self.act_cfg)
134
+ stage.append(csp_layer)
135
+ if use_spp:
136
+ spp = SPPFBottleneck(
137
+ out_channels,
138
+ out_channels,
139
+ kernel_sizes=5,
140
+ norm_cfg=self.norm_cfg,
141
+ act_cfg=self.act_cfg)
142
+ stage.append(spp)
143
+ return stage
144
+
145
+ def init_weights(self):
146
+ """Initialize the parameters."""
147
+ if self.init_cfg is None:
148
+ for m in self.modules():
149
+ if isinstance(m, torch.nn.Conv2d):
150
+ # In order to be consistent with the source code,
151
+ # reset the Conv2d initialization parameters
152
+ m.reset_parameters()
153
+ else:
154
+ super().init_weights()
155
+
156
+
157
+ @MODELS.register_module()
158
+ class YOLOv8CSPDarknet(BaseBackbone):
159
+ """CSP-Darknet backbone used in YOLOv8.
160
+
161
+ Args:
162
+ arch (str): Architecture of CSP-Darknet, from {P5}.
163
+ Defaults to P5.
164
+ last_stage_out_channels (int): Final layer output channel.
165
+ Defaults to 1024.
166
+ plugins (list[dict]): List of plugins for stages, each dict contains:
167
+ - cfg (dict, required): Cfg dict to build plugin.
168
+ - stages (tuple[bool], optional): Stages to apply plugin, length
169
+ should be same as 'num_stages'.
170
+ deepen_factor (float): Depth multiplier, multiply number of
171
+ blocks in CSP layer by this amount. Defaults to 1.0.
172
+ widen_factor (float): Width multiplier, multiply number of
173
+ channels in each layer by this amount. Defaults to 1.0.
174
+ input_channels (int): Number of input image channels. Defaults to 3.
175
+ out_indices (Tuple[int]): Output from which stages.
176
+ Defaults to (2, 3, 4).
177
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
178
+ mode). -1 means not freezing any parameters. Defaults to -1.
179
+ norm_cfg (dict): Dictionary to construct and config norm layer.
180
+ Defaults to dict(type='BN', momentum=0.03, eps=0.001).
181
+ act_cfg (dict): Config dict for activation layer.
182
+ Defaults to dict(type='SiLU', inplace=True).
183
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
184
+ freeze running stats (mean and var). Note: Effect on Batch Norm
185
+ and its variants only. Defaults to False.
186
+ init_cfg (Union[dict,list[dict]], optional): Initialization config
187
+ dict. Defaults to None.
188
+
189
+ Example:
190
+ >>> from mmyolo.models import YOLOv8CSPDarknet
191
+ >>> import torch
192
+ >>> model = YOLOv8CSPDarknet()
193
+ >>> model.eval()
194
+ >>> inputs = torch.rand(1, 3, 416, 416)
195
+ >>> level_outputs = model(inputs)
196
+ >>> for level_out in level_outputs:
197
+ ... print(tuple(level_out.shape))
198
+ ...
199
+ (1, 256, 52, 52)
200
+ (1, 512, 26, 26)
201
+ (1, 1024, 13, 13)
202
+ """
203
+ # From left to right:
204
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
205
+ # the final out_channels will be set according to the param.
206
+ arch_settings = {
207
+ 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
208
+ [256, 512, 6, True, False], [512, None, 3, True, True]],
209
+ }
210
+
211
+ def __init__(self,
212
+ arch: str = 'P5',
213
+ last_stage_out_channels: int = 1024,
214
+ plugins: Union[dict, List[dict]] = None,
215
+ deepen_factor: float = 1.0,
216
+ widen_factor: float = 1.0,
217
+ input_channels: int = 3,
218
+ out_indices: Tuple[int] = (2, 3, 4),
219
+ frozen_stages: int = -1,
220
+ norm_cfg: ConfigType = dict(
221
+ type='BN', momentum=0.03, eps=0.001),
222
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
223
+ norm_eval: bool = False,
224
+ init_cfg: OptMultiConfig = None):
225
+ self.arch_settings[arch][-1][1] = last_stage_out_channels
226
+ super().__init__(
227
+ self.arch_settings[arch],
228
+ deepen_factor,
229
+ widen_factor,
230
+ input_channels=input_channels,
231
+ out_indices=out_indices,
232
+ plugins=plugins,
233
+ frozen_stages=frozen_stages,
234
+ norm_cfg=norm_cfg,
235
+ act_cfg=act_cfg,
236
+ norm_eval=norm_eval,
237
+ init_cfg=init_cfg)
238
+
239
+ def build_stem_layer(self) -> nn.Module:
240
+ """Build a stem layer."""
241
+ return ConvModule(
242
+ self.input_channels,
243
+ make_divisible(self.arch_setting[0][0], self.widen_factor),
244
+ kernel_size=3,
245
+ stride=2,
246
+ padding=1,
247
+ norm_cfg=self.norm_cfg,
248
+ act_cfg=self.act_cfg)
249
+
250
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
251
+ """Build a stage layer.
252
+
253
+ Args:
254
+ stage_idx (int): The index of a stage layer.
255
+ setting (list): The architecture setting of a stage layer.
256
+ """
257
+ in_channels, out_channels, num_blocks, add_identity, use_spp = setting
258
+
259
+ in_channels = make_divisible(in_channels, self.widen_factor)
260
+ out_channels = make_divisible(out_channels, self.widen_factor)
261
+ num_blocks = make_round(num_blocks, self.deepen_factor)
262
+ stage = []
263
+ conv_layer = ConvModule(
264
+ in_channels,
265
+ out_channels,
266
+ kernel_size=3,
267
+ stride=2,
268
+ padding=1,
269
+ norm_cfg=self.norm_cfg,
270
+ act_cfg=self.act_cfg)
271
+ stage.append(conv_layer)
272
+ csp_layer = CSPLayerWithTwoConv(
273
+ out_channels,
274
+ out_channels,
275
+ num_blocks=num_blocks,
276
+ add_identity=add_identity,
277
+ norm_cfg=self.norm_cfg,
278
+ act_cfg=self.act_cfg)
279
+ stage.append(csp_layer)
280
+ if use_spp:
281
+ spp = SPPFBottleneck(
282
+ out_channels,
283
+ out_channels,
284
+ kernel_sizes=5,
285
+ norm_cfg=self.norm_cfg,
286
+ act_cfg=self.act_cfg)
287
+ stage.append(spp)
288
+ return stage
289
+
290
+ def init_weights(self):
291
+ """Initialize the parameters."""
292
+ if self.init_cfg is None:
293
+ for m in self.modules():
294
+ if isinstance(m, torch.nn.Conv2d):
295
+ # In order to be consistent with the source code,
296
+ # reset the Conv2d initialization parameters
297
+ m.reset_parameters()
298
+ else:
299
+ super().init_weights()
300
+
301
+
302
+ @MODELS.register_module()
303
+ class YOLOXCSPDarknet(BaseBackbone):
304
+ """CSP-Darknet backbone used in YOLOX.
305
+
306
+ Args:
307
+ arch (str): Architecture of CSP-Darknet, from {P5}.
308
+ Defaults to P5.
309
+ plugins (list[dict]): List of plugins for stages, each dict contains:
310
+
311
+ - cfg (dict, required): Cfg dict to build plugin.
312
+ - stages (tuple[bool], optional): Stages to apply plugin, length
313
+ should be same as 'num_stages'.
314
+ deepen_factor (float): Depth multiplier, multiply number of
315
+ blocks in CSP layer by this amount. Defaults to 1.0.
316
+ widen_factor (float): Width multiplier, multiply number of
317
+ channels in each layer by this amount. Defaults to 1.0.
318
+ input_channels (int): Number of input image channels. Defaults to 3.
319
+ out_indices (Tuple[int]): Output from which stages.
320
+ Defaults to (2, 3, 4).
321
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
322
+ mode). -1 means not freezing any parameters. Defaults to -1.
323
+ use_depthwise (bool): Whether to use depthwise separable convolution.
324
+ Defaults to False.
325
+ spp_kernal_sizes (tuple[int]): Sequence of kernel sizes of SPP
326
+ layers. Defaults to (5, 9, 13).
327
+ norm_cfg (dict): Dictionary to construct and config norm layer.
328
+ Defaults to dict(type='BN', momentum=0.03, eps=0.001).
329
+ act_cfg (dict): Config dict for activation layer.
330
+ Defaults to dict(type='SiLU', inplace=True).
331
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
332
+ freeze running stats (mean and var). Note: Effect on Batch Norm
333
+ and its variants only.
334
+ init_cfg (Union[dict,list[dict]], optional): Initialization config
335
+ dict. Defaults to None.
336
+ Example:
337
+ >>> from mmyolo.models import YOLOXCSPDarknet
338
+ >>> import torch
339
+ >>> model = YOLOXCSPDarknet()
340
+ >>> model.eval()
341
+ >>> inputs = torch.rand(1, 3, 416, 416)
342
+ >>> level_outputs = model(inputs)
343
+ >>> for level_out in level_outputs:
344
+ ... print(tuple(level_out.shape))
345
+ ...
346
+ (1, 256, 52, 52)
347
+ (1, 512, 26, 26)
348
+ (1, 1024, 13, 13)
349
+ """
350
+ # From left to right:
351
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
352
+ arch_settings = {
353
+ 'P5': [[64, 128, 3, True, False], [128, 256, 9, True, False],
354
+ [256, 512, 9, True, False], [512, 1024, 3, False, True]],
355
+ }
356
+
357
+ def __init__(self,
358
+ arch: str = 'P5',
359
+ plugins: Union[dict, List[dict]] = None,
360
+ deepen_factor: float = 1.0,
361
+ widen_factor: float = 1.0,
362
+ input_channels: int = 3,
363
+ out_indices: Tuple[int] = (2, 3, 4),
364
+ frozen_stages: int = -1,
365
+ use_depthwise: bool = False,
366
+ spp_kernal_sizes: Tuple[int] = (5, 9, 13),
367
+ norm_cfg: ConfigType = dict(
368
+ type='BN', momentum=0.03, eps=0.001),
369
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
370
+ norm_eval: bool = False,
371
+ init_cfg: OptMultiConfig = None):
372
+ self.use_depthwise = use_depthwise
373
+ self.spp_kernal_sizes = spp_kernal_sizes
374
+ super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
375
+ input_channels, out_indices, frozen_stages, plugins,
376
+ norm_cfg, act_cfg, norm_eval, init_cfg)
377
+
378
+ def build_stem_layer(self) -> nn.Module:
379
+ """Build a stem layer."""
380
+ return Focus(
381
+ 3,
382
+ make_divisible(64, self.widen_factor),
383
+ kernel_size=3,
384
+ norm_cfg=self.norm_cfg,
385
+ act_cfg=self.act_cfg)
386
+
387
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
388
+ """Build a stage layer.
389
+
390
+ Args:
391
+ stage_idx (int): The index of a stage layer.
392
+ setting (list): The architecture setting of a stage layer.
393
+ """
394
+ in_channels, out_channels, num_blocks, add_identity, use_spp = setting
395
+
396
+ in_channels = make_divisible(in_channels, self.widen_factor)
397
+ out_channels = make_divisible(out_channels, self.widen_factor)
398
+ num_blocks = make_round(num_blocks, self.deepen_factor)
399
+ stage = []
400
+ conv = DepthwiseSeparableConvModule \
401
+ if self.use_depthwise else ConvModule
402
+ conv_layer = conv(
403
+ in_channels,
404
+ out_channels,
405
+ kernel_size=3,
406
+ stride=2,
407
+ padding=1,
408
+ norm_cfg=self.norm_cfg,
409
+ act_cfg=self.act_cfg)
410
+ stage.append(conv_layer)
411
+ if use_spp:
412
+ spp = SPPFBottleneck(
413
+ out_channels,
414
+ out_channels,
415
+ kernel_sizes=self.spp_kernal_sizes,
416
+ norm_cfg=self.norm_cfg,
417
+ act_cfg=self.act_cfg)
418
+ stage.append(spp)
419
+ csp_layer = CSPLayer(
420
+ out_channels,
421
+ out_channels,
422
+ num_blocks=num_blocks,
423
+ add_identity=add_identity,
424
+ norm_cfg=self.norm_cfg,
425
+ act_cfg=self.act_cfg)
426
+ stage.append(csp_layer)
427
+ return stage
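
The three backbones above are registered in the MMYOLO MODELS registry, so they can be built directly or from a config dict. A minimal usage sketch, assuming mmyolo and its dependencies are installed and the registries have been populated (e.g. via mmyolo.utils.register_all_modules); the scaling factors here are illustrative:

    import torch
    from mmyolo.registry import MODELS
    from mmyolo.utils import register_all_modules

    register_all_modules()  # populate the MMYOLO registries

    # Build a half-width, one-third-depth YOLOv5 backbone from a config dict.
    backbone = MODELS.build(
        dict(type='YOLOv5CSPDarknet', deepen_factor=0.33, widen_factor=0.5))
    backbone.eval()

    # The default out_indices=(2, 3, 4) give feature maps at strides 8, 16, 32;
    # with widen_factor=0.5 the channels are 128, 256 and 512.
    for feat in backbone(torch.rand(1, 3, 640, 640)):
        print(tuple(feat.shape))
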
mmyolo/models/backbones/csp_resnet.py ADDED
@@ -0,0 +1,169 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Tuple, Union
3
+
4
+ import torch.nn as nn
5
+ from mmcv.cnn import ConvModule
6
+ from mmdet.utils import ConfigType, OptMultiConfig
7
+
8
+ from mmyolo.models.backbones import BaseBackbone
9
+ from mmyolo.models.layers.yolo_bricks import CSPResLayer
10
+ from mmyolo.registry import MODELS
11
+
12
+
13
+ @MODELS.register_module()
14
+ class PPYOLOECSPResNet(BaseBackbone):
15
+ """CSP-ResNet backbone used in PPYOLOE.
16
+
17
+ Args:
18
+ arch (str): Architecture of CSP-ResNet, from {P5}.
19
+ Defaults to P5.
20
+ deepen_factor (float): Depth multiplier, multiply number of
21
+ blocks in CSP layer by this amount. Defaults to 1.0.
22
+ widen_factor (float): Width multiplier, multiply number of
23
+ channels in each layer by this amount. Defaults to 1.0.
24
+ out_indices (Sequence[int]): Output from which stages.
25
+ Defaults to (2, 3, 4).
26
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
27
+ mode). -1 means not freezing any parameters. Defaults to -1.
28
+ plugins (list[dict]): List of plugins for stages, each dict contains:
29
+ - cfg (dict, required): Cfg dict to build plugin.
30
+ - stages (tuple[bool], optional): Stages to apply plugin, length
31
+ should be same as 'num_stages'.
32
+ arch_ovewrite (list): Overwrite default arch settings.
33
+ Defaults to None.
34
+ block_cfg (dict): Config dict for block. Defaults to
35
+ dict(type='PPYOLOEBasicBlock', shortcut=True, use_alpha=True)
36
+ norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
37
+ config norm layer. Defaults to dict(type='BN', momentum=0.1,
38
+ eps=1e-5).
39
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
40
+ Defaults to dict(type='SiLU', inplace=True).
41
+ attention_cfg (dict): Config dict for `EffectiveSELayer`.
42
+ Defaults to dict(type='EffectiveSELayer',
43
+ act_cfg=dict(type='HSigmoid')).
44
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
45
+ freeze running stats (mean and var). Note: Effect on Batch Norm
46
+ and its variants only.
47
+ init_cfg (:obj:`ConfigDict` or dict or list[dict] or
48
+ list[:obj:`ConfigDict`]): Initialization config dict.
49
+ use_large_stem (bool): Whether to use large stem layer.
50
+ Defaults to False.
51
+ """
52
+ # From left to right:
53
+ # in_channels, out_channels, num_blocks
54
+ arch_settings = {
55
+ 'P5': [[64, 128, 3], [128, 256, 6], [256, 512, 6], [512, 1024, 3]]
56
+ }
57
+
58
+ def __init__(self,
59
+ arch: str = 'P5',
60
+ deepen_factor: float = 1.0,
61
+ widen_factor: float = 1.0,
62
+ input_channels: int = 3,
63
+ out_indices: Tuple[int] = (2, 3, 4),
64
+ frozen_stages: int = -1,
65
+ plugins: Union[dict, List[dict]] = None,
66
+ arch_ovewrite: dict = None,
67
+ block_cfg: ConfigType = dict(
68
+ type='PPYOLOEBasicBlock', shortcut=True, use_alpha=True),
69
+ norm_cfg: ConfigType = dict(
70
+ type='BN', momentum=0.1, eps=1e-5),
71
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
72
+ attention_cfg: ConfigType = dict(
73
+ type='EffectiveSELayer', act_cfg=dict(type='HSigmoid')),
74
+ norm_eval: bool = False,
75
+ init_cfg: OptMultiConfig = None,
76
+ use_large_stem: bool = False):
77
+ arch_setting = self.arch_settings[arch]
78
+ if arch_ovewrite:
79
+ arch_setting = arch_ovewrite
80
+ arch_setting = [[
81
+ int(in_channels * widen_factor),
82
+ int(out_channels * widen_factor),
83
+ round(num_blocks * deepen_factor)
84
+ ] for in_channels, out_channels, num_blocks in arch_setting]
85
+ self.block_cfg = block_cfg
86
+ self.use_large_stem = use_large_stem
87
+ self.attention_cfg = attention_cfg
88
+
89
+ super().__init__(
90
+ arch_setting,
91
+ deepen_factor,
92
+ widen_factor,
93
+ input_channels=input_channels,
94
+ out_indices=out_indices,
95
+ plugins=plugins,
96
+ frozen_stages=frozen_stages,
97
+ norm_cfg=norm_cfg,
98
+ act_cfg=act_cfg,
99
+ norm_eval=norm_eval,
100
+ init_cfg=init_cfg)
101
+
102
+ def build_stem_layer(self) -> nn.Module:
103
+ """Build a stem layer."""
104
+ if self.use_large_stem:
105
+ stem = nn.Sequential(
106
+ ConvModule(
107
+ self.input_channels,
108
+ self.arch_setting[0][0] // 2,
109
+ 3,
110
+ stride=2,
111
+ padding=1,
112
+ act_cfg=self.act_cfg,
113
+ norm_cfg=self.norm_cfg),
114
+ ConvModule(
115
+ self.arch_setting[0][0] // 2,
116
+ self.arch_setting[0][0] // 2,
117
+ 3,
118
+ stride=1,
119
+ padding=1,
120
+ norm_cfg=self.norm_cfg,
121
+ act_cfg=self.act_cfg),
122
+ ConvModule(
123
+ self.arch_setting[0][0] // 2,
124
+ self.arch_setting[0][0],
125
+ 3,
126
+ stride=1,
127
+ padding=1,
128
+ norm_cfg=self.norm_cfg,
129
+ act_cfg=self.act_cfg))
130
+ else:
131
+ stem = nn.Sequential(
132
+ ConvModule(
133
+ self.input_channels,
134
+ self.arch_setting[0][0] // 2,
135
+ 3,
136
+ stride=2,
137
+ padding=1,
138
+ norm_cfg=self.norm_cfg,
139
+ act_cfg=self.act_cfg),
140
+ ConvModule(
141
+ self.arch_setting[0][0] // 2,
142
+ self.arch_setting[0][0],
143
+ 3,
144
+ stride=1,
145
+ padding=1,
146
+ norm_cfg=self.norm_cfg,
147
+ act_cfg=self.act_cfg))
148
+ return stem
149
+
150
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
151
+ """Build a stage layer.
152
+
153
+ Args:
154
+ stage_idx (int): The index of a stage layer.
155
+ setting (list): The architecture setting of a stage layer.
156
+ """
157
+ in_channels, out_channels, num_blocks = setting
158
+
159
+ cspres_layer = CSPResLayer(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ num_block=num_blocks,
163
+ block_cfg=self.block_cfg,
164
+ stride=2,
165
+ norm_cfg=self.norm_cfg,
166
+ act_cfg=self.act_cfg,
167
+ attention_cfg=self.attention_cfg,
168
+ use_spp=False)
169
+ return [cspres_layer]
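
A short usage sketch for PPYOLOECSPResNet, assuming it is exposed through mmyolo.models like the other backbones; the scaling factors are illustrative:

    import torch
    from mmyolo.models import PPYOLOECSPResNet

    model = PPYOLOECSPResNet(deepen_factor=0.33, widen_factor=0.5)
    model.eval()

    # With widen_factor=0.5 the P5 stage widths (256, 512, 1024) become
    # (128, 256, 512), produced at strides 8, 16 and 32.
    for out in model(torch.rand(1, 3, 640, 640)):
        print(tuple(out.shape))
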
mmyolo/models/backbones/cspnext.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import List, Sequence, Union
4
+
5
+ import torch.nn as nn
6
+ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
7
+ from mmdet.models.backbones.csp_darknet import CSPLayer
8
+ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
9
+
10
+ from mmyolo.registry import MODELS
11
+ from ..layers import SPPFBottleneck
12
+ from .base_backbone import BaseBackbone
13
+
14
+
15
+ @MODELS.register_module()
16
+ class CSPNeXt(BaseBackbone):
17
+ """CSPNeXt backbone used in RTMDet.
18
+
19
+ Args:
20
+ arch (str): Architecture of CSPNeXt, from {P5, P6}.
21
+ Defaults to P5.
22
+ deepen_factor (float): Depth multiplier, multiply number of
23
+ blocks in CSP layer by this amount. Defaults to 1.0.
24
+ widen_factor (float): Width multiplier, multiply number of
25
+ channels in each layer by this amount. Defaults to 1.0.
26
+ out_indices (Sequence[int]): Output from which stages.
27
+ Defaults to (2, 3, 4).
28
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
29
+ mode). -1 means not freezing any parameters. Defaults to -1.
30
+ plugins (list[dict]): List of plugins for stages, each dict contains:
31
+ - cfg (dict, required): Cfg dict to build plugin.
32
+ - stages (tuple[bool], optional): Stages to apply plugin, length
33
+ should be same as 'num_stages'.
34
+ use_depthwise (bool): Whether to use depthwise separable convolution.
35
+ Defaults to False.
36
+ expand_ratio (float): Ratio to adjust the number of channels of the
37
+ hidden layer. Defaults to 0.5.
38
+ arch_ovewrite (list): Overwrite default arch settings.
39
+ Defaults to None.
40
+ channel_attention (bool): Whether to add channel attention in each
41
+ stage. Defaults to True.
42
+ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
43
+ convolution layer. Defaults to None.
44
+ norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
45
+ config norm layer. Defaults to dict(type='BN').
46
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
47
+ Defaults to dict(type='SiLU', inplace=True).
48
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
49
+ freeze running stats (mean and var). Note: Effect on Batch Norm
50
+ and its variants only.
51
+ init_cfg (:obj:`ConfigDict` or dict or list[dict] or
52
+ list[:obj:`ConfigDict`]): Initialization config dict.
53
+ """
54
+ # From left to right:
55
+ # in_channels, out_channels, num_blocks, add_identity, use_spp
56
+ arch_settings = {
57
+ 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
58
+ [256, 512, 6, True, False], [512, 1024, 3, False, True]],
59
+ 'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
60
+ [256, 512, 6, True, False], [512, 768, 3, True, False],
61
+ [768, 1024, 3, False, True]]
62
+ }
63
+
64
+ def __init__(
65
+ self,
66
+ arch: str = 'P5',
67
+ deepen_factor: float = 1.0,
68
+ widen_factor: float = 1.0,
69
+ input_channels: int = 3,
70
+ out_indices: Sequence[int] = (2, 3, 4),
71
+ frozen_stages: int = -1,
72
+ plugins: Union[dict, List[dict]] = None,
73
+ use_depthwise: bool = False,
74
+ expand_ratio: float = 0.5,
75
+ arch_ovewrite: dict = None,
76
+ channel_attention: bool = True,
77
+ conv_cfg: OptConfigType = None,
78
+ norm_cfg: ConfigType = dict(type='BN'),
79
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
80
+ norm_eval: bool = False,
81
+ init_cfg: OptMultiConfig = dict(
82
+ type='Kaiming',
83
+ layer='Conv2d',
84
+ a=math.sqrt(5),
85
+ distribution='uniform',
86
+ mode='fan_in',
87
+ nonlinearity='leaky_relu')
88
+ ) -> None:
89
+ arch_setting = self.arch_settings[arch]
90
+ if arch_ovewrite:
91
+ arch_setting = arch_ovewrite
92
+ self.channel_attention = channel_attention
93
+ self.use_depthwise = use_depthwise
94
+ self.conv = DepthwiseSeparableConvModule \
95
+ if use_depthwise else ConvModule
96
+ self.expand_ratio = expand_ratio
97
+ self.conv_cfg = conv_cfg
98
+
99
+ super().__init__(
100
+ arch_setting,
101
+ deepen_factor,
102
+ widen_factor,
103
+ input_channels,
104
+ out_indices,
105
+ frozen_stages=frozen_stages,
106
+ plugins=plugins,
107
+ norm_cfg=norm_cfg,
108
+ act_cfg=act_cfg,
109
+ norm_eval=norm_eval,
110
+ init_cfg=init_cfg)
111
+
112
+ def build_stem_layer(self) -> nn.Module:
113
+ """Build a stem layer."""
114
+ stem = nn.Sequential(
115
+ ConvModule(
116
+ 3,
117
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
118
+ 3,
119
+ padding=1,
120
+ stride=2,
121
+ norm_cfg=self.norm_cfg,
122
+ act_cfg=self.act_cfg),
123
+ ConvModule(
124
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
125
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
126
+ 3,
127
+ padding=1,
128
+ stride=1,
129
+ norm_cfg=self.norm_cfg,
130
+ act_cfg=self.act_cfg),
131
+ ConvModule(
132
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
133
+ int(self.arch_setting[0][0] * self.widen_factor),
134
+ 3,
135
+ padding=1,
136
+ stride=1,
137
+ norm_cfg=self.norm_cfg,
138
+ act_cfg=self.act_cfg))
139
+ return stem
140
+
141
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
142
+ """Build a stage layer.
143
+
144
+ Args:
145
+ stage_idx (int): The index of a stage layer.
146
+ setting (list): The architecture setting of a stage layer.
147
+ """
148
+ in_channels, out_channels, num_blocks, add_identity, use_spp = setting
149
+
150
+ in_channels = int(in_channels * self.widen_factor)
151
+ out_channels = int(out_channels * self.widen_factor)
152
+ num_blocks = max(round(num_blocks * self.deepen_factor), 1)
153
+
154
+ stage = []
155
+ conv_layer = self.conv(
156
+ in_channels,
157
+ out_channels,
158
+ 3,
159
+ stride=2,
160
+ padding=1,
161
+ conv_cfg=self.conv_cfg,
162
+ norm_cfg=self.norm_cfg,
163
+ act_cfg=self.act_cfg)
164
+ stage.append(conv_layer)
165
+ if use_spp:
166
+ spp = SPPFBottleneck(
167
+ out_channels,
168
+ out_channels,
169
+ kernel_sizes=5,
170
+ conv_cfg=self.conv_cfg,
171
+ norm_cfg=self.norm_cfg,
172
+ act_cfg=self.act_cfg)
173
+ stage.append(spp)
174
+ csp_layer = CSPLayer(
175
+ out_channels,
176
+ out_channels,
177
+ num_blocks=num_blocks,
178
+ add_identity=add_identity,
179
+ use_depthwise=self.use_depthwise,
180
+ use_cspnext_block=True,
181
+ expand_ratio=self.expand_ratio,
182
+ channel_attention=self.channel_attention,
183
+ conv_cfg=self.conv_cfg,
184
+ norm_cfg=self.norm_cfg,
185
+ act_cfg=self.act_cfg)
186
+ stage.append(csp_layer)
187
+ return stage
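
A brief usage sketch for CSPNeXt; the scaling factors roughly follow the smaller RTMDet variants and are illustrative only:

    import torch
    from mmyolo.models import CSPNeXt

    model = CSPNeXt(deepen_factor=0.33, widen_factor=0.5)
    model.eval()

    # Default out_indices=(2, 3, 4): feature maps at strides 8, 16 and 32 with
    # int(256 * 0.5), int(512 * 0.5) and int(1024 * 0.5) channels.
    for out in model(torch.rand(1, 3, 640, 640)):
        print(tuple(out.shape))
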
mmyolo/models/backbones/efficient_rep.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from typing import List, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmdet.utils import ConfigType, OptMultiConfig
8
+
9
+ from mmyolo.models.layers.yolo_bricks import SPPFBottleneck
10
+ from mmyolo.registry import MODELS
11
+ from ..layers import BepC3StageBlock, RepStageBlock
12
+ from ..utils import make_round
13
+ from .base_backbone import BaseBackbone
14
+
15
+
16
+ @MODELS.register_module()
17
+ class YOLOv6EfficientRep(BaseBackbone):
18
+ """EfficientRep backbone used in YOLOv6.
19
+ Args:
20
+ arch (str): Architecture of the backbone, from {P5}.
21
+ Defaults to P5.
22
+ plugins (list[dict]): List of plugins for stages, each dict contains:
23
+ - cfg (dict, required): Cfg dict to build plugin.
24
+ - stages (tuple[bool], optional): Stages to apply plugin, length
25
+ should be same as 'num_stages'.
26
+ deepen_factor (float): Depth multiplier, multiply number of
27
+ blocks in CSP layer by this amount. Defaults to 1.0.
28
+ widen_factor (float): Width multiplier, multiply number of
29
+ channels in each layer by this amount. Defaults to 1.0.
30
+ input_channels (int): Number of input image channels. Defaults to 3.
31
+ out_indices (Tuple[int]): Output from which stages.
32
+ Defaults to (2, 3, 4).
33
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
34
+ mode). -1 means not freezing any parameters. Defaults to -1.
35
+ norm_cfg (dict): Dictionary to construct and config norm layer.
36
+ Defaults to dict(type='BN', momentum=0.03, eps=0.001).
37
+ act_cfg (dict): Config dict for activation layer.
38
+ Defaults to dict(type='ReLU', inplace=True).
39
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
40
+ freeze running stats (mean and var). Note: Effect on Batch Norm
41
+ and its variants only. Defaults to False.
42
+ block_cfg (dict): Config dict for the block used to build each
43
+ layer. Defaults to dict(type='RepVGGBlock').
44
+ init_cfg (Union[dict, list[dict]], optional): Initialization config
45
+ dict. Defaults to None.
46
+ Example:
47
+ >>> from mmyolo.models import YOLOv6EfficientRep
48
+ >>> import torch
49
+ >>> model = YOLOv6EfficientRep()
50
+ >>> model.eval()
51
+ >>> inputs = torch.rand(1, 3, 416, 416)
52
+ >>> level_outputs = model(inputs)
53
+ >>> for level_out in level_outputs:
54
+ ... print(tuple(level_out.shape))
55
+ ...
56
+ (1, 256, 52, 52)
57
+ (1, 512, 26, 26)
58
+ (1, 1024, 13, 13)
59
+ """
60
+ # From left to right:
61
+ # in_channels, out_channels, num_blocks, use_spp
62
+ arch_settings = {
63
+ 'P5': [[64, 128, 6, False], [128, 256, 12, False],
64
+ [256, 512, 18, False], [512, 1024, 6, True]]
65
+ }
66
+
67
+ def __init__(self,
68
+ arch: str = 'P5',
69
+ plugins: Union[dict, List[dict]] = None,
70
+ deepen_factor: float = 1.0,
71
+ widen_factor: float = 1.0,
72
+ input_channels: int = 3,
73
+ out_indices: Tuple[int] = (2, 3, 4),
74
+ frozen_stages: int = -1,
75
+ norm_cfg: ConfigType = dict(
76
+ type='BN', momentum=0.03, eps=0.001),
77
+ act_cfg: ConfigType = dict(type='ReLU', inplace=True),
78
+ norm_eval: bool = False,
79
+ block_cfg: ConfigType = dict(type='RepVGGBlock'),
80
+ init_cfg: OptMultiConfig = None):
81
+ self.block_cfg = block_cfg
82
+ super().__init__(
83
+ self.arch_settings[arch],
84
+ deepen_factor,
85
+ widen_factor,
86
+ input_channels=input_channels,
87
+ out_indices=out_indices,
88
+ plugins=plugins,
89
+ frozen_stages=frozen_stages,
90
+ norm_cfg=norm_cfg,
91
+ act_cfg=act_cfg,
92
+ norm_eval=norm_eval,
93
+ init_cfg=init_cfg)
94
+
95
+ def build_stem_layer(self) -> nn.Module:
96
+ """Build a stem layer."""
97
+
98
+ block_cfg = self.block_cfg.copy()
99
+ block_cfg.update(
100
+ dict(
101
+ in_channels=self.input_channels,
102
+ out_channels=int(self.arch_setting[0][0] * self.widen_factor),
103
+ kernel_size=3,
104
+ stride=2,
105
+ ))
106
+ return MODELS.build(block_cfg)
107
+
108
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
109
+ """Build a stage layer.
110
+
111
+ Args:
112
+ stage_idx (int): The index of a stage layer.
113
+ setting (list): The architecture setting of a stage layer.
114
+ """
115
+ in_channels, out_channels, num_blocks, use_spp = setting
116
+
117
+ in_channels = int(in_channels * self.widen_factor)
118
+ out_channels = int(out_channels * self.widen_factor)
119
+ num_blocks = make_round(num_blocks, self.deepen_factor)
120
+
121
+ rep_stage_block = RepStageBlock(
122
+ in_channels=out_channels,
123
+ out_channels=out_channels,
124
+ num_blocks=num_blocks,
125
+ block_cfg=self.block_cfg,
126
+ )
127
+
128
+ block_cfg = self.block_cfg.copy()
129
+ block_cfg.update(
130
+ dict(
131
+ in_channels=in_channels,
132
+ out_channels=out_channels,
133
+ kernel_size=3,
134
+ stride=2))
135
+ stage = []
136
+
137
+ ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
138
+
139
+ stage.append(ef_block)
140
+
141
+ if use_spp:
142
+ spp = SPPFBottleneck(
143
+ in_channels=out_channels,
144
+ out_channels=out_channels,
145
+ kernel_sizes=5,
146
+ norm_cfg=self.norm_cfg,
147
+ act_cfg=self.act_cfg)
148
+ stage.append(spp)
149
+ return stage
150
+
151
+ def init_weights(self):
152
+ """Initialize the parameters."""
153
+ if self.init_cfg is None:
154
+ for m in self.modules():
155
+ if isinstance(m, torch.nn.Conv2d):
156
+ # In order to be consistent with the source code,
157
+ # reset the Conv2d initialization parameters
158
+ m.reset_parameters()
159
+ else:
160
+ super().init_weights()
161
+
162
+
163
+ @MODELS.register_module()
164
+ class YOLOv6CSPBep(YOLOv6EfficientRep):
165
+ """CSPBep backbone used in YOLOv6.
166
+ Args:
167
+ arch (str): Architecture of the backbone, from {P5}.
168
+ Defaults to P5.
169
+ plugins (list[dict]): List of plugins for stages, each dict contains:
170
+ - cfg (dict, required): Cfg dict to build plugin.
171
+ - stages (tuple[bool], optional): Stages to apply plugin, length
172
+ should be same as 'num_stages'.
173
+ deepen_factor (float): Depth multiplier, multiply number of
174
+ blocks in CSP layer by this amount. Defaults to 1.0.
175
+ widen_factor (float): Width multiplier, multiply number of
176
+ channels in each layer by this amount. Defaults to 1.0.
177
+ input_channels (int): Number of input image channels. Defaults to 3.
178
+ out_indices (Tuple[int]): Output from which stages.
179
+ Defaults to (2, 3, 4).
180
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
181
+ mode). -1 means not freezing any parameters. Defaults to -1.
182
+ norm_cfg (dict): Dictionary to construct and config norm layer.
183
+ Defaults to dict(type='BN', momentum=0.03, eps=0.001).
184
+ act_cfg (dict): Config dict for activation layer.
185
+ Defaults to dict(type='SiLU', inplace=True).
186
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
187
+ freeze running stats (mean and var). Note: Effect on Batch Norm
188
+ and its variants only. Defaults to False.
189
+ block_cfg (dict): Config dict for the block used to build each
190
+ layer. Defaults to dict(type='ConvWrapper').
191
+ hidden_ratio (float): Hidden channel ratio used by the BepC3
192
+ stage blocks. Defaults to 0.5.
193
+ init_cfg (Union[dict, list[dict]], optional): Initialization config
194
+ dict. Defaults to None.
195
+ Example:
196
+ >>> from mmyolo.models import YOLOv6CSPBep
197
+ >>> import torch
198
+ >>> model = YOLOv6CSPBep()
199
+ >>> model.eval()
200
+ >>> inputs = torch.rand(1, 3, 416, 416)
201
+ >>> level_outputs = model(inputs)
202
+ >>> for level_out in level_outputs:
203
+ ... print(tuple(level_out.shape))
204
+ ...
205
+ (1, 256, 52, 52)
206
+ (1, 512, 26, 26)
207
+ (1, 1024, 13, 13)
208
+ """
209
+ # From left to right:
210
+ # in_channels, out_channels, num_blocks, use_spp
211
+ arch_settings = {
212
+ 'P5': [[64, 128, 6, False], [128, 256, 12, False],
213
+ [256, 512, 18, False], [512, 1024, 6, True]]
214
+ }
215
+
216
+ def __init__(self,
217
+ arch: str = 'P5',
218
+ plugins: Union[dict, List[dict]] = None,
219
+ deepen_factor: float = 1.0,
220
+ widen_factor: float = 1.0,
221
+ input_channels: int = 3,
222
+ hidden_ratio: float = 0.5,
223
+ out_indices: Tuple[int] = (2, 3, 4),
224
+ frozen_stages: int = -1,
225
+ norm_cfg: ConfigType = dict(
226
+ type='BN', momentum=0.03, eps=0.001),
227
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
228
+ norm_eval: bool = False,
229
+ block_cfg: ConfigType = dict(type='ConvWrapper'),
230
+ init_cfg: OptMultiConfig = None):
231
+ self.hidden_ratio = hidden_ratio
232
+ super().__init__(
233
+ arch=arch,
234
+ deepen_factor=deepen_factor,
235
+ widen_factor=widen_factor,
236
+ input_channels=input_channels,
237
+ out_indices=out_indices,
238
+ plugins=plugins,
239
+ frozen_stages=frozen_stages,
240
+ norm_cfg=norm_cfg,
241
+ act_cfg=act_cfg,
242
+ norm_eval=norm_eval,
243
+ block_cfg=block_cfg,
244
+ init_cfg=init_cfg)
245
+
246
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
247
+ """Build a stage layer.
248
+
249
+ Args:
250
+ stage_idx (int): The index of a stage layer.
251
+ setting (list): The architecture setting of a stage layer.
252
+ """
253
+ in_channels, out_channels, num_blocks, use_spp = setting
254
+ in_channels = int(in_channels * self.widen_factor)
255
+ out_channels = int(out_channels * self.widen_factor)
256
+ num_blocks = make_round(num_blocks, self.deepen_factor)
257
+
258
+ rep_stage_block = BepC3StageBlock(
259
+ in_channels=out_channels,
260
+ out_channels=out_channels,
261
+ num_blocks=num_blocks,
262
+ hidden_ratio=self.hidden_ratio,
263
+ block_cfg=self.block_cfg,
264
+ norm_cfg=self.norm_cfg,
265
+ act_cfg=self.act_cfg)
266
+ block_cfg = self.block_cfg.copy()
267
+ block_cfg.update(
268
+ dict(
269
+ in_channels=in_channels,
270
+ out_channels=out_channels,
271
+ kernel_size=3,
272
+ stride=2))
273
+ stage = []
274
+
275
+ ef_block = nn.Sequential(MODELS.build(block_cfg), rep_stage_block)
276
+
277
+ stage.append(ef_block)
278
+
279
+ if use_spp:
280
+ spp = SPPFBottleneck(
281
+ in_channels=out_channels,
282
+ out_channels=out_channels,
283
+ kernel_sizes=5,
284
+ norm_cfg=self.norm_cfg,
285
+ act_cfg=self.act_cfg)
286
+ stage.append(spp)
287
+ return stage
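
Both YOLOv6 backbones above are driven by block_cfg: YOLOv6EfficientRep stacks RepVGG-style blocks, while YOLOv6CSPBep wraps ConvWrapper blocks into BepC3 stages. A hedged sketch with illustrative scaling factors:

    import torch
    from mmyolo.models import YOLOv6CSPBep, YOLOv6EfficientRep

    small = YOLOv6EfficientRep(deepen_factor=0.33, widen_factor=0.5)
    large = YOLOv6CSPBep(deepen_factor=0.6, widen_factor=0.75, hidden_ratio=2 / 3)

    x = torch.rand(1, 3, 640, 640)
    for model in (small, large):
        model.eval()
        print([tuple(out.shape) for out in model(x)])
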
mmyolo/models/backbones/yolov7_backbone.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch.nn as nn
5
+ from mmcv.cnn import ConvModule
6
+ from mmdet.models.backbones.csp_darknet import Focus
7
+ from mmdet.utils import ConfigType, OptMultiConfig
8
+
9
+ from mmyolo.registry import MODELS
10
+ from ..layers import MaxPoolAndStrideConvBlock
11
+ from .base_backbone import BaseBackbone
12
+
13
+
14
+ @MODELS.register_module()
15
+ class YOLOv7Backbone(BaseBackbone):
16
+ """Backbone used in YOLOv7.
17
+
18
+ Args:
19
+ arch (str): Architecture of YOLOv7, from {Tiny, L, X, W, E, D, E2E}. Defaults to L.
20
+ deepen_factor (float): Depth multiplier, multiply number of
21
+ blocks in CSP layer by this amount. Defaults to 1.0.
22
+ widen_factor (float): Width multiplier, multiply number of
23
+ channels in each layer by this amount. Defaults to 1.0.
24
+ out_indices (Sequence[int]): Output from which stages.
25
+ Defaults to (2, 3, 4).
26
+ frozen_stages (int): Stages to be frozen (stop grad and set eval
27
+ mode). -1 means not freezing any parameters. Defaults to -1.
28
+ plugins (list[dict]): List of plugins for stages, each dict contains:
29
+
30
+ - cfg (dict, required): Cfg dict to build plugin.
31
+ - stages (tuple[bool], optional): Stages to apply plugin, length
32
+ should be same as 'num_stages'.
33
+ norm_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
34
+ config norm layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
35
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
36
+ Defaults to dict(type='SiLU', inplace=True).
37
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
38
+ freeze running stats (mean and var). Note: Effect on Batch Norm
39
+ and its variants only.
40
+ init_cfg (:obj:`ConfigDict` or dict or list[dict] or
41
+ list[:obj:`ConfigDict`]): Initialization config dict.
42
+ """
43
+ _tiny_stage1_cfg = dict(type='TinyDownSampleBlock', middle_ratio=0.5)
44
+ _tiny_stage2_4_cfg = dict(type='TinyDownSampleBlock', middle_ratio=1.0)
45
+ _l_expand_channel_2x = dict(
46
+ type='ELANBlock',
47
+ middle_ratio=0.5,
48
+ block_ratio=0.5,
49
+ num_blocks=2,
50
+ num_convs_in_block=2)
51
+ _l_no_change_channel = dict(
52
+ type='ELANBlock',
53
+ middle_ratio=0.25,
54
+ block_ratio=0.25,
55
+ num_blocks=2,
56
+ num_convs_in_block=2)
57
+ _x_expand_channel_2x = dict(
58
+ type='ELANBlock',
59
+ middle_ratio=0.4,
60
+ block_ratio=0.4,
61
+ num_blocks=3,
62
+ num_convs_in_block=2)
63
+ _x_no_change_channel = dict(
64
+ type='ELANBlock',
65
+ middle_ratio=0.2,
66
+ block_ratio=0.2,
67
+ num_blocks=3,
68
+ num_convs_in_block=2)
69
+ _w_no_change_channel = dict(
70
+ type='ELANBlock',
71
+ middle_ratio=0.5,
72
+ block_ratio=0.5,
73
+ num_blocks=2,
74
+ num_convs_in_block=2)
75
+ _e_no_change_channel = dict(
76
+ type='ELANBlock',
77
+ middle_ratio=0.4,
78
+ block_ratio=0.4,
79
+ num_blocks=3,
80
+ num_convs_in_block=2)
81
+ _d_no_change_channel = dict(
82
+ type='ELANBlock',
83
+ middle_ratio=1 / 3,
84
+ block_ratio=1 / 3,
85
+ num_blocks=4,
86
+ num_convs_in_block=2)
87
+ _e2e_no_change_channel = dict(
88
+ type='EELANBlock',
89
+ num_elan_block=2,
90
+ middle_ratio=0.4,
91
+ block_ratio=0.4,
92
+ num_blocks=3,
93
+ num_convs_in_block=2)
94
+
95
+ # From left to right:
96
+ # in_channels, out_channels, Block_params
97
+ arch_settings = {
98
+ 'Tiny': [[64, 64, _tiny_stage1_cfg], [64, 128, _tiny_stage2_4_cfg],
99
+ [128, 256, _tiny_stage2_4_cfg],
100
+ [256, 512, _tiny_stage2_4_cfg]],
101
+ 'L': [[64, 256, _l_expand_channel_2x],
102
+ [256, 512, _l_expand_channel_2x],
103
+ [512, 1024, _l_expand_channel_2x],
104
+ [1024, 1024, _l_no_change_channel]],
105
+ 'X': [[80, 320, _x_expand_channel_2x],
106
+ [320, 640, _x_expand_channel_2x],
107
+ [640, 1280, _x_expand_channel_2x],
108
+ [1280, 1280, _x_no_change_channel]],
109
+ 'W':
110
+ [[64, 128, _w_no_change_channel], [128, 256, _w_no_change_channel],
111
+ [256, 512, _w_no_change_channel], [512, 768, _w_no_change_channel],
112
+ [768, 1024, _w_no_change_channel]],
113
+ 'E':
114
+ [[80, 160, _e_no_change_channel], [160, 320, _e_no_change_channel],
115
+ [320, 640, _e_no_change_channel], [640, 960, _e_no_change_channel],
116
+ [960, 1280, _e_no_change_channel]],
117
+ 'D': [[96, 192,
118
+ _d_no_change_channel], [192, 384, _d_no_change_channel],
119
+ [384, 768, _d_no_change_channel],
120
+ [768, 1152, _d_no_change_channel],
121
+ [1152, 1536, _d_no_change_channel]],
122
+ 'E2E': [[80, 160, _e2e_no_change_channel],
123
+ [160, 320, _e2e_no_change_channel],
124
+ [320, 640, _e2e_no_change_channel],
125
+ [640, 960, _e2e_no_change_channel],
126
+ [960, 1280, _e2e_no_change_channel]],
127
+ }
128
+
129
+ def __init__(self,
130
+ arch: str = 'L',
131
+ deepen_factor: float = 1.0,
132
+ widen_factor: float = 1.0,
133
+ input_channels: int = 3,
134
+ out_indices: Tuple[int] = (2, 3, 4),
135
+ frozen_stages: int = -1,
136
+ plugins: Union[dict, List[dict]] = None,
137
+ norm_cfg: ConfigType = dict(
138
+ type='BN', momentum=0.03, eps=0.001),
139
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
140
+ norm_eval: bool = False,
141
+ init_cfg: OptMultiConfig = None):
142
+ assert arch in self.arch_settings.keys()
143
+ self.arch = arch
144
+ super().__init__(
145
+ self.arch_settings[arch],
146
+ deepen_factor,
147
+ widen_factor,
148
+ input_channels=input_channels,
149
+ out_indices=out_indices,
150
+ plugins=plugins,
151
+ frozen_stages=frozen_stages,
152
+ norm_cfg=norm_cfg,
153
+ act_cfg=act_cfg,
154
+ norm_eval=norm_eval,
155
+ init_cfg=init_cfg)
156
+
157
+ def build_stem_layer(self) -> nn.Module:
158
+ """Build a stem layer."""
159
+ if self.arch in ['L', 'X']:
160
+ stem = nn.Sequential(
161
+ ConvModule(
162
+ 3,
163
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
164
+ 3,
165
+ padding=1,
166
+ stride=1,
167
+ norm_cfg=self.norm_cfg,
168
+ act_cfg=self.act_cfg),
169
+ ConvModule(
170
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
171
+ int(self.arch_setting[0][0] * self.widen_factor),
172
+ 3,
173
+ padding=1,
174
+ stride=2,
175
+ norm_cfg=self.norm_cfg,
176
+ act_cfg=self.act_cfg),
177
+ ConvModule(
178
+ int(self.arch_setting[0][0] * self.widen_factor),
179
+ int(self.arch_setting[0][0] * self.widen_factor),
180
+ 3,
181
+ padding=1,
182
+ stride=1,
183
+ norm_cfg=self.norm_cfg,
184
+ act_cfg=self.act_cfg))
185
+ elif self.arch == 'Tiny':
186
+ stem = nn.Sequential(
187
+ ConvModule(
188
+ 3,
189
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
190
+ 3,
191
+ padding=1,
192
+ stride=2,
193
+ norm_cfg=self.norm_cfg,
194
+ act_cfg=self.act_cfg),
195
+ ConvModule(
196
+ int(self.arch_setting[0][0] * self.widen_factor // 2),
197
+ int(self.arch_setting[0][0] * self.widen_factor),
198
+ 3,
199
+ padding=1,
200
+ stride=2,
201
+ norm_cfg=self.norm_cfg,
202
+ act_cfg=self.act_cfg))
203
+ elif self.arch in ['W', 'E', 'D', 'E2E']:
204
+ stem = Focus(
205
+ 3,
206
+ int(self.arch_setting[0][0] * self.widen_factor),
207
+ kernel_size=3,
208
+ norm_cfg=self.norm_cfg,
209
+ act_cfg=self.act_cfg)
210
+ return stem
211
+
212
+ def build_stage_layer(self, stage_idx: int, setting: list) -> list:
213
+ """Build a stage layer.
214
+
215
+ Args:
216
+ stage_idx (int): The index of a stage layer.
217
+ setting (list): The architecture setting of a stage layer.
218
+ """
219
+ in_channels, out_channels, stage_block_cfg = setting
220
+ in_channels = int(in_channels * self.widen_factor)
221
+ out_channels = int(out_channels * self.widen_factor)
222
+
223
+ stage_block_cfg = stage_block_cfg.copy()
224
+ stage_block_cfg.setdefault('norm_cfg', self.norm_cfg)
225
+ stage_block_cfg.setdefault('act_cfg', self.act_cfg)
226
+
227
+ stage_block_cfg['in_channels'] = in_channels
228
+ stage_block_cfg['out_channels'] = out_channels
229
+
230
+ stage = []
231
+ if self.arch in ['W', 'E', 'D', 'E2E']:
232
+ stage_block_cfg['in_channels'] = out_channels
233
+ elif self.arch in ['L', 'X']:
234
+ if stage_idx == 0:
235
+ stage_block_cfg['in_channels'] = out_channels // 2
236
+
237
+ downsample_layer = self._build_downsample_layer(
238
+ stage_idx, in_channels, out_channels)
239
+ stage.append(MODELS.build(stage_block_cfg))
240
+ if downsample_layer is not None:
241
+ stage.insert(0, downsample_layer)
242
+ return stage
243
+
244
+ def _build_downsample_layer(self, stage_idx: int, in_channels: int,
245
+ out_channels: int) -> Optional[nn.Module]:
246
+ """Build a downsample layer preceding each stage."""
247
+ if self.arch in ['E', 'D', 'E2E']:
248
+ downsample_layer = MaxPoolAndStrideConvBlock(
249
+ in_channels,
250
+ out_channels,
251
+ use_in_channels_of_middle=True,
252
+ norm_cfg=self.norm_cfg,
253
+ act_cfg=self.act_cfg)
254
+ elif self.arch == 'W':
255
+ downsample_layer = ConvModule(
256
+ in_channels,
257
+ out_channels,
258
+ 3,
259
+ stride=2,
260
+ padding=1,
261
+ norm_cfg=self.norm_cfg,
262
+ act_cfg=self.act_cfg)
263
+ elif self.arch == 'Tiny':
264
+ if stage_idx != 0:
265
+ downsample_layer = nn.MaxPool2d(2, 2)
266
+ else:
267
+ downsample_layer = None
268
+ elif self.arch in ['L', 'X']:
269
+ if stage_idx == 0:
270
+ downsample_layer = ConvModule(
271
+ in_channels,
272
+ out_channels // 2,
273
+ 3,
274
+ stride=2,
275
+ padding=1,
276
+ norm_cfg=self.norm_cfg,
277
+ act_cfg=self.act_cfg)
278
+ else:
279
+ downsample_layer = MaxPoolAndStrideConvBlock(
280
+ in_channels,
281
+ in_channels,
282
+ use_in_channels_of_middle=False,
283
+ norm_cfg=self.norm_cfg,
284
+ act_cfg=self.act_cfg)
285
+ return downsample_layer
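
A short usage sketch for YOLOv7Backbone; the arch argument selects the stem, ELAN stage blocks and downsample layers implemented above (the value used here is illustrative):

    import torch
    from mmyolo.models import YOLOv7Backbone

    # 'L' uses the 3-conv stem; 'Tiny' and 'W'/'E'/'D'/'E2E' swap in the
    # other stem and downsample variants built above.
    model = YOLOv7Backbone(arch='L')
    model.eval()

    for out in model(torch.rand(1, 3, 640, 640)):
        print(tuple(out.shape))
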
mmyolo/models/data_preprocessors/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .data_preprocessor import (PPYOLOEBatchRandomResize,
3
+ PPYOLOEDetDataPreprocessor,
4
+ YOLOv5DetDataPreprocessor,
5
+ YOLOXBatchSyncRandomResize)
6
+
7
+ __all__ = [
8
+ 'YOLOv5DetDataPreprocessor', 'PPYOLOEDetDataPreprocessor',
9
+ 'PPYOLOEBatchRandomResize', 'YOLOXBatchSyncRandomResize'
10
+ ]
mmyolo/models/data_preprocessors/data_preprocessor.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import random
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from mmdet.models import BatchSyncRandomResize
8
+ from mmdet.models.data_preprocessors import DetDataPreprocessor
9
+ from mmengine import MessageHub, is_list_of
10
+ from mmengine.structures import BaseDataElement
11
+ from torch import Tensor
12
+
13
+ from mmyolo.registry import MODELS
14
+
15
+ CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
16
+ None]
17
+
18
+
19
+ @MODELS.register_module()
20
+ class YOLOXBatchSyncRandomResize(BatchSyncRandomResize):
21
+ """YOLOX batch random resize.
22
+
23
+ Args:
24
+ random_size_range (tuple): The multi-scale random range during
25
+ multi-scale training.
26
+ interval (int): The iter interval of change
27
+ image size. Defaults to 10.
28
+ size_divisor (int): Image size divisible factor.
29
+ Defaults to 32.
30
+ """
31
+
32
+ def forward(self, inputs: Tensor, data_samples: dict) -> Tuple[Tensor, dict]:
33
+ """Resize a batch of images and bboxes to shape ``self._input_size``."""
34
+ h, w = inputs.shape[-2:]
35
+ inputs = inputs.float()
36
+ assert isinstance(data_samples, dict)
37
+
38
+ if self._input_size is None:
39
+ self._input_size = (h, w)
40
+ scale_y = self._input_size[0] / h
41
+ scale_x = self._input_size[1] / w
42
+ if scale_x != 1 or scale_y != 1:
43
+ inputs = F.interpolate(
44
+ inputs,
45
+ size=self._input_size,
46
+ mode='bilinear',
47
+ align_corners=False)
48
+
49
+ data_samples['bboxes_labels'][:, 2::2] *= scale_x
50
+ data_samples['bboxes_labels'][:, 3::2] *= scale_y
51
+
52
+ message_hub = MessageHub.get_current_instance()
53
+ if (message_hub.get_info('iter') + 1) % self._interval == 0:
54
+ self._input_size = self._get_random_size(
55
+ aspect_ratio=float(w / h), device=inputs.device)
56
+
57
+ return inputs, data_samples
58
+
59
+
60
+ @MODELS.register_module()
61
+ class YOLOv5DetDataPreprocessor(DetDataPreprocessor):
62
+ """Rewrite collate_fn to get faster training speed.
63
+
64
+ Note: It must be used together with `mmyolo.datasets.utils.yolov5_collate`
65
+ """
66
+
67
+ def __init__(self, *args, non_blocking: Optional[bool] = True, **kwargs):
68
+ super().__init__(*args, non_blocking=non_blocking, **kwargs)
69
+
70
+ def forward(self, data: dict, training: bool = False) -> dict:
71
+ """Perform normalization, padding and bgr2rgb conversion based on
72
+ ``DetDataPreprocessorr``.
73
+
74
+ Args:
75
+ data (dict): Data sampled from dataloader.
76
+ training (bool): Whether to enable training time augmentation.
77
+
78
+ Returns:
79
+ dict: Data in the same format as the model input.
80
+ """
81
+ if not training:
82
+ return super().forward(data, training)
83
+
84
+ data = self.cast_data(data)
85
+ inputs, data_samples = data['inputs'], data['data_samples']
86
+ assert isinstance(data['data_samples'], dict)
87
+
88
+ # TODO: Supports multi-scale training
89
+ if self._channel_conversion and inputs.shape[1] == 3:
90
+ inputs = inputs[:, [2, 1, 0], ...]
91
+ if self._enable_normalize:
92
+ inputs = (inputs - self.mean) / self.std
93
+
94
+ if self.batch_augments is not None:
95
+ for batch_aug in self.batch_augments:
96
+ inputs, data_samples = batch_aug(inputs, data_samples)
97
+
98
+ img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
99
+ data_samples_output = {
100
+ 'bboxes_labels': data_samples['bboxes_labels'],
101
+ 'img_metas': img_metas
102
+ }
103
+ if 'masks' in data_samples:
104
+ data_samples_output['masks'] = data_samples['masks']
105
+
106
+ return {'inputs': inputs, 'data_samples': data_samples_output}
107
+
108
+
109
+ @MODELS.register_module()
110
+ class PPYOLOEDetDataPreprocessor(DetDataPreprocessor):
111
+ """Image pre-processor for detection tasks.
112
+
113
+ The main difference between PPYOLOEDetDataPreprocessor and
114
+ DetDataPreprocessor is the normalization order. The official
115
+ PPYOLOE resize image first, and then normalize image.
116
+ In DetDataPreprocessor, the order is reversed.
117
+
118
+ Note: It must be used together with
119
+ `mmyolo.datasets.utils.yolov5_collate`
120
+ """
121
+
122
+ def forward(self, data: dict, training: bool = False) -> dict:
123
+ """Perform normalization、padding and bgr2rgb conversion based on
124
+ ``BaseDataPreprocessor``. This class use batch_augments first, and then
125
+ normalize the image, which is different from the `DetDataPreprocessor`
126
+ .
127
+
128
+ Args:
129
+ data (dict): Data sampled from dataloader.
130
+ training (bool): Whether to enable training time augmentation.
131
+
132
+ Returns:
133
+ dict: Data in the same format as the model input.
134
+ """
135
+ if not training:
136
+ return super().forward(data, training)
137
+
138
+ assert isinstance(data['inputs'], list) and is_list_of(
139
+ data['inputs'], torch.Tensor), \
140
+ '"inputs" should be a list of Tensor, but got ' \
141
+ f'{type(data["inputs"])}. The possible reason for this ' \
142
+ 'is that you are not using it with ' \
143
+ '"mmyolo.datasets.utils.yolov5_collate". Please refer to ' \
144
+ '"cconfigs/ppyoloe/ppyoloe_plus_s_fast_8xb8-80e_coco.py".'
145
+
146
+ data = self.cast_data(data)
147
+ inputs, data_samples = data['inputs'], data['data_samples']
148
+ assert isinstance(data['data_samples'], dict)
149
+
150
+ # Process data.
151
+ batch_inputs = []
152
+ for _input in inputs:
153
+ # channel transform
154
+ if self._channel_conversion:
155
+ _input = _input[[2, 1, 0], ...]
156
+ # Convert to float after channel conversion to ensure
157
+ # efficiency
158
+ _input = _input.float()
159
+ batch_inputs.append(_input)
160
+
161
+ # Batch random resize image.
162
+ if self.batch_augments is not None:
163
+ for batch_aug in self.batch_augments:
164
+ inputs, data_samples = batch_aug(batch_inputs, data_samples)
165
+
166
+ if self._enable_normalize:
167
+ inputs = (inputs - self.mean) / self.std
168
+
169
+ img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs)
170
+ data_samples = {
171
+ 'bboxes_labels': data_samples['bboxes_labels'],
172
+ 'img_metas': img_metas
173
+ }
174
+
175
+ return {'inputs': inputs, 'data_samples': data_samples}
176
+
177
+
178
+ # TODO: No generality. Its input data format is different
179
+ # mmdet's batch aug, and it must be compatible in the future.
180
+ @MODELS.register_module()
181
+ class PPYOLOEBatchRandomResize(BatchSyncRandomResize):
182
+ """PPYOLOE batch random resize.
183
+
184
+ Args:
185
+ random_size_range (tuple): The multi-scale random range during
186
+ multi-scale training.
187
+ interval (int): The iter interval of change
188
+ image size. Defaults to 10.
189
+ size_divisor (int): Image size divisible factor.
190
+ Defaults to 32.
191
+ random_interp (bool): Whether to choose interp_mode randomly.
192
+ If set to True, the type of `interp_mode` must be list.
193
+ If set to False, the type of `interp_mode` must be str.
194
+ Defaults to True.
195
+ interp_mode (Union[List, str]): The modes available for resizing
196
+ are ('nearest', 'bilinear', 'bicubic', 'area').
197
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing
198
+ the image. Now we only support keep_ratio=False.
199
+ Defaults to False.
200
+ """
201
+
202
+ def __init__(self,
203
+ random_size_range: Tuple[int, int],
204
+ interval: int = 1,
205
+ size_divisor: int = 32,
206
+ random_interp=True,
207
+ interp_mode: Union[List[str], str] = [
208
+ 'nearest', 'bilinear', 'bicubic', 'area'
209
+ ],
210
+ keep_ratio: bool = False) -> None:
211
+ super().__init__(random_size_range, interval, size_divisor)
212
+ self.random_interp = random_interp
213
+ self.keep_ratio = keep_ratio
214
+ # TODO: need to support keep_ratio==True
215
+ assert not self.keep_ratio, 'We do not yet support keep_ratio=True'
216
+
217
+ if self.random_interp:
218
+ assert isinstance(interp_mode, list) and len(interp_mode) > 1,\
219
+ 'While random_interp==True, the type of `interp_mode`' \
220
+ ' must be list and len(interp_mode) must large than 1'
221
+ self.interp_mode_list = interp_mode
222
+ self.interp_mode = None
223
+ else:
224
+ assert isinstance(interp_mode, str),\
225
+ 'While random_interp==False, the type of ' \
226
+ '`interp_mode` must be str'
227
+ assert interp_mode in ['nearest', 'bilinear', 'bicubic', 'area']
228
+ self.interp_mode_list = None
229
+ self.interp_mode = interp_mode
230
+
231
+ def forward(self, inputs: list,
232
+ data_samples: dict) -> Tuple[Tensor, Tensor]:
233
+ """Resize a batch of images and bboxes to shape ``self._input_size``.
234
+
235
+ The inputs and data_samples should be list, and
236
+ ``PPYOLOEBatchRandomResize`` must be used with
237
+ ``PPYOLOEDetDataPreprocessor`` and ``yolov5_collate`` with
238
+ ``use_ms_training == True``.
239
+ """
240
+ assert isinstance(inputs, list),\
241
+ 'The type of inputs must be list. The possible reason for this ' \
242
+ 'is that you are not using it with `PPYOLOEDetDataPreprocessor` ' \
243
+ 'and `yolov5_collate` with use_ms_training == True.'
244
+
245
+ bboxes_labels = data_samples['bboxes_labels']
246
+
247
+ message_hub = MessageHub.get_current_instance()
248
+ if (message_hub.get_info('iter') + 1) % self._interval == 0:
249
+ # get current input size
250
+ self._input_size, interp_mode = self._get_random_size_and_interp()
251
+ if self.random_interp:
252
+ self.interp_mode = interp_mode
253
+
254
+ # TODO: need to support type(inputs)==Tensor
255
+ if isinstance(inputs, list):
256
+ outputs = []
257
+ for i in range(len(inputs)):
258
+ _batch_input = inputs[i]
259
+ h, w = _batch_input.shape[-2:]
260
+ scale_y = self._input_size[0] / h
261
+ scale_x = self._input_size[1] / w
262
+ if scale_x != 1. or scale_y != 1.:
263
+ if self.interp_mode in ('nearest', 'area'):
264
+ align_corners = None
265
+ else:
266
+ align_corners = False
267
+ _batch_input = F.interpolate(
268
+ _batch_input.unsqueeze(0),
269
+ size=self._input_size,
270
+ mode=self.interp_mode,
271
+ align_corners=align_corners)
272
+
273
+ # rescale boxes
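+ # bboxes_labels comes from yolov5_collate as a (num_gts, 6) tensor laid
+ # out as [batch_idx, label, x1, y1, x2, y2], so columns 2-5 are the box
+ # coordinates that must follow the resize of image i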
274
+ indexes = bboxes_labels[:, 0] == i
275
+ bboxes_labels[indexes, 2] *= scale_x
276
+ bboxes_labels[indexes, 3] *= scale_y
277
+ bboxes_labels[indexes, 4] *= scale_x
278
+ bboxes_labels[indexes, 5] *= scale_y
279
+
280
+ data_samples['bboxes_labels'] = bboxes_labels
281
+ else:
282
+ _batch_input = _batch_input.unsqueeze(0)
283
+
284
+ outputs.append(_batch_input)
285
+
286
+ # convert to Tensor
287
+ return torch.cat(outputs, dim=0), data_samples
288
+ else:
289
+ raise NotImplementedError('Not implemented yet!')
290
+
291
+ def _get_random_size_and_interp(self) -> Tuple[int, int]:
292
+ """Randomly generate a shape in ``_random_size_range`` and a
293
+ interp_mode in interp_mode_list."""
294
+ size = random.randint(*self._random_size_range)
295
+ input_size = (self._size_divisor * size, self._size_divisor * size)
296
+
297
+ if self.random_interp:
298
+ interp_ind = random.randint(0, len(self.interp_mode_list) - 1)
299
+ interp_mode = self.interp_mode_list[interp_ind]
300
+ else:
301
+ interp_mode = None
302
+ return input_size, interp_mode
mmyolo/models/dense_heads/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .ppyoloe_head import PPYOLOEHead, PPYOLOEHeadModule
3
+ from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
4
+ from .rtmdet_ins_head import RTMDetInsSepBNHead, RTMDetInsSepBNHeadModule
5
+ from .rtmdet_rotated_head import (RTMDetRotatedHead,
6
+ RTMDetRotatedSepBNHeadModule)
7
+ from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
8
+ from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
9
+ from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
10
+ from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
11
+ from .yolox_head import YOLOXHead, YOLOXHeadModule
12
+
13
+ __all__ = [
14
+ 'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
15
+ 'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
16
+ 'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
17
+ 'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
18
+ 'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
19
+ 'RTMDetInsSepBNHeadModule'
20
+ ]
mmyolo/models/dense_heads/ppyoloe_head.py ADDED
@@ -0,0 +1,374 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import Sequence, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from mmdet.models.utils import multi_apply
8
+ from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
9
+ OptMultiConfig, reduce_mean)
10
+ from mmengine import MessageHub
11
+ from mmengine.model import BaseModule, bias_init_with_prob
12
+ from mmengine.structures import InstanceData
13
+ from torch import Tensor
14
+
15
+ from mmyolo.registry import MODELS
16
+ from ..layers.yolo_bricks import PPYOLOESELayer
17
+ from ..utils import gt_instances_preprocess
18
+ from .yolov6_head import YOLOv6Head
19
+
20
+
21
+ @MODELS.register_module()
22
+ class PPYOLOEHeadModule(BaseModule):
23
+ """PPYOLOEHead head module used in `PPYOLOE.
24
+
25
+ `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.
26
+
27
+ Args:
28
+ num_classes (int): Number of categories excluding the background
29
+ category.
30
+ in_channels (int): Number of channels in the input feature map.
31
+ widen_factor (float): Width multiplier, multiply number of
32
+ channels in each layer by this amount. Defaults to 1.0.
33
+ num_base_priors (int): The number of priors (points) at a point
34
+ on the feature grid.
35
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
36
+ Defaults to (8, 16, 32).
37
+ reg_max (int): Max value of the integral set :math:`{0, ..., reg_max}`
38
+ in QFL setting. Defaults to 16.
39
+ norm_cfg (dict): Config dict for normalization layer.
40
+ Defaults to dict(type='BN', momentum=0.1, eps=1e-5).
41
+ act_cfg (dict): Config dict for activation layer.
42
+ Defaults to dict(type='SiLU', inplace=True).
43
+ init_cfg (dict or list[dict], optional): Initialization config dict.
44
+ Defaults to None.
45
+ """
46
+
47
+ def __init__(self,
48
+ num_classes: int,
49
+ in_channels: Union[int, Sequence],
50
+ widen_factor: float = 1.0,
51
+ num_base_priors: int = 1,
52
+ featmap_strides: Sequence[int] = (8, 16, 32),
53
+ reg_max: int = 16,
54
+ norm_cfg: ConfigType = dict(
55
+ type='BN', momentum=0.1, eps=1e-5),
56
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
57
+ init_cfg: OptMultiConfig = None):
58
+ super().__init__(init_cfg=init_cfg)
59
+
60
+ self.num_classes = num_classes
61
+ self.featmap_strides = featmap_strides
62
+ self.num_levels = len(self.featmap_strides)
63
+ self.num_base_priors = num_base_priors
64
+ self.norm_cfg = norm_cfg
65
+ self.act_cfg = act_cfg
66
+ self.reg_max = reg_max
67
+
68
+ if isinstance(in_channels, int):
69
+ self.in_channels = [int(in_channels * widen_factor)
70
+ ] * self.num_levels
71
+ else:
72
+ self.in_channels = [int(i * widen_factor) for i in in_channels]
73
+
74
+ self._init_layers()
75
+
76
+ def init_weights(self, prior_prob=0.01):
77
+ """Initialize the weight and bias of PPYOLOE head."""
78
+ super().init_weights()
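+ # The cls bias is set so that the initial sigmoid score is roughly
+ # `prior_prob`, which keeps the focal/varifocal-style classification
+ # loss stable at the beginning of training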
79
+ for conv in self.cls_preds:
80
+ conv.bias.data.fill_(bias_init_with_prob(prior_prob))
81
+ conv.weight.data.fill_(0.)
82
+
83
+ for conv in self.reg_preds:
84
+ conv.bias.data.fill_(1.0)
85
+ conv.weight.data.fill_(0.)
86
+
87
+ def _init_layers(self):
88
+ """initialize conv layers in PPYOLOE head."""
89
+ self.cls_preds = nn.ModuleList()
90
+ self.reg_preds = nn.ModuleList()
91
+ self.cls_stems = nn.ModuleList()
92
+ self.reg_stems = nn.ModuleList()
93
+
94
+ for in_channel in self.in_channels:
95
+ self.cls_stems.append(
96
+ PPYOLOESELayer(
97
+ in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
98
+ self.reg_stems.append(
99
+ PPYOLOESELayer(
100
+ in_channel, norm_cfg=self.norm_cfg, act_cfg=self.act_cfg))
101
+
102
+ for in_channel in self.in_channels:
103
+ self.cls_preds.append(
104
+ nn.Conv2d(in_channel, self.num_classes, 3, padding=1))
105
+ self.reg_preds.append(
106
+ nn.Conv2d(in_channel, 4 * (self.reg_max + 1), 3, padding=1))
107
+
108
+ # init proj
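+ # `proj` stores the integral values {0, ..., reg_max}; convolving the
+ # softmax of the distance distribution with it gives its expectation,
+ # i.e. the DFL-style decoding of the predicted ltrb distances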
109
+ proj = torch.linspace(0, self.reg_max, self.reg_max + 1).view(
110
+ [1, self.reg_max + 1, 1, 1])
111
+ self.register_buffer('proj', proj, persistent=False)
112
+
113
+ def forward(self, x: Tuple[Tensor]) -> Tensor:
114
+ """Forward features from the upstream network.
115
+
116
+ Args:
117
+ x (Tuple[Tensor]): Features from the upstream network, each is
118
+ a 4D-tensor.
119
+ Returns:
120
+ Tuple[List]: A tuple of multi-level classification scores, bbox
121
+ predictions.
122
+ """
123
+ assert len(x) == self.num_levels
124
+
125
+ return multi_apply(self.forward_single, x, self.cls_stems,
126
+ self.cls_preds, self.reg_stems, self.reg_preds)
127
+
128
+ def forward_single(self, x: Tensor, cls_stem: nn.ModuleList,
129
+ cls_pred: nn.ModuleList, reg_stem: nn.ModuleList,
130
+ reg_pred: nn.ModuleList) -> Tensor:
131
+ """Forward feature of a single scale level."""
132
+ b, _, h, w = x.shape
133
+ hw = h * w
134
+ avg_feat = F.adaptive_avg_pool2d(x, (1, 1))
135
+ cls_logit = cls_pred(cls_stem(x, avg_feat) + x)
136
+ bbox_dist_preds = reg_pred(reg_stem(x, avg_feat))
137
+ # TODO: Test whether use matmul instead of conv can speed up training.
138
+ bbox_dist_preds = bbox_dist_preds.reshape(
139
+ [-1, 4, self.reg_max + 1, hw]).permute(0, 2, 3, 1)
140
+
141
+ bbox_preds = F.conv2d(F.softmax(bbox_dist_preds, dim=1), self.proj)
142
+
143
+ if self.training:
144
+ return cls_logit, bbox_preds, bbox_dist_preds
145
+ else:
146
+ return cls_logit, bbox_preds
147
+
148
+
149
+ @MODELS.register_module()
150
+ class PPYOLOEHead(YOLOv6Head):
151
+ """PPYOLOEHead head used in `PPYOLOE <https://arxiv.org/abs/2203.16250>`_.
152
+ The YOLOv6 head and the PPYOLOE head are only slightly different.
153
+ PPYOLOE additionally uses a distribution focal loss, which YOLOv6 does not.
154
+
155
+ Args:
156
+ head_module (ConfigType): Base module used for PPYOLOEHead.
157
+ prior_generator (dict): Point generator for the feature maps in
158
+ 2D points-based detectors.
159
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
160
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
161
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
162
+ loss_dfl (:obj:`ConfigDict` or dict): Config of distribution focal
163
+ loss.
164
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
165
+ anchor head. Defaults to None.
166
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
167
+ anchor head. Defaults to None.
168
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
169
+ list[dict], optional): Initialization config dict.
170
+ Defaults to None.
171
+ """
172
+
173
+ def __init__(self,
174
+ head_module: ConfigType,
175
+ prior_generator: ConfigType = dict(
176
+ type='mmdet.MlvlPointGenerator',
177
+ offset=0.5,
178
+ strides=[8, 16, 32]),
179
+ bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
180
+ loss_cls: ConfigType = dict(
181
+ type='mmdet.VarifocalLoss',
182
+ use_sigmoid=True,
183
+ alpha=0.75,
184
+ gamma=2.0,
185
+ iou_weighted=True,
186
+ reduction='sum',
187
+ loss_weight=1.0),
188
+ loss_bbox: ConfigType = dict(
189
+ type='IoULoss',
190
+ iou_mode='giou',
191
+ bbox_format='xyxy',
192
+ reduction='mean',
193
+ loss_weight=2.5,
194
+ return_iou=False),
195
+ loss_dfl: ConfigType = dict(
196
+ type='mmdet.DistributionFocalLoss',
197
+ reduction='mean',
198
+ loss_weight=0.5 / 4),
199
+ train_cfg: OptConfigType = None,
200
+ test_cfg: OptConfigType = None,
201
+ init_cfg: OptMultiConfig = None):
202
+ super().__init__(
203
+ head_module=head_module,
204
+ prior_generator=prior_generator,
205
+ bbox_coder=bbox_coder,
206
+ loss_cls=loss_cls,
207
+ loss_bbox=loss_bbox,
208
+ train_cfg=train_cfg,
209
+ test_cfg=test_cfg,
210
+ init_cfg=init_cfg)
211
+ self.loss_dfl = MODELS.build(loss_dfl)
212
+ # ppyoloe doesn't need loss_obj
213
+ self.loss_obj = None
214
+
215
+ def loss_by_feat(
216
+ self,
217
+ cls_scores: Sequence[Tensor],
218
+ bbox_preds: Sequence[Tensor],
219
+ bbox_dist_preds: Sequence[Tensor],
220
+ batch_gt_instances: Sequence[InstanceData],
221
+ batch_img_metas: Sequence[dict],
222
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
223
+ """Calculate the loss based on the features extracted by the detection
224
+ head.
225
+
226
+ Args:
227
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
228
+ each is a 4D-tensor, the channel number is
229
+ num_priors * num_classes.
230
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
231
+ level, each is a 4D-tensor, the channel number is
232
+ num_priors * 4.
233
+ bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
234
+ each scale level with shape (bs, reg_max + 1, H*W, 4).
235
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
236
+ gt_instance. It usually includes ``bboxes`` and ``labels``
237
+ attributes.
238
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
239
+ image size, scaling factor, etc.
240
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
241
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
242
+ data that is ignored during training and testing.
243
+ Defaults to None.
244
+ Returns:
245
+ dict[str, Tensor]: A dictionary of losses.
246
+ """
247
+
248
+ # get epoch information from message hub
249
+ message_hub = MessageHub.get_current_instance()
250
+ current_epoch = message_hub.get_info('epoch')
251
+
252
+ num_imgs = len(batch_img_metas)
253
+
254
+ current_featmap_sizes = [
255
+ cls_score.shape[2:] for cls_score in cls_scores
256
+ ]
257
+ # If the shape does not equal, generate new one
258
+ if current_featmap_sizes != self.featmap_sizes_train:
259
+ self.featmap_sizes_train = current_featmap_sizes
260
+
261
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
262
+ self.featmap_sizes_train,
263
+ dtype=cls_scores[0].dtype,
264
+ device=cls_scores[0].device,
265
+ with_stride=True)
266
+
267
+ self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
268
+ self.flatten_priors_train = torch.cat(
269
+ mlvl_priors_with_stride, dim=0)
270
+ self.stride_tensor = self.flatten_priors_train[..., [2]]
271
+
272
+ # gt info
273
+ gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
274
+ gt_labels = gt_info[:, :, :1]
275
+ gt_bboxes = gt_info[:, :, 1:] # xyxy
276
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
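+ # gt_instances_preprocess pads every image to the same number of GT
+ # boxes; pad_bbox_flag marks which entries are real (non-zero) boxes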
277
+
278
+ # pred info
279
+ flatten_cls_preds = [
280
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
281
+ self.num_classes)
282
+ for cls_pred in cls_scores
283
+ ]
284
+ flatten_pred_bboxes = [
285
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
286
+ for bbox_pred in bbox_preds
287
+ ]
288
+ # (bs, reg_max+1, n, 4) -> (bs, n, 4, reg_max+1)
289
+ flatten_pred_dists = [
290
+ bbox_pred_org.permute(0, 2, 3, 1).reshape(
291
+ num_imgs, -1, (self.head_module.reg_max + 1) * 4)
292
+ for bbox_pred_org in bbox_dist_preds
293
+ ]
294
+
295
+ flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
296
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
297
+ flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
298
+ flatten_pred_bboxes = self.bbox_coder.decode(
299
+ self.flatten_priors_train[..., :2], flatten_pred_bboxes,
300
+ self.stride_tensor[..., 0])
301
+ pred_scores = torch.sigmoid(flatten_cls_preds)
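+ # Warm-up strategy: a static assigner is used for the first
+ # `initial_epoch` epochs, then the prediction-aware assigner takes over
+ # (both are set up by the parent YOLOv6Head from train_cfg)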
302
+
303
+ if current_epoch < self.initial_epoch:
304
+ assigned_result = self.initial_assigner(
305
+ flatten_pred_bboxes.detach(), self.flatten_priors_train,
306
+ self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
307
+ else:
308
+ assigned_result = self.assigner(flatten_pred_bboxes.detach(),
309
+ pred_scores.detach(),
310
+ self.flatten_priors_train,
311
+ gt_labels, gt_bboxes,
312
+ pad_bbox_flag)
313
+
314
+ assigned_bboxes = assigned_result['assigned_bboxes']
315
+ assigned_scores = assigned_result['assigned_scores']
316
+ fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
317
+
318
+ # cls loss
319
+ with torch.cuda.amp.autocast(enabled=False):
320
+ loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
321
+
322
+ # rescale bbox
323
+ assigned_bboxes /= self.stride_tensor
324
+ flatten_pred_bboxes /= self.stride_tensor
325
+
326
+ assigned_scores_sum = assigned_scores.sum()
327
+ # reduce_mean between all gpus
328
+ assigned_scores_sum = torch.clamp(
329
+ reduce_mean(assigned_scores_sum), min=1)
330
+ loss_cls /= assigned_scores_sum
331
+
332
+ # select positive samples mask
333
+ num_pos = fg_mask_pre_prior.sum()
334
+ if num_pos > 0:
335
+ # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
336
+ # will not report an error
337
+ # iou loss
338
+ prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
339
+ pred_bboxes_pos = torch.masked_select(
340
+ flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
341
+ assigned_bboxes_pos = torch.masked_select(
342
+ assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
343
+ bbox_weight = torch.masked_select(
344
+ assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
345
+ loss_bbox = self.loss_bbox(
346
+ pred_bboxes_pos,
347
+ assigned_bboxes_pos,
348
+ weight=bbox_weight,
349
+ avg_factor=assigned_scores_sum)
350
+
351
+ # dfl loss
352
+ dist_mask = fg_mask_pre_prior.unsqueeze(-1).repeat(
353
+ [1, 1, (self.head_module.reg_max + 1) * 4])
354
+
355
+ pred_dist_pos = torch.masked_select(
356
+ flatten_dist_preds,
357
+ dist_mask).reshape([-1, 4, self.head_module.reg_max + 1])
358
+ assigned_ltrb = self.bbox_coder.encode(
359
+ self.flatten_priors_train[..., :2] / self.stride_tensor,
360
+ assigned_bboxes,
361
+ max_dis=self.head_module.reg_max,
362
+ eps=0.01)
363
+ assigned_ltrb_pos = torch.masked_select(
364
+ assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
365
+ loss_dfl = self.loss_dfl(
366
+ pred_dist_pos.reshape(-1, self.head_module.reg_max + 1),
367
+ assigned_ltrb_pos.reshape(-1),
368
+ weight=bbox_weight.expand(-1, 4).reshape(-1),
369
+ avg_factor=assigned_scores_sum)
370
+ else:
371
+ loss_bbox = flatten_pred_bboxes.sum() * 0
372
+ loss_dfl = flatten_pred_bboxes.sum() * 0
373
+
374
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
mmyolo/models/dense_heads/rtmdet_head.py ADDED
@@ -0,0 +1,368 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Sequence, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from mmcv.cnn import ConvModule, is_norm
7
+ from mmdet.models.task_modules.samplers import PseudoSampler
8
+ from mmdet.structures.bbox import distance2bbox
9
+ from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
10
+ OptInstanceList, OptMultiConfig, reduce_mean)
11
+ from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
12
+ normal_init)
13
+ from torch import Tensor
14
+
15
+ from mmyolo.registry import MODELS, TASK_UTILS
16
+ from ..utils import gt_instances_preprocess
17
+ from .yolov5_head import YOLOv5Head
18
+
19
+
20
+ @MODELS.register_module()
21
+ class RTMDetSepBNHeadModule(BaseModule):
22
+ """Detection Head of RTMDet.
23
+
24
+ Args:
25
+ num_classes (int): Number of categories excluding the background
26
+ category.
27
+ in_channels (int): Number of channels in the input feature map.
28
+ widen_factor (float): Width multiplier, multiply number of
29
+ channels in each layer by this amount. Defaults to 1.0.
30
+ num_base_priors (int): The number of priors (points) at a point
31
+ on the feature grid. Defaults to 1.
32
+ feat_channels (int): Number of hidden channels. Used in child classes.
33
+ Defaults to 256
34
+ stacked_convs (int): Number of stacking convs of the head.
35
+ Defaults to 2.
36
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
37
+ Defaults to (8, 16, 32).
38
+ share_conv (bool): Whether to share conv layers between stages.
39
+ Defaults to True.
40
+ pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
41
+ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
42
+ convolution layer. Defaults to None.
43
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
44
+ layer. Defaults to ``dict(type='BN')``.
45
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
46
+ Default: dict(type='SiLU', inplace=True).
47
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
48
+ list[dict], optional): Initialization config dict.
49
+ Defaults to None.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ num_classes: int,
55
+ in_channels: int,
56
+ widen_factor: float = 1.0,
57
+ num_base_priors: int = 1,
58
+ feat_channels: int = 256,
59
+ stacked_convs: int = 2,
60
+ featmap_strides: Sequence[int] = [8, 16, 32],
61
+ share_conv: bool = True,
62
+ pred_kernel_size: int = 1,
63
+ conv_cfg: OptConfigType = None,
64
+ norm_cfg: ConfigType = dict(type='BN'),
65
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
66
+ init_cfg: OptMultiConfig = None,
67
+ ):
68
+ super().__init__(init_cfg=init_cfg)
69
+ self.share_conv = share_conv
70
+ self.num_classes = num_classes
71
+ self.pred_kernel_size = pred_kernel_size
72
+ self.feat_channels = int(feat_channels * widen_factor)
73
+ self.stacked_convs = stacked_convs
74
+ self.num_base_priors = num_base_priors
75
+
76
+ self.conv_cfg = conv_cfg
77
+ self.norm_cfg = norm_cfg
78
+ self.act_cfg = act_cfg
79
+ self.featmap_strides = featmap_strides
80
+
81
+ self.in_channels = int(in_channels * widen_factor)
82
+
83
+ self._init_layers()
84
+
85
+ def _init_layers(self):
86
+ """Initialize layers of the head."""
87
+ self.cls_convs = nn.ModuleList()
88
+ self.reg_convs = nn.ModuleList()
89
+
90
+ self.rtm_cls = nn.ModuleList()
91
+ self.rtm_reg = nn.ModuleList()
92
+ for n in range(len(self.featmap_strides)):
93
+ cls_convs = nn.ModuleList()
94
+ reg_convs = nn.ModuleList()
95
+ for i in range(self.stacked_convs):
96
+ chn = self.in_channels if i == 0 else self.feat_channels
97
+ cls_convs.append(
98
+ ConvModule(
99
+ chn,
100
+ self.feat_channels,
101
+ 3,
102
+ stride=1,
103
+ padding=1,
104
+ conv_cfg=self.conv_cfg,
105
+ norm_cfg=self.norm_cfg,
106
+ act_cfg=self.act_cfg))
107
+ reg_convs.append(
108
+ ConvModule(
109
+ chn,
110
+ self.feat_channels,
111
+ 3,
112
+ stride=1,
113
+ padding=1,
114
+ conv_cfg=self.conv_cfg,
115
+ norm_cfg=self.norm_cfg,
116
+ act_cfg=self.act_cfg))
117
+ self.cls_convs.append(cls_convs)
118
+ self.reg_convs.append(reg_convs)
119
+
120
+ self.rtm_cls.append(
121
+ nn.Conv2d(
122
+ self.feat_channels,
123
+ self.num_base_priors * self.num_classes,
124
+ self.pred_kernel_size,
125
+ padding=self.pred_kernel_size // 2))
126
+ self.rtm_reg.append(
127
+ nn.Conv2d(
128
+ self.feat_channels,
129
+ self.num_base_priors * 4,
130
+ self.pred_kernel_size,
131
+ padding=self.pred_kernel_size // 2))
132
+
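+ # "SepBN" design: conv weights are shared across feature levels while
+ # each level keeps its own BN statistics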
133
+ if self.share_conv:
134
+ for n in range(len(self.featmap_strides)):
135
+ for i in range(self.stacked_convs):
136
+ self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
137
+ self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
138
+
139
+ def init_weights(self) -> None:
140
+ """Initialize weights of the head."""
141
+ # Use prior in model initialization to improve stability
142
+ super().init_weights()
143
+ for m in self.modules():
144
+ if isinstance(m, nn.Conv2d):
145
+ normal_init(m, mean=0, std=0.01)
146
+ if is_norm(m):
147
+ constant_init(m, 1)
148
+ bias_cls = bias_init_with_prob(0.01)
149
+ for rtm_cls, rtm_reg in zip(self.rtm_cls, self.rtm_reg):
150
+ normal_init(rtm_cls, std=0.01, bias=bias_cls)
151
+ normal_init(rtm_reg, std=0.01)
152
+
153
+ def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
154
+ """Forward features from the upstream network.
155
+
156
+ Args:
157
+ feats (tuple[Tensor]): Features from the upstream network, each is
158
+ a 4D-tensor.
159
+
160
+ Returns:
161
+ tuple: Usually a tuple of classification scores and bbox prediction
162
+ - cls_scores (list[Tensor]): Classification scores for all scale
163
+ levels, each is a 4D-tensor, the channels number is
164
+ num_base_priors * num_classes.
165
+ - bbox_preds (list[Tensor]): Box energies / deltas for all scale
166
+ levels, each is a 4D-tensor, the channels number is
167
+ num_base_priors * 4.
168
+ """
169
+
170
+ cls_scores = []
171
+ bbox_preds = []
172
+ for idx, x in enumerate(feats):
173
+ cls_feat = x
174
+ reg_feat = x
175
+
176
+ for cls_layer in self.cls_convs[idx]:
177
+ cls_feat = cls_layer(cls_feat)
178
+ cls_score = self.rtm_cls[idx](cls_feat)
179
+
180
+ for reg_layer in self.reg_convs[idx]:
181
+ reg_feat = reg_layer(reg_feat)
182
+
183
+ reg_dist = self.rtm_reg[idx](reg_feat)
184
+ cls_scores.append(cls_score)
185
+ bbox_preds.append(reg_dist)
186
+ return tuple(cls_scores), tuple(bbox_preds)
187
+
188
+
189
+ @MODELS.register_module()
190
+ class RTMDetHead(YOLOv5Head):
191
+ """RTMDet head.
192
+
193
+ Args:
194
+ head_module (ConfigType): Base module used for RTMDetHead.
195
+ prior_generator: Point generator for the feature maps in
196
+ 2D points-based detectors.
197
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
198
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
199
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
200
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
201
+ anchor head. Defaults to None.
202
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
203
+ anchor head. Defaults to None.
204
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
205
+ list[dict], optional): Initialization config dict.
206
+ Defaults to None.
207
+ """
208
+
209
+ def __init__(self,
210
+ head_module: ConfigType,
211
+ prior_generator: ConfigType = dict(
212
+ type='mmdet.MlvlPointGenerator',
213
+ offset=0,
214
+ strides=[8, 16, 32]),
215
+ bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
216
+ loss_cls: ConfigType = dict(
217
+ type='mmdet.QualityFocalLoss',
218
+ use_sigmoid=True,
219
+ beta=2.0,
220
+ loss_weight=1.0),
221
+ loss_bbox: ConfigType = dict(
222
+ type='mmdet.GIoULoss', loss_weight=2.0),
223
+ train_cfg: OptConfigType = None,
224
+ test_cfg: OptConfigType = None,
225
+ init_cfg: OptMultiConfig = None):
226
+
227
+ super().__init__(
228
+ head_module=head_module,
229
+ prior_generator=prior_generator,
230
+ bbox_coder=bbox_coder,
231
+ loss_cls=loss_cls,
232
+ loss_bbox=loss_bbox,
233
+ train_cfg=train_cfg,
234
+ test_cfg=test_cfg,
235
+ init_cfg=init_cfg)
236
+
237
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
238
+ if self.use_sigmoid_cls:
239
+ self.cls_out_channels = self.num_classes
240
+ else:
241
+ self.cls_out_channels = self.num_classes + 1
242
+ # rtmdet doesn't need loss_obj
243
+ self.loss_obj = None
244
+
245
+ def special_init(self):
246
+ """Since YOLO series algorithms will inherit from YOLOv5Head, but
247
+ different algorithms require their own special initialization process.
248
+
249
+ The special_init function is designed to deal with this situation.
250
+ """
251
+ if self.train_cfg:
252
+ self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
253
+ if self.train_cfg.get('sampler', None) is not None:
254
+ self.sampler = TASK_UTILS.build(
255
+ self.train_cfg.sampler, default_args=dict(context=self))
256
+ else:
257
+ self.sampler = PseudoSampler(context=self)
258
+
259
+ self.featmap_sizes_train = None
260
+ self.flatten_priors_train = None
261
+
262
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
263
+ """Forward features from the upstream network.
264
+
265
+ Args:
266
+ x (Tuple[Tensor]): Features from the upstream network, each is
267
+ a 4D-tensor.
268
+ Returns:
269
+ Tuple[List]: A tuple of multi-level classification scores, bbox
270
+ predictions, and objectnesses.
271
+ """
272
+ return self.head_module(x)
273
+
274
+ def loss_by_feat(
275
+ self,
276
+ cls_scores: List[Tensor],
277
+ bbox_preds: List[Tensor],
278
+ batch_gt_instances: InstanceList,
279
+ batch_img_metas: List[dict],
280
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
281
+ """Compute losses of the head.
282
+
283
+ Args:
284
+ cls_scores (list[Tensor]): Box scores for each scale level
285
+ Has shape (N, num_anchors * num_classes, H, W)
286
+ bbox_preds (list[Tensor]): Decoded box for each scale
287
+ level with shape (N, num_anchors * 4, H, W) in
288
+ [tl_x, tl_y, br_x, br_y] format.
289
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
290
+ gt_instance. It usually includes ``bboxes`` and ``labels``
291
+ attributes.
292
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
293
+ image size, scaling factor, etc.
294
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
295
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
296
+ data that is ignored during training and testing.
297
+ Defaults to None.
298
+
299
+ Returns:
300
+ dict[str, Tensor]: A dictionary of loss components.
301
+ """
302
+ num_imgs = len(batch_img_metas)
303
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
304
+ assert len(featmap_sizes) == self.prior_generator.num_levels
305
+
306
+ gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
307
+ gt_labels = gt_info[:, :, :1]
308
+ gt_bboxes = gt_info[:, :, 1:] # xyxy
309
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
310
+
311
+ device = cls_scores[0].device
312
+
313
+ # If the shape does not equal, generate new one
314
+ if featmap_sizes != self.featmap_sizes_train:
315
+ self.featmap_sizes_train = featmap_sizes
316
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
317
+ featmap_sizes, device=device, with_stride=True)
318
+ self.flatten_priors_train = torch.cat(
319
+ mlvl_priors_with_stride, dim=0)
320
+
321
+ flatten_cls_scores = torch.cat([
322
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
323
+ self.cls_out_channels)
324
+ for cls_score in cls_scores
325
+ ], 1).contiguous()
326
+
327
+ flatten_bboxes = torch.cat([
328
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
329
+ for bbox_pred in bbox_preds
330
+ ], 1)
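+ # The reg branch predicts stride-normalized (l, t, r, b) distances, so
+ # scale by the per-prior stride before decoding them into xyxy boxes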
331
+ flatten_bboxes = flatten_bboxes * self.flatten_priors_train[..., -1,
332
+ None]
333
+ flatten_bboxes = distance2bbox(self.flatten_priors_train[..., :2],
334
+ flatten_bboxes)
335
+
336
+ assigned_result = self.assigner(flatten_bboxes.detach(),
337
+ flatten_cls_scores.detach(),
338
+ self.flatten_priors_train, gt_labels,
339
+ gt_bboxes, pad_bbox_flag)
340
+
341
+ labels = assigned_result['assigned_labels'].reshape(-1)
342
+ label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
343
+ bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 4)
344
+ assign_metrics = assigned_result['assign_metrics'].reshape(-1)
345
+ cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
346
+ bbox_preds = flatten_bboxes.reshape(-1, 4)
347
+
348
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
349
+ bg_class_ind = self.num_classes
350
+ pos_inds = ((labels >= 0)
351
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
352
+ avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
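+ # With the default QualityFocalLoss, (labels, assign_metrics) forms a
+ # soft target pair; the normalizer is the metric sum synced across GPUs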
353
+
354
+ loss_cls = self.loss_cls(
355
+ cls_preds, (labels, assign_metrics),
356
+ label_weights,
357
+ avg_factor=avg_factor)
358
+
359
+ if len(pos_inds) > 0:
360
+ loss_bbox = self.loss_bbox(
361
+ bbox_preds[pos_inds],
362
+ bbox_targets[pos_inds],
363
+ weight=assign_metrics[pos_inds],
364
+ avg_factor=avg_factor)
365
+ else:
366
+ loss_bbox = bbox_preds.sum() * 0
367
+
368
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
mmyolo/models/dense_heads/rtmdet_ins_head.py ADDED
@@ -0,0 +1,725 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ from typing import List, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from mmcv.cnn import ConvModule, is_norm
10
+ from mmcv.ops import batched_nms
11
+ from mmdet.models.utils import filter_scores_and_topk
12
+ from mmdet.structures.bbox import get_box_tensor, get_box_wh, scale_boxes
13
+ from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
14
+ OptInstanceList, OptMultiConfig)
15
+ from mmengine import ConfigDict
16
+ from mmengine.model import (BaseModule, bias_init_with_prob, constant_init,
17
+ normal_init)
18
+ from mmengine.structures import InstanceData
19
+ from torch import Tensor
20
+
21
+ from mmyolo.registry import MODELS
22
+ from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
23
+
24
+
25
+ class MaskFeatModule(BaseModule):
26
+ """Mask feature head used in RTMDet-Ins. Copy from mmdet.
27
+
28
+ Args:
29
+ in_channels (int): Number of channels in the input feature map.
30
+ feat_channels (int): Number of hidden channels of the mask feature
31
+ map branch.
32
+ stacked_convs (int): Number of convs in mask feature branch.
33
+ num_levels (int): The starting feature map level from RPN that
34
+ will be used to predict the mask feature map.
35
+ num_prototypes (int): Number of output channel of the mask feature
36
+ map branch. This is the channel count of the mask
37
+ feature map that will be dynamically convolved with the predicted
38
+ kernel.
39
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
40
+ Default: dict(type='ReLU', inplace=True)
41
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ in_channels: int,
47
+ feat_channels: int = 256,
48
+ stacked_convs: int = 4,
49
+ num_levels: int = 3,
50
+ num_prototypes: int = 8,
51
+ act_cfg: ConfigType = dict(type='ReLU', inplace=True),
52
+ norm_cfg: ConfigType = dict(type='BN')
53
+ ) -> None:
54
+ super().__init__(init_cfg=None)
55
+ self.num_levels = num_levels
56
+ self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1)
57
+ convs = []
58
+ for i in range(stacked_convs):
59
+ in_c = in_channels if i == 0 else feat_channels
60
+ convs.append(
61
+ ConvModule(
62
+ in_c,
63
+ feat_channels,
64
+ 3,
65
+ padding=1,
66
+ act_cfg=act_cfg,
67
+ norm_cfg=norm_cfg))
68
+ self.stacked_convs = nn.Sequential(*convs)
69
+ self.projection = nn.Conv2d(
70
+ feat_channels, num_prototypes, kernel_size=1)
71
+
72
+ def forward(self, features: Tuple[Tensor, ...]) -> Tensor:
73
+ # multi-level feature fusion
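+ # Upsample every level to the resolution of the first (highest
+ # resolution) level, then concatenate and fuse with a 1x1 conv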
74
+ fusion_feats = [features[0]]
75
+ size = features[0].shape[-2:]
76
+ for i in range(1, self.num_levels):
77
+ f = F.interpolate(features[i], size=size, mode='bilinear')
78
+ fusion_feats.append(f)
79
+ fusion_feats = torch.cat(fusion_feats, dim=1)
80
+ fusion_feats = self.fusion_conv(fusion_feats)
81
+ # pred mask feats
82
+ mask_features = self.stacked_convs(fusion_feats)
83
+ mask_features = self.projection(mask_features)
84
+ return mask_features
85
+
86
+
87
+ @MODELS.register_module()
88
+ class RTMDetInsSepBNHeadModule(RTMDetSepBNHeadModule):
89
+ """Detection and Instance Segmentation Head of RTMDet.
90
+
91
+ Args:
92
+ num_classes (int): Number of categories excluding the background
93
+ category.
94
+ num_prototypes (int): Number of mask prototype features extracted
95
+ from the mask head. Defaults to 8.
96
+ dyconv_channels (int): Channel of the dynamic conv layers.
97
+ Defaults to 8.
98
+ num_dyconvs (int): Number of the dynamic convolution layers.
99
+ Defaults to 3.
100
+ use_sigmoid_cls (bool): Use sigmoid for class prediction.
101
+ Defaults to True.
102
+ """
103
+
104
+ def __init__(self,
105
+ num_classes: int,
106
+ *args,
107
+ num_prototypes: int = 8,
108
+ dyconv_channels: int = 8,
109
+ num_dyconvs: int = 3,
110
+ use_sigmoid_cls: bool = True,
111
+ **kwargs):
112
+ self.num_prototypes = num_prototypes
113
+ self.num_dyconvs = num_dyconvs
114
+ self.dyconv_channels = dyconv_channels
115
+ self.use_sigmoid_cls = use_sigmoid_cls
116
+ if self.use_sigmoid_cls:
117
+ self.cls_out_channels = num_classes
118
+ else:
119
+ self.cls_out_channels = num_classes + 1
120
+ super().__init__(num_classes=num_classes, *args, **kwargs)
121
+
122
+ def _init_layers(self):
123
+ """Initialize layers of the head."""
124
+ self.cls_convs = nn.ModuleList()
125
+ self.reg_convs = nn.ModuleList()
126
+ self.kernel_convs = nn.ModuleList()
127
+
128
+ self.rtm_cls = nn.ModuleList()
129
+ self.rtm_reg = nn.ModuleList()
130
+ self.rtm_kernel = nn.ModuleList()
131
+ self.rtm_obj = nn.ModuleList()
132
+
133
+ # calculate num dynamic parameters
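+ # The dynamic mask head is a tiny per-instance conv stack: the first
+ # layer consumes num_prototypes + 2 channels (prototypes plus two
+ # relative-coordinate channels), hidden layers use dyconv_channels, and
+ # the last layer outputs a single mask-logit channel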
134
+ weight_nums, bias_nums = [], []
135
+ for i in range(self.num_dyconvs):
136
+ if i == 0:
137
+ weight_nums.append(
138
+ (self.num_prototypes + 2) * self.dyconv_channels)
139
+ bias_nums.append(self.dyconv_channels)
140
+ elif i == self.num_dyconvs - 1:
141
+ weight_nums.append(self.dyconv_channels)
142
+ bias_nums.append(1)
143
+ else:
144
+ weight_nums.append(self.dyconv_channels * self.dyconv_channels)
145
+ bias_nums.append(self.dyconv_channels)
146
+ self.weight_nums = weight_nums
147
+ self.bias_nums = bias_nums
148
+ self.num_gen_params = sum(weight_nums) + sum(bias_nums)
149
+ pred_pad_size = self.pred_kernel_size // 2
150
+
151
+ for n in range(len(self.featmap_strides)):
152
+ cls_convs = nn.ModuleList()
153
+ reg_convs = nn.ModuleList()
154
+ kernel_convs = nn.ModuleList()
155
+ for i in range(self.stacked_convs):
156
+ chn = self.in_channels if i == 0 else self.feat_channels
157
+ cls_convs.append(
158
+ ConvModule(
159
+ chn,
160
+ self.feat_channels,
161
+ 3,
162
+ stride=1,
163
+ padding=1,
164
+ conv_cfg=self.conv_cfg,
165
+ norm_cfg=self.norm_cfg,
166
+ act_cfg=self.act_cfg))
167
+ reg_convs.append(
168
+ ConvModule(
169
+ chn,
170
+ self.feat_channels,
171
+ 3,
172
+ stride=1,
173
+ padding=1,
174
+ conv_cfg=self.conv_cfg,
175
+ norm_cfg=self.norm_cfg,
176
+ act_cfg=self.act_cfg))
177
+ kernel_convs.append(
178
+ ConvModule(
179
+ chn,
180
+ self.feat_channels,
181
+ 3,
182
+ stride=1,
183
+ padding=1,
184
+ conv_cfg=self.conv_cfg,
185
+ norm_cfg=self.norm_cfg,
186
+ act_cfg=self.act_cfg))
187
+ self.cls_convs.append(cls_convs)
188
+ self.reg_convs.append(cls_convs)
189
+ self.kernel_convs.append(kernel_convs)
190
+
191
+ self.rtm_cls.append(
192
+ nn.Conv2d(
193
+ self.feat_channels,
194
+ self.num_base_priors * self.cls_out_channels,
195
+ self.pred_kernel_size,
196
+ padding=pred_pad_size))
197
+ self.rtm_reg.append(
198
+ nn.Conv2d(
199
+ self.feat_channels,
200
+ self.num_base_priors * 4,
201
+ self.pred_kernel_size,
202
+ padding=pred_pad_size))
203
+ self.rtm_kernel.append(
204
+ nn.Conv2d(
205
+ self.feat_channels,
206
+ self.num_gen_params,
207
+ self.pred_kernel_size,
208
+ padding=pred_pad_size))
209
+
210
+ if self.share_conv:
211
+ for n in range(len(self.featmap_strides)):
212
+ for i in range(self.stacked_convs):
213
+ self.cls_convs[n][i].conv = self.cls_convs[0][i].conv
214
+ self.reg_convs[n][i].conv = self.reg_convs[0][i].conv
215
+
216
+ self.mask_head = MaskFeatModule(
217
+ in_channels=self.in_channels,
218
+ feat_channels=self.feat_channels,
219
+ stacked_convs=4,
220
+ num_levels=len(self.featmap_strides),
221
+ num_prototypes=self.num_prototypes,
222
+ act_cfg=self.act_cfg,
223
+ norm_cfg=self.norm_cfg)
224
+
225
+ def init_weights(self) -> None:
226
+ """Initialize weights of the head."""
227
+ for m in self.modules():
228
+ if isinstance(m, nn.Conv2d):
229
+ normal_init(m, mean=0, std=0.01)
230
+ if is_norm(m):
231
+ constant_init(m, 1)
232
+ bias_cls = bias_init_with_prob(0.01)
233
+ for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg,
234
+ self.rtm_kernel):
235
+ normal_init(rtm_cls, std=0.01, bias=bias_cls)
236
+ normal_init(rtm_reg, std=0.01, bias=1)
237
+
238
+ def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
239
+ """Forward features from the upstream network.
240
+
241
+ Args:
242
+ feats (tuple[Tensor]): Features from the upstream network, each is
243
+ a 4D-tensor.
244
+
245
+ Returns:
246
+ tuple: Usually a tuple of classification scores and bbox prediction
247
+ - cls_scores (list[Tensor]): Classification scores for all scale
248
+ levels, each is a 4D-tensor, the channels number is
249
+ num_base_priors * num_classes.
250
+ - bbox_preds (list[Tensor]): Box energies / deltas for all scale
251
+ levels, each is a 4D-tensor, the channels number is
252
+ num_base_priors * 4.
253
+ - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale
254
+ levels, each is a 4D-tensor, the channels number is
255
+ num_gen_params.
256
+ - mask_feat (Tensor): Mask prototype features.
257
+ Has shape (batch_size, num_prototypes, H, W).
258
+ """
259
+ mask_feat = self.mask_head(feats)
260
+
261
+ cls_scores = []
262
+ bbox_preds = []
263
+ kernel_preds = []
264
+ for idx, (x, stride) in enumerate(zip(feats, self.featmap_strides)):
265
+ cls_feat = x
266
+ reg_feat = x
267
+ kernel_feat = x
268
+
269
+ for cls_layer in self.cls_convs[idx]:
270
+ cls_feat = cls_layer(cls_feat)
271
+ cls_score = self.rtm_cls[idx](cls_feat)
272
+
273
+ for kernel_layer in self.kernel_convs[idx]:
274
+ kernel_feat = kernel_layer(kernel_feat)
275
+ kernel_pred = self.rtm_kernel[idx](kernel_feat)
276
+
277
+ for reg_layer in self.reg_convs[idx]:
278
+ reg_feat = reg_layer(reg_feat)
279
+ reg_dist = self.rtm_reg[idx](reg_feat)
280
+
281
+ cls_scores.append(cls_score)
282
+ bbox_preds.append(reg_dist)
283
+ kernel_preds.append(kernel_pred)
284
+ return tuple(cls_scores), tuple(bbox_preds), tuple(
285
+ kernel_preds), mask_feat
286
+
287
+
288
+ @MODELS.register_module()
289
+ class RTMDetInsSepBNHead(RTMDetHead):
290
+ """RTMDet Instance Segmentation head.
291
+
292
+ Args:
293
+ head_module (ConfigType): Base module used for RTMDetInsSepBNHead.
294
+ prior_generator: Point generator for the feature maps in
295
+ 2D points-based detectors.
296
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
297
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
298
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
299
+ loss_mask (:obj:`ConfigDict` or dict): Config of mask loss.
300
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
301
+ anchor head. Defaults to None.
302
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
303
+ anchor head. Defaults to None.
304
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
305
+ list[dict], optional): Initialization config dict.
306
+ Defaults to None.
307
+ """
308
+
309
+ def __init__(self,
310
+ head_module: ConfigType,
311
+ prior_generator: ConfigType = dict(
312
+ type='mmdet.MlvlPointGenerator',
313
+ offset=0,
314
+ strides=[8, 16, 32]),
315
+ bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
316
+ loss_cls: ConfigType = dict(
317
+ type='mmdet.QualityFocalLoss',
318
+ use_sigmoid=True,
319
+ beta=2.0,
320
+ loss_weight=1.0),
321
+ loss_bbox: ConfigType = dict(
322
+ type='mmdet.GIoULoss', loss_weight=2.0),
323
+ loss_mask=dict(
324
+ type='mmdet.DiceLoss',
325
+ loss_weight=2.0,
326
+ eps=5e-6,
327
+ reduction='mean'),
328
+ train_cfg: OptConfigType = None,
329
+ test_cfg: OptConfigType = None,
330
+ init_cfg: OptMultiConfig = None):
331
+
332
+ super().__init__(
333
+ head_module=head_module,
334
+ prior_generator=prior_generator,
335
+ bbox_coder=bbox_coder,
336
+ loss_cls=loss_cls,
337
+ loss_bbox=loss_bbox,
338
+ train_cfg=train_cfg,
339
+ test_cfg=test_cfg,
340
+ init_cfg=init_cfg)
341
+
342
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
343
+ if isinstance(self.head_module, RTMDetInsSepBNHeadModule):
344
+ assert self.use_sigmoid_cls == self.head_module.use_sigmoid_cls
345
+ self.loss_mask = MODELS.build(loss_mask)
346
+
347
+ def predict_by_feat(self,
348
+ cls_scores: List[Tensor],
349
+ bbox_preds: List[Tensor],
350
+ kernel_preds: List[Tensor],
351
+ mask_feats: Tensor,
352
+ score_factors: Optional[List[Tensor]] = None,
353
+ batch_img_metas: Optional[List[dict]] = None,
354
+ cfg: Optional[ConfigDict] = None,
355
+ rescale: bool = True,
356
+ with_nms: bool = True) -> List[InstanceData]:
357
+ """Transform a batch of output features extracted from the head into
358
+ bbox results.
359
+
360
+ Note: When score_factors is not None, the cls_scores are
361
+ usually multiplied by it to obtain the real score used in NMS.
362
+
363
+ Args:
364
+ cls_scores (list[Tensor]): Classification scores for all
365
+ scale levels, each is a 4D-tensor, has shape
366
+ (batch_size, num_priors * num_classes, H, W).
367
+ bbox_preds (list[Tensor]): Box energies / deltas for all
368
+ scale levels, each is a 4D-tensor, has shape
369
+ (batch_size, num_priors * 4, H, W).
370
+ kernel_preds (list[Tensor]): Kernel predictions of dynamic
371
+ convs for all scale levels, each is a 4D-tensor, has shape
372
+ (batch_size, num_params, H, W).
373
+ mask_feats (Tensor): Mask prototype features extracted from the
374
+ mask head, has shape (batch_size, num_prototypes, H, W).
375
+ score_factors (list[Tensor], optional): Score factor for
376
+ all scale level, each is a 4D-tensor, has shape
377
+ (batch_size, num_priors * 1, H, W). Defaults to None.
378
+ batch_img_metas (list[dict], Optional): Batch image meta info.
379
+ Defaults to None.
380
+ cfg (ConfigDict, optional): Test / postprocessing
381
+ configuration, if None, test_cfg would be used.
382
+ Defaults to None.
383
+ rescale (bool): If True, return boxes in original image space.
384
+ Defaults to True.
385
+ with_nms (bool): If True, do nms before return boxes.
386
+ Defaults to True.
387
+
388
+ Returns:
389
+ list[:obj:`InstanceData`]: Object detection and instance
390
+ segmentation results of each image after the post process.
391
+ Each item usually contains following keys.
392
+
393
+ - scores (Tensor): Classification scores, has a shape
394
+ (num_instance, )
395
+ - labels (Tensor): Labels of bboxes, has a shape
396
+ (num_instances, ).
397
+ - bboxes (Tensor): Has a shape (num_instances, 4),
398
+ the last dimension 4 arrange as (x1, y1, x2, y2).
399
+ - masks (Tensor): Has a shape (num_instances, h, w).
400
+ """
401
+ cfg = self.test_cfg if cfg is None else cfg
402
+ cfg = copy.deepcopy(cfg)
403
+
404
+ multi_label = cfg.multi_label
405
+ multi_label &= self.num_classes > 1
406
+ cfg.multi_label = multi_label
407
+
408
+ num_imgs = len(batch_img_metas)
409
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
410
+
411
+ # If the shape does not change, use the previous mlvl_priors
412
+ if featmap_sizes != self.featmap_sizes:
413
+ self.mlvl_priors = self.prior_generator.grid_priors(
414
+ featmap_sizes,
415
+ dtype=cls_scores[0].dtype,
416
+ device=cls_scores[0].device,
417
+ with_stride=True)
418
+ self.featmap_sizes = featmap_sizes
419
+ flatten_priors = torch.cat(self.mlvl_priors)
420
+
421
+ mlvl_strides = [
422
+ flatten_priors.new_full(
423
+ (featmap_size.numel() * self.num_base_priors, ), stride) for
424
+ featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
425
+ ]
426
+ flatten_stride = torch.cat(mlvl_strides)
427
+
428
+ # flatten cls_scores, bbox_preds
429
+ flatten_cls_scores = [
430
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
431
+ self.num_classes)
432
+ for cls_score in cls_scores
433
+ ]
434
+ flatten_bbox_preds = [
435
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
436
+ for bbox_pred in bbox_preds
437
+ ]
438
+ flatten_kernel_preds = [
439
+ kernel_pred.permute(0, 2, 3,
440
+ 1).reshape(num_imgs, -1,
441
+ self.head_module.num_gen_params)
442
+ for kernel_pred in kernel_preds
443
+ ]
444
+
445
+ flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
446
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
447
+ flatten_decoded_bboxes = self.bbox_coder.decode(
448
+ flatten_priors[..., :2].unsqueeze(0), flatten_bbox_preds,
449
+ flatten_stride)
450
+
451
+ flatten_kernel_preds = torch.cat(flatten_kernel_preds, dim=1)
452
+
453
+ results_list = []
454
+ for (bboxes, scores, kernel_pred, mask_feat,
455
+ img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
456
+ flatten_kernel_preds, mask_feats,
457
+ batch_img_metas):
458
+ ori_shape = img_meta['ori_shape']
459
+ scale_factor = img_meta['scale_factor']
460
+ if 'pad_param' in img_meta:
461
+ pad_param = img_meta['pad_param']
462
+ else:
463
+ pad_param = None
464
+
465
+ score_thr = cfg.get('score_thr', -1)
466
+ if scores.shape[0] == 0:
467
+ empty_results = InstanceData()
468
+ empty_results.bboxes = bboxes
469
+ empty_results.scores = scores[:, 0]
470
+ empty_results.labels = scores[:, 0].int()
471
+ h, w = ori_shape[:2] if rescale else img_meta['img_shape'][:2]
472
+ empty_results.masks = torch.zeros(
473
+ size=(0, h, w), dtype=torch.bool, device=bboxes.device)
474
+ results_list.append(empty_results)
475
+ continue
476
+
477
+ nms_pre = cfg.get('nms_pre', 100000)
478
+ if cfg.multi_label is False:
479
+ scores, labels = scores.max(1, keepdim=True)
480
+ scores, _, keep_idxs, results = filter_scores_and_topk(
481
+ scores,
482
+ score_thr,
483
+ nms_pre,
484
+ results=dict(
485
+ labels=labels[:, 0],
486
+ kernel_pred=kernel_pred,
487
+ priors=flatten_priors))
488
+ labels = results['labels']
489
+ kernel_pred = results['kernel_pred']
490
+ priors = results['priors']
491
+ else:
492
+ out = filter_scores_and_topk(
493
+ scores,
494
+ score_thr,
495
+ nms_pre,
496
+ results=dict(
497
+ kernel_pred=kernel_pred, priors=flatten_priors))
498
+ scores, labels, keep_idxs, filtered_results = out
499
+ kernel_pred = filtered_results['kernel_pred']
500
+ priors = filtered_results['priors']
501
+
502
+ results = InstanceData(
503
+ scores=scores,
504
+ labels=labels,
505
+ bboxes=bboxes[keep_idxs],
506
+ kernels=kernel_pred,
507
+ priors=priors)
508
+
509
+ if rescale:
510
+ if pad_param is not None:
511
+ results.bboxes -= results.bboxes.new_tensor([
512
+ pad_param[2], pad_param[0], pad_param[2], pad_param[0]
513
+ ])
514
+ results.bboxes /= results.bboxes.new_tensor(
515
+ scale_factor).repeat((1, 2))
516
+
517
+ if cfg.get('yolox_style', False):
518
+ # do not need max_per_img
519
+ cfg.max_per_img = len(results)
520
+
521
+ results = self._bbox_mask_post_process(
522
+ results=results,
523
+ mask_feat=mask_feat,
524
+ cfg=cfg,
525
+ rescale_bbox=False,
526
+ rescale_mask=rescale,
527
+ with_nms=with_nms,
528
+ pad_param=pad_param,
529
+ img_meta=img_meta)
530
+ results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
531
+ results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
532
+
533
+ results_list.append(results)
534
+ return results_list
535
+
536
+ def _bbox_mask_post_process(
537
+ self,
538
+ results: InstanceData,
539
+ mask_feat: Tensor,
540
+ cfg: ConfigDict,
541
+ rescale_bbox: bool = False,
542
+ rescale_mask: bool = True,
543
+ with_nms: bool = True,
544
+ pad_param: Optional[np.ndarray] = None,
545
+ img_meta: Optional[dict] = None) -> InstanceData:
546
+ """bbox and mask post-processing method.
547
+
548
+ The boxes are rescaled to the original image scale and NMS is
549
+ applied. `with_nms` is usually set to False for aug test.
550
+
551
+ Args:
552
+ results (:obj:`InstanceData`): Detection instance results,
553
+ each item has shape (num_bboxes, ).
554
+ mask_feat (Tensor): Mask prototype features extracted from the
555
+ mask head, has shape (batch_size, num_prototypes, H, W).
556
+ cfg (ConfigDict): Test / postprocessing configuration,
557
+ if None, test_cfg would be used.
558
+ rescale_bbox (bool): If True, return boxes in original image space.
559
+ Default to False.
560
+ rescale_mask (bool): If True, return masks in original image space.
561
+ Default to True.
562
+ with_nms (bool): If True, do nms before return boxes.
563
+ Default to True.
564
+ img_meta (dict, optional): Image meta info. Defaults to None.
565
+
566
+ Returns:
567
+ :obj:`InstanceData`: Detection results of each image
568
+ after the post process.
569
+ Each item usually contains following keys.
570
+
571
+ - scores (Tensor): Classification scores, has a shape
572
+ (num_instance, )
573
+ - labels (Tensor): Labels of bboxes, has a shape
574
+ (num_instances, ).
575
+ - bboxes (Tensor): Has a shape (num_instances, 4),
576
+ the last dimension 4 arrange as (x1, y1, x2, y2).
577
+ - masks (Tensor): Has a shape (num_instances, h, w).
578
+ """
579
+ if rescale_bbox:
580
+ assert img_meta.get('scale_factor') is not None
581
+ scale_factor = [1 / s for s in img_meta['scale_factor']]
582
+ results.bboxes = scale_boxes(results.bboxes, scale_factor)
583
+
584
+ if hasattr(results, 'score_factors'):
585
+ # TODO: Add sqrt operation in order to be consistent with
586
+ # the paper.
587
+ score_factors = results.pop('score_factors')
588
+ results.scores = results.scores * score_factors
589
+
590
+ # filter small size bboxes
591
+ if cfg.get('min_bbox_size', -1) >= 0:
592
+ w, h = get_box_wh(results.bboxes)
593
+ valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
594
+ if not valid_mask.all():
595
+ results = results[valid_mask]
596
+
597
+ # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg
598
+ assert with_nms, 'with_nms must be True for RTMDet-Ins'
599
+ if results.bboxes.numel() > 0:
600
+ bboxes = get_box_tensor(results.bboxes)
601
+ det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
602
+ results.labels, cfg.nms)
603
+ results = results[keep_idxs]
604
+ # some nms would reweight the score, such as softnms
605
+ results.scores = det_bboxes[:, -1]
606
+ results = results[:cfg.max_per_img]
607
+
608
+ # process masks
609
+ mask_logits = self._mask_predict_by_feat(mask_feat,
610
+ results.kernels,
611
+ results.priors)
612
+
613
+ stride = self.prior_generator.strides[0][0]
614
+ mask_logits = F.interpolate(
615
+ mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
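+ # Mask logits are predicted at the resolution of the first feature
+ # level, so upsampling by that level's stride brings them to network
+ # input size before the optional crop / resize to the original image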
616
+ if rescale_mask:
617
+ # TODO: When use mmdet.Resize or mmdet.Pad, will meet bug
618
+ # Use img_meta to crop and resize
619
+ ori_h, ori_w = img_meta['ori_shape'][:2]
620
+ if isinstance(pad_param, np.ndarray):
621
+ pad_param = pad_param.astype(np.int32)
622
+ crop_y1, crop_y2 = pad_param[
623
+ 0], mask_logits.shape[-2] - pad_param[1]
624
+ crop_x1, crop_x2 = pad_param[
625
+ 2], mask_logits.shape[-1] - pad_param[3]
626
+ mask_logits = mask_logits[..., crop_y1:crop_y2,
627
+ crop_x1:crop_x2]
628
+ mask_logits = F.interpolate(
629
+ mask_logits,
630
+ size=[ori_h, ori_w],
631
+ mode='bilinear',
632
+ align_corners=False)
633
+
634
+ masks = mask_logits.sigmoid().squeeze(0)
635
+ masks = masks > cfg.mask_thr_binary
636
+ results.masks = masks
637
+ else:
638
+ h, w = img_meta['ori_shape'][:2] if rescale_mask else img_meta[
639
+ 'img_shape'][:2]
640
+ results.masks = torch.zeros(
641
+ size=(results.bboxes.shape[0], h, w),
642
+ dtype=torch.bool,
643
+ device=results.bboxes.device)
644
+ return results
645
+
646
+ def _mask_predict_by_feat(self, mask_feat: Tensor, kernels: Tensor,
647
+ priors: Tensor) -> Tensor:
648
+ """Generate mask logits from mask features with dynamic convs.
649
+
650
+ Args:
651
+ mask_feat (Tensor): Mask prototype features.
652
+ Has shape (num_prototypes, H, W).
653
+ kernels (Tensor): Kernel parameters for each instance.
654
+ Has shape (num_instance, num_params)
655
+ priors (Tensor): Center priors for each instance.
656
+ Has shape (num_instance, 4).
657
+ Returns:
658
+ Tensor: Instance segmentation masks for each instance.
659
+ Has shape (num_instance, H, W).
660
+ """
661
+ num_inst = kernels.shape[0]
662
+ h, w = mask_feat.size()[-2:]
663
+ if num_inst < 1:
664
+ return torch.empty(
665
+ size=(num_inst, h, w),
666
+ dtype=mask_feat.dtype,
667
+ device=mask_feat.device)
668
+ if len(mask_feat.shape) < 4:
669
+ mask_feat = mask_feat.unsqueeze(0)
670
+
671
+ coord = self.prior_generator.single_level_grid_priors(
672
+ (h, w), level_idx=0, device=mask_feat.device).reshape(1, -1, 2)
673
+ num_inst = priors.shape[0]
674
+ points = priors[:, :2].reshape(-1, 1, 2)
675
+ strides = priors[:, 2:].reshape(-1, 1, 2)
676
+ relative_coord = (points - coord).permute(0, 2, 1) / (
677
+ strides[..., 0].reshape(-1, 1, 1) * 8)
678
+ relative_coord = relative_coord.reshape(num_inst, 2, h, w)
679
+
680
+ mask_feat = torch.cat(
681
+ [relative_coord,
682
+ mask_feat.repeat(num_inst, 1, 1, 1)], dim=1)
683
+ weights, biases = self.parse_dynamic_params(kernels)
684
+
685
+ n_layers = len(weights)
686
+ x = mask_feat.reshape(1, -1, h, w)
687
+ for i, (weight, bias) in enumerate(zip(weights, biases)):
688
+ x = F.conv2d(
689
+ x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
690
+ if i < n_layers - 1:
691
+ x = F.relu(x)
692
+ x = x.reshape(num_inst, h, w)
693
+ return x
694
+
695
+ def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple:
696
+ """split kernel head prediction to conv weight and bias."""
697
+ n_inst = flatten_kernels.size(0)
698
+ n_layers = len(self.head_module.weight_nums)
699
+ params_splits = list(
700
+ torch.split_with_sizes(
701
+ flatten_kernels,
702
+ self.head_module.weight_nums + self.head_module.bias_nums,
703
+ dim=1))
704
+ weight_splits = params_splits[:n_layers]
705
+ bias_splits = params_splits[n_layers:]
706
+ for i in range(n_layers):
707
+ if i < n_layers - 1:
708
+ weight_splits[i] = weight_splits[i].reshape(
709
+ n_inst * self.head_module.dyconv_channels, -1, 1, 1)
710
+ bias_splits[i] = bias_splits[i].reshape(
711
+ n_inst * self.head_module.dyconv_channels)
712
+ else:
713
+ weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1)
714
+ bias_splits[i] = bias_splits[i].reshape(n_inst)
715
+
716
+ return weight_splits, bias_splits
717
+
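As a minimal, self-contained sketch of the dynamic-convolution scheme used by `parse_dynamic_params` and `_mask_predict_by_feat` above: each instance's flat kernel prediction is split into per-layer 1x1 conv weights/biases and applied with a single grouped convolution per layer. The channel sizes and the three-layer layout below are illustrative assumptions, not values taken from this upload.

import torch
import torch.nn.functional as F

num_inst, in_ch, mid_ch, h, w = 3, 10, 8, 20, 20
# Per-layer parameter counts for three 1x1 convs:
# (in_ch -> mid_ch), (mid_ch -> mid_ch), (mid_ch -> 1).
weight_nums = [in_ch * mid_ch, mid_ch * mid_ch, mid_ch * 1]
bias_nums = [mid_ch, mid_ch, 1]
flatten_kernels = torch.randn(num_inst, sum(weight_nums) + sum(bias_nums))

# Split each instance's flat parameter vector into per-layer weights and
# biases, mirroring parse_dynamic_params() above.
splits = list(
    torch.split_with_sizes(flatten_kernels, weight_nums + bias_nums, dim=1))
weights, biases = splits[:3], splits[3:]
weights[0] = weights[0].reshape(num_inst * mid_ch, in_ch, 1, 1)
weights[1] = weights[1].reshape(num_inst * mid_ch, mid_ch, 1, 1)
weights[2] = weights[2].reshape(num_inst, mid_ch, 1, 1)
biases[0] = biases[0].reshape(num_inst * mid_ch)
biases[1] = biases[1].reshape(num_inst * mid_ch)
biases[2] = biases[2].reshape(num_inst)

# One grouped conv per layer evaluates all instances at once.
x = torch.randn(num_inst, in_ch, h, w).reshape(1, -1, h, w)
for i, (weight, bias) in enumerate(zip(weights, biases)):
    x = F.conv2d(x, weight, bias=bias, stride=1, padding=0, groups=num_inst)
    if i < 2:
        x = F.relu(x)
print(x.reshape(num_inst, h, w).shape)  # torch.Size([3, 20, 20])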
718
+ def loss_by_feat(
719
+ self,
720
+ cls_scores: List[Tensor],
721
+ bbox_preds: List[Tensor],
722
+ batch_gt_instances: InstanceList,
723
+ batch_img_metas: List[dict],
724
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
725
+ raise NotImplementedError
mmyolo/models/dense_heads/rtmdet_rotated_head.py ADDED
@@ -0,0 +1,641 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import warnings
4
+ from typing import List, Optional, Sequence, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from mmdet.models.utils import filter_scores_and_topk
9
+ from mmdet.structures.bbox import HorizontalBoxes, distance2bbox
10
+ from mmdet.structures.bbox.transforms import bbox_cxcywh_to_xyxy, scale_boxes
11
+ from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
12
+ OptInstanceList, OptMultiConfig, reduce_mean)
13
+ from mmengine.config import ConfigDict
14
+ from mmengine.model import normal_init
15
+ from mmengine.structures import InstanceData
16
+ from torch import Tensor
17
+
18
+ from mmyolo.registry import MODELS, TASK_UTILS
19
+ from ..utils import gt_instances_preprocess
20
+ from .rtmdet_head import RTMDetHead, RTMDetSepBNHeadModule
21
+
22
+ try:
23
+ from mmrotate.structures.bbox import RotatedBoxes, distance2obb
24
+ MMROTATE_AVAILABLE = True
25
+ except ImportError:
26
+ RotatedBoxes = None
27
+ distance2obb = None
28
+ MMROTATE_AVAILABLE = False
29
+
30
+
31
+ @MODELS.register_module()
32
+ class RTMDetRotatedSepBNHeadModule(RTMDetSepBNHeadModule):
33
+ """Detection Head Module of RTMDet-R.
34
+
35
+ Compared with RTMDet Detection Head Module, RTMDet-R adds
36
+ a conv for angle prediction.
37
+ An `angle_out_dim` arg is added, which is generated by the
38
+ angle_coder module and controls the angle pred dim.
39
+
40
+ Args:
41
+ num_classes (int): Number of categories excluding the background
42
+ category.
43
+ in_channels (int): Number of channels in the input feature map.
44
+ widen_factor (float): Width multiplier, multiply number of
45
+ channels in each layer by this amount. Defaults to 1.0.
46
+ num_base_priors (int): The number of priors (points) at a point
47
+ on the feature grid. Defaults to 1.
48
+ feat_channels (int): Number of hidden channels. Used in child classes.
49
+ Defaults to 256
50
+ stacked_convs (int): Number of stacking convs of the head.
51
+ Defaults to 2.
52
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
53
+ Defaults to (8, 16, 32).
54
+ share_conv (bool): Whether to share conv layers between stages.
55
+ Defaults to True.
56
+ pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 1.
57
+ angle_out_dim (int): Encoded length of angle, will be passed by the head.
58
+ Defaults to 1.
59
+ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
60
+ convolution layer. Defaults to None.
61
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
62
+ layer. Defaults to ``dict(type='BN')``.
63
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
64
+ Default: dict(type='SiLU', inplace=True).
65
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
66
+ list[dict], optional): Initialization config dict.
67
+ Defaults to None.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ num_classes: int,
73
+ in_channels: int,
74
+ widen_factor: float = 1.0,
75
+ num_base_priors: int = 1,
76
+ feat_channels: int = 256,
77
+ stacked_convs: int = 2,
78
+ featmap_strides: Sequence[int] = [8, 16, 32],
79
+ share_conv: bool = True,
80
+ pred_kernel_size: int = 1,
81
+ angle_out_dim: int = 1,
82
+ conv_cfg: OptConfigType = None,
83
+ norm_cfg: ConfigType = dict(type='BN'),
84
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
85
+ init_cfg: OptMultiConfig = None,
86
+ ):
87
+ self.angle_out_dim = angle_out_dim
88
+ super().__init__(
89
+ num_classes=num_classes,
90
+ in_channels=in_channels,
91
+ widen_factor=widen_factor,
92
+ num_base_priors=num_base_priors,
93
+ feat_channels=feat_channels,
94
+ stacked_convs=stacked_convs,
95
+ featmap_strides=featmap_strides,
96
+ share_conv=share_conv,
97
+ pred_kernel_size=pred_kernel_size,
98
+ conv_cfg=conv_cfg,
99
+ norm_cfg=norm_cfg,
100
+ act_cfg=act_cfg,
101
+ init_cfg=init_cfg)
102
+
103
+ def _init_layers(self):
104
+ """Initialize layers of the head."""
105
+ super()._init_layers()
106
+ self.rtm_ang = nn.ModuleList()
107
+ for _ in range(len(self.featmap_strides)):
108
+ self.rtm_ang.append(
109
+ nn.Conv2d(
110
+ self.feat_channels,
111
+ self.num_base_priors * self.angle_out_dim,
112
+ self.pred_kernel_size,
113
+ padding=self.pred_kernel_size // 2))
114
+
115
+ def init_weights(self) -> None:
116
+ """Initialize weights of the head."""
117
+ # Use prior in model initialization to improve stability
118
+ super().init_weights()
119
+ for rtm_ang in self.rtm_ang:
120
+ normal_init(rtm_ang, std=0.01)
121
+
122
+ def forward(self, feats: Tuple[Tensor, ...]) -> tuple:
123
+ """Forward features from the upstream network.
124
+
125
+ Args:
126
+ feats (tuple[Tensor]): Features from the upstream network, each is
127
+ a 4D-tensor.
128
+
129
+ Returns:
130
+ tuple: Usually a tuple of classification scores and bbox prediction
131
+ - cls_scores (list[Tensor]): Classification scores for all scale
132
+ levels, each is a 4D-tensor, the channels number is
133
+ num_base_priors * num_classes.
134
+ - bbox_preds (list[Tensor]): Box energies / deltas for all scale
135
+ levels, each is a 4D-tensor, the channels number is
136
+ num_base_priors * 4.
137
+ - angle_preds (list[Tensor]): Angle prediction for all scale
138
+ levels, each is a 4D-tensor, the channels number is
139
+ num_base_priors * angle_out_dim.
140
+ """
141
+
142
+ cls_scores = []
143
+ bbox_preds = []
144
+ angle_preds = []
145
+ for idx, x in enumerate(feats):
146
+ cls_feat = x
147
+ reg_feat = x
148
+
149
+ for cls_layer in self.cls_convs[idx]:
150
+ cls_feat = cls_layer(cls_feat)
151
+ cls_score = self.rtm_cls[idx](cls_feat)
152
+
153
+ for reg_layer in self.reg_convs[idx]:
154
+ reg_feat = reg_layer(reg_feat)
155
+
156
+ reg_dist = self.rtm_reg[idx](reg_feat)
157
+ angle_pred = self.rtm_ang[idx](reg_feat)
158
+
159
+ cls_scores.append(cls_score)
160
+ bbox_preds.append(reg_dist)
161
+ angle_preds.append(angle_pred)
162
+ return tuple(cls_scores), tuple(bbox_preds), tuple(angle_preds)
163
+
164
+
165
+ @MODELS.register_module()
166
+ class RTMDetRotatedHead(RTMDetHead):
167
+ """RTMDet-R head.
168
+
169
+ Compared with RTMDetHead, RTMDetRotatedHead adds some args to support
170
+ rotated object detection.
171
+
172
+ - `angle_version` used to limit angle_range during training.
173
+ - `angle_coder` used to encode and decode angle, which is similar
174
+ to bbox_coder.
175
+ - `use_hbbox_loss` and `loss_angle` allow custom regression loss
176
+ calculation for rotated box.
177
+
178
+ There are three combination options for regression:
179
+
180
+ 1. `use_hbbox_loss=False` and loss_angle is None.
181
+
182
+ .. code:: text
183
+
184
+ bbox_pred────(tblr)───┐
185
+
186
+ angle_pred decode──►rbox_pred──(xywha)─►loss_bbox
187
+ │ ▲
188
+ └────►decode──(a)─┘
189
+
190
+ 2. `use_hbbox_loss=False` and loss_angle is specified.
191
+ An angle loss is added on angle_pred.
192
+
193
+ .. code:: text
194
+
195
+ bbox_pred────(tblr)───┐
196
+
197
+ angle_pred decode──►rbox_pred──(xywha)─►loss_bbox
198
+ │ ▲
199
+ ├────►decode──(a)─┘
200
+
201
+ └───────────────────────────────────────────►loss_angle
202
+
203
+ 3. `use_hbbox_loss=True` and loss_angle is specified.
204
+ In this case `loss_angle` is required and must not be None.
205
+
206
+ .. code:: text
207
+
208
+ bbox_pred──(tblr)──►decode──►hbox_pred──(xyxy)──►loss_bbox
209
+
210
+ angle_pred──────────────────────────────────────►loss_angle
211
+
212
+ - There's a `decode_with_angle` flag in test_cfg, which is similar
+ to the training process.
214
+
215
+ When `decode_with_angle=True`:
216
+
217
+ .. code:: text
218
+
219
+ bbox_pred────(tblr)───┐
220
+
221
+ angle_pred decode──(xywha)──►rbox_pred
222
+ │ ▲
223
+ └────►decode──(a)─┘
224
+
225
+ When `decode_with_angle=False`:
226
+
227
+ .. code:: text
228
+
229
+ bbox_pred──(tblr)─►decode
230
+ │ (xyxy)
231
+
232
+ format───(xywh)──►concat──(xywha)──►rbox_pred
233
+
234
+ angle_pred────────►decode────(a)───────┘
235
+
236
+ Args:
237
+ head_module(ConfigType): Base module used for RTMDetRotatedHead.
238
+ prior_generator: Points generator feature maps in
239
+ 2D points-based detectors.
240
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
241
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
242
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
243
+ angle_version (str): Angle representations. Defaults to 'le90'.
244
+ use_hbbox_loss (bool): If true, use horizontal bbox loss and
245
+ loss_angle should not be None. Defaults to False.
246
+ angle_coder (:obj:`ConfigDict` or dict): Config of angle coder.
247
+ loss_angle (:obj:`ConfigDict` or dict, optional): Config of angle loss.
248
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
249
+ anchor head. Defaults to None.
250
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
251
+ anchor head. Defaults to None.
252
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
253
+ list[dict], optional): Initialization config dict.
254
+ Defaults to None.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ head_module: ConfigType,
260
+ prior_generator: ConfigType = dict(
261
+ type='mmdet.MlvlPointGenerator', strides=[8, 16, 32],
262
+ offset=0),
263
+ bbox_coder: ConfigType = dict(type='DistanceAnglePointCoder'),
264
+ loss_cls: ConfigType = dict(
265
+ type='mmdet.QualityFocalLoss',
266
+ use_sigmoid=True,
267
+ beta=2.0,
268
+ loss_weight=1.0),
269
+ loss_bbox: ConfigType = dict(
270
+ type='mmrotate.RotatedIoULoss', mode='linear',
271
+ loss_weight=2.0),
272
+ angle_version: str = 'le90',
273
+ use_hbbox_loss: bool = False,
274
+ angle_coder: ConfigType = dict(type='mmrotate.PseudoAngleCoder'),
275
+ loss_angle: OptConfigType = None,
276
+ train_cfg: OptConfigType = None,
277
+ test_cfg: OptConfigType = None,
278
+ init_cfg: OptMultiConfig = None):
279
+ if not MMROTATE_AVAILABLE:
280
+ raise ImportError(
281
+ 'Please run "mim install -r requirements/mmrotate.txt" '
282
+ 'to install mmrotate first for rotated detection.')
283
+
284
+ self.angle_version = angle_version
285
+ self.use_hbbox_loss = use_hbbox_loss
286
+ if self.use_hbbox_loss:
287
+ assert loss_angle is not None, \
288
+ ('When using hbbox loss, loss_angle needs to be specified')
289
+ self.angle_coder = TASK_UTILS.build(angle_coder)
290
+ self.angle_out_dim = self.angle_coder.encode_size
291
+ if head_module.get('angle_out_dim') is not None:
292
+ warnings.warn('angle_out_dim will be overridden by angle_coder '
293
+ 'and does not need to be set manually')
294
+
295
+ head_module['angle_out_dim'] = self.angle_out_dim
296
+ super().__init__(
297
+ head_module=head_module,
298
+ prior_generator=prior_generator,
299
+ bbox_coder=bbox_coder,
300
+ loss_cls=loss_cls,
301
+ loss_bbox=loss_bbox,
302
+ train_cfg=train_cfg,
303
+ test_cfg=test_cfg,
304
+ init_cfg=init_cfg)
305
+
306
+ if loss_angle is not None:
307
+ self.loss_angle = MODELS.build(loss_angle)
308
+ else:
309
+ self.loss_angle = None
310
+
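For reference, the three regression combinations described in the class docstring could be expressed roughly as the head-option sketches below; the angle-loss type and weight are placeholders for illustration, not settings taken from this upload.

# Option 1 (default): rotated IoU loss only.
rbox_only = dict(use_hbbox_loss=False, loss_angle=None)

# Option 2: rotated IoU loss plus an extra loss on angle_pred.
rbox_plus_angle = dict(
    use_hbbox_loss=False,
    loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.2))  # placeholder loss

# Option 3: horizontal bbox loss; loss_angle is then mandatory
# (see the assert in __init__ above).
hbox_plus_angle = dict(
    use_hbbox_loss=True,
    loss_angle=dict(type='mmdet.L1Loss', loss_weight=0.2))  # placeholder loss

These dicts would be merged into the RTMDetRotatedHead config, e.g. dict(type='RTMDetRotatedHead', ..., **rbox_plus_angle).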
311
+ def predict_by_feat(self,
312
+ cls_scores: List[Tensor],
313
+ bbox_preds: List[Tensor],
314
+ angle_preds: List[Tensor],
315
+ objectnesses: Optional[List[Tensor]] = None,
316
+ batch_img_metas: Optional[List[dict]] = None,
317
+ cfg: Optional[ConfigDict] = None,
318
+ rescale: bool = True,
319
+ with_nms: bool = True) -> List[InstanceData]:
320
+ """Transform a batch of output features extracted by the head into bbox
321
+ results.
322
+
323
+ Args:
324
+ cls_scores (list[Tensor]): Classification scores for all
325
+ scale levels, each is a 4D-tensor, has shape
326
+ (batch_size, num_priors * num_classes, H, W).
327
+ bbox_preds (list[Tensor]): Box energies / deltas for all
328
+ scale levels, each is a 4D-tensor, has shape
329
+ (batch_size, num_priors * 4, H, W).
330
+ angle_preds (list[Tensor]): Box angle for each scale level
331
+ with shape (N, num_points * angle_dim, H, W)
332
+ objectnesses (list[Tensor], Optional): Score factor for
333
+ all scale level, each is a 4D-tensor, has shape
334
+ (batch_size, 1, H, W).
335
+ batch_img_metas (list[dict], Optional): Batch image meta info.
336
+ Defaults to None.
337
+ cfg (ConfigDict, optional): Test / postprocessing
338
+ configuration, if None, test_cfg would be used.
339
+ Defaults to None.
340
+ rescale (bool): If True, return boxes in original image space.
341
+ Defaults to True.
342
+ with_nms (bool): If True, do nms before return boxes.
343
+ Defaults to True.
344
+
345
+ Returns:
346
+ list[:obj:`InstanceData`]: Object detection results of each image
347
+ after the post process. Each item usually contains following keys.
348
+ - scores (Tensor): Classification scores, has a shape
349
+ (num_instance, )
350
+ - labels (Tensor): Labels of bboxes, has a shape
351
+ (num_instances, ).
352
+ - bboxes (Tensor): Has a shape (num_instances, 5),
353
+ the last dimension 5 arranged as (x, y, w, h, angle).
354
+ """
355
+ assert len(cls_scores) == len(bbox_preds)
356
+ if objectnesses is None:
357
+ with_objectnesses = False
358
+ else:
359
+ with_objectnesses = True
360
+ assert len(cls_scores) == len(objectnesses)
361
+
362
+ cfg = self.test_cfg if cfg is None else cfg
363
+ cfg = copy.deepcopy(cfg)
364
+
365
+ multi_label = cfg.multi_label
366
+ multi_label &= self.num_classes > 1
367
+ cfg.multi_label = multi_label
368
+
369
+ # Whether to decode rbox with angle.
370
+ # Different settings lead to different final results.
371
+ # Defaults to True.
372
+ decode_with_angle = cfg.get('decode_with_angle', True)
373
+
374
+ num_imgs = len(batch_img_metas)
375
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
376
+
377
+ # If the shape does not change, use the previous mlvl_priors
378
+ if featmap_sizes != self.featmap_sizes:
379
+ self.mlvl_priors = self.prior_generator.grid_priors(
380
+ featmap_sizes,
381
+ dtype=cls_scores[0].dtype,
382
+ device=cls_scores[0].device)
383
+ self.featmap_sizes = featmap_sizes
384
+ flatten_priors = torch.cat(self.mlvl_priors)
385
+
386
+ mlvl_strides = [
387
+ flatten_priors.new_full(
388
+ (featmap_size.numel() * self.num_base_priors, ), stride) for
389
+ featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
390
+ ]
391
+ flatten_stride = torch.cat(mlvl_strides)
392
+
393
+ # flatten cls_scores, bbox_preds and objectness
394
+ flatten_cls_scores = [
395
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
396
+ self.num_classes)
397
+ for cls_score in cls_scores
398
+ ]
399
+ flatten_bbox_preds = [
400
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
401
+ for bbox_pred in bbox_preds
402
+ ]
403
+ flatten_angle_preds = [
404
+ angle_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
405
+ self.angle_out_dim)
406
+ for angle_pred in angle_preds
407
+ ]
408
+
409
+ flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
410
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
411
+ flatten_angle_preds = torch.cat(flatten_angle_preds, dim=1)
412
+ flatten_angle_preds = self.angle_coder.decode(
413
+ flatten_angle_preds, keepdim=True)
414
+
415
+ if decode_with_angle:
416
+ flatten_rbbox_preds = torch.cat(
417
+ [flatten_bbox_preds, flatten_angle_preds], dim=-1)
418
+ flatten_decoded_bboxes = self.bbox_coder.decode(
419
+ flatten_priors[None], flatten_rbbox_preds, flatten_stride)
420
+ else:
421
+ flatten_decoded_hbboxes = self.bbox_coder.decode(
422
+ flatten_priors[None], flatten_bbox_preds, flatten_stride)
423
+ flatten_decoded_hbboxes = HorizontalBoxes.xyxy_to_cxcywh(
424
+ flatten_decoded_hbboxes)
425
+ flatten_decoded_bboxes = torch.cat(
426
+ [flatten_decoded_hbboxes, flatten_angle_preds], dim=-1)
427
+
428
+ if with_objectnesses:
429
+ flatten_objectness = [
430
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
431
+ for objectness in objectnesses
432
+ ]
433
+ flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
434
+ else:
435
+ flatten_objectness = [None for _ in range(num_imgs)]
436
+
437
+ results_list = []
438
+ for (bboxes, scores, objectness,
439
+ img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
440
+ flatten_objectness, batch_img_metas):
441
+ scale_factor = img_meta['scale_factor']
442
+ if 'pad_param' in img_meta:
443
+ pad_param = img_meta['pad_param']
444
+ else:
445
+ pad_param = None
446
+
447
+ score_thr = cfg.get('score_thr', -1)
448
+ # yolox_style does not require the following operations
449
+ if objectness is not None and score_thr > 0 and not cfg.get(
450
+ 'yolox_style', False):
451
+ conf_inds = objectness > score_thr
452
+ bboxes = bboxes[conf_inds, :]
453
+ scores = scores[conf_inds, :]
454
+ objectness = objectness[conf_inds]
455
+
456
+ if objectness is not None:
457
+ # conf = obj_conf * cls_conf
458
+ scores *= objectness[:, None]
459
+
460
+ if scores.shape[0] == 0:
461
+ empty_results = InstanceData()
462
+ empty_results.bboxes = RotatedBoxes(bboxes)
463
+ empty_results.scores = scores[:, 0]
464
+ empty_results.labels = scores[:, 0].int()
465
+ results_list.append(empty_results)
466
+ continue
467
+
468
+ nms_pre = cfg.get('nms_pre', 100000)
469
+ if cfg.multi_label is False:
470
+ scores, labels = scores.max(1, keepdim=True)
471
+ scores, _, keep_idxs, results = filter_scores_and_topk(
472
+ scores,
473
+ score_thr,
474
+ nms_pre,
475
+ results=dict(labels=labels[:, 0]))
476
+ labels = results['labels']
477
+ else:
478
+ scores, labels, keep_idxs, _ = filter_scores_and_topk(
479
+ scores, score_thr, nms_pre)
480
+
481
+ results = InstanceData(
482
+ scores=scores,
483
+ labels=labels,
484
+ bboxes=RotatedBoxes(bboxes[keep_idxs]))
485
+
486
+ if rescale:
487
+ if pad_param is not None:
488
+ results.bboxes.translate_([-pad_param[2], -pad_param[0]])
489
+
490
+ scale_factor = [1 / s for s in img_meta['scale_factor']]
491
+ results.bboxes = scale_boxes(results.bboxes, scale_factor)
492
+
493
+ if cfg.get('yolox_style', False):
494
+ # do not need max_per_img
495
+ cfg.max_per_img = len(results)
496
+
497
+ results = self._bbox_post_process(
498
+ results=results,
499
+ cfg=cfg,
500
+ rescale=False,
501
+ with_nms=with_nms,
502
+ img_meta=img_meta)
503
+
504
+ results_list.append(results)
505
+ return results_list
506
+
507
+ def loss_by_feat(
508
+ self,
509
+ cls_scores: List[Tensor],
510
+ bbox_preds: List[Tensor],
511
+ angle_preds: List[Tensor],
512
+ batch_gt_instances: InstanceList,
513
+ batch_img_metas: List[dict],
514
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
515
+ """Compute losses of the head.
516
+
517
+ Args:
518
+ cls_scores (list[Tensor]): Box scores for each scale level
519
+ Has shape (N, num_anchors * num_classes, H, W)
520
+ bbox_preds (list[Tensor]): Decoded box for each scale
521
+ level with shape (N, num_anchors * 4, H, W) in
522
+ [tl_x, tl_y, br_x, br_y] format.
523
+ angle_preds (list[Tensor]): Angle prediction for each scale
524
+ level with shape (N, num_anchors * angle_out_dim, H, W).
525
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
526
+ gt_instance. It usually includes ``bboxes`` and ``labels``
527
+ attributes.
528
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
529
+ image size, scaling factor, etc.
530
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
531
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
532
+ data that is ignored during training and testing.
533
+ Defaults to None.
534
+
535
+ Returns:
536
+ dict[str, Tensor]: A dictionary of loss components.
537
+ """
538
+ num_imgs = len(batch_img_metas)
539
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
540
+ assert len(featmap_sizes) == self.prior_generator.num_levels
541
+
542
+ gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
543
+ gt_labels = gt_info[:, :, :1]
544
+ gt_bboxes = gt_info[:, :, 1:] # xywha
545
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
546
+
547
+ device = cls_scores[0].device
548
+
549
+ # If the shape has changed, generate new priors
550
+ if featmap_sizes != self.featmap_sizes_train:
551
+ self.featmap_sizes_train = featmap_sizes
552
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
553
+ featmap_sizes, device=device, with_stride=True)
554
+ self.flatten_priors_train = torch.cat(
555
+ mlvl_priors_with_stride, dim=0)
556
+
557
+ flatten_cls_scores = torch.cat([
558
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
559
+ self.cls_out_channels)
560
+ for cls_score in cls_scores
561
+ ], 1).contiguous()
562
+
563
+ flatten_tblrs = torch.cat([
564
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
565
+ for bbox_pred in bbox_preds
566
+ ], 1)
567
+ flatten_tblrs = flatten_tblrs * self.flatten_priors_train[..., -1,
568
+ None]
569
+ flatten_angles = torch.cat([
570
+ angle_pred.permute(0, 2, 3, 1).reshape(
571
+ num_imgs, -1, self.angle_out_dim) for angle_pred in angle_preds
572
+ ], 1)
573
+ flatten_decoded_angle = self.angle_coder.decode(
574
+ flatten_angles, keepdim=True)
575
+ flatten_tblra = torch.cat([flatten_tblrs, flatten_decoded_angle],
576
+ dim=-1)
577
+ flatten_rbboxes = distance2obb(
578
+ self.flatten_priors_train[..., :2],
579
+ flatten_tblra,
580
+ angle_version=self.angle_version)
581
+ if self.use_hbbox_loss:
582
+ flatten_hbboxes = distance2bbox(self.flatten_priors_train[..., :2],
583
+ flatten_tblrs)
584
+
585
+ assigned_result = self.assigner(flatten_rbboxes.detach(),
586
+ flatten_cls_scores.detach(),
587
+ self.flatten_priors_train, gt_labels,
588
+ gt_bboxes, pad_bbox_flag)
589
+
590
+ labels = assigned_result['assigned_labels'].reshape(-1)
591
+ label_weights = assigned_result['assigned_labels_weights'].reshape(-1)
592
+ bbox_targets = assigned_result['assigned_bboxes'].reshape(-1, 5)
593
+ assign_metrics = assigned_result['assign_metrics'].reshape(-1)
594
+ cls_preds = flatten_cls_scores.reshape(-1, self.num_classes)
595
+
596
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
597
+ bg_class_ind = self.num_classes
598
+ pos_inds = ((labels >= 0)
599
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
600
+ avg_factor = reduce_mean(assign_metrics.sum()).clamp_(min=1).item()
601
+
602
+ loss_cls = self.loss_cls(
603
+ cls_preds, (labels, assign_metrics),
604
+ label_weights,
605
+ avg_factor=avg_factor)
606
+
607
+ pos_bbox_targets = bbox_targets[pos_inds]
608
+
609
+ if self.use_hbbox_loss:
610
+ bbox_preds = flatten_hbboxes.reshape(-1, 4)
611
+ pos_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets[:, :4])
612
+ else:
613
+ bbox_preds = flatten_rbboxes.reshape(-1, 5)
614
+ angle_preds = flatten_angles.reshape(-1, self.angle_out_dim)
615
+
616
+ if len(pos_inds) > 0:
617
+ loss_bbox = self.loss_bbox(
618
+ bbox_preds[pos_inds],
619
+ pos_bbox_targets,
620
+ weight=assign_metrics[pos_inds],
621
+ avg_factor=avg_factor)
622
+ loss_angle = angle_preds.sum() * 0
623
+ if self.loss_angle is not None:
624
+ pos_angle_targets = bbox_targets[pos_inds][:, 4:5]
625
+ pos_angle_targets = self.angle_coder.encode(pos_angle_targets)
626
+ loss_angle = self.loss_angle(
627
+ angle_preds[pos_inds],
628
+ pos_angle_targets,
629
+ weight=assign_metrics[pos_inds],
630
+ avg_factor=avg_factor)
631
+ else:
632
+ loss_bbox = bbox_preds.sum() * 0
633
+ loss_angle = angle_preds.sum() * 0
634
+
635
+ losses = dict()
636
+ losses['loss_cls'] = loss_cls
637
+ losses['loss_bbox'] = loss_bbox
638
+ if self.loss_angle is not None:
639
+ losses['loss_angle'] = loss_angle
640
+
641
+ return losses
mmyolo/models/dense_heads/yolov5_head.py ADDED
@@ -0,0 +1,890 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import math
4
+ from typing import List, Optional, Sequence, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from mmdet.models.dense_heads.base_dense_head import BaseDenseHead
9
+ from mmdet.models.utils import filter_scores_and_topk, multi_apply
10
+ from mmdet.structures.bbox import bbox_overlaps
11
+ from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
12
+ OptMultiConfig)
13
+ from mmengine.config import ConfigDict
14
+ from mmengine.dist import get_dist_info
15
+ from mmengine.logging import print_log
16
+ from mmengine.model import BaseModule
17
+ from mmengine.structures import InstanceData
18
+ from torch import Tensor
19
+
20
+ from mmyolo.registry import MODELS, TASK_UTILS
21
+ from ..utils import make_divisible
22
+
23
+
24
+ def get_prior_xy_info(index: int, num_base_priors: int,
25
+ featmap_sizes: int) -> Tuple[int, int, int]:
26
+ """Get prior index and xy index in feature map by flatten index."""
27
+ _, featmap_w = featmap_sizes
28
+ priors = index % num_base_priors
29
+ xy_index = index // num_base_priors
30
+ grid_y = xy_index // featmap_w
31
+ grid_x = xy_index % featmap_w
32
+ return priors, grid_x, grid_y
33
+
34
+
35
+ @MODELS.register_module()
36
+ class YOLOv5HeadModule(BaseModule):
37
+ """YOLOv5Head head module used in `YOLOv5`.
38
+
39
+ Args:
40
+ num_classes (int): Number of categories excluding the background
41
+ category.
42
+ in_channels (Union[int, Sequence]): Number of channels in the input
43
+ feature map.
44
+ widen_factor (float): Width multiplier, multiply number of
45
+ channels in each layer by this amount. Defaults to 1.0.
46
+ num_base_priors (int): The number of priors (points) at a point
47
+ on the feature grid.
48
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
49
+ Defaults to (8, 16, 32).
50
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
51
+ list[dict], optional): Initialization config dict.
52
+ Defaults to None.
53
+ """
54
+
55
+ def __init__(self,
56
+ num_classes: int,
57
+ in_channels: Union[int, Sequence],
58
+ widen_factor: float = 1.0,
59
+ num_base_priors: int = 3,
60
+ featmap_strides: Sequence[int] = (8, 16, 32),
61
+ init_cfg: OptMultiConfig = None):
62
+ super().__init__(init_cfg=init_cfg)
63
+ self.num_classes = num_classes
64
+ self.widen_factor = widen_factor
65
+
66
+ self.featmap_strides = featmap_strides
67
+ self.num_out_attrib = 5 + self.num_classes
68
+ self.num_levels = len(self.featmap_strides)
69
+ self.num_base_priors = num_base_priors
70
+
71
+ if isinstance(in_channels, int):
72
+ self.in_channels = [make_divisible(in_channels, widen_factor)
73
+ ] * self.num_levels
74
+ else:
75
+ self.in_channels = [
76
+ make_divisible(i, widen_factor) for i in in_channels
77
+ ]
78
+
79
+ self._init_layers()
80
+
81
+ def _init_layers(self):
82
+ """initialize conv layers in YOLOv5 head."""
83
+ self.convs_pred = nn.ModuleList()
84
+ for i in range(self.num_levels):
85
+ conv_pred = nn.Conv2d(self.in_channels[i],
86
+ self.num_base_priors * self.num_out_attrib,
87
+ 1)
88
+
89
+ self.convs_pred.append(conv_pred)
90
+
91
+ def init_weights(self):
92
+ """Initialize the bias of YOLOv5 head."""
93
+ super().init_weights()
94
+ for mi, s in zip(self.convs_pred, self.featmap_strides): # from
95
+ b = mi.bias.data.view(self.num_base_priors, -1)
96
+ # obj (8 objects per 640 image)
97
+ b.data[:, 4] += math.log(8 / (640 / s)**2)
98
+ b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.999999))
99
+
100
+ mi.bias.data = b.view(-1)
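A quick numeric check of the objectness bias prior added above (roughly 8 objects spread over the (640/s)^2 cells of a 640x640 image); the printed sigmoid values are only approximate, since the offset is added on top of the default conv bias.

import math

for s in (8, 16, 32):
    prior = math.log(8 / (640 / s) ** 2)  # same formula as in init_weights
    print(s, round(prior, 3), round(1 / (1 + math.exp(-prior)), 5))
# stride 8  -> offset ~ -6.685  (8 objects / 6400 cells)
# stride 16 -> offset ~ -5.298  (8 objects / 1600 cells)
# stride 32 -> offset ~ -3.912  (8 objects /  400 cells)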
101
+
102
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
103
+ """Forward features from the upstream network.
104
+
105
+ Args:
106
+ x (Tuple[Tensor]): Features from the upstream network, each is
107
+ a 4D-tensor.
108
+ Returns:
109
+ Tuple[List]: A tuple of multi-level classification scores, bbox
110
+ predictions, and objectnesses.
111
+ """
112
+ assert len(x) == self.num_levels
113
+ return multi_apply(self.forward_single, x, self.convs_pred)
114
+
115
+ def forward_single(self, x: Tensor,
116
+ convs: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
117
+ """Forward feature of a single scale level."""
118
+
119
+ pred_map = convs(x)
120
+ bs, _, ny, nx = pred_map.shape
121
+ pred_map = pred_map.view(bs, self.num_base_priors, self.num_out_attrib,
122
+ ny, nx)
123
+
124
+ cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
125
+ bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
126
+ objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
127
+
128
+ return cls_score, bbox_pred, objectness
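A small shape walkthrough of `forward_single` above: the prediction map packs num_base_priors * (4 bbox + 1 obj + num_classes) channels, which are split into three outputs. Batch size, prior count, class count and feature size below are demo assumptions.

import torch

bs, num_base_priors, num_classes, ny, nx = 2, 3, 80, 20, 20
num_out_attrib = 5 + num_classes
pred_map = torch.randn(bs, num_base_priors * num_out_attrib, ny, nx)

pred_map = pred_map.view(bs, num_base_priors, num_out_attrib, ny, nx)
bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)    # (2, 12, 20, 20)
objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)  # (2, 3, 20, 20)
cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)    # (2, 240, 20, 20)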
129
+
130
+
131
+ @MODELS.register_module()
132
+ class YOLOv5Head(BaseDenseHead):
133
+ """YOLOv5Head head used in `YOLOv5`.
134
+
135
+ Args:
136
+ head_module(ConfigType): Base module used for YOLOv5Head
137
+ prior_generator(dict): Points generator feature maps in
138
+ 2D points-based detectors.
139
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
140
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
141
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
142
+ loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
143
+ prior_match_thr (float): Defaults to 4.0.
144
+ ignore_iof_thr (float): Defaults to -1.0.
145
+ obj_level_weights (List[float]): Defaults to [4.0, 1.0, 0.4].
146
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
147
+ anchor head. Defaults to None.
148
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
149
+ anchor head. Defaults to None.
150
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
151
+ list[dict], optional): Initialization config dict.
152
+ Defaults to None.
153
+ """
154
+
155
+ def __init__(self,
156
+ head_module: ConfigType,
157
+ prior_generator: ConfigType = dict(
158
+ type='mmdet.YOLOAnchorGenerator',
159
+ base_sizes=[[(10, 13), (16, 30), (33, 23)],
160
+ [(30, 61), (62, 45), (59, 119)],
161
+ [(116, 90), (156, 198), (373, 326)]],
162
+ strides=[8, 16, 32]),
163
+ bbox_coder: ConfigType = dict(type='YOLOv5BBoxCoder'),
164
+ loss_cls: ConfigType = dict(
165
+ type='mmdet.CrossEntropyLoss',
166
+ use_sigmoid=True,
167
+ reduction='mean',
168
+ loss_weight=0.5),
169
+ loss_bbox: ConfigType = dict(
170
+ type='IoULoss',
171
+ iou_mode='ciou',
172
+ bbox_format='xywh',
173
+ eps=1e-7,
174
+ reduction='mean',
175
+ loss_weight=0.05,
176
+ return_iou=True),
177
+ loss_obj: ConfigType = dict(
178
+ type='mmdet.CrossEntropyLoss',
179
+ use_sigmoid=True,
180
+ reduction='mean',
181
+ loss_weight=1.0),
182
+ prior_match_thr: float = 4.0,
183
+ near_neighbor_thr: float = 0.5,
184
+ ignore_iof_thr: float = -1.0,
185
+ obj_level_weights: List[float] = [4.0, 1.0, 0.4],
186
+ train_cfg: OptConfigType = None,
187
+ test_cfg: OptConfigType = None,
188
+ init_cfg: OptMultiConfig = None):
189
+ super().__init__(init_cfg=init_cfg)
190
+
191
+ self.head_module = MODELS.build(head_module)
192
+ self.num_classes = self.head_module.num_classes
193
+ self.featmap_strides = self.head_module.featmap_strides
194
+ self.num_levels = len(self.featmap_strides)
195
+
196
+ self.train_cfg = train_cfg
197
+ self.test_cfg = test_cfg
198
+
199
+ self.loss_cls: nn.Module = MODELS.build(loss_cls)
200
+ self.loss_bbox: nn.Module = MODELS.build(loss_bbox)
201
+ self.loss_obj: nn.Module = MODELS.build(loss_obj)
202
+
203
+ self.prior_generator = TASK_UTILS.build(prior_generator)
204
+ self.bbox_coder = TASK_UTILS.build(bbox_coder)
205
+ self.num_base_priors = self.prior_generator.num_base_priors[0]
206
+
207
+ self.featmap_sizes = [torch.empty(1)] * self.num_levels
208
+
209
+ self.prior_match_thr = prior_match_thr
210
+ self.near_neighbor_thr = near_neighbor_thr
211
+ self.obj_level_weights = obj_level_weights
212
+ self.ignore_iof_thr = ignore_iof_thr
213
+
214
+ self.special_init()
215
+
216
+ def special_init(self):
217
+ """Since YOLO series algorithms will inherit from YOLOv5Head, but
218
+ different algorithms have special initialization process.
219
+
220
+ The special_init function is designed to deal with this situation.
221
+ """
222
+ assert len(self.obj_level_weights) == len(
223
+ self.featmap_strides) == self.num_levels
224
+ if self.prior_match_thr != 4.0:
225
+ print_log(
226
+ "!!!Now, you've changed the prior_match_thr "
227
+ 'parameter to something other than 4.0. Please make sure '
228
+ 'that you have modified the regression formula both in '
+ 'bbox_coder and before the loss_box computation, '
230
+ 'otherwise the accuracy may be degraded!!!')
231
+
232
+ if self.num_classes == 1:
233
+ print_log('!!!You are using `YOLOv5Head` with num_classes == 1.'
234
+ ' The loss_cls will be 0. This is a normal phenomenon.')
235
+
236
+ priors_base_sizes = torch.tensor(
237
+ self.prior_generator.base_sizes, dtype=torch.float)
238
+ featmap_strides = torch.tensor(
239
+ self.featmap_strides, dtype=torch.float)[:, None, None]
240
+ self.register_buffer(
241
+ 'priors_base_sizes',
242
+ priors_base_sizes / featmap_strides,
243
+ persistent=False)
244
+
245
+ grid_offset = torch.tensor([
246
+ [0, 0], # center
247
+ [1, 0], # left
248
+ [0, 1], # up
249
+ [-1, 0], # right
250
+ [0, -1], # bottom
251
+ ]).float()
252
+ self.register_buffer(
253
+ 'grid_offset', grid_offset[:, None], persistent=False)
254
+
255
+ prior_inds = torch.arange(self.num_base_priors).float().view(
256
+ self.num_base_priors, 1)
257
+ self.register_buffer('prior_inds', prior_inds, persistent=False)
258
+
259
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
260
+ """Forward features from the upstream network.
261
+
262
+ Args:
263
+ x (Tuple[Tensor]): Features from the upstream network, each is
264
+ a 4D-tensor.
265
+ Returns:
266
+ Tuple[List]: A tuple of multi-level classification scores, bbox
267
+ predictions, and objectnesses.
268
+ """
269
+ return self.head_module(x)
270
+
271
+ def predict_by_feat(self,
272
+ cls_scores: List[Tensor],
273
+ bbox_preds: List[Tensor],
274
+ objectnesses: Optional[List[Tensor]] = None,
275
+ batch_img_metas: Optional[List[dict]] = None,
276
+ cfg: Optional[ConfigDict] = None,
277
+ rescale: bool = True,
278
+ with_nms: bool = True) -> List[InstanceData]:
279
+ """Transform a batch of output features extracted by the head into
280
+ bbox results.
281
+ Args:
282
+ cls_scores (list[Tensor]): Classification scores for all
283
+ scale levels, each is a 4D-tensor, has shape
284
+ (batch_size, num_priors * num_classes, H, W).
285
+ bbox_preds (list[Tensor]): Box energies / deltas for all
286
+ scale levels, each is a 4D-tensor, has shape
287
+ (batch_size, num_priors * 4, H, W).
288
+ objectnesses (list[Tensor], Optional): Score factor for
289
+ all scale level, each is a 4D-tensor, has shape
290
+ (batch_size, 1, H, W).
291
+ batch_img_metas (list[dict], Optional): Batch image meta info.
292
+ Defaults to None.
293
+ cfg (ConfigDict, optional): Test / postprocessing
294
+ configuration, if None, test_cfg would be used.
295
+ Defaults to None.
296
+ rescale (bool): If True, return boxes in original image space.
297
+ Defaults to True.
298
+ with_nms (bool): If True, do nms before return boxes.
299
+ Defaults to True.
300
+
301
+ Returns:
302
+ list[:obj:`InstanceData`]: Object detection results of each image
303
+ after the post process. Each item usually contains following keys.
304
+
305
+ - scores (Tensor): Classification scores, has a shape
306
+ (num_instance, )
307
+ - labels (Tensor): Labels of bboxes, has a shape
308
+ (num_instances, ).
309
+ - bboxes (Tensor): Has a shape (num_instances, 4),
310
+ the last dimension 4 arrange as (x1, y1, x2, y2).
311
+ """
312
+ assert len(cls_scores) == len(bbox_preds)
313
+ if objectnesses is None:
314
+ with_objectnesses = False
315
+ else:
316
+ with_objectnesses = True
317
+ assert len(cls_scores) == len(objectnesses)
318
+
319
+ cfg = self.test_cfg if cfg is None else cfg
320
+ cfg = copy.deepcopy(cfg)
321
+
322
+ multi_label = cfg.multi_label
323
+ multi_label &= self.num_classes > 1
324
+ cfg.multi_label = multi_label
325
+
326
+ num_imgs = len(batch_img_metas)
327
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
328
+
329
+ # If the shape does not change, use the previous mlvl_priors
330
+ if featmap_sizes != self.featmap_sizes:
331
+ self.mlvl_priors = self.prior_generator.grid_priors(
332
+ featmap_sizes,
333
+ dtype=cls_scores[0].dtype,
334
+ device=cls_scores[0].device)
335
+ self.featmap_sizes = featmap_sizes
336
+ flatten_priors = torch.cat(self.mlvl_priors)
337
+
338
+ mlvl_strides = [
339
+ flatten_priors.new_full(
340
+ (featmap_size.numel() * self.num_base_priors, ), stride) for
341
+ featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
342
+ ]
343
+ flatten_stride = torch.cat(mlvl_strides)
344
+
345
+ # flatten cls_scores, bbox_preds and objectness
346
+ flatten_cls_scores = [
347
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1,
348
+ self.num_classes)
349
+ for cls_score in cls_scores
350
+ ]
351
+ flatten_bbox_preds = [
352
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
353
+ for bbox_pred in bbox_preds
354
+ ]
355
+
356
+ flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
357
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
358
+ flatten_decoded_bboxes = self.bbox_coder.decode(
359
+ flatten_priors[None], flatten_bbox_preds, flatten_stride)
360
+
361
+ if with_objectnesses:
362
+ flatten_objectness = [
363
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
364
+ for objectness in objectnesses
365
+ ]
366
+ flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
367
+ else:
368
+ flatten_objectness = [None for _ in range(num_imgs)]
369
+
370
+ results_list = []
371
+ for (bboxes, scores, objectness,
372
+ img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores,
373
+ flatten_objectness, batch_img_metas):
374
+ ori_shape = img_meta['ori_shape']
375
+ scale_factor = img_meta['scale_factor']
376
+ if 'pad_param' in img_meta:
377
+ pad_param = img_meta['pad_param']
378
+ else:
379
+ pad_param = None
380
+
381
+ score_thr = cfg.get('score_thr', -1)
382
+ # yolox_style does not require the following operations
383
+ if objectness is not None and score_thr > 0 and not cfg.get(
384
+ 'yolox_style', False):
385
+ conf_inds = objectness > score_thr
386
+ bboxes = bboxes[conf_inds, :]
387
+ scores = scores[conf_inds, :]
388
+ objectness = objectness[conf_inds]
389
+
390
+ if objectness is not None:
391
+ # conf = obj_conf * cls_conf
392
+ scores *= objectness[:, None]
393
+
394
+ if scores.shape[0] == 0:
395
+ empty_results = InstanceData()
396
+ empty_results.bboxes = bboxes
397
+ empty_results.scores = scores[:, 0]
398
+ empty_results.labels = scores[:, 0].int()
399
+ results_list.append(empty_results)
400
+ continue
401
+
402
+ nms_pre = cfg.get('nms_pre', 100000)
403
+ if cfg.multi_label is False:
404
+ scores, labels = scores.max(1, keepdim=True)
405
+ scores, _, keep_idxs, results = filter_scores_and_topk(
406
+ scores,
407
+ score_thr,
408
+ nms_pre,
409
+ results=dict(labels=labels[:, 0]))
410
+ labels = results['labels']
411
+ else:
412
+ scores, labels, keep_idxs, _ = filter_scores_and_topk(
413
+ scores, score_thr, nms_pre)
414
+
415
+ results = InstanceData(
416
+ scores=scores, labels=labels, bboxes=bboxes[keep_idxs])
417
+
418
+ if rescale:
419
+ if pad_param is not None:
420
+ results.bboxes -= results.bboxes.new_tensor([
421
+ pad_param[2], pad_param[0], pad_param[2], pad_param[0]
422
+ ])
423
+ results.bboxes /= results.bboxes.new_tensor(
424
+ scale_factor).repeat((1, 2))
425
+
426
+ if cfg.get('yolox_style', False):
427
+ # do not need max_per_img
428
+ cfg.max_per_img = len(results)
429
+
430
+ results = self._bbox_post_process(
431
+ results=results,
432
+ cfg=cfg,
433
+ rescale=False,
434
+ with_nms=with_nms,
435
+ img_meta=img_meta)
436
+ results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
437
+ results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
438
+
439
+ results_list.append(results)
440
+ return results_list
441
+
442
+ def loss(self, x: Tuple[Tensor], batch_data_samples: Union[list,
443
+ dict]) -> dict:
444
+ """Perform forward propagation and loss calculation of the detection
445
+ head on the features of the upstream network.
446
+
447
+ Args:
448
+ x (tuple[Tensor]): Features from the upstream network, each is
449
+ a 4D-tensor.
450
+ batch_data_samples (List[:obj:`DetDataSample`], dict): The Data
451
+ Samples. It usually includes information such as
452
+ `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
453
+
454
+ Returns:
455
+ dict: A dictionary of loss components.
456
+ """
457
+
458
+ if isinstance(batch_data_samples, list):
459
+ losses = super().loss(x, batch_data_samples)
460
+ else:
461
+ outs = self(x)
462
+ # Fast version
463
+ loss_inputs = outs + (batch_data_samples['bboxes_labels'],
464
+ batch_data_samples['img_metas'])
465
+ losses = self.loss_by_feat(*loss_inputs)
466
+
467
+ return losses
468
+
469
+ def loss_by_feat(
470
+ self,
471
+ cls_scores: Sequence[Tensor],
472
+ bbox_preds: Sequence[Tensor],
473
+ objectnesses: Sequence[Tensor],
474
+ batch_gt_instances: Sequence[InstanceData],
475
+ batch_img_metas: Sequence[dict],
476
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
477
+ """Calculate the loss based on the features extracted by the detection
478
+ head.
479
+
480
+ Args:
481
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
482
+ each is a 4D-tensor, the channel number is
483
+ num_priors * num_classes.
484
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
485
+ level, each is a 4D-tensor, the channel number is
486
+ num_priors * 4.
487
+ objectnesses (Sequence[Tensor]): Score factor for
488
+ all scale level, each is a 4D-tensor, has shape
489
+ (batch_size, 1, H, W).
490
+ batch_gt_instances (Sequence[InstanceData]): Batch of
491
+ gt_instance. It usually includes ``bboxes`` and ``labels``
492
+ attributes.
493
+ batch_img_metas (Sequence[dict]): Meta information of each image,
494
+ e.g., image size, scaling factor, etc.
495
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
496
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
497
+ data that is ignored during training and testing.
498
+ Defaults to None.
499
+ Returns:
500
+ dict[str, Tensor]: A dictionary of losses.
501
+ """
502
+ if self.ignore_iof_thr != -1:
503
+ # TODO: Support fast version
504
+ # convert ignore gt
505
+ batch_target_ignore_list = []
506
+ for i, gt_instances_ignore in enumerate(batch_gt_instances_ignore):
507
+ bboxes = gt_instances_ignore.bboxes
508
+ labels = gt_instances_ignore.labels
509
+ index = bboxes.new_full((len(bboxes), 1), i)
510
+ # (batch_idx, label, bboxes)
511
+ target = torch.cat((index, labels[:, None].float(), bboxes),
512
+ dim=1)
513
+ batch_target_ignore_list.append(target)
514
+
515
+ # (num_bboxes, 6)
516
+ batch_gt_targets_ignore = torch.cat(
517
+ batch_target_ignore_list, dim=0)
518
+ if batch_gt_targets_ignore.shape[0] != 0:
519
+ # Consider regions with ignore in annotations
520
+ return self._loss_by_feat_with_ignore(
521
+ cls_scores,
522
+ bbox_preds,
523
+ objectnesses,
524
+ batch_gt_instances=batch_gt_instances,
525
+ batch_img_metas=batch_img_metas,
526
+ batch_gt_instances_ignore=batch_gt_targets_ignore)
527
+
528
+ # 1. Convert gt to norm format
529
+ batch_targets_normed = self._convert_gt_to_norm_format(
530
+ batch_gt_instances, batch_img_metas)
531
+
532
+ device = cls_scores[0].device
533
+ loss_cls = torch.zeros(1, device=device)
534
+ loss_box = torch.zeros(1, device=device)
535
+ loss_obj = torch.zeros(1, device=device)
536
+ scaled_factor = torch.ones(7, device=device)
537
+
538
+ for i in range(self.num_levels):
539
+ batch_size, _, h, w = bbox_preds[i].shape
540
+ target_obj = torch.zeros_like(objectnesses[i])
541
+
542
+ # empty gt bboxes
543
+ if batch_targets_normed.shape[1] == 0:
544
+ loss_box += bbox_preds[i].sum() * 0
545
+ loss_cls += cls_scores[i].sum() * 0
546
+ loss_obj += self.loss_obj(
547
+ objectnesses[i], target_obj) * self.obj_level_weights[i]
548
+ continue
549
+
550
+ priors_base_sizes_i = self.priors_base_sizes[i]
551
+ # feature map scale whwh
552
+ scaled_factor[2:6] = torch.tensor(
553
+ bbox_preds[i].shape)[[3, 2, 3, 2]]
554
+ # Scale batch_targets from range 0-1 to range 0-features_maps size.
555
+ # (num_base_priors, num_bboxes, 7)
556
+ batch_targets_scaled = batch_targets_normed * scaled_factor
557
+
558
+ # 2. Shape match
559
+ wh_ratio = batch_targets_scaled[...,
560
+ 4:6] / priors_base_sizes_i[:, None]
561
+ match_inds = torch.max(
562
+ wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
563
+ batch_targets_scaled = batch_targets_scaled[match_inds]
564
+
565
+ # no gt bbox matches anchor
566
+ if batch_targets_scaled.shape[0] == 0:
567
+ loss_box += bbox_preds[i].sum() * 0
568
+ loss_cls += cls_scores[i].sum() * 0
569
+ loss_obj += self.loss_obj(
570
+ objectnesses[i], target_obj) * self.obj_level_weights[i]
571
+ continue
572
+
573
+ # 3. Positive samples with additional neighbors
574
+
575
+ # check the left, up, right, bottom sides of the
576
+ # target grid, and determine whether to assign
+ # them as positive samples as well.
578
+ batch_targets_cxcy = batch_targets_scaled[:, 2:4]
579
+ grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
580
+ left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
581
+ (batch_targets_cxcy > 1)).T
582
+ right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
583
+ (grid_xy > 1)).T
584
+ offset_inds = torch.stack(
585
+ (torch.ones_like(left), left, up, right, bottom))
586
+
587
+ batch_targets_scaled = batch_targets_scaled.repeat(
588
+ (5, 1, 1))[offset_inds]
589
+ retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
590
+ 1)[offset_inds]
591
+
592
+ # prepare pred results and positive sample indexes to
593
+ # calculate class loss and bbox loss
594
+ _chunk_targets = batch_targets_scaled.chunk(4, 1)
595
+ img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
596
+ priors_inds, (img_inds, class_inds) = priors_inds.long().view(
597
+ -1), img_class_inds.long().T
598
+
599
+ grid_xy_long = (grid_xy -
600
+ retained_offsets * self.near_neighbor_thr).long()
601
+ grid_x_inds, grid_y_inds = grid_xy_long.T
602
+ bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
603
+
604
+ # 4. Calculate loss
605
+ # bbox loss
606
+ retained_bbox_pred = bbox_preds[i].reshape(
607
+ batch_size, self.num_base_priors, -1, h,
608
+ w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
609
+ priors_base_sizes_i = priors_base_sizes_i[priors_inds]
610
+ decoded_bbox_pred = self._decode_bbox_to_xywh(
611
+ retained_bbox_pred, priors_base_sizes_i)
612
+ loss_box_i, iou = self.loss_bbox(decoded_bbox_pred, bboxes_targets)
613
+ loss_box += loss_box_i
614
+
615
+ # obj loss
616
+ iou = iou.detach().clamp(0)
617
+ target_obj[img_inds, priors_inds, grid_y_inds,
618
+ grid_x_inds] = iou.type(target_obj.dtype)
619
+ loss_obj += self.loss_obj(objectnesses[i],
620
+ target_obj) * self.obj_level_weights[i]
621
+
622
+ # cls loss
623
+ if self.num_classes > 1:
624
+ pred_cls_scores = cls_scores[i].reshape(
625
+ batch_size, self.num_base_priors, -1, h,
626
+ w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
627
+
628
+ target_class = torch.full_like(pred_cls_scores, 0.)
629
+ target_class[range(batch_targets_scaled.shape[0]),
630
+ class_inds] = 1.
631
+ loss_cls += self.loss_cls(pred_cls_scores, target_class)
632
+ else:
633
+ loss_cls += cls_scores[i].sum() * 0
634
+
635
+ _, world_size = get_dist_info()
636
+ return dict(
637
+ loss_cls=loss_cls * batch_size * world_size,
638
+ loss_obj=loss_obj * batch_size * world_size,
639
+ loss_bbox=loss_box * batch_size * world_size)
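A minimal illustration of the "Shape match" step in `loss_by_feat` above: a GT box is kept for a prior only when every width/height ratio between them stays within `prior_match_thr` (default 4.0). The prior and GT sizes below are demo assumptions.

import torch

prior_match_thr = 4.0
priors_wh = torch.tensor([[10., 13.], [16., 30.], [33., 23.]])  # one level
gt_wh = torch.tensor([[12., 20.], [150., 8.]])                  # two GT boxes

wh_ratio = gt_wh[None, :, :] / priors_wh[:, None, :]            # (3, 2, 2)
match = torch.max(wh_ratio, 1 / wh_ratio).max(2)[0] < prior_match_thr
print(match)
# tensor([[ True, False],
#         [ True, False],
#         [ True, False]])  -> the elongated 150x8 box matches no prior here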
640
+
641
+ def _convert_gt_to_norm_format(self,
642
+ batch_gt_instances: Sequence[InstanceData],
643
+ batch_img_metas: Sequence[dict]) -> Tensor:
644
+ if isinstance(batch_gt_instances, torch.Tensor):
645
+ # fast version
646
+ img_shape = batch_img_metas[0]['batch_input_shape']
647
+ gt_bboxes_xyxy = batch_gt_instances[:, 2:]
648
+ xy1, xy2 = gt_bboxes_xyxy.split((2, 2), dim=-1)
649
+ gt_bboxes_xywh = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
650
+ gt_bboxes_xywh[:, 1::2] /= img_shape[0]
651
+ gt_bboxes_xywh[:, 0::2] /= img_shape[1]
652
+ batch_gt_instances[:, 2:] = gt_bboxes_xywh
653
+
654
+ # (num_base_priors, num_bboxes, 6)
655
+ batch_targets_normed = batch_gt_instances.repeat(
656
+ self.num_base_priors, 1, 1)
657
+ else:
658
+ batch_target_list = []
659
+ # Convert xyxy bbox to yolo format.
660
+ for i, gt_instances in enumerate(batch_gt_instances):
661
+ img_shape = batch_img_metas[i]['batch_input_shape']
662
+ bboxes = gt_instances.bboxes
663
+ labels = gt_instances.labels
664
+
665
+ xy1, xy2 = bboxes.split((2, 2), dim=-1)
666
+ bboxes = torch.cat([(xy2 + xy1) / 2, (xy2 - xy1)], dim=-1)
667
+ # normalized to 0-1
668
+ bboxes[:, 1::2] /= img_shape[0]
669
+ bboxes[:, 0::2] /= img_shape[1]
670
+
671
+ index = bboxes.new_full((len(bboxes), 1), i)
672
+ # (batch_idx, label, normed_bbox)
673
+ target = torch.cat((index, labels[:, None].float(), bboxes),
674
+ dim=1)
675
+ batch_target_list.append(target)
676
+
677
+ # (num_base_priors, num_bboxes, 6)
678
+ batch_targets_normed = torch.cat(
679
+ batch_target_list, dim=0).repeat(self.num_base_priors, 1, 1)
680
+
681
+ # (num_base_priors, num_bboxes, 1)
682
+ batch_targets_prior_inds = self.prior_inds.repeat(
683
+ 1, batch_targets_normed.shape[1])[..., None]
684
+ # (num_base_priors, num_bboxes, 7)
685
+ # (img_ind, labels, bbox_cx, bbox_cy, bbox_w, bbox_h, prior_ind)
686
+ batch_targets_normed = torch.cat(
687
+ (batch_targets_normed, batch_targets_prior_inds), 2)
688
+ return batch_targets_normed
689
+
690
+ def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes) -> Tensor:
691
+ bbox_pred = bbox_pred.sigmoid()
692
+ pred_xy = bbox_pred[:, :2] * 2 - 0.5
693
+ pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
694
+ decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
695
+ return decoded_bbox_pred
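A worked example of `_decode_bbox_to_xywh` above: the sigmoid-squashed offsets map the center into a (-0.5, 1.5) grid-cell range and the size into (0, 4) times the prior size, which is consistent with the default `prior_match_thr` of 4.0. The input values are demo assumptions.

import torch

bbox_pred = torch.tensor([[0.0, 0.0, 0.0, 0.0]])      # raw head output
priors_base_sizes = torch.tensor([[3.25, 4.125]])     # prior wh in grid units

p = bbox_pred.sigmoid()                                # all 0.5
pred_xy = p[:, :2] * 2 - 0.5                           # -> (0.5, 0.5)
pred_wh = (p[:, 2:] * 2) ** 2 * priors_base_sizes      # -> 1.0 * prior wh
print(torch.cat((pred_xy, pred_wh), dim=-1))
# tensor([[0.5000, 0.5000, 3.2500, 4.1250]])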
696
+
697
+ def _loss_by_feat_with_ignore(
698
+ self, cls_scores: Sequence[Tensor], bbox_preds: Sequence[Tensor],
699
+ objectnesses: Sequence[Tensor],
700
+ batch_gt_instances: Sequence[InstanceData],
701
+ batch_img_metas: Sequence[dict],
702
+ batch_gt_instances_ignore: Sequence[Tensor]) -> dict:
703
+ """Calculate the loss based on the features extracted by the detection
704
+ head.
705
+
706
+ Args:
707
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
708
+ each is a 4D-tensor, the channel number is
709
+ num_priors * num_classes.
710
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
711
+ level, each is a 4D-tensor, the channel number is
712
+ num_priors * 4.
713
+ objectnesses (Sequence[Tensor]): Score factor for
714
+ all scale level, each is a 4D-tensor, has shape
715
+ (batch_size, 1, H, W).
716
+ batch_gt_instances (Sequence[InstanceData]): Batch of
717
+ gt_instance. It usually includes ``bboxes`` and ``labels``
718
+ attributes.
719
+ batch_img_metas (Sequence[dict]): Meta information of each image,
720
+ e.g., image size, scaling factor, etc.
721
+ batch_gt_instances_ignore (Sequence[Tensor]): Ignore boxes with
722
+ batch_ids and labels, each is a 2D-tensor, the channel number
723
+ is 6, means that (batch_id, label, xmin, ymin, xmax, ymax).
724
+ Returns:
725
+ dict[str, Tensor]: A dictionary of losses.
726
+ """
727
+ # 1. Convert gt to norm format
728
+ batch_targets_normed = self._convert_gt_to_norm_format(
729
+ batch_gt_instances, batch_img_metas)
730
+
731
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
732
+ if featmap_sizes != self.featmap_sizes:
733
+ self.mlvl_priors = self.prior_generator.grid_priors(
734
+ featmap_sizes,
735
+ dtype=cls_scores[0].dtype,
736
+ device=cls_scores[0].device)
737
+ self.featmap_sizes = featmap_sizes
738
+
739
+ device = cls_scores[0].device
740
+ loss_cls = torch.zeros(1, device=device)
741
+ loss_box = torch.zeros(1, device=device)
742
+ loss_obj = torch.zeros(1, device=device)
743
+ scaled_factor = torch.ones(7, device=device)
744
+
745
+ for i in range(self.num_levels):
746
+ batch_size, _, h, w = bbox_preds[i].shape
747
+ target_obj = torch.zeros_like(objectnesses[i])
748
+
749
+ not_ignore_flags = bbox_preds[i].new_ones(batch_size,
750
+ self.num_base_priors, h,
751
+ w)
752
+
753
+ ignore_overlaps = bbox_overlaps(self.mlvl_priors[i],
754
+ batch_gt_instances_ignore[..., 2:],
755
+ 'iof')
756
+ ignore_max_overlaps, ignore_max_ignore_index = ignore_overlaps.max(
757
+ dim=1)
758
+
759
+ batch_inds = batch_gt_instances_ignore[:,
760
+ 0][ignore_max_ignore_index]
761
+ ignore_inds = (ignore_max_overlaps > self.ignore_iof_thr).nonzero(
762
+ as_tuple=True)[0]
763
+ batch_inds = batch_inds[ignore_inds].long()
764
+ ignore_priors, ignore_grid_xs, ignore_grid_ys = get_prior_xy_info(
765
+ ignore_inds, self.num_base_priors, self.featmap_sizes[i])
766
+ not_ignore_flags[batch_inds, ignore_priors, ignore_grid_ys,
767
+ ignore_grid_xs] = 0
768
+
769
+ # empty gt bboxes
770
+ if batch_targets_normed.shape[1] == 0:
771
+ loss_box += bbox_preds[i].sum() * 0
772
+ loss_cls += cls_scores[i].sum() * 0
773
+ loss_obj += self.loss_obj(
774
+ objectnesses[i],
775
+ target_obj,
776
+ weight=not_ignore_flags,
777
+ avg_factor=max(not_ignore_flags.sum(),
778
+ 1)) * self.obj_level_weights[i]
779
+ continue
780
+
781
+ priors_base_sizes_i = self.priors_base_sizes[i]
782
+ # feature map scale whwh
783
+ scaled_factor[2:6] = torch.tensor(
784
+ bbox_preds[i].shape)[[3, 2, 3, 2]]
785
+ # Scale batch_targets from range 0-1 to range 0-features_maps size.
786
+ # (num_base_priors, num_bboxes, 7)
787
+ batch_targets_scaled = batch_targets_normed * scaled_factor
788
+
789
+ # 2. Shape match
790
+ wh_ratio = batch_targets_scaled[...,
791
+ 4:6] / priors_base_sizes_i[:, None]
792
+ match_inds = torch.max(
793
+ wh_ratio, 1 / wh_ratio).max(2)[0] < self.prior_match_thr
794
+ batch_targets_scaled = batch_targets_scaled[match_inds]
795
+
796
+ # no gt bbox matches anchor
797
+ if batch_targets_scaled.shape[0] == 0:
798
+ loss_box += bbox_preds[i].sum() * 0
799
+ loss_cls += cls_scores[i].sum() * 0
800
+ loss_obj += self.loss_obj(
801
+ objectnesses[i],
802
+ target_obj,
803
+ weight=not_ignore_flags,
804
+ avg_factor=max(not_ignore_flags.sum(),
805
+ 1)) * self.obj_level_weights[i]
806
+ continue
807
+
808
+ # 3. Positive samples with additional neighbors
809
+
810
+ # check the left, up, right, bottom sides of the
811
+ # targets grid, and determine whether assigned
812
+ # them as positive samples as well.
813
+ batch_targets_cxcy = batch_targets_scaled[:, 2:4]
814
+ grid_xy = scaled_factor[[2, 3]] - batch_targets_cxcy
815
+ left, up = ((batch_targets_cxcy % 1 < self.near_neighbor_thr) &
816
+ (batch_targets_cxcy > 1)).T
817
+ right, bottom = ((grid_xy % 1 < self.near_neighbor_thr) &
818
+ (grid_xy > 1)).T
819
+ offset_inds = torch.stack(
820
+ (torch.ones_like(left), left, up, right, bottom))
821
+
822
+ batch_targets_scaled = batch_targets_scaled.repeat(
823
+ (5, 1, 1))[offset_inds]
824
+ retained_offsets = self.grid_offset.repeat(1, offset_inds.shape[1],
825
+ 1)[offset_inds]
826
+
827
+ # prepare pred results and positive sample indexes to
828
+ # calculate class loss and bbox loss
829
+ _chunk_targets = batch_targets_scaled.chunk(4, 1)
830
+ img_class_inds, grid_xy, grid_wh, priors_inds = _chunk_targets
831
+ priors_inds, (img_inds, class_inds) = priors_inds.long().view(
832
+ -1), img_class_inds.long().T
833
+
834
+ grid_xy_long = (grid_xy -
835
+ retained_offsets * self.near_neighbor_thr).long()
836
+ grid_x_inds, grid_y_inds = grid_xy_long.T
837
+ bboxes_targets = torch.cat((grid_xy - grid_xy_long, grid_wh), 1)
838
+
839
+ # 4. Calculate loss
840
+ # bbox loss
841
+ retained_bbox_pred = bbox_preds[i].reshape(
842
+ batch_size, self.num_base_priors, -1, h,
843
+ w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
844
+ priors_base_sizes_i = priors_base_sizes_i[priors_inds]
845
+ decoded_bbox_pred = self._decode_bbox_to_xywh(
846
+ retained_bbox_pred, priors_base_sizes_i)
847
+
848
+ not_ignore_weights = not_ignore_flags[img_inds, priors_inds,
849
+ grid_y_inds, grid_x_inds]
850
+ loss_box_i, iou = self.loss_bbox(
851
+ decoded_bbox_pred,
852
+ bboxes_targets,
853
+ weight=not_ignore_weights,
854
+ avg_factor=max(not_ignore_weights.sum(), 1))
855
+ loss_box += loss_box_i
856
+
857
+ # obj loss
858
+ iou = iou.detach().clamp(0)
859
+ target_obj[img_inds, priors_inds, grid_y_inds,
860
+ grid_x_inds] = iou.type(target_obj.dtype)
861
+ loss_obj += self.loss_obj(
862
+ objectnesses[i],
863
+ target_obj,
864
+ weight=not_ignore_flags,
865
+ avg_factor=max(not_ignore_flags.sum(),
866
+ 1)) * self.obj_level_weights[i]
867
+
868
+ # cls loss
869
+ if self.num_classes > 1:
870
+ pred_cls_scores = cls_scores[i].reshape(
871
+ batch_size, self.num_base_priors, -1, h,
872
+ w)[img_inds, priors_inds, :, grid_y_inds, grid_x_inds]
873
+
874
+ target_class = torch.full_like(pred_cls_scores, 0.)
875
+ target_class[range(batch_targets_scaled.shape[0]),
876
+ class_inds] = 1.
877
+ loss_cls += self.loss_cls(
878
+ pred_cls_scores,
879
+ target_class,
880
+ weight=not_ignore_weights[:, None].repeat(
881
+ 1, self.num_classes),
882
+ avg_factor=max(not_ignore_weights.sum(), 1))
883
+ else:
884
+ loss_cls += cls_scores[i].sum() * 0
885
+
886
+ _, world_size = get_dist_info()
887
+ return dict(
888
+ loss_cls=loss_cls * batch_size * world_size,
889
+ loss_obj=loss_obj * batch_size * world_size,
890
+ loss_bbox=loss_box * batch_size * world_size)
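The `_decode_bbox_to_xywh` helper above applies the YOLOv5-style decoding: sigmoid-bounded center offsets and a squared, prior-scaled width/height. A minimal standalone sketch of that decoding, using made-up tensors rather than real head outputs:

import torch


def decode_bbox_to_xywh(bbox_pred: torch.Tensor,
                        priors_base_sizes: torch.Tensor) -> torch.Tensor:
    """The same decoding as `_decode_bbox_to_xywh` above, in isolation."""
    bbox_pred = bbox_pred.sigmoid()
    # center offsets are bounded to (-0.5, 1.5) cells around the grid point
    pred_xy = bbox_pred[:, :2] * 2 - 0.5
    # width/height are bounded to (0, 4) times the matched prior size
    pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
    return torch.cat((pred_xy, pred_wh), dim=-1)


raw_pred = torch.randn(5, 4)                        # 5 dummy positive preds
prior_wh = torch.tensor([[10., 13.]]).expand(5, 2)  # hypothetical prior size
decoded = decode_bbox_to_xywh(raw_pred, prior_wh)
print(decoded.shape)  # torch.Size([5, 4]) -> (cx, cy, w, h) per prediction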
mmyolo/models/dense_heads/yolov6_head.py ADDED
@@ -0,0 +1,369 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Sequence, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from mmcv.cnn import ConvModule
7
+ from mmdet.models.utils import multi_apply
8
+ from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
9
+ OptMultiConfig)
10
+ from mmengine import MessageHub
11
+ from mmengine.dist import get_dist_info
12
+ from mmengine.model import BaseModule, bias_init_with_prob
13
+ from mmengine.structures import InstanceData
14
+ from torch import Tensor
15
+
16
+ from mmyolo.registry import MODELS, TASK_UTILS
17
+ from ..utils import gt_instances_preprocess
18
+ from .yolov5_head import YOLOv5Head
19
+
20
+
21
+ @MODELS.register_module()
22
+ class YOLOv6HeadModule(BaseModule):
23
+ """YOLOv6Head head module used in `YOLOv6.
24
+
25
+ <https://arxiv.org/pdf/2209.02976>`_.
26
+
27
+ Args:
28
+ num_classes (int): Number of categories excluding the background
29
+ category.
30
+ in_channels (Union[int, Sequence]): Number of channels in the input
31
+ feature map.
32
+ widen_factor (float): Width multiplier, multiply number of
33
+ channels in each layer by this amount. Defaults to 1.0.
34
+ num_base_priors (int): The number of priors (points) at a point
35
+ on the feature grid.
36
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
37
+ Defaults to [8, 16, 32].
38
39
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
40
+ layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
41
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
42
+ Defaults to dict(type='SiLU', inplace=True).
43
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
44
+ list[dict], optional): Initialization config dict.
45
+ Defaults to None.
46
+ """
47
+
48
+ def __init__(self,
49
+ num_classes: int,
50
+ in_channels: Union[int, Sequence],
51
+ widen_factor: float = 1.0,
52
+ num_base_priors: int = 1,
53
+ featmap_strides: Sequence[int] = (8, 16, 32),
54
+ norm_cfg: ConfigType = dict(
55
+ type='BN', momentum=0.03, eps=0.001),
56
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
57
+ init_cfg: OptMultiConfig = None):
58
+ super().__init__(init_cfg=init_cfg)
59
+
60
+ self.num_classes = num_classes
61
+ self.featmap_strides = featmap_strides
62
+ self.num_levels = len(self.featmap_strides)
63
+ self.num_base_priors = num_base_priors
64
+ self.norm_cfg = norm_cfg
65
+ self.act_cfg = act_cfg
66
+
67
+ if isinstance(in_channels, int):
68
+ self.in_channels = [int(in_channels * widen_factor)
69
+ ] * self.num_levels
70
+ else:
71
+ self.in_channels = [int(i * widen_factor) for i in in_channels]
72
+
73
+ self._init_layers()
74
+
75
+ def _init_layers(self):
76
+ """initialize conv layers in YOLOv6 head."""
77
+ # Init decouple head
78
+ self.cls_convs = nn.ModuleList()
79
+ self.reg_convs = nn.ModuleList()
80
+ self.cls_preds = nn.ModuleList()
81
+ self.reg_preds = nn.ModuleList()
82
+ self.stems = nn.ModuleList()
83
+ for i in range(self.num_levels):
84
+ self.stems.append(
85
+ ConvModule(
86
+ in_channels=self.in_channels[i],
87
+ out_channels=self.in_channels[i],
88
+ kernel_size=1,
89
+ stride=1,
90
+ padding=1 // 2,
91
+ norm_cfg=self.norm_cfg,
92
+ act_cfg=self.act_cfg))
93
+ self.cls_convs.append(
94
+ ConvModule(
95
+ in_channels=self.in_channels[i],
96
+ out_channels=self.in_channels[i],
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=3 // 2,
100
+ norm_cfg=self.norm_cfg,
101
+ act_cfg=self.act_cfg))
102
+ self.reg_convs.append(
103
+ ConvModule(
104
+ in_channels=self.in_channels[i],
105
+ out_channels=self.in_channels[i],
106
+ kernel_size=3,
107
+ stride=1,
108
+ padding=3 // 2,
109
+ norm_cfg=self.norm_cfg,
110
+ act_cfg=self.act_cfg))
111
+ self.cls_preds.append(
112
+ nn.Conv2d(
113
+ in_channels=self.in_channels[i],
114
+ out_channels=self.num_base_priors * self.num_classes,
115
+ kernel_size=1))
116
+ self.reg_preds.append(
117
+ nn.Conv2d(
118
+ in_channels=self.in_channels[i],
119
+ out_channels=self.num_base_priors * 4,
120
+ kernel_size=1))
121
+
122
+ def init_weights(self):
123
+ super().init_weights()
124
+ bias_init = bias_init_with_prob(0.01)
125
+ for conv in self.cls_preds:
126
+ conv.bias.data.fill_(bias_init)
127
+ conv.weight.data.fill_(0.)
128
+
129
+ for conv in self.reg_preds:
130
+ conv.bias.data.fill_(1.0)
131
+ conv.weight.data.fill_(0.)
132
+
133
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
134
+ """Forward features from the upstream network.
135
+
136
+ Args:
137
+ x (Tuple[Tensor]): Features from the upstream network, each is
138
+ a 4D-tensor.
139
+ Returns:
140
+ Tuple[List]: A tuple of multi-level classification scores, bbox
141
+ predictions.
142
+ """
143
+ assert len(x) == self.num_levels
144
+ return multi_apply(self.forward_single, x, self.stems, self.cls_convs,
145
+ self.cls_preds, self.reg_convs, self.reg_preds)
146
+
147
+ def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
148
+ cls_pred: nn.Module, reg_conv: nn.Module,
149
+ reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
150
+ """Forward feature of a single scale level."""
151
+ y = stem(x)
152
+ cls_x = y
153
+ reg_x = y
154
+ cls_feat = cls_conv(cls_x)
155
+ reg_feat = reg_conv(reg_x)
156
+
157
+ cls_score = cls_pred(cls_feat)
158
+ bbox_pred = reg_pred(reg_feat)
159
+
160
+ return cls_score, bbox_pred
161
+
162
+
163
+ @MODELS.register_module()
164
+ class YOLOv6Head(YOLOv5Head):
165
+ """YOLOv6Head head used in `YOLOv6 <https://arxiv.org/pdf/2209.02976>`_.
166
+
167
+ Args:
168
+ head_module(ConfigType): Base module used for YOLOv6Head
169
+ prior_generator(dict): Points generator feature maps
170
+ in 2D points-based detectors.
171
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
172
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
173
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
174
+ anchor head. Defaults to None.
175
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
176
+ anchor head. Defaults to None.
177
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
178
+ list[dict], optional): Initialization config dict.
179
+ Defaults to None.
180
+ """
181
+
182
+ def __init__(self,
183
+ head_module: ConfigType,
184
+ prior_generator: ConfigType = dict(
185
+ type='mmdet.MlvlPointGenerator',
186
+ offset=0.5,
187
+ strides=[8, 16, 32]),
188
+ bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
189
+ loss_cls: ConfigType = dict(
190
+ type='mmdet.VarifocalLoss',
191
+ use_sigmoid=True,
192
+ alpha=0.75,
193
+ gamma=2.0,
194
+ iou_weighted=True,
195
+ reduction='sum',
196
+ loss_weight=1.0),
197
+ loss_bbox: ConfigType = dict(
198
+ type='IoULoss',
199
+ iou_mode='giou',
200
+ bbox_format='xyxy',
201
+ reduction='mean',
202
+ loss_weight=2.5,
203
+ return_iou=False),
204
+ train_cfg: OptConfigType = None,
205
+ test_cfg: OptConfigType = None,
206
+ init_cfg: OptMultiConfig = None):
207
+ super().__init__(
208
+ head_module=head_module,
209
+ prior_generator=prior_generator,
210
+ bbox_coder=bbox_coder,
211
+ loss_cls=loss_cls,
212
+ loss_bbox=loss_bbox,
213
+ train_cfg=train_cfg,
214
+ test_cfg=test_cfg,
215
+ init_cfg=init_cfg)
216
+ # yolov6 doesn't need loss_obj
217
+ self.loss_obj = None
218
+
219
+ def special_init(self):
220
+ """Since YOLO series algorithms will inherit from YOLOv5Head, but
221
+ each algorithm has its own special initialization process.
222
+
223
+ The special_init function is designed to deal with this situation.
224
+ """
225
+ if self.train_cfg:
226
+ self.initial_epoch = self.train_cfg['initial_epoch']
227
+ self.initial_assigner = TASK_UTILS.build(
228
+ self.train_cfg.initial_assigner)
229
+ self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
230
+
231
+ # Add common attributes to reduce calculation
232
+ self.featmap_sizes_train = None
233
+ self.num_level_priors = None
234
+ self.flatten_priors_train = None
235
+ self.stride_tensor = None
236
+
237
+ def loss_by_feat(
238
+ self,
239
+ cls_scores: Sequence[Tensor],
240
+ bbox_preds: Sequence[Tensor],
241
+ batch_gt_instances: Sequence[InstanceData],
242
+ batch_img_metas: Sequence[dict],
243
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
244
+ """Calculate the loss based on the features extracted by the detection
245
+ head.
246
+
247
+ Args:
248
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
249
+ each is a 4D-tensor, the channel number is
250
+ num_priors * num_classes.
251
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
252
+ level, each is a 4D-tensor, the channel number is
253
+ num_priors * 4.
254
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
255
+ gt_instance. It usually includes ``bboxes`` and ``labels``
256
+ attributes.
257
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
258
+ image size, scaling factor, etc.
259
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
260
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
261
+ data that is ignored during training and testing.
262
+ Defaults to None.
263
+ Returns:
264
+ dict[str, Tensor]: A dictionary of losses.
265
+ """
266
+
267
+ # get epoch information from message hub
268
+ message_hub = MessageHub.get_current_instance()
269
+ current_epoch = message_hub.get_info('epoch')
270
+
271
+ num_imgs = len(batch_img_metas)
272
+ if batch_gt_instances_ignore is None:
273
+ batch_gt_instances_ignore = [None] * num_imgs
274
+
275
+ current_featmap_sizes = [
276
+ cls_score.shape[2:] for cls_score in cls_scores
277
+ ]
278
+ # If the shape does not equal, generate new one
279
+ if current_featmap_sizes != self.featmap_sizes_train:
280
+ self.featmap_sizes_train = current_featmap_sizes
281
+
282
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
283
+ self.featmap_sizes_train,
284
+ dtype=cls_scores[0].dtype,
285
+ device=cls_scores[0].device,
286
+ with_stride=True)
287
+
288
+ self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
289
+ self.flatten_priors_train = torch.cat(
290
+ mlvl_priors_with_stride, dim=0)
291
+ self.stride_tensor = self.flatten_priors_train[..., [2]]
292
+
293
+ # gt info
294
+ gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
295
+ gt_labels = gt_info[:, :, :1]
296
+ gt_bboxes = gt_info[:, :, 1:] # xyxy
297
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
298
+
299
+ # pred info
300
+ flatten_cls_preds = [
301
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
302
+ self.num_classes)
303
+ for cls_pred in cls_scores
304
+ ]
305
+
306
+ flatten_pred_bboxes = [
307
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
308
+ for bbox_pred in bbox_preds
309
+ ]
310
+
311
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
312
+ flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
313
+ flatten_pred_bboxes = self.bbox_coder.decode(
314
+ self.flatten_priors_train[..., :2], flatten_pred_bboxes,
315
+ self.stride_tensor[:, 0])
316
+ pred_scores = torch.sigmoid(flatten_cls_preds)
317
+
318
+ if current_epoch < self.initial_epoch:
319
+ assigned_result = self.initial_assigner(
320
+ flatten_pred_bboxes.detach(), self.flatten_priors_train,
321
+ self.num_level_priors, gt_labels, gt_bboxes, pad_bbox_flag)
322
+ else:
323
+ assigned_result = self.assigner(flatten_pred_bboxes.detach(),
324
+ pred_scores.detach(),
325
+ self.flatten_priors_train,
326
+ gt_labels, gt_bboxes,
327
+ pad_bbox_flag)
328
+
329
+ assigned_bboxes = assigned_result['assigned_bboxes']
330
+ assigned_scores = assigned_result['assigned_scores']
331
+ fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
332
+
333
+ # cls loss
334
+ with torch.cuda.amp.autocast(enabled=False):
335
+ loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores)
336
+
337
+ # rescale bbox
338
+ assigned_bboxes /= self.stride_tensor
339
+ flatten_pred_bboxes /= self.stride_tensor
340
+
341
+ # TODO: Add all_reduce makes training more stable
342
+ assigned_scores_sum = assigned_scores.sum()
343
+ if assigned_scores_sum > 0:
344
+ loss_cls /= assigned_scores_sum
345
+
346
+ # select positive samples mask
347
+ num_pos = fg_mask_pre_prior.sum()
348
+ if num_pos > 0:
349
+ # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
350
+ # will not report an error
351
+ # iou loss
352
+ prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
353
+ pred_bboxes_pos = torch.masked_select(
354
+ flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
355
+ assigned_bboxes_pos = torch.masked_select(
356
+ assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
357
+ bbox_weight = torch.masked_select(
358
+ assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
359
+ loss_bbox = self.loss_bbox(
360
+ pred_bboxes_pos,
361
+ assigned_bboxes_pos,
362
+ weight=bbox_weight,
363
+ avg_factor=assigned_scores_sum)
364
+ else:
365
+ loss_bbox = flatten_pred_bboxes.sum() * 0
366
+
367
+ _, world_size = get_dist_info()
368
+ return dict(
369
+ loss_cls=loss_cls * world_size, loss_bbox=loss_bbox * world_size)
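The YOLOv6 loss above decodes (left, top, right, bottom) distance predictions with a distance-point bbox coder before running the assigner. A rough standalone sketch of that decoding, assuming the common convention that the predicted distances are expressed in stride units (an illustration, not the coder's exact implementation):

import torch


def distance_point_decode(points: torch.Tensor, pred_dists: torch.Tensor,
                          stride: torch.Tensor) -> torch.Tensor:
    """Turn per-point (l, t, r, b) distances into xyxy boxes."""
    pred_dists = pred_dists * stride       # back to input-image units
    x1y1 = points - pred_dists[:, :2]      # top-left corner
    x2y2 = points + pred_dists[:, 2:]      # bottom-right corner
    return torch.cat([x1y1, x2y2], dim=-1)


points = torch.tensor([[4., 4.], [12., 4.]])  # two prior centers at stride 8
dists = torch.tensor([[0.5, 0.5, 0.5, 0.5],
                      [1.0, 0.5, 1.0, 0.5]])
stride = torch.tensor([[8.]])
print(distance_point_decode(points, dists, stride))
# tensor([[ 0.,  0.,  8.,  8.],
#         [ 4.,  0., 20.,  8.]])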
mmyolo/models/dense_heads/yolov7_head.py ADDED
@@ -0,0 +1,404 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import List, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmcv.cnn import ConvModule
8
+ from mmdet.models.utils import multi_apply
9
+ from mmdet.utils import ConfigType, OptInstanceList
10
+ from mmengine.dist import get_dist_info
11
+ from mmengine.structures import InstanceData
12
+ from torch import Tensor
13
+
14
+ from mmyolo.registry import MODELS
15
+ from ..layers import ImplicitA, ImplicitM
16
+ from ..task_modules.assigners.batch_yolov7_assigner import BatchYOLOv7Assigner
17
+ from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
18
+
19
+
20
+ @MODELS.register_module()
21
+ class YOLOv7HeadModule(YOLOv5HeadModule):
22
+ """YOLOv7Head head module used in YOLOv7."""
23
+
24
+ def _init_layers(self):
25
+ """initialize conv layers in YOLOv7 head."""
26
+ self.convs_pred = nn.ModuleList()
27
+ for i in range(self.num_levels):
28
+ conv_pred = nn.Sequential(
29
+ ImplicitA(self.in_channels[i]),
30
+ nn.Conv2d(self.in_channels[i],
31
+ self.num_base_priors * self.num_out_attrib, 1),
32
+ ImplicitM(self.num_base_priors * self.num_out_attrib),
33
+ )
34
+ self.convs_pred.append(conv_pred)
35
+
36
+ def init_weights(self):
37
+ """Initialize the bias of YOLOv7 head."""
38
+ super(YOLOv5HeadModule, self).init_weights()
39
+ for mi, s in zip(self.convs_pred, self.featmap_strides): # from
40
+ mi = mi[1] # nn.Conv2d
41
+
42
+ b = mi.bias.data.view(3, -1)
43
+ # obj (8 objects per 640 image)
44
+ b.data[:, 4] += math.log(8 / (640 / s)**2)
45
+ b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
46
+
47
+ mi.bias.data = b.view(-1)
48
+
49
+
50
+ @MODELS.register_module()
51
+ class YOLOv7p6HeadModule(YOLOv5HeadModule):
52
+ """YOLOv7Head head module used in YOLOv7."""
53
+
54
+ def __init__(self,
55
+ *args,
56
+ main_out_channels: Sequence[int] = [256, 512, 768, 1024],
57
+ aux_out_channels: Sequence[int] = [320, 640, 960, 1280],
58
+ use_aux: bool = True,
59
+ norm_cfg: ConfigType = dict(
60
+ type='BN', momentum=0.03, eps=0.001),
61
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
62
+ **kwargs):
63
+ self.main_out_channels = main_out_channels
64
+ self.aux_out_channels = aux_out_channels
65
+ self.use_aux = use_aux
66
+ self.norm_cfg = norm_cfg
67
+ self.act_cfg = act_cfg
68
+ super().__init__(*args, **kwargs)
69
+
70
+ def _init_layers(self):
71
+ """initialize conv layers in YOLOv7 head."""
72
+ self.main_convs_pred = nn.ModuleList()
73
+ for i in range(self.num_levels):
74
+ conv_pred = nn.Sequential(
75
+ ConvModule(
76
+ self.in_channels[i],
77
+ self.main_out_channels[i],
78
+ 3,
79
+ padding=1,
80
+ norm_cfg=self.norm_cfg,
81
+ act_cfg=self.act_cfg),
82
+ ImplicitA(self.main_out_channels[i]),
83
+ nn.Conv2d(self.main_out_channels[i],
84
+ self.num_base_priors * self.num_out_attrib, 1),
85
+ ImplicitM(self.num_base_priors * self.num_out_attrib),
86
+ )
87
+ self.main_convs_pred.append(conv_pred)
88
+
89
+ if self.use_aux:
90
+ self.aux_convs_pred = nn.ModuleList()
91
+ for i in range(self.num_levels):
92
+ aux_pred = nn.Sequential(
93
+ ConvModule(
94
+ self.in_channels[i],
95
+ self.aux_out_channels[i],
96
+ 3,
97
+ padding=1,
98
+ norm_cfg=self.norm_cfg,
99
+ act_cfg=self.act_cfg),
100
+ nn.Conv2d(self.aux_out_channels[i],
101
+ self.num_base_priors * self.num_out_attrib, 1))
102
+ self.aux_convs_pred.append(aux_pred)
103
+ else:
104
+ self.aux_convs_pred = [None] * len(self.main_convs_pred)
105
+
106
+ def init_weights(self):
107
+ """Initialize the bias of YOLOv5 head."""
108
+ super(YOLOv5HeadModule, self).init_weights()
109
+ for mi, aux, s in zip(self.main_convs_pred, self.aux_convs_pred,
110
+ self.featmap_strides): # from
111
+ mi = mi[2] # nn.Conv2d
112
+ b = mi.bias.data.view(3, -1)
113
+ # obj (8 objects per 640 image)
114
+ b.data[:, 4] += math.log(8 / (640 / s)**2)
115
+ b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
116
+ mi.bias.data = b.view(-1)
117
+
118
+ if self.use_aux:
119
+ aux = aux[1] # nn.Conv2d
120
+ b = aux.bias.data.view(3, -1)
121
+ # obj (8 objects per 640 image)
122
+ b.data[:, 4] += math.log(8 / (640 / s)**2)
123
+ b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.99))
124
+ aux.bias.data = b.view(-1)  # write the adjusted bias back to the aux conv
125
+
126
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
127
+ """Forward features from the upstream network.
128
+
129
+ Args:
130
+ x (Tuple[Tensor]): Features from the upstream network, each is
131
+ a 4D-tensor.
132
+ Returns:
133
+ Tuple[List]: A tuple of multi-level classification scores, bbox
134
+ predictions, and objectnesses.
135
+ """
136
+ assert len(x) == self.num_levels
137
+ return multi_apply(self.forward_single, x, self.main_convs_pred,
138
+ self.aux_convs_pred)
139
+
140
+ def forward_single(self, x: Tensor, convs: nn.Module,
141
+ aux_convs: Optional[nn.Module]) \
142
+ -> Tuple[Union[Tensor, List], Union[Tensor, List],
143
+ Union[Tensor, List]]:
144
+ """Forward feature of a single scale level."""
145
+
146
+ pred_map = convs(x)
147
+ bs, _, ny, nx = pred_map.shape
148
+ pred_map = pred_map.view(bs, self.num_base_priors, self.num_out_attrib,
149
+ ny, nx)
150
+
151
+ cls_score = pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
152
+ bbox_pred = pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
153
+ objectness = pred_map[:, :, 4:5, ...].reshape(bs, -1, ny, nx)
154
+
155
+ if not self.training or not self.use_aux:
156
+ return cls_score, bbox_pred, objectness
157
+ else:
158
+ aux_pred_map = aux_convs(x)
159
+ aux_pred_map = aux_pred_map.view(bs, self.num_base_priors,
160
+ self.num_out_attrib, ny, nx)
161
+ aux_cls_score = aux_pred_map[:, :, 5:, ...].reshape(bs, -1, ny, nx)
162
+ aux_bbox_pred = aux_pred_map[:, :, :4, ...].reshape(bs, -1, ny, nx)
163
+ aux_objectness = aux_pred_map[:, :, 4:5,
164
+ ...].reshape(bs, -1, ny, nx)
165
+
166
+ return [cls_score,
167
+ aux_cls_score], [bbox_pred, aux_bbox_pred
168
+ ], [objectness, aux_objectness]
169
+
170
+
171
+ @MODELS.register_module()
172
+ class YOLOv7Head(YOLOv5Head):
173
+ """YOLOv7Head head used in `YOLOv7 <https://arxiv.org/abs/2207.02696>`_.
174
+
175
+ Args:
176
+ simota_candidate_topk (int): The candidate top-k which is used to
177
+ get top-k ious to calculate dynamic-k in BatchYOLOv7Assigner.
178
+ Defaults to 20.
179
+ simota_iou_weight (float): The scale factor for regression
180
+ iou cost in BatchYOLOv7Assigner. Defaults to 3.0.
181
+ simota_cls_weight (float): The scale factor for classification
182
+ cost in BatchYOLOv7Assigner. Defaults to 1.0.
183
+ """
184
+
185
+ def __init__(self,
186
+ *args,
187
+ simota_candidate_topk: int = 20,
188
+ simota_iou_weight: float = 3.0,
189
+ simota_cls_weight: float = 1.0,
190
+ aux_loss_weights: float = 0.25,
191
+ **kwargs):
192
+ super().__init__(*args, **kwargs)
193
+ self.aux_loss_weights = aux_loss_weights
194
+ self.assigner = BatchYOLOv7Assigner(
195
+ num_classes=self.num_classes,
196
+ num_base_priors=self.num_base_priors,
197
+ featmap_strides=self.featmap_strides,
198
+ prior_match_thr=self.prior_match_thr,
199
+ candidate_topk=simota_candidate_topk,
200
+ iou_weight=simota_iou_weight,
201
+ cls_weight=simota_cls_weight)
202
+
203
+ def loss_by_feat(
204
+ self,
205
+ cls_scores: Sequence[Union[Tensor, List]],
206
+ bbox_preds: Sequence[Union[Tensor, List]],
207
+ objectnesses: Sequence[Union[Tensor, List]],
208
+ batch_gt_instances: Sequence[InstanceData],
209
+ batch_img_metas: Sequence[dict],
210
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
211
+ """Calculate the loss based on the features extracted by the detection
212
+ head.
213
+
214
+ Args:
215
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
216
+ each is a 4D-tensor, the channel number is
217
+ num_priors * num_classes.
218
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
219
+ level, each is a 4D-tensor, the channel number is
220
+ num_priors * 4.
221
+ objectnesses (Sequence[Tensor]): Score factor for
222
+ all scale level, each is a 4D-tensor, has shape
223
+ (batch_size, 1, H, W).
224
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
225
+ gt_instance. It usually includes ``bboxes`` and ``labels``
226
+ attributes.
227
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
228
+ image size, scaling factor, etc.
229
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
230
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
231
+ data that is ignored during training and testing.
232
+ Defaults to None.
233
+ Returns:
234
+ dict[str, Tensor]: A dictionary of losses.
235
+ """
236
+
237
+ if isinstance(cls_scores[0], Sequence):
238
+ with_aux = True
239
+ batch_size = cls_scores[0][0].shape[0]
240
+ device = cls_scores[0][0].device
241
+
242
+ bbox_preds_main, bbox_preds_aux = zip(*bbox_preds)
243
+ objectnesses_main, objectnesses_aux = zip(*objectnesses)
244
+ cls_scores_main, cls_scores_aux = zip(*cls_scores)
245
+
246
+ head_preds = self._merge_predict_results(bbox_preds_main,
247
+ objectnesses_main,
248
+ cls_scores_main)
249
+ head_preds_aux = self._merge_predict_results(
250
+ bbox_preds_aux, objectnesses_aux, cls_scores_aux)
251
+ else:
252
+ with_aux = False
253
+ batch_size = cls_scores[0].shape[0]
254
+ device = cls_scores[0].device
255
+
256
+ head_preds = self._merge_predict_results(bbox_preds, objectnesses,
257
+ cls_scores)
258
+
259
+ # Convert gt to norm xywh format
260
+ # (num_base_priors, num_batch_gt, 7)
261
+ # 7 is mean (batch_idx, cls_id, x_norm, y_norm,
262
+ # w_norm, h_norm, prior_idx)
263
+ batch_targets_normed = self._convert_gt_to_norm_format(
264
+ batch_gt_instances, batch_img_metas)
265
+
266
+ scaled_factors = [
267
+ torch.tensor(head_pred.shape, device=device)[[3, 2, 3, 2]]
268
+ for head_pred in head_preds
269
+ ]
270
+
271
+ loss_cls, loss_obj, loss_box = self._calc_loss(
272
+ head_preds=head_preds,
273
+ head_preds_aux=None,
274
+ batch_targets_normed=batch_targets_normed,
275
+ near_neighbor_thr=self.near_neighbor_thr,
276
+ scaled_factors=scaled_factors,
277
+ batch_img_metas=batch_img_metas,
278
+ device=device)
279
+
280
+ if with_aux:
281
+ loss_cls_aux, loss_obj_aux, loss_box_aux = self._calc_loss(
282
+ head_preds=head_preds,
283
+ head_preds_aux=head_preds_aux,
284
+ batch_targets_normed=batch_targets_normed,
285
+ near_neighbor_thr=self.near_neighbor_thr * 2,
286
+ scaled_factors=scaled_factors,
287
+ batch_img_metas=batch_img_metas,
288
+ device=device)
289
+ loss_cls += self.aux_loss_weights * loss_cls_aux
290
+ loss_obj += self.aux_loss_weights * loss_obj_aux
291
+ loss_box += self.aux_loss_weights * loss_box_aux
292
+
293
+ _, world_size = get_dist_info()
294
+ return dict(
295
+ loss_cls=loss_cls * batch_size * world_size,
296
+ loss_obj=loss_obj * batch_size * world_size,
297
+ loss_bbox=loss_box * batch_size * world_size)
298
+
299
+ def _calc_loss(self, head_preds, head_preds_aux, batch_targets_normed,
300
+ near_neighbor_thr, scaled_factors, batch_img_metas, device):
301
+ loss_cls = torch.zeros(1, device=device)
302
+ loss_box = torch.zeros(1, device=device)
303
+ loss_obj = torch.zeros(1, device=device)
304
+
305
+ assigner_results = self.assigner(
306
+ head_preds,
307
+ batch_targets_normed,
308
+ batch_img_metas[0]['batch_input_shape'],
309
+ self.priors_base_sizes,
310
+ self.grid_offset,
311
+ near_neighbor_thr=near_neighbor_thr)
312
+ # mlvl is mean multi_level
313
+ mlvl_positive_infos = assigner_results['mlvl_positive_infos']
314
+ mlvl_priors = assigner_results['mlvl_priors']
315
+ mlvl_targets_normed = assigner_results['mlvl_targets_normed']
316
+
317
+ if head_preds_aux is not None:
318
+ # This is mean calc aux branch loss
319
+ head_preds = head_preds_aux
320
+
321
+ for i, head_pred in enumerate(head_preds):
322
+ batch_inds, prior_idx, grid_x, grid_y = mlvl_positive_infos[i].T
323
+ num_pred_positive = batch_inds.shape[0]
324
+ target_obj = torch.zeros_like(head_pred[..., 0])
325
+ # no positive samples
326
+ if num_pred_positive == 0:
327
+ loss_box += head_pred[..., :4].sum() * 0
328
+ loss_cls += head_pred[..., 5:].sum() * 0
329
+ loss_obj += self.loss_obj(
330
+ head_pred[..., 4], target_obj) * self.obj_level_weights[i]
331
+ continue
332
+
333
+ priors = mlvl_priors[i]
334
+ targets_normed = mlvl_targets_normed[i]
335
+
336
+ head_pred_positive = head_pred[batch_inds, prior_idx, grid_y,
337
+ grid_x]
338
+
339
+ # calc bbox loss
340
+ grid_xy = torch.stack([grid_x, grid_y], dim=1)
341
+ decoded_pred_bbox = self._decode_bbox_to_xywh(
342
+ head_pred_positive[:, :4], priors, grid_xy)
343
+ target_bbox_scaled = targets_normed[:, 2:6] * scaled_factors[i]
344
+
345
+ loss_box_i, iou = self.loss_bbox(decoded_pred_bbox,
346
+ target_bbox_scaled)
347
+ loss_box += loss_box_i
348
+
349
+ # calc obj loss
350
+ target_obj[batch_inds, prior_idx, grid_y,
351
+ grid_x] = iou.detach().clamp(0).type(target_obj.dtype)
352
+ loss_obj += self.loss_obj(head_pred[..., 4],
353
+ target_obj) * self.obj_level_weights[i]
354
+
355
+ # calc cls loss
356
+ if self.num_classes > 1:
357
+ pred_cls_scores = targets_normed[:, 1].long()
358
+ target_class = torch.full_like(
359
+ head_pred_positive[:, 5:], 0., device=device)
360
+ target_class[range(num_pred_positive), pred_cls_scores] = 1.
361
+ loss_cls += self.loss_cls(head_pred_positive[:, 5:],
362
+ target_class)
363
+ else:
364
+ loss_cls += head_pred_positive[:, 5:].sum() * 0
365
+ return loss_cls, loss_obj, loss_box
366
+
367
+ def _merge_predict_results(self, bbox_preds: Sequence[Tensor],
368
+ objectnesses: Sequence[Tensor],
369
+ cls_scores: Sequence[Tensor]) -> List[Tensor]:
370
+ """Merge predict output from 3 heads.
371
+
372
+ Args:
373
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
374
+ each is a 4D-tensor, the channel number is
375
+ num_priors * num_classes.
376
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
377
+ level, each is a 4D-tensor, the channel number is
378
+ num_priors * 4.
379
+ objectnesses (Sequence[Tensor]): Score factor for
380
+ all scale level, each is a 4D-tensor, has shape
381
+ (batch_size, 1, H, W).
382
+
383
+ Returns:
384
+ List[Tensor]: Merged output.
385
+ """
386
+ head_preds = []
387
+ for bbox_pred, objectness, cls_score in zip(bbox_preds, objectnesses,
388
+ cls_scores):
389
+ b, _, h, w = bbox_pred.shape
390
+ bbox_pred = bbox_pred.reshape(b, self.num_base_priors, -1, h, w)
391
+ objectness = objectness.reshape(b, self.num_base_priors, -1, h, w)
392
+ cls_score = cls_score.reshape(b, self.num_base_priors, -1, h, w)
393
+ head_pred = torch.cat([bbox_pred, objectness, cls_score],
394
+ dim=2).permute(0, 1, 3, 4, 2).contiguous()
395
+ head_preds.append(head_pred)
396
+ return head_preds
397
+
398
+ def _decode_bbox_to_xywh(self, bbox_pred, priors_base_sizes,
399
+ grid_xy) -> Tensor:
400
+ bbox_pred = bbox_pred.sigmoid()
401
+ pred_xy = bbox_pred[:, :2] * 2 - 0.5 + grid_xy
402
+ pred_wh = (bbox_pred[:, 2:] * 2)**2 * priors_base_sizes
403
+ decoded_bbox_pred = torch.cat((pred_xy, pred_wh), dim=-1)
404
+ return decoded_bbox_pred
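`_merge_predict_results` above folds the separate bbox, objectness and classification maps of each level into one tensor so that the assigner can index a single prediction vector per positive grid cell. A quick shape check with dummy tensors (batch size, 3 priors, 80 classes and a 20x20 level are arbitrary numbers chosen only for the example):

import torch

b, num_base_priors, num_classes, h, w = 2, 3, 80, 20, 20
bbox_pred = torch.randn(b, num_base_priors * 4, h, w)
objectness = torch.randn(b, num_base_priors * 1, h, w)
cls_score = torch.randn(b, num_base_priors * num_classes, h, w)

# same reshape/concat/permute as in `_merge_predict_results`
bbox_pred = bbox_pred.reshape(b, num_base_priors, -1, h, w)
objectness = objectness.reshape(b, num_base_priors, -1, h, w)
cls_score = cls_score.reshape(b, num_base_priors, -1, h, w)
head_pred = torch.cat([bbox_pred, objectness, cls_score],
                      dim=2).permute(0, 1, 3, 4, 2).contiguous()
print(head_pred.shape)  # torch.Size([2, 3, 20, 20, 85]) -> 4 + 1 + 80 attribs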
mmyolo/models/dense_heads/yolov8_head.py ADDED
@@ -0,0 +1,398 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import List, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmcv.cnn import ConvModule
8
+ from mmdet.models.utils import multi_apply
9
+ from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
10
+ OptMultiConfig)
11
+ from mmengine.dist import get_dist_info
12
+ from mmengine.model import BaseModule
13
+ from mmengine.structures import InstanceData
14
+ from torch import Tensor
15
+
16
+ from mmyolo.registry import MODELS, TASK_UTILS
17
+ from ..utils import gt_instances_preprocess, make_divisible
18
+ from .yolov5_head import YOLOv5Head
19
+
20
+
21
+ @MODELS.register_module()
22
+ class YOLOv8HeadModule(BaseModule):
23
+ """YOLOv8HeadModule head module used in `YOLOv8`.
24
+
25
+ Args:
26
+ num_classes (int): Number of categories excluding the background
27
+ category.
28
+ in_channels (Union[int, Sequence]): Number of channels in the input
29
+ feature map.
30
+ widen_factor (float): Width multiplier, multiply number of
31
+ channels in each layer by this amount. Defaults to 1.0.
32
+ num_base_priors (int): The number of priors (points) at a point
33
+ on the feature grid.
34
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
35
+ Defaults to [8, 16, 32].
36
+ reg_max (int): Max value of integral set :math: ``{0, ..., reg_max-1}``
37
+ in QFL setting. Defaults to 16.
38
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
39
+ layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
40
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
41
+ Defaults to dict(type='SiLU', inplace=True).
42
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
43
+ list[dict], optional): Initialization config dict.
44
+ Defaults to None.
45
+ """
46
+
47
+ def __init__(self,
48
+ num_classes: int,
49
+ in_channels: Union[int, Sequence],
50
+ widen_factor: float = 1.0,
51
+ num_base_priors: int = 1,
52
+ featmap_strides: Sequence[int] = (8, 16, 32),
53
+ reg_max: int = 16,
54
+ norm_cfg: ConfigType = dict(
55
+ type='BN', momentum=0.03, eps=0.001),
56
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
57
+ init_cfg: OptMultiConfig = None):
58
+ super().__init__(init_cfg=init_cfg)
59
+ self.num_classes = num_classes
60
+ self.featmap_strides = featmap_strides
61
+ self.num_levels = len(self.featmap_strides)
62
+ self.num_base_priors = num_base_priors
63
+ self.norm_cfg = norm_cfg
64
+ self.act_cfg = act_cfg
65
+ self.in_channels = in_channels
66
+ self.reg_max = reg_max
67
+
68
+ in_channels = []
69
+ for channel in self.in_channels:
70
+ channel = make_divisible(channel, widen_factor)
71
+ in_channels.append(channel)
72
+ self.in_channels = in_channels
73
+
74
+ self._init_layers()
75
+
76
+ def init_weights(self, prior_prob=0.01):
77
+ """Initialize the weight and bias of PPYOLOE head."""
78
+ super().init_weights()
79
+ for reg_pred, cls_pred, stride in zip(self.reg_preds, self.cls_preds,
80
+ self.featmap_strides):
81
+ reg_pred[-1].bias.data[:] = 1.0 # box
82
+ # cls (.01 objects, 80 classes, 640 img)
83
+ cls_pred[-1].bias.data[:self.num_classes] = math.log(
84
+ 5 / self.num_classes / (640 / stride)**2)
85
+
86
+ def _init_layers(self):
87
+ """initialize conv layers in YOLOv8 head."""
88
+ # Init decouple head
89
+ self.cls_preds = nn.ModuleList()
90
+ self.reg_preds = nn.ModuleList()
91
+
92
+ reg_out_channels = max(
93
+ (16, self.in_channels[0] // 4, self.reg_max * 4))
94
+ cls_out_channels = max(self.in_channels[0], self.num_classes)
95
+
96
+ for i in range(self.num_levels):
97
+ self.reg_preds.append(
98
+ nn.Sequential(
99
+ ConvModule(
100
+ in_channels=self.in_channels[i],
101
+ out_channels=reg_out_channels,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1,
105
+ norm_cfg=self.norm_cfg,
106
+ act_cfg=self.act_cfg),
107
+ ConvModule(
108
+ in_channels=reg_out_channels,
109
+ out_channels=reg_out_channels,
110
+ kernel_size=3,
111
+ stride=1,
112
+ padding=1,
113
+ norm_cfg=self.norm_cfg,
114
+ act_cfg=self.act_cfg),
115
+ nn.Conv2d(
116
+ in_channels=reg_out_channels,
117
+ out_channels=4 * self.reg_max,
118
+ kernel_size=1)))
119
+ self.cls_preds.append(
120
+ nn.Sequential(
121
+ ConvModule(
122
+ in_channels=self.in_channels[i],
123
+ out_channels=cls_out_channels,
124
+ kernel_size=3,
125
+ stride=1,
126
+ padding=1,
127
+ norm_cfg=self.norm_cfg,
128
+ act_cfg=self.act_cfg),
129
+ ConvModule(
130
+ in_channels=cls_out_channels,
131
+ out_channels=cls_out_channels,
132
+ kernel_size=3,
133
+ stride=1,
134
+ padding=1,
135
+ norm_cfg=self.norm_cfg,
136
+ act_cfg=self.act_cfg),
137
+ nn.Conv2d(
138
+ in_channels=cls_out_channels,
139
+ out_channels=self.num_classes,
140
+ kernel_size=1)))
141
+
142
+ proj = torch.arange(self.reg_max, dtype=torch.float)
143
+ self.register_buffer('proj', proj, persistent=False)
144
+
145
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
146
+ """Forward features from the upstream network.
147
+
148
+ Args:
149
+ x (Tuple[Tensor]): Features from the upstream network, each is
150
+ a 4D-tensor.
151
+ Returns:
152
+ Tuple[List]: A tuple of multi-level classification scores, bbox
153
+ predictions
154
+ """
155
+ assert len(x) == self.num_levels
156
+ return multi_apply(self.forward_single, x, self.cls_preds,
157
+ self.reg_preds)
158
+
159
+ def forward_single(self, x: torch.Tensor, cls_pred: nn.ModuleList,
160
+ reg_pred: nn.ModuleList) -> Tuple:
161
+ """Forward feature of a single scale level."""
162
+ b, _, h, w = x.shape
163
+ cls_logit = cls_pred(x)
164
+ bbox_dist_preds = reg_pred(x)
165
+ if self.reg_max > 1:
166
+ bbox_dist_preds = bbox_dist_preds.reshape(
167
+ [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2)
168
+
169
+ # TODO: The get_flops script cannot handle the situation of
170
+ # matmul, and needs to be fixed later
171
+ # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
172
+ bbox_preds = bbox_dist_preds.softmax(3).matmul(
173
+ self.proj.view([-1, 1])).squeeze(-1)
174
+ bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
175
+ else:
176
+ bbox_preds = bbox_dist_preds
177
+ if self.training:
178
+ return cls_logit, bbox_preds, bbox_dist_preds
179
+ else:
180
+ return cls_logit, bbox_preds
181
+
182
+
183
+ @MODELS.register_module()
184
+ class YOLOv8Head(YOLOv5Head):
185
+ """YOLOv8Head head used in `YOLOv8`.
186
+
187
+ Args:
188
+ head_module(:obj:`ConfigDict` or dict): Base module used for YOLOv8Head
189
+ prior_generator(dict): Points generator feature maps
190
+ in 2D points-based detectors.
191
+ bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder.
192
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
193
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
194
+ loss_dfl (:obj:`ConfigDict` or dict): Config of Distribution Focal
195
+ Loss.
196
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
197
+ anchor head. Defaults to None.
198
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
199
+ anchor head. Defaults to None.
200
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
201
+ list[dict], optional): Initialization config dict.
202
+ Defaults to None.
203
+ """
204
+
205
+ def __init__(self,
206
+ head_module: ConfigType,
207
+ prior_generator: ConfigType = dict(
208
+ type='mmdet.MlvlPointGenerator',
209
+ offset=0.5,
210
+ strides=[8, 16, 32]),
211
+ bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
212
+ loss_cls: ConfigType = dict(
213
+ type='mmdet.CrossEntropyLoss',
214
+ use_sigmoid=True,
215
+ reduction='none',
216
+ loss_weight=0.5),
217
+ loss_bbox: ConfigType = dict(
218
+ type='IoULoss',
219
+ iou_mode='ciou',
220
+ bbox_format='xyxy',
221
+ reduction='sum',
222
+ loss_weight=7.5,
223
+ return_iou=False),
224
+ loss_dfl=dict(
225
+ type='mmdet.DistributionFocalLoss',
226
+ reduction='mean',
227
+ loss_weight=1.5 / 4),
228
+ train_cfg: OptConfigType = None,
229
+ test_cfg: OptConfigType = None,
230
+ init_cfg: OptMultiConfig = None
231
+ ):
232
+ super().__init__(
233
+ head_module=head_module,
234
+ prior_generator=prior_generator,
235
+ bbox_coder=bbox_coder,
236
+ loss_cls=loss_cls,
237
+ loss_bbox=loss_bbox,
238
+ train_cfg=train_cfg,
239
+ test_cfg=test_cfg,
240
+ init_cfg=init_cfg)
241
+ self.loss_dfl = MODELS.build(loss_dfl)
242
+ # YOLOv8 doesn't need loss_obj
243
+ self.loss_obj = None
244
+
245
+ def special_init(self):
246
+ """Since YOLO series algorithms will inherit from YOLOv5Head, but
247
+ each algorithm has its own special initialization process.
248
+
249
+ The special_init function is designed to deal with this situation.
250
+ """
251
+
252
+ if self.train_cfg:
253
+ self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
254
+
255
+ # Add common attributes to reduce calculation
256
+ self.featmap_sizes_train = None
257
+ self.num_level_priors = None
258
+ self.flatten_priors_train = None
259
+ self.stride_tensor = None
260
+
261
+ def loss_by_feat(
262
+ self,
263
+ cls_scores: Sequence[Tensor],
264
+ bbox_preds: Sequence[Tensor],
265
+ bbox_dist_preds: Sequence[Tensor],
266
+ batch_gt_instances: Sequence[InstanceData],
267
+ batch_img_metas: Sequence[dict],
268
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
269
+ """Calculate the loss based on the features extracted by the detection
270
+ head.
271
+
272
+ Args:
273
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
274
+ each is a 4D-tensor, the channel number is
275
+ num_priors * num_classes.
276
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
277
+ level, each is a 4D-tensor, the channel number is
278
+ num_priors * 4.
279
+ bbox_dist_preds (Sequence[Tensor]): Box distribution logits for
280
+ each scale level with shape (bs, H*W, 4, reg_max).
281
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
282
+ gt_instance. It usually includes ``bboxes`` and ``labels``
283
+ attributes.
284
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
285
+ image size, scaling factor, etc.
286
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
287
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
288
+ data that is ignored during training and testing.
289
+ Defaults to None.
290
+ Returns:
291
+ dict[str, Tensor]: A dictionary of losses.
292
+ """
293
+ num_imgs = len(batch_img_metas)
294
+
295
+ current_featmap_sizes = [
296
+ cls_score.shape[2:] for cls_score in cls_scores
297
+ ]
298
+ # If the shape does not equal, generate new one
299
+ if current_featmap_sizes != self.featmap_sizes_train:
300
+ self.featmap_sizes_train = current_featmap_sizes
301
+
302
+ mlvl_priors_with_stride = self.prior_generator.grid_priors(
303
+ self.featmap_sizes_train,
304
+ dtype=cls_scores[0].dtype,
305
+ device=cls_scores[0].device,
306
+ with_stride=True)
307
+
308
+ self.num_level_priors = [len(n) for n in mlvl_priors_with_stride]
309
+ self.flatten_priors_train = torch.cat(
310
+ mlvl_priors_with_stride, dim=0)
311
+ self.stride_tensor = self.flatten_priors_train[..., [2]]
312
+
313
+ # gt info
314
+ gt_info = gt_instances_preprocess(batch_gt_instances, num_imgs)
315
+ gt_labels = gt_info[:, :, :1]
316
+ gt_bboxes = gt_info[:, :, 1:] # xyxy
317
+ pad_bbox_flag = (gt_bboxes.sum(-1, keepdim=True) > 0).float()
318
+
319
+ # pred info
320
+ flatten_cls_preds = [
321
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
322
+ self.num_classes)
323
+ for cls_pred in cls_scores
324
+ ]
325
+ flatten_pred_bboxes = [
326
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
327
+ for bbox_pred in bbox_preds
328
+ ]
329
+ # (bs, n, 4 * reg_max)
330
+ flatten_pred_dists = [
331
+ bbox_pred_org.reshape(num_imgs, -1, self.head_module.reg_max * 4)
332
+ for bbox_pred_org in bbox_dist_preds
333
+ ]
334
+
335
+ flatten_dist_preds = torch.cat(flatten_pred_dists, dim=1)
336
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
337
+ flatten_pred_bboxes = torch.cat(flatten_pred_bboxes, dim=1)
338
+ flatten_pred_bboxes = self.bbox_coder.decode(
339
+ self.flatten_priors_train[..., :2], flatten_pred_bboxes,
340
+ self.stride_tensor[..., 0])
341
+
342
+ assigned_result = self.assigner(
343
+ (flatten_pred_bboxes.detach()).type(gt_bboxes.dtype),
344
+ flatten_cls_preds.detach().sigmoid(), self.flatten_priors_train,
345
+ gt_labels, gt_bboxes, pad_bbox_flag)
346
+
347
+ assigned_bboxes = assigned_result['assigned_bboxes']
348
+ assigned_scores = assigned_result['assigned_scores']
349
+ fg_mask_pre_prior = assigned_result['fg_mask_pre_prior']
350
+
351
+ assigned_scores_sum = assigned_scores.sum().clamp(min=1)
352
+
353
+ loss_cls = self.loss_cls(flatten_cls_preds, assigned_scores).sum()
354
+ loss_cls /= assigned_scores_sum
355
+
356
+ # rescale bbox
357
+ assigned_bboxes /= self.stride_tensor
358
+ flatten_pred_bboxes /= self.stride_tensor
359
+
360
+ # select positive samples mask
361
+ num_pos = fg_mask_pre_prior.sum()
362
+ if num_pos > 0:
363
+ # when num_pos > 0, assigned_scores_sum will >0, so the loss_bbox
364
+ # will not report an error
365
+ # iou loss
366
+ prior_bbox_mask = fg_mask_pre_prior.unsqueeze(-1).repeat([1, 1, 4])
367
+ pred_bboxes_pos = torch.masked_select(
368
+ flatten_pred_bboxes, prior_bbox_mask).reshape([-1, 4])
369
+ assigned_bboxes_pos = torch.masked_select(
370
+ assigned_bboxes, prior_bbox_mask).reshape([-1, 4])
371
+ bbox_weight = torch.masked_select(
372
+ assigned_scores.sum(-1), fg_mask_pre_prior).unsqueeze(-1)
373
+ loss_bbox = self.loss_bbox(
374
+ pred_bboxes_pos, assigned_bboxes_pos,
375
+ weight=bbox_weight) / assigned_scores_sum
376
+
377
+ # dfl loss
378
+ pred_dist_pos = flatten_dist_preds[fg_mask_pre_prior]
379
+ assigned_ltrb = self.bbox_coder.encode(
380
+ self.flatten_priors_train[..., :2] / self.stride_tensor,
381
+ assigned_bboxes,
382
+ max_dis=self.head_module.reg_max - 1,
383
+ eps=0.01)
384
+ assigned_ltrb_pos = torch.masked_select(
385
+ assigned_ltrb, prior_bbox_mask).reshape([-1, 4])
386
+ loss_dfl = self.loss_dfl(
387
+ pred_dist_pos.reshape(-1, self.head_module.reg_max),
388
+ assigned_ltrb_pos.reshape(-1),
389
+ weight=bbox_weight.expand(-1, 4).reshape(-1),
390
+ avg_factor=assigned_scores_sum)
391
+ else:
392
+ loss_bbox = flatten_pred_bboxes.sum() * 0
393
+ loss_dfl = flatten_pred_bboxes.sum() * 0
394
+ _, world_size = get_dist_info()
395
+ return dict(
396
+ loss_cls=loss_cls * num_imgs * world_size,
397
+ loss_bbox=loss_bbox * num_imgs * world_size,
398
+ loss_dfl=loss_dfl * num_imgs * world_size)
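For reference, the integral (DFL-style) decoding in `forward_single` above can be reproduced in isolation: each box side is predicted as a softmax distribution over `reg_max` discrete bins, and the decoded distance is the expectation of that distribution. A rough sketch with dummy tensors:

import torch

reg_max, h, w = 16, 20, 20
proj = torch.arange(reg_max, dtype=torch.float)      # bin values 0..reg_max-1
bbox_dist_preds = torch.randn(2, 4 * reg_max, h, w)  # raw regression output

# (b, 4 * reg_max, h, w) -> (b, h * w, 4, reg_max)
bbox_dist_preds = bbox_dist_preds.reshape(
    [-1, 4, reg_max, h * w]).permute(0, 3, 1, 2)
# expectation over the bins gives one distance per box side
bbox_preds = bbox_dist_preds.softmax(3).matmul(
    proj.view([-1, 1])).squeeze(-1)
print(bbox_preds.shape)  # torch.Size([2, 400, 4])
assert (bbox_preds >= 0).all() and (bbox_preds <= reg_max - 1).all()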
mmyolo/models/dense_heads/yolox_head.py ADDED
@@ -0,0 +1,514 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from typing import List, Optional, Sequence, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
8
+ from mmdet.models.task_modules.samplers import PseudoSampler
9
+ from mmdet.models.utils import multi_apply
10
+ from mmdet.structures.bbox import bbox_xyxy_to_cxcywh
11
+ from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
12
+ OptMultiConfig, reduce_mean)
13
+ from mmengine.model import BaseModule, bias_init_with_prob
14
+ from mmengine.structures import InstanceData
15
+ from torch import Tensor
16
+
17
+ from mmyolo.registry import MODELS, TASK_UTILS
18
+ from .yolov5_head import YOLOv5Head
19
+
20
+
21
+ @MODELS.register_module()
22
+ class YOLOXHeadModule(BaseModule):
23
+ """YOLOXHead head module used in `YOLOX.
24
+
25
+ `<https://arxiv.org/abs/2107.08430>`_
26
+
27
+ Args:
28
+ num_classes (int): Number of categories excluding the background
29
+ category.
30
+ in_channels (Union[int, Sequence]): Number of channels in the input
31
+ feature map.
32
+ widen_factor (float): Width multiplier, multiply number of
33
+ channels in each layer by this amount. Defaults to 1.0.
34
+ num_base_priors (int): The number of priors (points) at a point
35
+ on the feature grid
36
+ stacked_convs (int): Number of stacking convs of the head.
37
+ Defaults to 2.
38
+ featmap_strides (Sequence[int]): Downsample factor of each feature map.
39
+ Defaults to [8, 16, 32].
40
+ use_depthwise (bool): Whether to use depthwise separable convolution in
41
+ blocks. Defaults to False.
42
+ dcn_on_last_conv (bool): If true, use dcn in the last layer of
43
+ towers. Defaults to False.
44
+ conv_bias (bool or str): If specified as `auto`, it will be decided by
45
+ the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
46
+ None, otherwise False. Defaults to "auto".
47
+ conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
48
+ convolution layer. Defaults to None.
49
+ norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
50
+ layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
51
+ act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer.
52
+ Defaults to dict(type='SiLU', inplace=True).
53
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
54
+ list[dict], optional): Initialization config dict.
55
+ Defaults to None.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ num_classes: int,
61
+ in_channels: Union[int, Sequence],
62
+ widen_factor: float = 1.0,
63
+ num_base_priors: int = 1,
64
+ feat_channels: int = 256,
65
+ stacked_convs: int = 2,
66
+ featmap_strides: Sequence[int] = [8, 16, 32],
67
+ use_depthwise: bool = False,
68
+ dcn_on_last_conv: bool = False,
69
+ conv_bias: Union[bool, str] = 'auto',
70
+ conv_cfg: OptConfigType = None,
71
+ norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
72
+ act_cfg: ConfigType = dict(type='SiLU', inplace=True),
73
+ init_cfg: OptMultiConfig = None,
74
+ ):
75
+ super().__init__(init_cfg=init_cfg)
76
+ self.num_classes = num_classes
77
+ self.feat_channels = int(feat_channels * widen_factor)
78
+ self.stacked_convs = stacked_convs
79
+ self.use_depthwise = use_depthwise
80
+ self.dcn_on_last_conv = dcn_on_last_conv
81
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
82
+ self.conv_bias = conv_bias
83
+ self.num_base_priors = num_base_priors
84
+
85
+ self.conv_cfg = conv_cfg
86
+ self.norm_cfg = norm_cfg
87
+ self.act_cfg = act_cfg
88
+ self.featmap_strides = featmap_strides
89
+
90
+ if isinstance(in_channels, int):
91
+ in_channels = int(in_channels * widen_factor)
92
+ self.in_channels = in_channels
93
+
94
+ self._init_layers()
95
+
96
+ def _init_layers(self):
97
+ """Initialize heads for all level feature maps."""
98
+ self.multi_level_cls_convs = nn.ModuleList()
99
+ self.multi_level_reg_convs = nn.ModuleList()
100
+ self.multi_level_conv_cls = nn.ModuleList()
101
+ self.multi_level_conv_reg = nn.ModuleList()
102
+ self.multi_level_conv_obj = nn.ModuleList()
103
+ for _ in self.featmap_strides:
104
+ self.multi_level_cls_convs.append(self._build_stacked_convs())
105
+ self.multi_level_reg_convs.append(self._build_stacked_convs())
106
+ conv_cls, conv_reg, conv_obj = self._build_predictor()
107
+ self.multi_level_conv_cls.append(conv_cls)
108
+ self.multi_level_conv_reg.append(conv_reg)
109
+ self.multi_level_conv_obj.append(conv_obj)
110
+
111
+ def _build_stacked_convs(self) -> nn.Sequential:
112
+ """Initialize conv layers of a single level head."""
113
+ conv = DepthwiseSeparableConvModule \
114
+ if self.use_depthwise else ConvModule
115
+ stacked_convs = []
116
+ for i in range(self.stacked_convs):
117
+ chn = self.in_channels if i == 0 else self.feat_channels
118
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
119
+ conv_cfg = dict(type='DCNv2')
120
+ else:
121
+ conv_cfg = self.conv_cfg
122
+ stacked_convs.append(
123
+ conv(
124
+ chn,
125
+ self.feat_channels,
126
+ 3,
127
+ stride=1,
128
+ padding=1,
129
+ conv_cfg=conv_cfg,
130
+ norm_cfg=self.norm_cfg,
131
+ act_cfg=self.act_cfg,
132
+ bias=self.conv_bias))
133
+ return nn.Sequential(*stacked_convs)
134
+
135
+ def _build_predictor(self) -> Tuple[nn.Module, nn.Module, nn.Module]:
136
+ """Initialize predictor layers of a single level head."""
137
+ conv_cls = nn.Conv2d(self.feat_channels, self.num_classes, 1)
138
+ conv_reg = nn.Conv2d(self.feat_channels, 4, 1)
139
+ conv_obj = nn.Conv2d(self.feat_channels, 1, 1)
140
+ return conv_cls, conv_reg, conv_obj
141
+
142
+ def init_weights(self):
143
+ """Initialize weights of the head."""
144
+ # Use prior in model initialization to improve stability
145
+ super().init_weights()
146
+ bias_init = bias_init_with_prob(0.01)
147
+ for conv_cls, conv_obj in zip(self.multi_level_conv_cls,
148
+ self.multi_level_conv_obj):
149
+ conv_cls.bias.data.fill_(bias_init)
150
+ conv_obj.bias.data.fill_(bias_init)
151
+
152
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
153
+ """Forward features from the upstream network.
154
+
155
+ Args:
156
+ x (Tuple[Tensor]): Features from the upstream network, each is
157
+ a 4D-tensor.
158
+ Returns:
159
+ Tuple[List]: A tuple of multi-level classification scores, bbox
160
+ predictions, and objectnesses.
161
+ """
162
+
163
+ return multi_apply(self.forward_single, x, self.multi_level_cls_convs,
164
+ self.multi_level_reg_convs,
165
+ self.multi_level_conv_cls,
166
+ self.multi_level_conv_reg,
167
+ self.multi_level_conv_obj)
168
+
169
+ def forward_single(self, x: Tensor, cls_convs: nn.Module,
170
+ reg_convs: nn.Module, conv_cls: nn.Module,
171
+ conv_reg: nn.Module,
172
+ conv_obj: nn.Module) -> Tuple[Tensor, Tensor, Tensor]:
173
+ """Forward feature of a single scale level."""
174
+
175
+ cls_feat = cls_convs(x)
176
+ reg_feat = reg_convs(x)
177
+
178
+ cls_score = conv_cls(cls_feat)
179
+ bbox_pred = conv_reg(reg_feat)
180
+ objectness = conv_obj(reg_feat)
181
+
182
+ return cls_score, bbox_pred, objectness
183
+
184
+
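
For reference, a minimal usage sketch of the head module defined above (assuming it is exposed as `YOLOXHeadModule` from `mmyolo.models.dense_heads`; the import path and shapes are illustrative only):

import torch
from mmyolo.models.dense_heads import YOLOXHeadModule  # assumed export name

head_module = YOLOXHeadModule(num_classes=80, in_channels=256, feat_channels=256)
# Three feature levels matching featmap_strides=[8, 16, 32] on a 640x640 input.
feats = tuple(torch.rand(1, 256, s, s) for s in (80, 40, 20))
cls_scores, bbox_preds, objectnesses = head_module(feats)
# cls_scores[i]: (1, 80, H_i, W_i); bbox_preds[i]: (1, 4, H_i, W_i);
# objectnesses[i]: (1, 1, H_i, W_i)
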
185
+ @MODELS.register_module()
186
+ class YOLOXHead(YOLOv5Head):
187
+ """YOLOXHead head used in `YOLOX <https://arxiv.org/abs/2107.08430>`_.
188
+
189
+ Args:
190
+ head_module (ConfigType): Base module used for YOLOXHead.
191
+ prior_generator (:obj:`ConfigDict` or dict): Config of the point
192
+ generator that produces priors on the 2D feature maps.
193
+ loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
194
+ loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
195
+ loss_obj (:obj:`ConfigDict` or dict): Config of objectness loss.
196
+ loss_bbox_aux (:obj:`ConfigDict` or dict): Config of bbox aux loss.
197
+ train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
198
+ anchor head. Defaults to None.
199
+ test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
200
+ anchor head. Defaults to None.
201
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
202
+ list[dict], optional): Initialization config dict.
203
+ Defaults to None.
204
+ """
205
+
206
+ def __init__(self,
207
+ head_module: ConfigType,
208
+ prior_generator: ConfigType = dict(
209
+ type='mmdet.MlvlPointGenerator',
210
+ offset=0,
211
+ strides=[8, 16, 32]),
212
+ bbox_coder: ConfigType = dict(type='YOLOXBBoxCoder'),
213
+ loss_cls: ConfigType = dict(
214
+ type='mmdet.CrossEntropyLoss',
215
+ use_sigmoid=True,
216
+ reduction='sum',
217
+ loss_weight=1.0),
218
+ loss_bbox: ConfigType = dict(
219
+ type='mmdet.IoULoss',
220
+ mode='square',
221
+ eps=1e-16,
222
+ reduction='sum',
223
+ loss_weight=5.0),
224
+ loss_obj: ConfigType = dict(
225
+ type='mmdet.CrossEntropyLoss',
226
+ use_sigmoid=True,
227
+ reduction='sum',
228
+ loss_weight=1.0),
229
+ loss_bbox_aux: ConfigType = dict(
230
+ type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
231
+ train_cfg: OptConfigType = None,
232
+ test_cfg: OptConfigType = None,
233
+ init_cfg: OptMultiConfig = None):
234
+ self.use_bbox_aux = False
235
+ self.loss_bbox_aux = loss_bbox_aux
236
+
237
+ super().__init__(
238
+ head_module=head_module,
239
+ prior_generator=prior_generator,
240
+ bbox_coder=bbox_coder,
241
+ loss_cls=loss_cls,
242
+ loss_bbox=loss_bbox,
243
+ loss_obj=loss_obj,
244
+ train_cfg=train_cfg,
245
+ test_cfg=test_cfg,
246
+ init_cfg=init_cfg)
247
+
248
+ def special_init(self):
249
+ """Since YOLO series algorithms will inherit from YOLOv5Head, but
250
+ different algorithms have special initialization process.
251
+
252
+ The special_init function is designed to deal with this situation.
253
+ """
254
+ self.loss_bbox_aux: nn.Module = MODELS.build(self.loss_bbox_aux)
255
+ if self.train_cfg:
256
+ self.assigner = TASK_UTILS.build(self.train_cfg.assigner)
257
+ # YOLOX does not support sampling
258
+ self.sampler = PseudoSampler()
259
+
260
+ def forward(self, x: Tuple[Tensor]) -> Tuple[List]:
261
+ return self.head_module(x)
262
+
263
+ def loss_by_feat(
264
+ self,
265
+ cls_scores: Sequence[Tensor],
266
+ bbox_preds: Sequence[Tensor],
267
+ objectnesses: Sequence[Tensor],
268
+ batch_gt_instances: Tensor,
269
+ batch_img_metas: Sequence[dict],
270
+ batch_gt_instances_ignore: OptInstanceList = None) -> dict:
271
+ """Calculate the loss based on the features extracted by the detection
272
+ head.
273
+
274
+ Args:
275
+ cls_scores (Sequence[Tensor]): Box scores for each scale level,
276
+ each is a 4D-tensor, the channel number is
277
+ num_priors * num_classes.
278
+ bbox_preds (Sequence[Tensor]): Box energies / deltas for each scale
279
+ level, each is a 4D-tensor, the channel number is
280
+ num_priors * 4.
281
+ objectnesses (Sequence[Tensor]): Score factor for
282
+ all scale levels, each is a 4D-tensor with shape
283
+ (batch_size, 1, H, W).
284
+ batch_gt_instances (Tensor): Packed ground truth of the whole batch,
285
+ a 2D tensor of shape (num_total_gts, 6); it is split into per-image
286
+ ``InstanceData`` by ``gt_instances_preprocess``.
287
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
288
+ image size, scaling factor, etc.
289
+ batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
290
+ Batch of gt_instances_ignore. It includes ``bboxes`` attribute
291
+ data that is ignored during training and testing.
292
+ Defaults to None.
293
+ Returns:
294
+ dict[str, Tensor]: A dictionary of losses.
295
+ """
296
+ num_imgs = len(batch_img_metas)
297
+ if batch_gt_instances_ignore is None:
298
+ batch_gt_instances_ignore = [None] * num_imgs
299
+
300
+ batch_gt_instances = self.gt_instances_preprocess(
301
+ batch_gt_instances, len(batch_img_metas))
302
+
303
+ featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
304
+ mlvl_priors = self.prior_generator.grid_priors(
305
+ featmap_sizes,
306
+ dtype=cls_scores[0].dtype,
307
+ device=cls_scores[0].device,
308
+ with_stride=True)
309
+
310
+ flatten_cls_preds = [
311
+ cls_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
312
+ self.num_classes)
313
+ for cls_pred in cls_scores
314
+ ]
315
+ flatten_bbox_preds = [
316
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
317
+ for bbox_pred in bbox_preds
318
+ ]
319
+ flatten_objectness = [
320
+ objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
321
+ for objectness in objectnesses
322
+ ]
323
+
324
+ flatten_cls_preds = torch.cat(flatten_cls_preds, dim=1)
325
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
326
+ flatten_objectness = torch.cat(flatten_objectness, dim=1)
327
+ flatten_priors = torch.cat(mlvl_priors)
328
+ flatten_bboxes = self.bbox_coder.decode(flatten_priors[..., :2],
329
+ flatten_bbox_preds,
330
+ flatten_priors[..., 2])
331
+
332
+ (pos_masks, cls_targets, obj_targets, bbox_targets, bbox_aux_target,
333
+ num_fg_imgs) = multi_apply(
334
+ self._get_targets_single,
335
+ flatten_priors.unsqueeze(0).repeat(num_imgs, 1, 1),
336
+ flatten_cls_preds.detach(), flatten_bboxes.detach(),
337
+ flatten_objectness.detach(), batch_gt_instances, batch_img_metas,
338
+ batch_gt_instances_ignore)
339
+
340
+ # The experimental results show that 'reduce_mean' can improve
341
+ # performance on the COCO dataset.
342
+ num_pos = torch.tensor(
343
+ sum(num_fg_imgs),
344
+ dtype=torch.float,
345
+ device=flatten_cls_preds.device)
346
+ num_total_samples = max(reduce_mean(num_pos), 1.0)
347
+
348
+ pos_masks = torch.cat(pos_masks, 0)
349
+ cls_targets = torch.cat(cls_targets, 0)
350
+ obj_targets = torch.cat(obj_targets, 0)
351
+ bbox_targets = torch.cat(bbox_targets, 0)
352
+ if self.use_bbox_aux:
353
+ bbox_aux_target = torch.cat(bbox_aux_target, 0)
354
+
355
+ loss_obj = self.loss_obj(flatten_objectness.view(-1, 1),
356
+ obj_targets) / num_total_samples
357
+ if num_pos > 0:
358
+ loss_cls = self.loss_cls(
359
+ flatten_cls_preds.view(-1, self.num_classes)[pos_masks],
360
+ cls_targets) / num_total_samples
361
+ loss_bbox = self.loss_bbox(
362
+ flatten_bboxes.view(-1, 4)[pos_masks],
363
+ bbox_targets) / num_total_samples
364
+ else:
365
+ # Avoid cls and reg branch not participating in the gradient
366
+ # propagation when there is no ground-truth in the images.
367
+ # For more details, please refer to
368
+ # https://github.com/open-mmlab/mmdetection/issues/7298
369
+ loss_cls = flatten_cls_preds.sum() * 0
370
+ loss_bbox = flatten_bboxes.sum() * 0
371
+
372
+ loss_dict = dict(
373
+ loss_cls=loss_cls, loss_bbox=loss_bbox, loss_obj=loss_obj)
374
+
375
+ if self.use_bbox_aux:
376
+ if num_pos > 0:
377
+ loss_bbox_aux = self.loss_bbox_aux(
378
+ flatten_bbox_preds.view(-1, 4)[pos_masks],
379
+ bbox_aux_target) / num_total_samples
380
+ else:
381
+ # Avoid cls and reg branch not participating in the gradient
382
+ # propagation when there is no ground-truth in the images.
383
+ # For more details, please refer to
384
+ # https://github.com/open-mmlab/mmdetection/issues/7298
385
+ loss_bbox_aux = flatten_bbox_preds.sum() * 0
386
+ loss_dict.update(loss_bbox_aux=loss_bbox_aux)
387
+
388
+ return loss_dict
389
+
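
A small, self-contained sketch of the flatten-and-concatenate step used above, just to make the shape bookkeeping explicit (sizes are illustrative for a 640x640 input):

import torch

num_imgs, num_classes = 2, 80
featmap_sizes = [(80, 80), (40, 40), (20, 20)]  # strides 8, 16, 32
cls_scores = [torch.rand(num_imgs, num_classes, h, w) for h, w in featmap_sizes]
flatten_cls_preds = torch.cat([
    s.permute(0, 2, 3, 1).reshape(num_imgs, -1, num_classes) for s in cls_scores
], dim=1)
# One row per prior across all levels: 80*80 + 40*40 + 20*20 = 8400 priors.
assert flatten_cls_preds.shape == (num_imgs, 8400, num_classes)
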
390
+ @torch.no_grad()
391
+ def _get_targets_single(
392
+ self,
393
+ priors: Tensor,
394
+ cls_preds: Tensor,
395
+ decoded_bboxes: Tensor,
396
+ objectness: Tensor,
397
+ gt_instances: InstanceData,
398
+ img_meta: dict,
399
+ gt_instances_ignore: Optional[InstanceData] = None) -> tuple:
400
+ """Compute classification, regression, and objectness targets for
401
+ priors in a single image.
402
+
403
+ Args:
404
+ priors (Tensor): All priors of one image, a 2D-Tensor with shape
405
+ [num_priors, 4] in [cx, cy, stride_w, stride_h] format.
406
+ cls_preds (Tensor): Classification predictions of one image,
407
+ a 2D-Tensor with shape [num_priors, num_classes]
408
+ decoded_bboxes (Tensor): Decoded bboxes predictions of one image,
409
+ a 2D-Tensor with shape [num_priors, 4] in [tl_x, tl_y,
410
+ br_x, br_y] format.
411
+ objectness (Tensor): Objectness predictions of one image,
412
+ a 1D-Tensor with shape [num_priors]
413
+ gt_instances (:obj:`InstanceData`): Ground truth of instance
414
+ annotations. It should include ``bboxes`` and ``labels``
415
+ attributes.
416
+ img_meta (dict): Meta information for current image.
417
+ gt_instances_ignore (:obj:`InstanceData`, optional): Instances
418
+ to be ignored during training. It includes ``bboxes`` attribute
419
+ data that is ignored during training and testing.
420
+ Defaults to None.
421
+ Returns:
422
+ tuple:
423
+ foreground_mask (Tensor): Binary mask of foreground
424
+ targets.
425
+ cls_target (Tensor): Classification targets of an image.
426
+ obj_target (Tensor): Objectness targets of an image.
427
+ bbox_target (Tensor): BBox targets of an image.
428
+ bbox_aux_target (Tensor): BBox aux targets of an image.
429
+ num_pos_per_img (int): Number of positive samples in an image.
430
+ """
431
+
432
+ num_priors = priors.size(0)
433
+ num_gts = len(gt_instances)
434
+ # No target
435
+ if num_gts == 0:
436
+ cls_target = cls_preds.new_zeros((0, self.num_classes))
437
+ bbox_target = cls_preds.new_zeros((0, 4))
438
+ bbox_aux_target = cls_preds.new_zeros((0, 4))
439
+ obj_target = cls_preds.new_zeros((num_priors, 1))
440
+ foreground_mask = cls_preds.new_zeros(num_priors).bool()
441
+ return (foreground_mask, cls_target, obj_target, bbox_target,
442
+ bbox_aux_target, 0)
443
+
444
+ # YOLOX uses center priors with 0.5 offset to assign targets,
445
+ but uses center priors without offset to regress bboxes.
446
+ offset_priors = torch.cat(
447
+ [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)
448
+
449
+ scores = cls_preds.sigmoid() * objectness.unsqueeze(1).sigmoid()
450
+ pred_instances = InstanceData(
451
+ bboxes=decoded_bboxes, scores=scores.sqrt_(), priors=offset_priors)
452
+ assign_result = self.assigner.assign(
453
+ pred_instances=pred_instances,
454
+ gt_instances=gt_instances,
455
+ gt_instances_ignore=gt_instances_ignore)
456
+
457
+ sampling_result = self.sampler.sample(assign_result, pred_instances,
458
+ gt_instances)
459
+ pos_inds = sampling_result.pos_inds
460
+ num_pos_per_img = pos_inds.size(0)
461
+
462
+ pos_ious = assign_result.max_overlaps[pos_inds]
463
+ # IOU aware classification score
464
+ cls_target = F.one_hot(sampling_result.pos_gt_labels,
465
+ self.num_classes) * pos_ious.unsqueeze(-1)
466
+ obj_target = torch.zeros_like(objectness).unsqueeze(-1)
467
+ obj_target[pos_inds] = 1
468
+ bbox_target = sampling_result.pos_gt_bboxes
469
+ bbox_aux_target = cls_preds.new_zeros((num_pos_per_img, 4))
470
+ if self.use_bbox_aux:
471
+ bbox_aux_target = self._get_bbox_aux_target(
472
+ bbox_aux_target, bbox_target, priors[pos_inds])
473
+ foreground_mask = torch.zeros_like(objectness).to(torch.bool)
474
+ foreground_mask[pos_inds] = 1
475
+ return (foreground_mask, cls_target, obj_target, bbox_target,
476
+ bbox_aux_target, num_pos_per_img)
477
+
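
An illustration of the 0.5-stride offset mentioned in the comment above: assignment uses priors shifted to the cell centers, while regression keeps the un-shifted grid points (values are illustrative):

import torch

# [cx, cy, stride_w, stride_h] priors from MlvlPointGenerator with offset=0
priors = torch.tensor([[0., 0., 8., 8.],
                       [8., 0., 8., 8.]])
offset_priors = torch.cat(
    [priors[:, :2] + priors[:, 2:] * 0.5, priors[:, 2:]], dim=-1)
# offset_priors -> [[4., 4., 8., 8.], [12., 4., 8., 8.]]
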
478
+ def _get_bbox_aux_target(self,
479
+ bbox_aux_target: Tensor,
480
+ gt_bboxes: Tensor,
481
+ priors: Tensor,
482
+ eps: float = 1e-8) -> Tensor:
483
+ """Convert gt bboxes to center offset and log width height."""
484
+ gt_cxcywh = bbox_xyxy_to_cxcywh(gt_bboxes)
485
+ bbox_aux_target[:, :2] = (gt_cxcywh[:, :2] -
486
+ priors[:, :2]) / priors[:, 2:]
487
+ bbox_aux_target[:,
488
+ 2:] = torch.log(gt_cxcywh[:, 2:] / priors[:, 2:] + eps)
489
+ return bbox_aux_target
490
+
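
A worked example of the aux-target encoding above (center offset in stride units, log of the width/height ratio); numbers are illustrative:

import torch

prior = torch.tensor([[100., 100., 8., 8.]])        # cx, cy, stride_w, stride_h
gt_cxcywh = torch.tensor([[104., 96., 32., 16.]])   # cx, cy, w, h
offset = (gt_cxcywh[:, :2] - prior[:, :2]) / prior[:, 2:]      # -> [[0.5, -0.5]]
log_wh = torch.log(gt_cxcywh[:, 2:] / prior[:, 2:] + 1e-8)     # -> [[log 4, log 2]]
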
491
+ @staticmethod
492
+ def gt_instances_preprocess(batch_gt_instances: Tensor,
493
+ batch_size: int) -> List[InstanceData]:
494
+ """Split batch_gt_instances with batch size.
495
+
496
+ Args:
497
+ batch_gt_instances (Tensor): Packed ground truth of the whole batch,
498
+ a 2D-Tensor of shape [all_gt_bboxes, 6].
499
+ batch_size (int): Batch size.
500
+
501
+ Returns:
502
+ List[:obj:`InstanceData`]: Per-image ground-truth instances, one per image.
503
+ """
504
+ # faster version
505
+ batch_instance_list = []
506
+ for i in range(batch_size):
507
+ batch_gt_instance_ = InstanceData()
508
+ single_batch_instance = \
509
+ batch_gt_instances[batch_gt_instances[:, 0] == i, :]
510
+ batch_gt_instance_.bboxes = single_batch_instance[:, 2:]
511
+ batch_gt_instance_.labels = single_batch_instance[:, 1]
512
+ batch_instance_list.append(batch_gt_instance_)
513
+
514
+ return batch_instance_list
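
The packed ground-truth layout consumed above is: column 0 = image index in the batch, column 1 = class label, columns 2:6 = box coordinates. A quick sketch with toy values, calling the static method on the class defined in this file:

import torch

batch_gt = torch.tensor([
    [0., 3., 10., 10., 50., 60.],   # image 0, class 3
    [1., 7., 20., 30., 80., 90.],   # image 1, class 7
    [1., 2., 15., 15., 40., 45.],   # image 1, class 2
])
per_img = YOLOXHead.gt_instances_preprocess(batch_gt, batch_size=2)
assert len(per_img) == 2 and per_img[1].bboxes.shape == (2, 4)
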
mmyolo/models/detectors/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .yolo_detector import YOLODetector
3
+
4
+ __all__ = ['YOLODetector']
mmyolo/models/detectors/yolo_detector.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmdet.models.detectors.single_stage import SingleStageDetector
4
+ from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
5
+ from mmengine.dist import get_world_size
6
+ from mmengine.logging import print_log
7
+
8
+ from mmyolo.registry import MODELS
9
+
10
+
11
+ @MODELS.register_module()
12
+ class YOLODetector(SingleStageDetector):
13
+ r"""Implementation of YOLO Series
14
+
15
+ Args:
16
+ backbone (:obj:`ConfigDict` or dict): The backbone config.
17
+ neck (:obj:`ConfigDict` or dict): The neck config.
18
+ bbox_head (:obj:`ConfigDict` or dict): The bbox head config.
19
+ train_cfg (:obj:`ConfigDict` or dict, optional): The training config
20
+ of YOLO. Defaults to None.
21
+ test_cfg (:obj:`ConfigDict` or dict, optional): The testing config
22
+ of YOLO. Defaults to None.
23
+ data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of
24
+ :class:`DetDataPreprocessor` to process the input data.
25
+ Defaults to None.
26
+ init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
27
+ list[dict], optional): Initialization config dict.
28
+ Defaults to None.
29
+ use_syncbn (bool): Whether to use SyncBatchNorm. Defaults to True.
30
+ """
31
+
32
+ def __init__(self,
33
+ backbone: ConfigType,
34
+ neck: ConfigType,
35
+ bbox_head: ConfigType,
36
+ train_cfg: OptConfigType = None,
37
+ test_cfg: OptConfigType = None,
38
+ data_preprocessor: OptConfigType = None,
39
+ init_cfg: OptMultiConfig = None,
40
+ use_syncbn: bool = True):
41
+ super().__init__(
42
+ backbone=backbone,
43
+ neck=neck,
44
+ bbox_head=bbox_head,
45
+ train_cfg=train_cfg,
46
+ test_cfg=test_cfg,
47
+ data_preprocessor=data_preprocessor,
48
+ init_cfg=init_cfg)
49
+
50
+ # TODO: Waiting for mmengine support
51
+ if use_syncbn and get_world_size() > 1:
52
+ torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
53
+ print_log('Using SyncBatchNorm()', 'current')
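
A minimal sketch of what the `use_syncbn` branch above does, outside of the runner (plain PyTorch; the toy model is illustrative):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
# Replaces every BatchNorm*d layer with SyncBatchNorm; the synchronization
# only takes effect when running under torch.distributed with world size > 1.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
assert isinstance(model[1], nn.SyncBatchNorm)
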
mmyolo/models/layers/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .ema import ExpMomentumEMA
3
+ from .yolo_bricks import (BepC3StageBlock, CSPLayerWithTwoConv,
4
+ DarknetBottleneck, EELANBlock, EffectiveSELayer,
5
+ ELANBlock, ImplicitA, ImplicitM,
6
+ MaxPoolAndStrideConvBlock, PPYOLOEBasicBlock,
7
+ RepStageBlock, RepVGGBlock, SPPFBottleneck,
8
+ SPPFCSPBlock, TinyDownSampleBlock)
9
+
10
+ __all__ = [
11
+ 'SPPFBottleneck', 'RepVGGBlock', 'RepStageBlock', 'ExpMomentumEMA',
12
+ 'ELANBlock', 'MaxPoolAndStrideConvBlock', 'SPPFCSPBlock',
13
+ 'PPYOLOEBasicBlock', 'EffectiveSELayer', 'TinyDownSampleBlock',
14
+ 'EELANBlock', 'ImplicitA', 'ImplicitM', 'BepC3StageBlock',
15
+ 'CSPLayerWithTwoConv', 'DarknetBottleneck'
16
+ ]
mmyolo/models/layers/ema.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from mmdet.models.layers import ExpMomentumEMA as MMDET_ExpMomentumEMA
8
+ from torch import Tensor
9
+
10
+ from mmyolo.registry import MODELS
11
+
12
+
13
+ @MODELS.register_module()
14
+ class ExpMomentumEMA(MMDET_ExpMomentumEMA):
15
+ """Exponential moving average (EMA) with exponential momentum strategy,
16
+ which is used in YOLO.
17
+
18
+ Args:
19
+ model (nn.Module): The model to be averaged.
20
+ momentum (float): The momentum used for updating ema parameter.
21
+ Ema's parameters are updated with the formula:
22
+ `averaged_param = (1-momentum) * averaged_param + momentum *
23
+ source_param`. Defaults to 0.0002.
24
+ gamma (int): Use a larger momentum early in training and gradually
25
+ annealing to a smaller value to update the ema model smoothly. The
26
+ momentum is calculated as
27
+ `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
28
+ Defaults to 2000.
29
+ interval (int): Interval between two updates. Defaults to 1.
30
+ device (torch.device, optional): If provided, the averaged model will
31
+ be stored on the :attr:`device`. Defaults to None.
32
+ update_buffers (bool): if True, it will compute running averages for
33
+ both the parameters and the buffers of the model. Defaults to
34
+ False.
35
+ """
36
+
37
+ def __init__(self,
38
+ model: nn.Module,
39
+ momentum: float = 0.0002,
40
+ gamma: int = 2000,
41
+ interval=1,
42
+ device: Optional[torch.device] = None,
43
+ update_buffers: bool = False):
44
+ super().__init__(
45
+ model=model,
46
+ momentum=momentum,
47
+ interval=interval,
48
+ device=device,
49
+ update_buffers=update_buffers)
50
+ assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
51
+ self.gamma = gamma
52
+
53
+ # Note: There is no need to re-fetch every update,
54
+ # as most models do not change their structure
55
+ # during the training process.
56
+ self.src_parameters = (
57
+ model.state_dict()
58
+ if self.update_buffers else dict(model.named_parameters()))
59
+ if not self.update_buffers:
60
+ self.src_buffers = model.buffers()
61
+
62
+ def avg_func(self, averaged_param: Tensor, source_param: Tensor,
63
+ steps: int):
64
+ """Compute the moving average of the parameters using the exponential
65
+ momentum strategy.
66
+
67
+ Args:
68
+ averaged_param (Tensor): The averaged parameters.
69
+ source_param (Tensor): The source parameters.
70
+ steps (int): The number of times the parameters have been
71
+ updated.
72
+ """
73
+ momentum = (1 - self.momentum) * math.exp(
74
+ -float(1 + steps) / self.gamma) + self.momentum
75
+ averaged_param.lerp_(source_param, momentum)
76
+
77
+ def update_parameters(self, model: nn.Module):
78
+ """Update the parameters after each training step.
79
+
80
+ Args:
81
+ model (nn.Module): The model of the parameter needs to be updated.
82
+ """
83
+ if self.steps == 0:
84
+ for k, p_avg in self.avg_parameters.items():
85
+ p_avg.data.copy_(self.src_parameters[k].data)
86
+ elif self.steps % self.interval == 0:
87
+ for k, p_avg in self.avg_parameters.items():
88
+ if p_avg.dtype.is_floating_point:
89
+ self.avg_func(p_avg.data, self.src_parameters[k].data,
90
+ self.steps)
91
+ if not self.update_buffers:
92
+ # If not update the buffers,
93
+ # keep the buffers in sync with the source model.
94
+ for b_avg, b_src in zip(self.module.buffers(), self.src_buffers):
95
+ b_avg.data.copy_(b_src.data)
96
+ self.steps += 1