ffb3164d894ec3b8685d8944bc2ebd3ff9ffc7bee135a21bd1326d58cb5f3d19
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py +153 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py +133 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py +889 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/stare.py +27 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/voc.py +29 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/__init__.py +12 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/__init__.py +16 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/cgnet.py +367 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/fast_scnn.py +375 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/hrnet.py +555 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v2.py +180 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v3.py +255 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnest.py +314 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnet.py +688 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnext.py +145 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/unet.py +429 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/vit.py +459 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/builder.py +46 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/__init__.py +28 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ann_head.py +245 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/apc_head.py +158 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/aspp_head.py +107 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cascade_decode_head.py +57 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cc_head.py +45 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/da_head.py +178 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/decode_head.py +234 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dm_head.py +140 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dnl_head.py +131 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ema_head.py +168 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/enc_head.py +187 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fcn_head.py +81 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fpn_head.py +68 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/gc_head.py +47 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/lraspp_head.py +90 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/nl_head.py +49 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ocr_head.py +127 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/point_head.py +354 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psa_head.py +199 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psp_head.py +101 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_aspp_head.py +101 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_fcn_head.py +51 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/uper_head.py +126 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/__init__.py +12 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/accuracy.py +78 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/cross_entropy_loss.py +198 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/dice_loss.py +119 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/lovasz_loss.py +303 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/utils.py +121 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/necks/__init__.py +4 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/necks/fpn.py +212 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py
ADDED
@@ -0,0 +1,153 @@
import os.path as osp

import annotator.mmpkg.mmcv as mmcv
import numpy as np

from ..builder import PIPELINES


@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'cv2'
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('img_prefix') is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']
        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(
            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(to_float32={self.to_float32},'
        repr_str += f"color_type='{self.color_type}',"
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str


@PIPELINES.register_module()
class LoadAnnotations(object):
    """Load annotations for semantic segmentation.

    Args:
        reduce_zero_label (bool): Whether reduce all label value by 1.
            Usually used for datasets where 0 is background label.
            Default: False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'pillow'
    """

    def __init__(self,
                 reduce_zero_label=False,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow'):
        self.reduce_zero_label = reduce_zero_label
        self.file_client_args = file_client_args.copy()
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """

        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']
        img_bytes = self.file_client.get(filename)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)
        # modify if custom classes
        if results.get('label_map', None) is not None:
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg == old_id] = new_id
        # reduce zero_label
        if self.reduce_zero_label:
            # avoid using underflow conversion
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_semantic_seg'] = gt_semantic_seg
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str
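For context (not part of this commit): both loading transforms are registered in `PIPELINES` and normally sit at the front of an mmseg-style pipeline config, with the dataset supplying `img_prefix`, `img_info` and `ann_info` as described in the docstrings above. A minimal, illustrative sketch:

# Hypothetical pipeline fragment; values are the defaults documented above.
train_loading = [
    dict(type='LoadImageFromFile', to_float32=False, color_type='color'),
    dict(type='LoadAnnotations', reduce_zero_label=False),
]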
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py
ADDED
@@ -0,0 +1,133 @@
import warnings

import annotator.mmpkg.mmcv as mmcv

from ..builder import PIPELINES
from .compose import Compose


@PIPELINES.register_module()
class MultiScaleFlipAug(object):
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as followed:

    .. code-block::

        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MultiScaleFLipAug with above configuration, the results are wrapped
    into lists of the same length as followed:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (None | tuple | list[tuple]): Images scales for resizing.
        img_ratios (float | list[float]): Image ratios for resizing
        flip (bool): Whether apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal" and "vertical". If flip_direction is list,
            multiple flip augmentations will be applied.
            It has no effect when flip == False. Default: "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale,
                 img_ratios=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        if img_ratios is not None:
            img_ratios = img_ratios if isinstance(img_ratios,
                                                  list) else [img_ratios]
            assert mmcv.is_list_of(img_ratios, float)
        if img_scale is None:
            # mode 1: given img_scale=None and a range of image ratio
            self.img_scale = None
            assert mmcv.is_list_of(img_ratios, float)
        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
                img_ratios, float):
            assert len(img_scale) == 2
            # mode 2: given a scale and a range of image ratio
            self.img_scale = [(int(img_scale[0] * ratio),
                               int(img_scale[1] * ratio))
                              for ratio in img_ratios]
        else:
            # mode 3: given multiple scales
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
        self.flip = flip
        self.img_ratios = img_ratios
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            warnings.warn(
                'flip_direction has no effect when flip is set to False')
        if (self.flip
                and not any([t['type'] == 'RandomFlip' for t in transforms])):
            warnings.warn(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """

        aug_data = []
        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
            h, w = results['img'].shape[:2]
            img_scale = [(int(w * ratio), int(h * ratio))
                         for ratio in self.img_ratios]
        else:
            img_scale = self.img_scale
        flip_aug = [False, True] if self.flip else [False]
        for scale in img_scale:
            for flip in flip_aug:
                for direction in self.flip_direction:
                    _results = results.copy()
                    _results['scale'] = scale
                    _results['flip'] = flip
                    _results['flip_direction'] = direction
                    data = self.transforms(_results)
                    aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip})'
        repr_str += f'flip_direction={self.flip_direction}'
        return repr_str
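The class docstring already shows the intended configuration; restated below as a complete test-pipeline sketch (illustrative only, and the normalization statistics are placeholders rather than values taken from this commit):

img_norm_cfg = dict(  # placeholder statistics, not part of this commit
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        img_ratios=[0.5, 1.0],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]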
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py
ADDED
@@ -0,0 +1,889 @@
import annotator.mmpkg.mmcv as mmcv
import numpy as np
from annotator.mmpkg.mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random

from ..builder import PIPELINES


@PIPELINES.register_module()
class Resize(object):
    """Resize images & seg.

    This transform resizes the input image to some scale. If the input dict
    contains the key "scale", then the scale in the input dict is used,
    otherwise the specified scale in the init method is used.

    ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
    (multi-scale). There are 4 multiscale modes:

    - ``ratio_range is not None``:
    1. When img_scale is None, img_scale is the shape of image in results
        (img_scale = results['img'].shape[:2]) and the image is resized based
        on the original size. (mode 1)
    2. When img_scale is a tuple (single-scale), randomly sample a ratio from
        the ratio range and multiply it with the image scale. (mode 2)

    - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
      scale from the a range. (mode 3)

    - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
      scale from multiple scales. (mode 4)

    Args:
        img_scale (tuple or list[tuple]): Images scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given img_scale=None and a range of image ratio
            # mode 2: given a scale and a range of image ratio
            assert self.img_scale is None or len(self.img_scale) == 1
        else:
            # mode 3 and 4: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Images scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
                where ``img_scale`` is the selected image scale and
                ``scale_idx`` is the selected index in the given candidates.
        """

        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Images scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where
                ``img_scale`` is sampled scale and None is just a placeholder
                to be consistent with :func:`random_select`.
        """

        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate sampled scale.

        Args:
            img_scale (tuple): Images scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where
                ``scale`` is sampled ratio multiplied with ``img_scale`` and
                None is just a placeholder to be consistent with
                :func:`random_select`.
        """

        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, results):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and be
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, single scale will be used.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: Two new keys 'scale` and 'scale_idx` are added into
                ``results``, which would be used by subsequent pipelines.
        """

        if self.ratio_range is not None:
            if self.img_scale is None:
                h, w = results['img'].shape[:2]
                scale, scale_idx = self.random_sample_ratio((w, h),
                                                            self.ratio_range)
            else:
                scale, scale_idx = self.random_sample_ratio(
                    self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        if self.keep_ratio:
            img, scale_factor = mmcv.imrescale(
                results['img'], results['scale'], return_scale=True)
            # the w_scale and h_scale has minor difference
            # a real fix should be done in the mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]
            h, w = results['img'].shape[:2]
            w_scale = new_w / w
            h_scale = new_h / h
        else:
            img, w_scale, h_scale = mmcv.imresize(
                results['img'], results['scale'], return_scale=True)
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        results['img'] = img
        results['img_shape'] = img.shape
        results['pad_shape'] = img.shape  # in case that there is no padding
        results['scale_factor'] = scale_factor
        results['keep_ratio'] = self.keep_ratio

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(
                    results[key], results['scale'], interpolation='nearest')
            else:
                gt_seg = mmcv.imresize(
                    results[key], results['scale'], interpolation='nearest')
            results[key] = gt_seg

    def __call__(self, results):
        """Call function to resize images, bounding boxes, masks, semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
                'keep_ratio' keys are added into result dict.
        """

        if 'scale' not in results:
            self._random_scale(results)
        self._resize_img(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(img_scale={self.img_scale}, '
                     f'multiscale_mode={self.multiscale_mode}, '
                     f'ratio_range={self.ratio_range}, '
                     f'keep_ratio={self.keep_ratio})')
        return repr_str


@PIPELINES.register_module()
class RandomFlip(object):
    """Flip the image & seg.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    Args:
        prob (float, optional): The flipping probability. Default: None.
        direction(str, optional): The flipping direction. Options are
            'horizontal' and 'vertical'. Default: 'horizontal'.
    """

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert prob >= 0 and prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, results):
        """Call function to flip bounding boxes, masks, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results, 'flip', 'flip_direction' keys are added into
                result dict.
        """

        if 'flip' not in results:
            flip = True if np.random.rand() < self.prob else False
            results['flip'] = flip
        if 'flip_direction' not in results:
            results['flip_direction'] = self.direction
        if results['flip']:
            # flip image
            results['img'] = mmcv.imflip(
                results['img'], direction=results['flip_direction'])

            # flip segs
            for key in results.get('seg_fields', []):
                # use copy() to make numpy stride positive
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction']).copy()
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(prob={self.prob})'


@PIPELINES.register_module()
class Pad(object):
    """Pad the image & mask.

    There are two padding modes: (1) pad to a fixed size and (2) pad to the
    minimum size that is divisible by some number.
    Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
    """

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_val=0,
                 seg_pad_val=255):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad images according to ``self.size``."""
        if self.size is not None:
            padded_img = mmcv.impad(
                results['img'], shape=self.size, pad_val=self.pad_val)
        elif self.size_divisor is not None:
            padded_img = mmcv.impad_to_multiple(
                results['img'], self.size_divisor, pad_val=self.pad_val)
        results['img'] = padded_img
        results['pad_shape'] = padded_img.shape
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def _pad_seg(self, results):
        """Pad masks according to ``results['pad_shape']``."""
        for key in results.get('seg_fields', []):
            results[key] = mmcv.impad(
                results[key],
                shape=results['pad_shape'][:2],
                pad_val=self.seg_pad_val)

    def __call__(self, results):
        """Call function to pad images, masks, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """

        self._pad_img(results)
        self._pad_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
                    f'pad_val={self.pad_val})'
        return repr_str


@PIPELINES.register_module()
class Normalize(object):
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function to normalize images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results, 'img_norm_cfg' key is added into
                result dict.
        """

        results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
                                          self.to_rgb)
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
                    f'{self.to_rgb})'
        return repr_str


@PIPELINES.register_module()
class Rerange(object):
    """Rerange the image pixel value.

    Args:
        min_value (float or int): Minimum value of the reranged image.
            Default: 0.
        max_value (float or int): Maximum value of the reranged image.
            Default: 255.
    """

    def __init__(self, min_value=0, max_value=255):
        assert isinstance(min_value, float) or isinstance(min_value, int)
        assert isinstance(max_value, float) or isinstance(max_value, int)
        assert min_value < max_value
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, results):
        """Call function to rerange images.

        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Reranged results.
        """

        img = results['img']
        img_min_value = np.min(img)
        img_max_value = np.max(img)

        assert img_min_value < img_max_value
        # rerange to [0, 1]
        img = (img - img_min_value) / (img_max_value - img_min_value)
        # rerange to [min_value, max_value]
        img = img * (self.max_value - self.min_value) + self.min_value
        results['img'] = img

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
        return repr_str


@PIPELINES.register_module()
class CLAHE(object):
    """Use CLAHE method to process the image.

    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
    Graphics Gems, 1994:474-485.` for more information.

    Args:
        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
            Input image will be divided into equally sized rectangular tiles.
            It defines the number of tiles in row and column. Default: (8, 8).
    """

    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
        assert isinstance(clip_limit, (float, int))
        self.clip_limit = clip_limit
        assert is_tuple_of(tile_grid_size, int)
        assert len(tile_grid_size) == 2
        self.tile_grid_size = tile_grid_size

    def __call__(self, results):
        """Call function to Use CLAHE method process images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """

        for i in range(results['img'].shape[2]):
            results['img'][:, :, i] = mmcv.clahe(
                np.array(results['img'][:, :, i], dtype=np.uint8),
                self.clip_limit, self.tile_grid_size)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(clip_limit={self.clip_limit}, '\
                    f'tile_grid_size={self.tile_grid_size})'
        return repr_str


@PIPELINES.register_module()
class RandomCrop(object):
    """Random crop the image & seg.

    Args:
        crop_size (tuple): Expected size after cropping, (h, w).
        cat_max_ratio (float): The maximum ratio that single category could
            occupy.
    """

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        margin_h = max(img.shape[0] - self.crop_size[0], 0)
        margin_w = max(img.shape[1] - self.crop_size[1], 0)
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        return img

    def __call__(self, results):
        """Call function to randomly crop images, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """

        img = results['img']
        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        img_shape = img.shape
        results['img'] = img
        results['img_shape'] = img_shape

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class RandomRotate(object):
    """Rotate the image & seg.

    Args:
        prob (float): The rotation probability.
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert prob >= 0 and prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'
        self.pal_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    def __call__(self, results):
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """

        rotate = True if np.random.rand() < self.prob else False
        degree = np.random.uniform(min(*self.degree), max(*self.degree))
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pal_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pal_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str


@PIPELINES.register_module()
class RGB2Gray(object):
    """Convert RGB image to grayscale image.

    This transform calculate the weighted mean of input image channels with
    ``weights`` and then expand the channels to ``out_channels``. When
    ``out_channels`` is None, the number of output channels is the same as
    input channels.

    Args:
        out_channels (int): Expected number of output channels after
            transforming. Default: None.
        weights (tuple[float]): The weights to calculate the weighted mean.
            Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        assert out_channels is None or out_channels > 0
        self.out_channels = out_channels
        assert isinstance(weights, tuple)
        for item in weights:
            assert isinstance(item, (float, int))
        self.weights = weights

    def __call__(self, results):
        """Call function to convert RGB image to grayscale image.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with grayscale image.
        """
        img = results['img']
        assert len(img.shape) == 3
        assert img.shape[2] == len(self.weights)
        weights = np.array(self.weights).reshape((1, 1, -1))
        img = (img * weights).sum(2, keepdims=True)
        if self.out_channels is None:
            img = img.repeat(weights.shape[2], axis=2)
        else:
            img = img.repeat(self.out_channels, axis=2)

        results['img'] = img
        results['img_shape'] = img.shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(out_channels={self.out_channels}, ' \
                    f'weights={self.weights})'
        return repr_str


@PIPELINES.register_module()
class AdjustGamma(object):
    """Using gamma correction to process the image.

    Args:
        gamma (float or int): Gamma value used in gamma correction.
            Default: 1.0.
    """

    def __init__(self, gamma=1.0):
        assert isinstance(gamma, float) or isinstance(gamma, int)
        assert gamma > 0
        self.gamma = gamma
        inv_gamma = 1.0 / gamma
        self.table = np.array([(i / 255.0)**inv_gamma * 255
                               for i in np.arange(256)]).astype('uint8')

    def __call__(self, results):
        """Call function to process the image with gamma correction.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """

        results['img'] = mmcv.lut_transform(
            np.array(results['img'], dtype=np.uint8), self.table)

        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(gamma={self.gamma})'


@PIPELINES.register_module()
class SegRescale(object):
    """Rescale semantic segmentation maps.

    Args:
        scale_factor (float): The scale factor of the final output.
    """

    def __init__(self, scale_factor=1):
        self.scale_factor = scale_factor

    def __call__(self, results):
        """Call function to scale the semantic segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with semantic segmentation map scaled.
        """
        for key in results.get('seg_fields', []):
            if self.scale_factor != 1:
                results[key] = mmcv.imrescale(
                    results[key], self.scale_factor, interpolation='nearest')
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'


@PIPELINES.register_module()
class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiple with alpha and add beat with clip."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        if random.randint(2):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(2):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :,
                0] = (img[:, :, 0].astype(int) +
                      random.randint(-self.hue_delta, self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, results):
        """Call function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """

        img = results['img']
        # random brightness
        img = self.brightness(img)

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(brightness_delta={self.brightness_delta}, '
                     f'contrast_range=({self.contrast_lower}, '
                     f'{self.contrast_upper}), '
                     f'saturation_range=({self.saturation_lower}, '
                     f'{self.saturation_upper}), '
                     f'hue_delta={self.hue_delta})')
        return repr_str
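Taken together, the transforms above make up the usual training-time augmentation chain. A minimal sketch of how they compose (illustrative values only; `img_norm_cfg` is the same placeholder used in the earlier sketch, not data from this commit):

crop_size = (512, 1024)  # illustrative crop size
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
]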
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/stare.py
ADDED
@@ -0,0 +1,27 @@
import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class STAREDataset(CustomDataset):
    """STARE dataset.

    In segmentation map annotation for STARE, 0 stands for background, which is
    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
    '.ah.png'.
    """

    CLASSES = ('background', 'vessel')

    PALETTE = [[120, 120, 120], [6, 230, 230]]

    def __init__(self, **kwargs):
        super(STAREDataset, self).__init__(
            img_suffix='.png',
            seg_map_suffix='.ah.png',
            reduce_zero_label=False,
            **kwargs)
        assert osp.exists(self.img_dir)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/voc.py
ADDED
@@ -0,0 +1,29 @@
import os.path as osp

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class PascalVOCDataset(CustomDataset):
    """Pascal VOC dataset.

    Args:
        split (str): Split txt file for Pascal VOC.
    """

    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
               'train', 'tvmonitor')

    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]

    def __init__(self, split, **kwargs):
        super(PascalVOCDataset, self).__init__(
            img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None
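A hedged sketch of how a registered dataset class like this is typically wired into a data config (the paths and split file below are hypothetical and not part of this commit):

data = dict(
    train=dict(
        type='PascalVOCDataset',
        data_root='data/VOCdevkit/VOC2012',  # hypothetical location
        img_dir='JPEGImages',
        ann_dir='SegmentationClass',
        split='ImageSets/Segmentation/train.txt',
        pipeline=train_pipeline))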
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .backbones import *  # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
                      build_head, build_loss, build_segmentor)
from .decode_heads import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .segmentors import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
    'build_head', 'build_loss', 'build_segmentor'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .cgnet import CGNet
# from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .unet import UNet
from .vit import VisionTransformer

__all__ = [
    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
    'VisionTransformer'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/cgnet.py
ADDED
@@ -0,0 +1,367 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from annotator.mmpkg.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
                                      constant_init, kaiming_init)
from annotator.mmpkg.mmcv.runner import load_checkpoint
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm

from annotator.mmpkg.mmseg.utils import get_root_logger
from ..builder import BACKBONES


class GlobalContextExtractor(nn.Module):
    """Global Context Extractor for CGNet.

    This class is employed to refine the joint feature of both local feature
    and surrounding context.

    Args:
        channel (int): Number of input feature channels.
        reduction (int): Reductions for global context extractor. Default: 16.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self, channel, reduction=16, with_cp=False):
        super(GlobalContextExtractor, self).__init__()
        self.channel = channel
        self.reduction = reduction
        assert reduction >= 1 and channel >= reduction
        self.with_cp = with_cp
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())

    def forward(self, x):

        def _inner_forward(x):
            num_batch, num_channel = x.size()[:2]
            y = self.avg_pool(x).view(num_batch, num_channel)
            y = self.fc(y).view(num_batch, num_channel, 1, 1)
            return x * y

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class ContextGuidedBlock(nn.Module):
    """Context Guided Block for CGNet.

    This class consists of four components: local feature extractor,
    surrounding feature extractor, joint feature extractor and global
    context extractor.

    Args:
        in_channels (int): Number of input feature channels.
        out_channels (int): Number of output feature channels.
        dilation (int): Dilation rate for surrounding context extractor.
            Default: 2.
        reduction (int): Reduction for global context extractor. Default: 16.
        skip_connect (bool): Add input to output or not. Default: True.
        downsample (bool): Downsample the input to 1/2 or not. Default: False.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 dilation=2,
                 reduction=16,
                 skip_connect=True,
                 downsample=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 with_cp=False):
        super(ContextGuidedBlock, self).__init__()
        self.with_cp = with_cp
        self.downsample = downsample

        channels = out_channels if downsample else out_channels // 2
        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
            act_cfg['num_parameters'] = channels
        kernel_size = 3 if downsample else 1
        stride = 2 if downsample else 1
        padding = (kernel_size - 1) // 2

        self.conv1x1 = ConvModule(
            in_channels,
            channels,
            kernel_size,
            stride,
            padding,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.f_loc = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=False)
        self.f_sur = build_conv_layer(
            conv_cfg,
            channels,
            channels,
            kernel_size=3,
            padding=dilation,
            groups=channels,
            dilation=dilation,
            bias=False)

        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
        self.activate = nn.PReLU(2 * channels)

        if downsample:
            self.bottleneck = build_conv_layer(
                conv_cfg,
                2 * channels,
                out_channels,
                kernel_size=1,
                bias=False)

        self.skip_connect = skip_connect and not downsample
        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)

    def forward(self, x):

        def _inner_forward(x):
            out = self.conv1x1(x)
            loc = self.f_loc(out)
            sur = self.f_sur(out)

            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
            joi_feat = self.bn(joi_feat)
            joi_feat = self.activate(joi_feat)
            if self.downsample:
                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
            # f_glo is employed to refine the joint feature
            out = self.f_glo(joi_feat)

            if self.skip_connect:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out


class InputInjection(nn.Module):
    """Downsampling module for CGNet."""

    def __init__(self, num_downsampling):
        super(InputInjection, self).__init__()
        self.pool = nn.ModuleList()
        for i in range(num_downsampling):
            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        for pool in self.pool:
            x = pool(x)
        return x


@BACKBONES.register_module()
class CGNet(nn.Module):
    """CGNet backbone.

    A Light-weight Context Guided Network for Semantic Segmentation
    arXiv: https://arxiv.org/abs/1811.08201

    Args:
        in_channels (int): Number of input image channels. Normally 3.
        num_channels (tuple[int]): Numbers of feature channels at each stages.
            Default: (32, 64, 128).
        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
            Default: (3, 21).
        dilations (tuple[int]): Dilation rate for surrounding context
            extractors at stage 1 and stage 2. Default: (2, 4).
        reductions (tuple[int]): Reductions for global context extractors at
            stage 1 and stage 2. Default: (8, 16).
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='PReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels=3,
                 num_channels=(32, 64, 128),
                 num_blocks=(3, 21),
                 dilations=(2, 4),
                 reductions=(8, 16),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='PReLU'),
                 norm_eval=False,
                 with_cp=False):

        super(CGNet, self).__init__()
        self.in_channels = in_channels
        self.num_channels = num_channels
        assert isinstance(self.num_channels, tuple) and len(
            self.num_channels) == 3
        self.num_blocks = num_blocks
        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
        self.dilations = dilations
        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
        self.reductions = reductions
        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
            self.act_cfg['num_parameters'] = num_channels[0]
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        cur_channels = in_channels
        self.stem = nn.ModuleList()
        for i in range(3):
            self.stem.append(
                ConvModule(
                    cur_channels,
                    num_channels[0],
                    3,
                    2 if i == 0 else 1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
            cur_channels = num_channels[0]

        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4

        cur_channels += in_channels
        self.norm_prelu_0 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 1
        self.level1 = nn.ModuleList()
        for i in range(num_blocks[0]):
            self.level1.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[1],
                    num_channels[1],
                    dilations[0],
                    reductions[0],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        cur_channels = 2 * num_channels[1] + in_channels
        self.norm_prelu_1 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

        # stage 2
        self.level2 = nn.ModuleList()
        for i in range(num_blocks[1]):
            self.level2.append(
                ContextGuidedBlock(
                    cur_channels if i == 0 else num_channels[2],
                    num_channels[2],
                    dilations[1],
                    reductions[1],
                    downsample=(i == 0),
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp))  # CG block

        cur_channels = 2 * num_channels[2]
        self.norm_prelu_2 = nn.Sequential(
            build_norm_layer(norm_cfg, cur_channels)[1],
            nn.PReLU(cur_channels))

    def forward(self, x):
        output = []

        # stage 0
        inp_2x = self.inject_2x(x)
        inp_4x = self.inject_4x(x)
        for layer in self.stem:
            x = layer(x)
        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
        output.append(x)

        # stage 1
        for i, layer in enumerate(self.level1):
            x = layer(x)
            if i == 0:
                down1 = x
        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
        output.append(x)

        # stage 2
        for i, layer in enumerate(self.level2):
            x = layer(x)
            if i == 0:
                down2 = x
        x = self.norm_prelu_2(torch.cat([down2, x], 1))
        output.append(x)

        return output

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
                elif isinstance(m, nn.PReLU):
                    constant_init(m, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def train(self, mode=True):
        """Convert the model into training mode will keeping the normalization
        layer freezed."""
        super(CGNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
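As a rough sanity check of the channel bookkeeping in CGNet.__init__ above, a minimal forward-pass sketch (not part of the diff; it assumes the default arguments and that the class can be imported from this package):

import torch

# Minimal smoke-test sketch. With the default num_channels=(32, 64, 128) the
# three outputs carry 32+3, 2*64+3 and 2*128 channels at 1/2, 1/4 and 1/8 of
# the input resolution, following the cur_channels updates in __init__.
model = CGNet()
model.init_weights()
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 64, 64))
for feat in feats:
    print(tuple(feat.shape))
# Expected: (1, 35, 32, 32), (1, 131, 16, 16), (1, 256, 8, 8)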
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/fast_scnn.py
ADDED
@@ -0,0 +1,375 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
                                      kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm

from annotator.mmpkg.mmseg.models.decode_heads.psp_head import PPM
from annotator.mmpkg.mmseg.ops import resize
from ..builder import BACKBONES
from ..utils.inverted_residual import InvertedResidual


class LearningToDownsample(nn.Module):
    """Learning to downsample module.

    Args:
        in_channels (int): Number of input channels.
        dw_channels (tuple[int]): Number of output channels of the first and
            the second depthwise conv (dwconv) layers.
        out_channels (int): Number of output channels of the whole
            'learning to downsample' module.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
    """

    def __init__(self,
                 in_channels,
                 dw_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU')):
        super(LearningToDownsample, self).__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        dw_channels1 = dw_channels[0]
        dw_channels2 = dw_channels[1]

        self.conv = ConvModule(
            in_channels,
            dw_channels1,
            3,
            stride=2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.dsconv1 = DepthwiseSeparableConvModule(
            dw_channels1,
            dw_channels2,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg)
        self.dsconv2 = DepthwiseSeparableConvModule(
            dw_channels2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg)

    def forward(self, x):
        x = self.conv(x)
        x = self.dsconv1(x)
        x = self.dsconv2(x)
        return x


class GlobalFeatureExtractor(nn.Module):
    """Global feature extractor module.

    Args:
        in_channels (int): Number of input channels of the GFE module.
            Default: 64
        block_channels (tuple[int]): Tuple of ints. Each int specifies the
            number of output channels of each Inverted Residual module.
            Default: (64, 96, 128)
        out_channels(int): Number of output channels of the GFE module.
            Default: 128
        expand_ratio (int): Adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
            Default: 6
        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
            number of times each Inverted Residual module is repeated.
            The repeated Inverted Residual modules are called a 'group'.
            Default: (3, 3, 3)
        strides (tuple[int]): Tuple of ints. Each int specifies
            the downsampling factor of each 'group'.
            Default: (2, 2, 1)
        pool_scales (tuple[int]): Tuple of ints. Each int specifies
            the parameter required in 'global average pooling' within PPM.
            Default: (1, 2, 3, 6)
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=64,
                 block_channels=(64, 96, 128),
                 out_channels=128,
                 expand_ratio=6,
                 num_blocks=(3, 3, 3),
                 strides=(2, 2, 1),
                 pool_scales=(1, 2, 3, 6),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super(GlobalFeatureExtractor, self).__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        assert len(block_channels) == len(num_blocks) == 3
        self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
                                            num_blocks[0], strides[0],
                                            expand_ratio)
        self.bottleneck2 = self._make_layer(block_channels[0],
                                            block_channels[1], num_blocks[1],
                                            strides[1], expand_ratio)
        self.bottleneck3 = self._make_layer(block_channels[1],
                                            block_channels[2], num_blocks[2],
                                            strides[2], expand_ratio)
        self.ppm = PPM(
            pool_scales,
            block_channels[2],
            block_channels[2] // 4,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=align_corners)
        self.out = ConvModule(
            block_channels[2] * 2,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    blocks,
                    stride=1,
                    expand_ratio=6):
        layers = [
            InvertedResidual(
                in_channels,
                out_channels,
                stride,
                expand_ratio,
                norm_cfg=self.norm_cfg)
        ]
        for i in range(1, blocks):
            layers.append(
                InvertedResidual(
                    out_channels,
                    out_channels,
                    1,
                    expand_ratio,
                    norm_cfg=self.norm_cfg))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = torch.cat([x, *self.ppm(x)], dim=1)
        x = self.out(x)
        return x


class FeatureFusionModule(nn.Module):
    """Feature fusion module.

    Args:
        higher_in_channels (int): Number of input channels of the
            higher-resolution branch.
        lower_in_channels (int): Number of input channels of the
            lower-resolution branch.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 higher_in_channels,
                 lower_in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):
        super(FeatureFusionModule, self).__init__()
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.dwconv = ConvModule(
            lower_in_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.conv_lower_res = ConvModule(
            out_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.conv_higher_res = ConvModule(
            higher_in_channels,
            out_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.relu = nn.ReLU(True)

    def forward(self, higher_res_feature, lower_res_feature):
        lower_res_feature = resize(
            lower_res_feature,
            size=higher_res_feature.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        lower_res_feature = self.dwconv(lower_res_feature)
        lower_res_feature = self.conv_lower_res(lower_res_feature)

        higher_res_feature = self.conv_higher_res(higher_res_feature)
        out = higher_res_feature + lower_res_feature
        return self.relu(out)


@BACKBONES.register_module()
class FastSCNN(nn.Module):
    """Fast-SCNN Backbone.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        downsample_dw_channels (tuple[int]): Number of output channels after
            the first conv layer & the second conv layer in
            Learning-To-Downsample (LTD) module.
            Default: (32, 48).
        global_in_channels (int): Number of input channels of
            Global Feature Extractor(GFE).
            Equal to number of output channels of LTD.
            Default: 64.
        global_block_channels (tuple[int]): Tuple of integers that describe
            the output channels for each of the MobileNet-v2 bottleneck
            residual blocks in GFE.
            Default: (64, 96, 128).
        global_block_strides (tuple[int]): Tuple of integers
            that describe the strides (downsampling factors) for each of the
            MobileNet-v2 bottleneck residual blocks in GFE.
            Default: (2, 2, 1).
        global_out_channels (int): Number of output channels of GFE.
            Default: 128.
        higher_in_channels (int): Number of input channels of the higher
            resolution branch in FFM.
            Equal to global_in_channels.
            Default: 64.
        lower_in_channels (int): Number of input channels of the lower
            resolution branch in FFM.
            Equal to global_out_channels.
            Default: 128.
        fusion_out_channels (int): Number of output channels of FFM.
            Default: 128.
        out_indices (tuple): Tuple of indices of list
            [higher_res_features, lower_res_features, fusion_output].
            Often set to (0,1,2) to enable aux. heads.
            Default: (0, 1, 2).
        conv_cfg (dict | None): Config of conv layers. Default: None
        norm_cfg (dict | None): Config of norm layers. Default:
            dict(type='BN')
        act_cfg (dict): Config of activation layers. Default:
            dict(type='ReLU')
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False
    """

    def __init__(self,
                 in_channels=3,
                 downsample_dw_channels=(32, 48),
                 global_in_channels=64,
                 global_block_channels=(64, 96, 128),
                 global_block_strides=(2, 2, 1),
                 global_out_channels=128,
                 higher_in_channels=64,
                 lower_in_channels=128,
                 fusion_out_channels=128,
                 out_indices=(0, 1, 2),
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 align_corners=False):

        super(FastSCNN, self).__init__()
        if global_in_channels != higher_in_channels:
            raise AssertionError('Global Input Channels must be the same \
                                 with Higher Input Channels!')
        elif global_out_channels != lower_in_channels:
            raise AssertionError('Global Output Channels must be the same \
                                 with Lower Input Channels!')

        self.in_channels = in_channels
        self.downsample_dw_channels1 = downsample_dw_channels[0]
        self.downsample_dw_channels2 = downsample_dw_channels[1]
        self.global_in_channels = global_in_channels
        self.global_block_channels = global_block_channels
        self.global_block_strides = global_block_strides
        self.global_out_channels = global_out_channels
        self.higher_in_channels = higher_in_channels
        self.lower_in_channels = lower_in_channels
        self.fusion_out_channels = fusion_out_channels
        self.out_indices = out_indices
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners
        self.learning_to_downsample = LearningToDownsample(
            in_channels,
            downsample_dw_channels,
            global_in_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.global_feature_extractor = GlobalFeatureExtractor(
            global_in_channels,
            global_block_channels,
            global_out_channels,
            strides=self.global_block_strides,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.feature_fusion = FeatureFusionModule(
            higher_in_channels,
            lower_in_channels,
            fusion_out_channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)

    def init_weights(self, pretrained=None):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)

    def forward(self, x):
        higher_res_features = self.learning_to_downsample(x)
        lower_res_features = self.global_feature_extractor(higher_res_features)
        fusion_output = self.feature_fusion(higher_res_features,
                                            lower_res_features)

        outs = [higher_res_features, lower_res_features, fusion_output]
        outs = [outs[i] for i in self.out_indices]
        return tuple(outs)
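FastSCNN's forward returns the intermediate tensors selected by out_indices. A small sketch of the default behaviour (not part of the diff; note the FastSCNN import in backbones/__init__.py above is commented out, so the class would have to be imported from this module directly):

import torch

# Hypothetical smoke test under the default arguments. LearningToDownsample
# applies three stride-2 convolutions (1/8 resolution) and the first two GFE
# groups add two more stride-2 steps (1/32) before fusion back at 1/8.
model = FastSCNN()
model.init_weights()
model.eval()
with torch.no_grad():
    higher, lower, fused = model(torch.rand(1, 3, 256, 256))
print(tuple(higher.shape))  # (1, 64, 32, 32)
print(tuple(lower.shape))   # (1, 128, 8, 8)
print(tuple(fused.shape))   # (1, 128, 32, 32)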
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/hrnet.py
ADDED
@@ -0,0 +1,555 @@
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                                      kaiming_init)
from annotator.mmpkg.mmcv.runner import load_checkpoint
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm

from annotator.mmpkg.mmseg.ops import Upsample, resize
from annotator.mmpkg.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck


class HRModule(nn.Module):
    """High-Resolution Module for HRNet.

    In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
    is in this module.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=True,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super(HRModule, self).__init__()
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, num_blocks, in_channels,
                        num_channels):
        """Check branches configuration."""
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
                        f'{len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
                        f'{len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
                        f'{len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch."""
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
                                 block.expansion)[1])

        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build multiple branch."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            # we set align_corners=False for HRNet
                            Upsample(
                                scale_factor=2**(j - i),
                                mode='bilinear',
                                align_corners=False)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    y = y + resize(
                        self.fuse_layers[i][j](x[j]),
                        size=x[i].shape[2:],
                        mode='bilinear',
                        align_corners=False)
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse


@BACKBONES.register_module()
class HRNet(nn.Module):
    """HRNet backbone.

    High-Resolution Representations for Labeling Pixels and Regions
    arXiv: https://arxiv.org/abs/1904.04514

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from annotator.mmpkg.mmseg.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False):
        super(HRNet, self).__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            64,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)

        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # stage 1
        self.stage1_cfg = self.extra['stage1']
        num_channels = self.stage1_cfg['num_channels'][0]
        block_type = self.stage1_cfg['block']
        num_blocks = self.stage1_cfg['num_blocks'][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra['stage2']
        num_channels = self.stage2_cfg['num_channels']
        block_type = self.stage2_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer([stage1_out_channels],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        # stage 3
        self.stage3_cfg = self.extra['stage3']
        num_channels = self.stage3_cfg['num_channels']
        block_type = self.stage3_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        # stage 4
        self.stage4_cfg = self.extra['stage4']
        num_channels = self.stage4_cfg['num_channels']
        block_type = self.stage4_cfg['block']

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Make each layer."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = []
        layers.append(
            block(
                inplanes,
                planes,
                stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
        """Make each stage."""
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HRModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg['num_branches']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['num_branches']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['num_branches']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode will keeping the normalization
        layer freezed."""
        super(HRNet, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v2.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init, kaiming_init
from annotator.mmpkg.mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible


@BACKBONES.register_module()
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone.

    Args:
        widen_factor (float): Width multiplier, multiply number of
            channels in each layer by this amount. Default: 1.0.
        strides (Sequence[int], optional): Strides of the first block of each
            layer. If not specified, default config in ``arch_setting`` will
            be used.
        dilations (Sequence[int]): Dilation of each layer.
        out_indices (None or Sequence[int]): Output from which stages.
            Default: (1, 2, 4, 6).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU6').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    # Parameters to build layers. 3 parameters are needed to construct a
    # layer, from left to right: expand_ratio, channel, num_blocks.
    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]

    def __init__(self,
                 widen_factor=1.,
                 strides=(1, 2, 2, 2, 1, 2, 1),
                 dilations=(1, 1, 1, 1, 1, 1, 1),
                 out_indices=(1, 2, 4, 6),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU6'),
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV2, self).__init__()
        self.widen_factor = widen_factor
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == len(self.arch_settings)
        self.out_indices = out_indices
        for index in out_indices:
            if index not in range(0, 7):
                raise ValueError('the item in out_indices must be in '
                                 f'range(0, 7). But received {index}')

        if frozen_stages not in range(-1, 7):
            raise ValueError('frozen_stages must be in range(-1, 7). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = make_divisible(32 * widen_factor, 8)

        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.layers = []

        for i, layer_cfg in enumerate(self.arch_settings):
            expand_ratio, channel, num_blocks = layer_cfg
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = make_divisible(channel * widen_factor, 8)
            inverted_res_layer = self.make_layer(
                out_channels=out_channels,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                expand_ratio=expand_ratio)
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, inverted_res_layer)
            self.layers.append(layer_name)

    def make_layer(self, out_channels, num_blocks, stride, dilation,
                   expand_ratio):
        """Stack InvertedResidual blocks to build a layer for MobileNetV2.

        Args:
            out_channels (int): out_channels of block.
            num_blocks (int): Number of blocks.
            stride (int): Stride of the first block.
            dilation (int): Dilation of the first block.
            expand_ratio (int): Expand the number of channels of the
                hidden layer in InvertedResidual by this ratio.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(
                InvertedResidual(
                    self.in_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    expand_ratio=expand_ratio,
                    dilation=dilation if i == 0 else 1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(MobileNetV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
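For context, a minimal usage sketch of the backbone above; the input size, variable names and printed shapes are illustrative and not part of the file.

# Hypothetical usage sketch: build MobileNetV2 directly and run a dummy
# forward pass. Shapes assume a 224x224 input and widen_factor=1.0.
import torch

backbone = MobileNetV2(widen_factor=1.0, out_indices=(1, 2, 4, 6))
backbone.init_weights(pretrained=None)  # Kaiming init for convs, constant init for norms
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.rand(1, 3, 224, 224))

# out_indices=(1, 2, 4, 6) selects features at strides 4, 8, 16 and 32,
# so the expected channel counts are 24, 32, 96 and 320 respectively.
for f in feats:
    print(tuple(f.shape))  # e.g. (1, 24, 56, 56), (1, 32, 28, 28), ...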
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/mobilenet_v3.py
ADDED
@@ -0,0 +1,255 @@
import logging

import annotator.mmpkg.mmcv as mmcv
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule, constant_init, kaiming_init
from annotator.mmpkg.mmcv.cnn.bricks import Conv2dAdaptivePadding
from annotator.mmpkg.mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
from ..utils import InvertedResidualV3 as InvertedResidual


@BACKBONES.register_module()
class MobileNetV3(nn.Module):
    """MobileNetV3 backbone.

    This backbone is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
            Default: 'small'.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (tuple[int]): Output from which layer.
            Default: (0, 1, 12).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    #     [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
                  [3, 72, 24, False, 'ReLU', 1],
                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
                  [5, 120, 40, True, 'ReLU', 1],
                  [5, 120, 40, True, 'ReLU', 1],
                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
                  [3, 200, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 184, 80, False, 'HSwish', 1],
                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
                  [3, 672, 112, True, 'HSwish', 1],
                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
                  [5, 960, 160, True, 'HSwish', 1],
                  [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(0, 1, 12),
                 frozen_stages=-1,
                 reduction_factor=1,
                 norm_eval=False,
                 with_cp=False):
        super(MobileNetV3, self).__init__()
        assert arch in self.arch_settings
        assert isinstance(reduction_factor, int) and reduction_factor > 0
        assert mmcv.is_tuple_of(out_indices, int)
        for index in out_indices:
            if index not in range(0, len(self.arch_settings[arch]) + 2):
                raise ValueError(
                    'the item in out_indices must be in '
                    f'range(0, {len(self.arch_settings[arch])+2}). '
                    f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch])+2}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.reduction_factor = reduction_factor
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.layers = self._make_layer()

    def _make_layer(self):
        layers = []

        # build the first layer (layer0)
        in_channels = 16
        layer = ConvModule(
            in_channels=3,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=dict(type='Conv2dAdaptivePadding'),
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        self.add_module('layer0', layer)
        layers.append('layer0')

        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params

            if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
                    i >= 8:
                mid_channels = mid_channels // self.reduction_factor
                out_channels = out_channels // self.reduction_factor

            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'),
                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=(in_channels != mid_channels),
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            in_channels = out_channels
            layer_name = 'layer{}'.format(i + 1)
            self.add_module(layer_name, layer)
            layers.append(layer_name)

        # build the last layer
        # block5 layer12 os=32 for small model
        # block6 layer16 os=32 for large model
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=576 if self.arch == 'small' else 960,
            kernel_size=1,
            stride=1,
            dilation=4,
            padding=0,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='HSwish'))
        layer_name = 'layer{}'.format(len(layer_setting) + 1)
        self.add_module(layer_name, layer)
        layers.append(layer_name)

        # next, convert backbone MobileNetV3 to a semantic segmentation version
        if self.arch == 'small':
            self.layer4.depthwise_conv.conv.stride = (1, 1)
            self.layer9.depthwise_conv.conv.stride = (1, 1)
            for i in range(4, len(layers)):
                layer = getattr(self, layers[i])
                if isinstance(layer, InvertedResidual):
                    modified_module = layer.depthwise_conv.conv
                else:
                    modified_module = layer.conv

                if i < 9:
                    modified_module.dilation = (2, 2)
                    pad = 2
                else:
                    modified_module.dilation = (4, 4)
                    pad = 4

                if not isinstance(modified_module, Conv2dAdaptivePadding):
                    # Adjust padding
                    pad *= (modified_module.kernel_size[0] - 1) // 2
                    modified_module.padding = (pad, pad)
        else:
            self.layer7.depthwise_conv.conv.stride = (1, 1)
            self.layer13.depthwise_conv.conv.stride = (1, 1)
            for i in range(7, len(layers)):
                layer = getattr(self, layers[i])
                if isinstance(layer, InvertedResidual):
                    modified_module = layer.depthwise_conv.conv
                else:
                    modified_module = layer.conv

                if i < 13:
                    modified_module.dilation = (2, 2)
                    pad = 2
                else:
                    modified_module.dilation = (4, 4)
                    pad = 4

                if not isinstance(modified_module, Conv2dAdaptivePadding):
                    # Adjust padding
                    pad *= (modified_module.kernel_size[0] - 1) // 2
                    modified_module.padding = (pad, pad)

        return layers

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)
        return outs

    def _freeze_stages(self):
        for i in range(self.frozen_stages + 1):
            layer = getattr(self, f'layer{i}')
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(MobileNetV3, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnest.py
ADDED
@@ -0,0 +1,314 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from annotator.mmpkg.mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d


class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x


class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d in ResNeSt.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        dcn (dict): Config dict for DCN. Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None):
        super(SplitAttentionConv2d, self).__init__()
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        self.with_dcn = dcn is not None
        self.dcn = dcn
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if self.with_dcn and not fallback_on_stride:
            assert conv_cfg is None, 'conv_cfg must be None for DCN'
            conv_cfg = dcn
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer named "norm0" """
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        batch, rchannel = x.shape[:2]
        batch = x.size(0)
        if self.radix > 1:
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()


class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        inplanes (int): Input planes of this block.
        planes (int): Middle planes of this block.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Key word arguments for base class.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        """Bottleneck block for ResNeSt."""
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.with_modulated_dcn = False
        self.conv2 = SplitAttentionConv2d(
            width,
            width,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            dcn=self.dcn)
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)

            if self.avg_down_stride:
                out = self.avd_layer(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


@BACKBONES.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Args:
        groups (int): Number of groups of Bottleneck. Default: 1
        base_width (int): Base width of Bottleneck. Default: 4
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of inter_channels in
            SplitAttentionConv2d. Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        kwargs (dict): Keyword arguments for ResNet.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3))
    }

    def __init__(self,
                 groups=1,
                 base_width=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        self.groups = groups
        self.base_width = base_width
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super(ResNeSt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
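As a quick illustration of the split-attention computation above, a minimal standalone sketch of SplitAttentionConv2d; the tensor sizes are made up and this usage is not shown in the file itself.

# Sketch: with radix=2 the module produces two "splits" of the grouped 3x3
# conv output and fuses them with a radix softmax over the split dimension.
import torch

sa_conv = SplitAttentionConv2d(
    in_channels=64, channels=64, kernel_size=3, padding=1, radix=2, groups=1)
out = sa_conv(torch.rand(2, 64, 32, 32))
print(out.shape)  # expected torch.Size([2, 64, 32, 32]): same spatial size, 'channels' outputs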
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnet.py
ADDED
@@ -0,0 +1,688 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from annotator.mmpkg.mmcv.cnn import (build_conv_layer, build_norm_layer,
                                      build_plugin_layer, constant_init,
                                      kaiming_init)
from annotator.mmpkg.mmcv.runner import load_checkpoint
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm

from annotator.mmpkg.mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import ResLayer


class BasicBlock(nn.Module):
    """Basic block for ResNet."""

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(BasicBlock, self).__init__()
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block for ResNet.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
    "caffe", the stride-two layer is the first 1x1 conv layer.
    """

    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None):
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        assert dcn is None or isinstance(dcn, dict)
        assert plugins is None or isinstance(plugins, list)
        if plugins is not None:
            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
            assert all(p['position'] in allowed_position for p in plugins)

        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.dcn = dcn
        self.with_dcn = dcn is not None
        self.plugins = plugins
        self.with_plugins = plugins is not None

        if self.with_plugins:
            # collect plugins for conv1/conv2/conv3
            self.after_conv1_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv1'
            ]
            self.after_conv2_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv2'
            ]
            self.after_conv3_plugins = [
                plugin['cfg'] for plugin in plugins
                if plugin['position'] == 'after_conv3'
            ]

        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        if self.with_dcn:
            fallback_on_stride = dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                conv_cfg,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                dcn,
                planes,
                planes,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=dilation,
                dilation=dilation,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            planes,
            planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

        if self.with_plugins:
            self.after_conv1_plugin_names = self.make_block_plugins(
                planes, self.after_conv1_plugins)
            self.after_conv2_plugin_names = self.make_block_plugins(
                planes, self.after_conv2_plugins)
            self.after_conv3_plugin_names = self.make_block_plugins(
                planes * self.expansion, self.after_conv3_plugins)

    def make_block_plugins(self, in_channels, plugins):
        """make plugins for block.

        Args:
            in_channels (int): Input channels of plugin.
            plugins (list[dict]): List of plugins cfg to build.

        Returns:
            list[str]: List of the names of plugin.
        """
        assert isinstance(plugins, list)
        plugin_names = []
        for plugin in plugins:
            plugin = plugin.copy()
            name, layer = build_plugin_layer(
                plugin,
                in_channels=in_channels,
                postfix=plugin.pop('postfix', ''))
            assert not hasattr(self, name), f'duplicate plugin {name}'
            self.add_module(name, layer)
            plugin_names.append(name)
        return plugin_names

    def forward_plugin(self, x, plugin_names):
        """Forward function for plugins."""
        out = x
        for name in plugin_names:
            out = getattr(self, name)(x)
        return out

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: normalization layer after the third convolution layer"""
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv1_plugin_names)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv2_plugin_names)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.with_plugins:
                out = self.forward_plugin(out, self.after_conv3_plugin_names)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


@BACKBONES.register_module()
class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer. Default: 64.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        norm_cfg (dict): Dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        plugins (list[dict]): List of plugins for stages, each dict contains:

            - cfg (dict, required): Cfg dict to build plugin.

            - position (str, required): Position inside block to insert plugin,
              options: 'after_conv1', 'after_conv2', 'after_conv3'.

            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'
        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
            stage. Default: None
        contract_dilation (bool): Whether contract first dilation of each layer
            Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from annotator.mmpkg.mmseg.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False),
                 plugins=None,
                 multi_grid=None,
                 contract_dilation=False,
                 with_cp=False,
                 zero_init_residual=True):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        if dcn is not None:
            assert len(stage_with_dcn) == num_stages
        self.plugins = plugins
        self.multi_grid = multi_grid
        self.contract_dilation = contract_dilation
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        self.inplanes = stem_channels

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            dcn = self.dcn if self.stage_with_dcn[i] else None
            if plugins is not None:
                stage_plugins = self.make_stage_plugins(plugins, i)
            else:
                stage_plugins = None
            # multi grid is applied to last layer only
            stage_multi_grid = multi_grid if i == len(
                self.stage_blocks) - 1 else None
            planes = base_channels * 2**i
            res_layer = self.make_res_layer(
                block=self.block,
                inplanes=self.inplanes,
                planes=planes,
                num_blocks=num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                dcn=dcn,
                plugins=stage_plugins,
                multi_grid=stage_multi_grid,
                contract_dilation=contract_dilation)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i+1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = self.block.expansion * base_channels * 2**(
            len(self.stage_blocks) - 1)

    def make_stage_plugins(self, plugins, stage_idx):
        """make plugins for ResNet 'stage_idx'th stage .

        Currently we support to insert 'context_block',
        'empirical_attention_block', 'nonlocal_block' into the backbone like
        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
        Bottleneck.

        An example of plugins format could be :
        >>> plugins=[
        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
        ...          stages=(False, True, True, True),
        ...          position='after_conv2'),
        ...     dict(cfg=dict(type='yyy'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='1'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3'),
        ...     dict(cfg=dict(type='zzz', postfix='2'),
        ...          stages=(True, True, True, True),
        ...          position='after_conv3')
        ... ]
        >>> self = ResNet(depth=18)
        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
        >>> assert len(stage_plugins) == 3

        Suppose 'stage_idx=0', the structure of blocks in the stage would be:
            conv1-> conv2->conv3->yyy->zzz1->zzz2
        Suppose 'stage_idx=1', the structure of blocks in the stage would be:
            conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2

        If stages is missing, the plugin would be applied to all stages.

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix is
                required if multiple same type plugins are inserted.
            stage_idx (int): Index of stage to build

        Returns:
            list[dict]: Plugins for current stage
        """
        stage_plugins = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # whether to insert plugin into current stage
            if stages is None or stages[stage_idx]:
                stage_plugins.append(plugin)

        return stage_plugins

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Make stem layer for ResNet."""
        if self.deep_stem:
            self.stem = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
                nn.ReLU(inplace=True),
                build_conv_layer(
                    self.conv_cfg,
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False),
                build_norm_layer(self.norm_cfg, stem_channels)[1],
                nn.ReLU(inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze stages param and norm stats."""
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

            for i in range(1, self.frozen_stages + 1):
                m = getattr(self, f'layer{i}')
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()


@BACKBONES.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c variant described in [1]_.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
    in the input stem with three 3x3 convs.

    References:
        .. [1] https://arxiv.org/pdf/1812.01187.pdf
    """

    def __init__(self, **kwargs):
        super(ResNetV1c, self).__init__(
            deep_stem=True, avg_down=False, **kwargs)


@BACKBONES.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d variant described in [1]_.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
    the input stem with three 3x3 convs. And in the downsampling block, a 2x2
    avg_pool with stride 2 is added before conv, whose stride is changed to 1.
    """

    def __init__(self, **kwargs):
        super(ResNetV1d, self).__init__(
            deep_stem=True, avg_down=True, **kwargs)
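Since this ResNet is the base class that ResNeSt, ResNeXt and the V1c/V1d variants extend, a hedged configuration sketch may help: segmentation backbones typically keep the output stride at 8 by trading the strides of the last two stages for dilations, which the constructor arguments above support directly. The concrete numbers below are illustrative, not taken from this commit.

# Dilated ResNet-50 as commonly used under segmentation heads (illustrative):
# strides=(1, 2, 1, 1) and dilations=(1, 1, 2, 4) keep layer3/layer4 at 1/8
# resolution instead of 1/16 and 1/32.
import torch

backbone = ResNet(
    depth=50,
    out_indices=(0, 1, 2, 3),
    strides=(1, 2, 1, 1),
    dilations=(1, 1, 2, 4),
    norm_eval=True,       # freeze BN statistics during fine-tuning
    frozen_stages=1)      # freeze the stem and layer1
backbone.init_weights(pretrained=None)
backbone.eval()

with torch.no_grad():
    c1, c2, c3, c4 = backbone(torch.rand(1, 3, 512, 512))
# Expected channel counts: 256, 512, 1024, 2048; c2..c4 all stay at 1/8 resolution here.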
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/resnext.py
ADDED
@@ -0,0 +1,145 @@
import math

from annotator.mmpkg.mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet


class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
    "caffe", the stride-two layer is the first 1x1 conv layer.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 groups=1,
                 base_width=4,
                 base_channels=64,
                 **kwargs):
        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)

        if groups == 1:
            width = self.planes
        else:
            width = math.floor(self.planes *
                               (base_width / base_channels)) * groups

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, width, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.planes * self.expansion, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.inplanes,
            width,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = build_conv_layer(
                self.conv_cfg,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)
        else:
            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
            self.conv2 = build_conv_layer(
                self.dcn,
                width,
                width,
                kernel_size=3,
                stride=self.conv2_stride,
                padding=self.dilation,
                dilation=self.dilation,
                groups=groups,
                bias=False)

        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width,
            self.planes * self.expansion,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)


@BACKBONES.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Normally 3.
        num_stages (int): Resnet stages, normally 4.
        groups (int): Group of resnext.
        base_width (int): Base width of resnext.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.

    Example:
        >>> from annotator.mmpkg.mmseg.models import ResNeXt
        >>> import torch
        >>> self = ResNeXt(depth=50)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, groups=1, base_width=4, **kwargs):
        self.groups = groups
        self.base_width = base_width
        super(ResNeXt, self).__init__(**kwargs)

    def make_res_layer(self, **kwargs):
        """Pack all blocks in a stage into a ``ResLayer``"""
        return ResLayer(
            groups=self.groups,
            base_width=self.base_width,
            base_channels=self.base_channels,
            **kwargs)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/unet.py
ADDED
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.utils.checkpoint as cp
|
3 |
+
from annotator.mmpkg.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
4 |
+
build_norm_layer, constant_init, kaiming_init)
|
5 |
+
from annotator.mmpkg.mmcv.runner import load_checkpoint
|
6 |
+
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
|
7 |
+
|
8 |
+
from annotator.mmpkg.mmseg.utils import get_root_logger
|
9 |
+
from ..builder import BACKBONES
|
10 |
+
from ..utils import UpConvBlock
|
11 |
+
|
12 |
+
|
13 |
+
class BasicConvBlock(nn.Module):
|
14 |
+
"""Basic convolutional block for UNet.
|
15 |
+
|
16 |
+
This module consists of several plain convolutional layers.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
in_channels (int): Number of input channels.
|
20 |
+
out_channels (int): Number of output channels.
|
21 |
+
num_convs (int): Number of convolutional layers. Default: 2.
|
22 |
+
stride (int): Whether use stride convolution to downsample
|
23 |
+
the input feature map. If stride=2, it only uses stride convolution
|
24 |
+
in the first convolutional layer to downsample the input feature
|
25 |
+
map. Options are 1 or 2. Default: 1.
|
26 |
+
dilation (int): Whether use dilated convolution to expand the
|
27 |
+
receptive field. Set dilation rate of each convolutional layer and
|
28 |
+
the dilation rate of the first convolutional layer is always 1.
|
29 |
+
Default: 1.
|
30 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
31 |
+
memory while slowing down the training speed. Default: False.
|
32 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
33 |
+
Default: None.
|
34 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
35 |
+
Default: dict(type='BN').
|
36 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
37 |
+
Default: dict(type='ReLU').
|
38 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
39 |
+
Default: None.
|
40 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self,
|
44 |
+
in_channels,
|
45 |
+
out_channels,
|
46 |
+
num_convs=2,
|
47 |
+
stride=1,
|
48 |
+
dilation=1,
|
49 |
+
with_cp=False,
|
50 |
+
conv_cfg=None,
|
51 |
+
norm_cfg=dict(type='BN'),
|
52 |
+
act_cfg=dict(type='ReLU'),
|
53 |
+
dcn=None,
|
54 |
+
plugins=None):
|
55 |
+
super(BasicConvBlock, self).__init__()
|
56 |
+
assert dcn is None, 'Not implemented yet.'
|
57 |
+
assert plugins is None, 'Not implemented yet.'
|
58 |
+
|
59 |
+
self.with_cp = with_cp
|
60 |
+
convs = []
|
61 |
+
for i in range(num_convs):
|
62 |
+
convs.append(
|
63 |
+
ConvModule(
|
64 |
+
in_channels=in_channels if i == 0 else out_channels,
|
65 |
+
out_channels=out_channels,
|
66 |
+
kernel_size=3,
|
67 |
+
stride=stride if i == 0 else 1,
|
68 |
+
dilation=1 if i == 0 else dilation,
|
69 |
+
padding=1 if i == 0 else dilation,
|
70 |
+
conv_cfg=conv_cfg,
|
71 |
+
norm_cfg=norm_cfg,
|
72 |
+
act_cfg=act_cfg))
|
73 |
+
|
74 |
+
self.convs = nn.Sequential(*convs)
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
"""Forward function."""
|
78 |
+
|
79 |
+
if self.with_cp and x.requires_grad:
|
80 |
+
out = cp.checkpoint(self.convs, x)
|
81 |
+
else:
|
82 |
+
out = self.convs(x)
|
83 |
+
return out
|
84 |
+
|
85 |
+
|
86 |
+
@UPSAMPLE_LAYERS.register_module()
|
87 |
+
class DeconvModule(nn.Module):
|
88 |
+
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
89 |
+
|
90 |
+
This module uses deconvolution to upsample feature map in the decoder
|
91 |
+
of UNet.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
in_channels (int): Number of input channels.
|
95 |
+
out_channels (int): Number of output channels.
|
96 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
97 |
+
memory while slowing down the training speed. Default: False.
|
98 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
99 |
+
Default: dict(type='BN').
|
100 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
101 |
+
Default: dict(type='ReLU').
|
102 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 4.
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self,
|
106 |
+
in_channels,
|
107 |
+
out_channels,
|
108 |
+
with_cp=False,
|
109 |
+
norm_cfg=dict(type='BN'),
|
110 |
+
act_cfg=dict(type='ReLU'),
|
111 |
+
*,
|
112 |
+
kernel_size=4,
|
113 |
+
scale_factor=2):
|
114 |
+
super(DeconvModule, self).__init__()
|
115 |
+
|
116 |
+
assert (kernel_size - scale_factor >= 0) and\
|
117 |
+
(kernel_size - scale_factor) % 2 == 0,\
|
118 |
+
f'kernel_size should be greater than or equal to scale_factor '\
|
119 |
+
f'and (kernel_size - scale_factor) should be even numbers, '\
|
120 |
+
f'while the kernel size is {kernel_size} and scale_factor is '\
|
121 |
+
f'{scale_factor}.'
|
122 |
+
|
123 |
+
stride = scale_factor
|
124 |
+
padding = (kernel_size - scale_factor) // 2
|
125 |
+
self.with_cp = with_cp
|
126 |
+
deconv = nn.ConvTranspose2d(
|
127 |
+
in_channels,
|
128 |
+
out_channels,
|
129 |
+
kernel_size=kernel_size,
|
130 |
+
stride=stride,
|
131 |
+
padding=padding)
|
132 |
+
|
133 |
+
norm_name, norm = build_norm_layer(norm_cfg, out_channels)
|
134 |
+
activate = build_activation_layer(act_cfg)
|
135 |
+
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
"""Forward function."""
|
139 |
+
|
140 |
+
if self.with_cp and x.requires_grad:
|
141 |
+
out = cp.checkpoint(self.deconv_upsamping, x)
|
142 |
+
else:
|
143 |
+
out = self.deconv_upsamping(x)
|
144 |
+
return out
|
145 |
+
|
146 |
+
|
147 |
+
@UPSAMPLE_LAYERS.register_module()
|
148 |
+
class InterpConv(nn.Module):
|
149 |
+
"""Interpolation upsample module in decoder for UNet.
|
150 |
+
|
151 |
+
This module uses interpolation to upsample feature map in the decoder
|
152 |
+
of UNet. It consists of one interpolation upsample layer and one
|
153 |
+
convolutional layer. It can be one interpolation upsample layer followed
|
154 |
+
by one convolutional layer (conv_first=False) or one convolutional layer
|
155 |
+
followed by one interpolation upsample layer (conv_first=True).
|
156 |
+
|
157 |
+
Args:
|
158 |
+
in_channels (int): Number of input channels.
|
159 |
+
out_channels (int): Number of output channels.
|
160 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
161 |
+
memory while slowing down the training speed. Default: False.
|
162 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
163 |
+
Default: dict(type='BN').
|
164 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
165 |
+
Default: dict(type='ReLU').
|
166 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
167 |
+
Default: None.
|
168 |
+
conv_first (bool): Whether convolutional layer or interpolation
|
169 |
+
upsample layer first. Default: False. It means interpolation
|
170 |
+
upsample layer followed by one convolutional layer.
|
171 |
+
kernel_size (int): Kernel size of the convolutional layer. Default: 1.
|
172 |
+
stride (int): Stride of the convolutional layer. Default: 1.
|
173 |
+
padding (int): Padding of the convolutional layer. Default: 1.
|
174 |
+
upsample_cfg (dict): Interpolation config of the upsample layer.
|
175 |
+
Default: dict(
|
176 |
+
scale_factor=2, mode='bilinear', align_corners=False).
|
177 |
+
"""
|
178 |
+
|
179 |
+
def __init__(self,
|
180 |
+
in_channels,
|
181 |
+
out_channels,
|
182 |
+
with_cp=False,
|
183 |
+
norm_cfg=dict(type='BN'),
|
184 |
+
act_cfg=dict(type='ReLU'),
|
185 |
+
*,
|
186 |
+
conv_cfg=None,
|
187 |
+
conv_first=False,
|
188 |
+
kernel_size=1,
|
189 |
+
stride=1,
|
190 |
+
padding=0,
|
191 |
+
upsample_cfg=dict(
|
192 |
+
scale_factor=2, mode='bilinear', align_corners=False)):
|
193 |
+
super(InterpConv, self).__init__()
|
194 |
+
|
195 |
+
self.with_cp = with_cp
|
196 |
+
conv = ConvModule(
|
197 |
+
in_channels,
|
198 |
+
out_channels,
|
199 |
+
kernel_size=kernel_size,
|
200 |
+
stride=stride,
|
201 |
+
padding=padding,
|
202 |
+
conv_cfg=conv_cfg,
|
203 |
+
norm_cfg=norm_cfg,
|
204 |
+
act_cfg=act_cfg)
|
205 |
+
upsample = nn.Upsample(**upsample_cfg)
|
206 |
+
if conv_first:
|
207 |
+
self.interp_upsample = nn.Sequential(conv, upsample)
|
208 |
+
else:
|
209 |
+
self.interp_upsample = nn.Sequential(upsample, conv)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
"""Forward function."""
|
213 |
+
|
214 |
+
if self.with_cp and x.requires_grad:
|
215 |
+
out = cp.checkpoint(self.interp_upsample, x)
|
216 |
+
else:
|
217 |
+
out = self.interp_upsample(x)
|
218 |
+
return out
|
219 |
+
|
220 |
+
|
221 |
+
@BACKBONES.register_module()
|
222 |
+
class UNet(nn.Module):
|
223 |
+
"""UNet backbone.
|
224 |
+
U-Net: Convolutional Networks for Biomedical Image Segmentation.
|
225 |
+
https://arxiv.org/pdf/1505.04597.pdf
|
226 |
+
|
227 |
+
Args:
|
228 |
+
in_channels (int): Number of input image channels. Default" 3.
|
229 |
+
base_channels (int): Number of base channels of each stage.
|
230 |
+
The output channels of the first stage. Default: 64.
|
231 |
+
num_stages (int): Number of stages in encoder, normally 5. Default: 5.
|
232 |
+
strides (Sequence[int 1 | 2]): Strides of each stage in encoder.
|
233 |
+
len(strides) is equal to num_stages. Normally the stride of the
|
234 |
+
first stage in encoder is 1. If strides[i]=2, it uses stride
|
235 |
+
convolution to downsample in the correspondence encoder stage.
|
236 |
+
Default: (1, 1, 1, 1, 1).
|
237 |
+
enc_num_convs (Sequence[int]): Number of convolutional layers in the
|
238 |
+
convolution block of the correspondence encoder stage.
|
239 |
+
Default: (2, 2, 2, 2, 2).
|
240 |
+
dec_num_convs (Sequence[int]): Number of convolutional layers in the
|
241 |
+
convolution block of the correspondence decoder stage.
|
242 |
+
Default: (2, 2, 2, 2).
|
243 |
+
downsamples (Sequence[int]): Whether use MaxPool to downsample the
|
244 |
+
feature map after the first stage of encoder
|
245 |
+
(stages: [1, num_stages)). If the correspondence encoder stage use
|
246 |
+
stride convolution (strides[i]=2), it will never use MaxPool to
|
247 |
+
downsample, even downsamples[i-1]=True.
|
248 |
+
Default: (True, True, True, True).
|
249 |
+
enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
|
250 |
+
Default: (1, 1, 1, 1, 1).
|
251 |
+
dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
|
252 |
+
Default: (1, 1, 1, 1).
|
253 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
254 |
+
memory while slowing down the training speed. Default: False.
|
255 |
+
conv_cfg (dict | None): Config dict for convolution layer.
|
256 |
+
Default: None.
|
257 |
+
norm_cfg (dict | None): Config dict for normalization layer.
|
258 |
+
Default: dict(type='BN').
|
259 |
+
act_cfg (dict | None): Config dict for activation layer in ConvModule.
|
260 |
+
Default: dict(type='ReLU').
|
261 |
+
upsample_cfg (dict): The upsample config of the upsample module in
|
262 |
+
decoder. Default: dict(type='InterpConv').
|
263 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
264 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
265 |
+
and its variants only. Default: False.
|
266 |
+
dcn (bool): Use deformable convolution in convolutional layer or not.
|
267 |
+
Default: None.
|
268 |
+
plugins (dict): plugins for convolutional layers. Default: None.
|
269 |
+
|
270 |
+
Notice:
|
271 |
+
The input image size should be divisible by the whole downsample rate
|
272 |
+
of the encoder. More detail of the whole downsample rate can be found
|
273 |
+
in UNet._check_input_divisible.
|
274 |
+
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self,
|
278 |
+
in_channels=3,
|
279 |
+
base_channels=64,
|
280 |
+
num_stages=5,
|
281 |
+
strides=(1, 1, 1, 1, 1),
|
282 |
+
enc_num_convs=(2, 2, 2, 2, 2),
|
283 |
+
dec_num_convs=(2, 2, 2, 2),
|
284 |
+
downsamples=(True, True, True, True),
|
285 |
+
enc_dilations=(1, 1, 1, 1, 1),
|
286 |
+
dec_dilations=(1, 1, 1, 1),
|
287 |
+
with_cp=False,
|
288 |
+
conv_cfg=None,
|
289 |
+
norm_cfg=dict(type='BN'),
|
290 |
+
act_cfg=dict(type='ReLU'),
|
291 |
+
upsample_cfg=dict(type='InterpConv'),
|
292 |
+
norm_eval=False,
|
293 |
+
dcn=None,
|
294 |
+
plugins=None):
|
295 |
+
super(UNet, self).__init__()
|
296 |
+
assert dcn is None, 'Not implemented yet.'
|
297 |
+
assert plugins is None, 'Not implemented yet.'
|
298 |
+
assert len(strides) == num_stages, \
|
299 |
+
'The length of strides should be equal to num_stages, '\
|
300 |
+
f'while the strides is {strides}, the length of '\
|
301 |
+
f'strides is {len(strides)}, and the num_stages is '\
|
302 |
+
f'{num_stages}.'
|
303 |
+
assert len(enc_num_convs) == num_stages, \
|
304 |
+
'The length of enc_num_convs should be equal to num_stages, '\
|
305 |
+
f'while the enc_num_convs is {enc_num_convs}, the length of '\
|
306 |
+
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
|
307 |
+
f'{num_stages}.'
|
308 |
+
assert len(dec_num_convs) == (num_stages-1), \
|
309 |
+
'The length of dec_num_convs should be equal to (num_stages-1), '\
|
310 |
+
f'while the dec_num_convs is {dec_num_convs}, the length of '\
|
311 |
+
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
|
312 |
+
f'{num_stages}.'
|
313 |
+
assert len(downsamples) == (num_stages-1), \
|
314 |
+
'The length of downsamples should be equal to (num_stages-1), '\
|
315 |
+
f'while the downsamples is {downsamples}, the length of '\
|
316 |
+
f'downsamples is {len(downsamples)}, and the num_stages is '\
|
317 |
+
f'{num_stages}.'
|
318 |
+
assert len(enc_dilations) == num_stages, \
|
319 |
+
'The length of enc_dilations should be equal to num_stages, '\
|
320 |
+
f'while the enc_dilations is {enc_dilations}, the length of '\
|
321 |
+
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
|
322 |
+
f'{num_stages}.'
|
323 |
+
assert len(dec_dilations) == (num_stages-1), \
|
324 |
+
'The length of dec_dilations should be equal to (num_stages-1), '\
|
325 |
+
f'while the dec_dilations is {dec_dilations}, the length of '\
|
326 |
+
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
|
327 |
+
f'{num_stages}.'
|
328 |
+
self.num_stages = num_stages
|
329 |
+
self.strides = strides
|
330 |
+
self.downsamples = downsamples
|
331 |
+
self.norm_eval = norm_eval
|
332 |
+
self.base_channels = base_channels
|
333 |
+
|
334 |
+
self.encoder = nn.ModuleList()
|
335 |
+
self.decoder = nn.ModuleList()
|
336 |
+
|
337 |
+
for i in range(num_stages):
|
338 |
+
enc_conv_block = []
|
339 |
+
if i != 0:
|
340 |
+
if strides[i] == 1 and downsamples[i - 1]:
|
341 |
+
enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
|
342 |
+
upsample = (strides[i] != 1 or downsamples[i - 1])
|
343 |
+
self.decoder.append(
|
344 |
+
UpConvBlock(
|
345 |
+
conv_block=BasicConvBlock,
|
346 |
+
in_channels=base_channels * 2**i,
|
347 |
+
skip_channels=base_channels * 2**(i - 1),
|
348 |
+
out_channels=base_channels * 2**(i - 1),
|
349 |
+
num_convs=dec_num_convs[i - 1],
|
350 |
+
stride=1,
|
351 |
+
dilation=dec_dilations[i - 1],
|
352 |
+
with_cp=with_cp,
|
353 |
+
conv_cfg=conv_cfg,
|
354 |
+
norm_cfg=norm_cfg,
|
355 |
+
act_cfg=act_cfg,
|
356 |
+
upsample_cfg=upsample_cfg if upsample else None,
|
357 |
+
dcn=None,
|
358 |
+
plugins=None))
|
359 |
+
|
360 |
+
enc_conv_block.append(
|
361 |
+
BasicConvBlock(
|
362 |
+
in_channels=in_channels,
|
363 |
+
out_channels=base_channels * 2**i,
|
364 |
+
num_convs=enc_num_convs[i],
|
365 |
+
stride=strides[i],
|
366 |
+
dilation=enc_dilations[i],
|
367 |
+
with_cp=with_cp,
|
368 |
+
conv_cfg=conv_cfg,
|
369 |
+
norm_cfg=norm_cfg,
|
370 |
+
act_cfg=act_cfg,
|
371 |
+
dcn=None,
|
372 |
+
plugins=None))
|
373 |
+
self.encoder.append((nn.Sequential(*enc_conv_block)))
|
374 |
+
in_channels = base_channels * 2**i
|
375 |
+
|
376 |
+
def forward(self, x):
|
377 |
+
self._check_input_divisible(x)
|
378 |
+
enc_outs = []
|
379 |
+
for enc in self.encoder:
|
380 |
+
x = enc(x)
|
381 |
+
enc_outs.append(x)
|
382 |
+
dec_outs = [x]
|
383 |
+
for i in reversed(range(len(self.decoder))):
|
384 |
+
x = self.decoder[i](enc_outs[i], x)
|
385 |
+
dec_outs.append(x)
|
386 |
+
|
387 |
+
return dec_outs
|
388 |
+
|
389 |
+
def train(self, mode=True):
|
390 |
+
"""Convert the model into training mode while keep normalization layer
|
391 |
+
freezed."""
|
392 |
+
super(UNet, self).train(mode)
|
393 |
+
if mode and self.norm_eval:
|
394 |
+
for m in self.modules():
|
395 |
+
# trick: eval have effect on BatchNorm only
|
396 |
+
if isinstance(m, _BatchNorm):
|
397 |
+
m.eval()
|
398 |
+
|
399 |
+
def _check_input_divisible(self, x):
|
400 |
+
h, w = x.shape[-2:]
|
401 |
+
whole_downsample_rate = 1
|
402 |
+
for i in range(1, self.num_stages):
|
403 |
+
if self.strides[i] == 2 or self.downsamples[i - 1]:
|
404 |
+
whole_downsample_rate *= 2
|
405 |
+
assert (h % whole_downsample_rate == 0) \
|
406 |
+
and (w % whole_downsample_rate == 0),\
|
407 |
+
f'The input image size {(h, w)} should be divisible by the whole '\
|
408 |
+
f'downsample rate {whole_downsample_rate}, when num_stages is '\
|
409 |
+
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
|
410 |
+
f'is {self.downsamples}.'
|
411 |
+
|
412 |
+
def init_weights(self, pretrained=None):
|
413 |
+
"""Initialize the weights in backbone.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
pretrained (str, optional): Path to pre-trained weights.
|
417 |
+
Defaults to None.
|
418 |
+
"""
|
419 |
+
if isinstance(pretrained, str):
|
420 |
+
logger = get_root_logger()
|
421 |
+
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
422 |
+
elif pretrained is None:
|
423 |
+
for m in self.modules():
|
424 |
+
if isinstance(m, nn.Conv2d):
|
425 |
+
kaiming_init(m)
|
426 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
427 |
+
constant_init(m, 1)
|
428 |
+
else:
|
429 |
+
raise TypeError('pretrained must be a str or None')
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/backbones/vit.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/rwightman/pytorch-image-
|
2 |
+
models/blob/master/timm/models/vision_transformer.py."""
|
3 |
+
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.utils.checkpoint as cp
|
10 |
+
from annotator.mmpkg.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
|
11 |
+
constant_init, kaiming_init, normal_init)
|
12 |
+
from annotator.mmpkg.mmcv.runner import _load_checkpoint
|
13 |
+
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm
|
14 |
+
|
15 |
+
from annotator.mmpkg.mmseg.utils import get_root_logger
|
16 |
+
from ..builder import BACKBONES
|
17 |
+
from ..utils import DropPath, trunc_normal_
|
18 |
+
|
19 |
+
|
20 |
+
class Mlp(nn.Module):
|
21 |
+
"""MLP layer for Encoder block.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
in_features(int): Input dimension for the first fully
|
25 |
+
connected layer.
|
26 |
+
hidden_features(int): Output dimension for the first fully
|
27 |
+
connected layer.
|
28 |
+
out_features(int): Output dementsion for the second fully
|
29 |
+
connected layer.
|
30 |
+
act_cfg(dict): Config dict for activation layer.
|
31 |
+
Default: dict(type='GELU').
|
32 |
+
drop(float): Drop rate for the dropout layer. Dropout rate has
|
33 |
+
to be between 0 and 1. Default: 0.
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self,
|
37 |
+
in_features,
|
38 |
+
hidden_features=None,
|
39 |
+
out_features=None,
|
40 |
+
act_cfg=dict(type='GELU'),
|
41 |
+
drop=0.):
|
42 |
+
super(Mlp, self).__init__()
|
43 |
+
out_features = out_features or in_features
|
44 |
+
hidden_features = hidden_features or in_features
|
45 |
+
self.fc1 = Linear(in_features, hidden_features)
|
46 |
+
self.act = build_activation_layer(act_cfg)
|
47 |
+
self.fc2 = Linear(hidden_features, out_features)
|
48 |
+
self.drop = nn.Dropout(drop)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
x = self.fc1(x)
|
52 |
+
x = self.act(x)
|
53 |
+
x = self.drop(x)
|
54 |
+
x = self.fc2(x)
|
55 |
+
x = self.drop(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class Attention(nn.Module):
|
60 |
+
"""Attention layer for Encoder block.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
dim (int): Dimension for the input vector.
|
64 |
+
num_heads (int): Number of parallel attention heads.
|
65 |
+
qkv_bias (bool): Enable bias for qkv if True. Default: False.
|
66 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
67 |
+
attn_drop (float): Drop rate for attention output weights.
|
68 |
+
Default: 0.
|
69 |
+
proj_drop (float): Drop rate for output weights. Default: 0.
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self,
|
73 |
+
dim,
|
74 |
+
num_heads=8,
|
75 |
+
qkv_bias=False,
|
76 |
+
qk_scale=None,
|
77 |
+
attn_drop=0.,
|
78 |
+
proj_drop=0.):
|
79 |
+
super(Attention, self).__init__()
|
80 |
+
self.num_heads = num_heads
|
81 |
+
head_dim = dim // num_heads
|
82 |
+
self.scale = qk_scale or head_dim**-0.5
|
83 |
+
|
84 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
85 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
86 |
+
self.proj = Linear(dim, dim)
|
87 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
b, n, c = x.shape
|
91 |
+
qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
|
92 |
+
c // self.num_heads).permute(2, 0, 3, 1, 4)
|
93 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
94 |
+
|
95 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
96 |
+
attn = attn.softmax(dim=-1)
|
97 |
+
attn = self.attn_drop(attn)
|
98 |
+
|
99 |
+
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
|
100 |
+
x = self.proj(x)
|
101 |
+
x = self.proj_drop(x)
|
102 |
+
return x
|
103 |
+
|
104 |
+
|
105 |
+
class Block(nn.Module):
|
106 |
+
"""Implements encoder block with residual connection.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
dim (int): The feature dimension.
|
110 |
+
num_heads (int): Number of parallel attention heads.
|
111 |
+
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
|
112 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
|
113 |
+
drop (float): Drop rate for mlp output weights. Default: 0.
|
114 |
+
attn_drop (float): Drop rate for attention output weights.
|
115 |
+
Default: 0.
|
116 |
+
proj_drop (float): Drop rate for attn layer output weights.
|
117 |
+
Default: 0.
|
118 |
+
drop_path (float): Drop rate for paths of model.
|
119 |
+
Default: 0.
|
120 |
+
act_cfg (dict): Config dict for activation layer.
|
121 |
+
Default: dict(type='GELU').
|
122 |
+
norm_cfg (dict): Config dict for normalization layer.
|
123 |
+
Default: dict(type='LN', requires_grad=True).
|
124 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
125 |
+
memory while slowing down the training speed. Default: False.
|
126 |
+
"""
|
127 |
+
|
128 |
+
def __init__(self,
|
129 |
+
dim,
|
130 |
+
num_heads,
|
131 |
+
mlp_ratio=4,
|
132 |
+
qkv_bias=False,
|
133 |
+
qk_scale=None,
|
134 |
+
drop=0.,
|
135 |
+
attn_drop=0.,
|
136 |
+
proj_drop=0.,
|
137 |
+
drop_path=0.,
|
138 |
+
act_cfg=dict(type='GELU'),
|
139 |
+
norm_cfg=dict(type='LN', eps=1e-6),
|
140 |
+
with_cp=False):
|
141 |
+
super(Block, self).__init__()
|
142 |
+
self.with_cp = with_cp
|
143 |
+
_, self.norm1 = build_norm_layer(norm_cfg, dim)
|
144 |
+
self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
|
145 |
+
proj_drop)
|
146 |
+
self.drop_path = DropPath(
|
147 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
148 |
+
_, self.norm2 = build_norm_layer(norm_cfg, dim)
|
149 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
150 |
+
self.mlp = Mlp(
|
151 |
+
in_features=dim,
|
152 |
+
hidden_features=mlp_hidden_dim,
|
153 |
+
act_cfg=act_cfg,
|
154 |
+
drop=drop)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
|
158 |
+
def _inner_forward(x):
|
159 |
+
out = x + self.drop_path(self.attn(self.norm1(x)))
|
160 |
+
out = out + self.drop_path(self.mlp(self.norm2(out)))
|
161 |
+
return out
|
162 |
+
|
163 |
+
if self.with_cp and x.requires_grad:
|
164 |
+
out = cp.checkpoint(_inner_forward, x)
|
165 |
+
else:
|
166 |
+
out = _inner_forward(x)
|
167 |
+
|
168 |
+
return out
|
169 |
+
|
170 |
+
|
171 |
+
class PatchEmbed(nn.Module):
|
172 |
+
"""Image to Patch Embedding.
|
173 |
+
|
174 |
+
Args:
|
175 |
+
img_size (int | tuple): Input image size.
|
176 |
+
default: 224.
|
177 |
+
patch_size (int): Width and height for a patch.
|
178 |
+
default: 16.
|
179 |
+
in_channels (int): Input channels for images. Default: 3.
|
180 |
+
embed_dim (int): The embedding dimension. Default: 768.
|
181 |
+
"""
|
182 |
+
|
183 |
+
def __init__(self,
|
184 |
+
img_size=224,
|
185 |
+
patch_size=16,
|
186 |
+
in_channels=3,
|
187 |
+
embed_dim=768):
|
188 |
+
super(PatchEmbed, self).__init__()
|
189 |
+
if isinstance(img_size, int):
|
190 |
+
self.img_size = (img_size, img_size)
|
191 |
+
elif isinstance(img_size, tuple):
|
192 |
+
self.img_size = img_size
|
193 |
+
else:
|
194 |
+
raise TypeError('img_size must be type of int or tuple')
|
195 |
+
h, w = self.img_size
|
196 |
+
self.patch_size = (patch_size, patch_size)
|
197 |
+
self.num_patches = (h // patch_size) * (w // patch_size)
|
198 |
+
self.proj = Conv2d(
|
199 |
+
in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
return self.proj(x).flatten(2).transpose(1, 2)
|
203 |
+
|
204 |
+
|
205 |
+
@BACKBONES.register_module()
|
206 |
+
class VisionTransformer(nn.Module):
|
207 |
+
"""Vision transformer backbone.
|
208 |
+
|
209 |
+
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
|
210 |
+
Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
|
211 |
+
|
212 |
+
Args:
|
213 |
+
img_size (tuple): input image size. Default: (224, 224).
|
214 |
+
patch_size (int, tuple): patch size. Default: 16.
|
215 |
+
in_channels (int): number of input channels. Default: 3.
|
216 |
+
embed_dim (int): embedding dimension. Default: 768.
|
217 |
+
depth (int): depth of transformer. Default: 12.
|
218 |
+
num_heads (int): number of attention heads. Default: 12.
|
219 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
|
220 |
+
Default: 4.
|
221 |
+
out_indices (list | tuple | int): Output from which stages.
|
222 |
+
Default: -1.
|
223 |
+
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
224 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
|
225 |
+
drop_rate (float): dropout rate. Default: 0.
|
226 |
+
attn_drop_rate (float): attention dropout rate. Default: 0.
|
227 |
+
drop_path_rate (float): Rate of DropPath. Default: 0.
|
228 |
+
norm_cfg (dict): Config dict for normalization layer.
|
229 |
+
Default: dict(type='LN', eps=1e-6, requires_grad=True).
|
230 |
+
act_cfg (dict): Config dict for activation layer.
|
231 |
+
Default: dict(type='GELU').
|
232 |
+
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
233 |
+
freeze running stats (mean and var). Note: Effect on Batch Norm
|
234 |
+
and its variants only. Default: False.
|
235 |
+
final_norm (bool): Whether to add a additional layer to normalize
|
236 |
+
final feature map. Default: False.
|
237 |
+
interpolate_mode (str): Select the interpolate mode for position
|
238 |
+
embeding vector resize. Default: bicubic.
|
239 |
+
with_cls_token (bool): If concatenating class token into image tokens
|
240 |
+
as transformer input. Default: True.
|
241 |
+
with_cp (bool): Use checkpoint or not. Using checkpoint
|
242 |
+
will save some memory while slowing down the training speed.
|
243 |
+
Default: False.
|
244 |
+
"""
|
245 |
+
|
246 |
+
def __init__(self,
|
247 |
+
img_size=(224, 224),
|
248 |
+
patch_size=16,
|
249 |
+
in_channels=3,
|
250 |
+
embed_dim=768,
|
251 |
+
depth=12,
|
252 |
+
num_heads=12,
|
253 |
+
mlp_ratio=4,
|
254 |
+
out_indices=11,
|
255 |
+
qkv_bias=True,
|
256 |
+
qk_scale=None,
|
257 |
+
drop_rate=0.,
|
258 |
+
attn_drop_rate=0.,
|
259 |
+
drop_path_rate=0.,
|
260 |
+
norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
|
261 |
+
act_cfg=dict(type='GELU'),
|
262 |
+
norm_eval=False,
|
263 |
+
final_norm=False,
|
264 |
+
with_cls_token=True,
|
265 |
+
interpolate_mode='bicubic',
|
266 |
+
with_cp=False):
|
267 |
+
super(VisionTransformer, self).__init__()
|
268 |
+
self.img_size = img_size
|
269 |
+
self.patch_size = patch_size
|
270 |
+
self.features = self.embed_dim = embed_dim
|
271 |
+
self.patch_embed = PatchEmbed(
|
272 |
+
img_size=img_size,
|
273 |
+
patch_size=patch_size,
|
274 |
+
in_channels=in_channels,
|
275 |
+
embed_dim=embed_dim)
|
276 |
+
|
277 |
+
self.with_cls_token = with_cls_token
|
278 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
279 |
+
self.pos_embed = nn.Parameter(
|
280 |
+
torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
281 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
282 |
+
|
283 |
+
if isinstance(out_indices, int):
|
284 |
+
self.out_indices = [out_indices]
|
285 |
+
elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
|
286 |
+
self.out_indices = out_indices
|
287 |
+
else:
|
288 |
+
raise TypeError('out_indices must be type of int, list or tuple')
|
289 |
+
|
290 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
291 |
+
] # stochastic depth decay rule
|
292 |
+
self.blocks = nn.ModuleList([
|
293 |
+
Block(
|
294 |
+
dim=embed_dim,
|
295 |
+
num_heads=num_heads,
|
296 |
+
mlp_ratio=mlp_ratio,
|
297 |
+
qkv_bias=qkv_bias,
|
298 |
+
qk_scale=qk_scale,
|
299 |
+
drop=dpr[i],
|
300 |
+
attn_drop=attn_drop_rate,
|
301 |
+
act_cfg=act_cfg,
|
302 |
+
norm_cfg=norm_cfg,
|
303 |
+
with_cp=with_cp) for i in range(depth)
|
304 |
+
])
|
305 |
+
|
306 |
+
self.interpolate_mode = interpolate_mode
|
307 |
+
self.final_norm = final_norm
|
308 |
+
if final_norm:
|
309 |
+
_, self.norm = build_norm_layer(norm_cfg, embed_dim)
|
310 |
+
|
311 |
+
self.norm_eval = norm_eval
|
312 |
+
self.with_cp = with_cp
|
313 |
+
|
314 |
+
def init_weights(self, pretrained=None):
|
315 |
+
if isinstance(pretrained, str):
|
316 |
+
logger = get_root_logger()
|
317 |
+
checkpoint = _load_checkpoint(pretrained, logger=logger)
|
318 |
+
if 'state_dict' in checkpoint:
|
319 |
+
state_dict = checkpoint['state_dict']
|
320 |
+
else:
|
321 |
+
state_dict = checkpoint
|
322 |
+
|
323 |
+
if 'pos_embed' in state_dict.keys():
|
324 |
+
if self.pos_embed.shape != state_dict['pos_embed'].shape:
|
325 |
+
logger.info(msg=f'Resize the pos_embed shape from \
|
326 |
+
{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
|
327 |
+
h, w = self.img_size
|
328 |
+
pos_size = int(
|
329 |
+
math.sqrt(state_dict['pos_embed'].shape[1] - 1))
|
330 |
+
state_dict['pos_embed'] = self.resize_pos_embed(
|
331 |
+
state_dict['pos_embed'], (h, w), (pos_size, pos_size),
|
332 |
+
self.patch_size, self.interpolate_mode)
|
333 |
+
|
334 |
+
self.load_state_dict(state_dict, False)
|
335 |
+
|
336 |
+
elif pretrained is None:
|
337 |
+
# We only implement the 'jax_impl' initialization implemented at
|
338 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
|
339 |
+
trunc_normal_(self.pos_embed, std=.02)
|
340 |
+
trunc_normal_(self.cls_token, std=.02)
|
341 |
+
for n, m in self.named_modules():
|
342 |
+
if isinstance(m, Linear):
|
343 |
+
trunc_normal_(m.weight, std=.02)
|
344 |
+
if m.bias is not None:
|
345 |
+
if 'mlp' in n:
|
346 |
+
normal_init(m.bias, std=1e-6)
|
347 |
+
else:
|
348 |
+
constant_init(m.bias, 0)
|
349 |
+
elif isinstance(m, Conv2d):
|
350 |
+
kaiming_init(m.weight, mode='fan_in')
|
351 |
+
if m.bias is not None:
|
352 |
+
constant_init(m.bias, 0)
|
353 |
+
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
|
354 |
+
constant_init(m.bias, 0)
|
355 |
+
constant_init(m.weight, 1.0)
|
356 |
+
else:
|
357 |
+
raise TypeError('pretrained must be a str or None')
|
358 |
+
|
359 |
+
def _pos_embeding(self, img, patched_img, pos_embed):
|
360 |
+
"""Positiong embeding method.
|
361 |
+
|
362 |
+
Resize the pos_embed, if the input image size doesn't match
|
363 |
+
the training size.
|
364 |
+
Args:
|
365 |
+
img (torch.Tensor): The inference image tensor, the shape
|
366 |
+
must be [B, C, H, W].
|
367 |
+
patched_img (torch.Tensor): The patched image, it should be
|
368 |
+
shape of [B, L1, C].
|
369 |
+
pos_embed (torch.Tensor): The pos_embed weighs, it should be
|
370 |
+
shape of [B, L2, c].
|
371 |
+
Return:
|
372 |
+
torch.Tensor: The pos encoded image feature.
|
373 |
+
"""
|
374 |
+
assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
|
375 |
+
'the shapes of patched_img and pos_embed must be [B, L, C]'
|
376 |
+
x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
|
377 |
+
if x_len != pos_len:
|
378 |
+
if pos_len == (self.img_size[0] // self.patch_size) * (
|
379 |
+
self.img_size[1] // self.patch_size) + 1:
|
380 |
+
pos_h = self.img_size[0] // self.patch_size
|
381 |
+
pos_w = self.img_size[1] // self.patch_size
|
382 |
+
else:
|
383 |
+
raise ValueError(
|
384 |
+
'Unexpected shape of pos_embed, got {}.'.format(
|
385 |
+
pos_embed.shape))
|
386 |
+
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
|
387 |
+
(pos_h, pos_w), self.patch_size,
|
388 |
+
self.interpolate_mode)
|
389 |
+
return self.pos_drop(patched_img + pos_embed)
|
390 |
+
|
391 |
+
@staticmethod
|
392 |
+
def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
|
393 |
+
"""Resize pos_embed weights.
|
394 |
+
|
395 |
+
Resize pos_embed using bicubic interpolate method.
|
396 |
+
Args:
|
397 |
+
pos_embed (torch.Tensor): pos_embed weights.
|
398 |
+
input_shpae (tuple): Tuple for (input_h, intput_w).
|
399 |
+
pos_shape (tuple): Tuple for (pos_h, pos_w).
|
400 |
+
patch_size (int): Patch size.
|
401 |
+
Return:
|
402 |
+
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
403 |
+
"""
|
404 |
+
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
405 |
+
input_h, input_w = input_shpae
|
406 |
+
pos_h, pos_w = pos_shape
|
407 |
+
cls_token_weight = pos_embed[:, 0]
|
408 |
+
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
409 |
+
pos_embed_weight = pos_embed_weight.reshape(
|
410 |
+
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
411 |
+
pos_embed_weight = F.interpolate(
|
412 |
+
pos_embed_weight,
|
413 |
+
size=[input_h // patch_size, input_w // patch_size],
|
414 |
+
align_corners=False,
|
415 |
+
mode=mode)
|
416 |
+
cls_token_weight = cls_token_weight.unsqueeze(1)
|
417 |
+
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
418 |
+
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
419 |
+
return pos_embed
|
420 |
+
|
421 |
+
def forward(self, inputs):
|
422 |
+
B = inputs.shape[0]
|
423 |
+
|
424 |
+
x = self.patch_embed(inputs)
|
425 |
+
|
426 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
427 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
428 |
+
x = self._pos_embeding(inputs, x, self.pos_embed)
|
429 |
+
|
430 |
+
if not self.with_cls_token:
|
431 |
+
# Remove class token for transformer input
|
432 |
+
x = x[:, 1:]
|
433 |
+
|
434 |
+
outs = []
|
435 |
+
for i, blk in enumerate(self.blocks):
|
436 |
+
x = blk(x)
|
437 |
+
if i == len(self.blocks) - 1:
|
438 |
+
if self.final_norm:
|
439 |
+
x = self.norm(x)
|
440 |
+
if i in self.out_indices:
|
441 |
+
if self.with_cls_token:
|
442 |
+
# Remove class token and reshape token for decoder head
|
443 |
+
out = x[:, 1:]
|
444 |
+
else:
|
445 |
+
out = x
|
446 |
+
B, _, C = out.shape
|
447 |
+
out = out.reshape(B, inputs.shape[2] // self.patch_size,
|
448 |
+
inputs.shape[3] // self.patch_size,
|
449 |
+
C).permute(0, 3, 1, 2)
|
450 |
+
outs.append(out)
|
451 |
+
|
452 |
+
return tuple(outs)
|
453 |
+
|
454 |
+
def train(self, mode=True):
|
455 |
+
super(VisionTransformer, self).train(mode)
|
456 |
+
if mode and self.norm_eval:
|
457 |
+
for m in self.modules():
|
458 |
+
if isinstance(m, nn.LayerNorm):
|
459 |
+
m.eval()
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/builder.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
from annotator.mmpkg.mmcv.cnn import MODELS as MMCV_MODELS
|
4 |
+
from annotator.mmpkg.mmcv.utils import Registry
|
5 |
+
|
6 |
+
MODELS = Registry('models', parent=MMCV_MODELS)
|
7 |
+
|
8 |
+
BACKBONES = MODELS
|
9 |
+
NECKS = MODELS
|
10 |
+
HEADS = MODELS
|
11 |
+
LOSSES = MODELS
|
12 |
+
SEGMENTORS = MODELS
|
13 |
+
|
14 |
+
|
15 |
+
def build_backbone(cfg):
|
16 |
+
"""Build backbone."""
|
17 |
+
return BACKBONES.build(cfg)
|
18 |
+
|
19 |
+
|
20 |
+
def build_neck(cfg):
|
21 |
+
"""Build neck."""
|
22 |
+
return NECKS.build(cfg)
|
23 |
+
|
24 |
+
|
25 |
+
def build_head(cfg):
|
26 |
+
"""Build head."""
|
27 |
+
return HEADS.build(cfg)
|
28 |
+
|
29 |
+
|
30 |
+
def build_loss(cfg):
|
31 |
+
"""Build loss."""
|
32 |
+
return LOSSES.build(cfg)
|
33 |
+
|
34 |
+
|
35 |
+
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
|
36 |
+
"""Build segmentor."""
|
37 |
+
if train_cfg is not None or test_cfg is not None:
|
38 |
+
warnings.warn(
|
39 |
+
'train_cfg and test_cfg is deprecated, '
|
40 |
+
'please specify them in model', UserWarning)
|
41 |
+
assert cfg.get('train_cfg') is None or train_cfg is None, \
|
42 |
+
'train_cfg specified in both outer field and model field '
|
43 |
+
assert cfg.get('test_cfg') is None or test_cfg is None, \
|
44 |
+
'test_cfg specified in both outer field and model field '
|
45 |
+
return SEGMENTORS.build(
|
46 |
+
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/__init__.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ann_head import ANNHead
|
2 |
+
from .apc_head import APCHead
|
3 |
+
from .aspp_head import ASPPHead
|
4 |
+
from .cc_head import CCHead
|
5 |
+
from .da_head import DAHead
|
6 |
+
from .dm_head import DMHead
|
7 |
+
from .dnl_head import DNLHead
|
8 |
+
from .ema_head import EMAHead
|
9 |
+
from .enc_head import EncHead
|
10 |
+
from .fcn_head import FCNHead
|
11 |
+
from .fpn_head import FPNHead
|
12 |
+
from .gc_head import GCHead
|
13 |
+
from .lraspp_head import LRASPPHead
|
14 |
+
from .nl_head import NLHead
|
15 |
+
from .ocr_head import OCRHead
|
16 |
+
# from .point_head import PointHead
|
17 |
+
from .psa_head import PSAHead
|
18 |
+
from .psp_head import PSPHead
|
19 |
+
from .sep_aspp_head import DepthwiseSeparableASPPHead
|
20 |
+
from .sep_fcn_head import DepthwiseSeparableFCNHead
|
21 |
+
from .uper_head import UPerHead
|
22 |
+
|
23 |
+
__all__ = [
|
24 |
+
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
|
25 |
+
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
|
26 |
+
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
|
27 |
+
'APCHead', 'DMHead', 'LRASPPHead'
|
28 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ann_head.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from annotator.mmpkg.mmcv.cnn import ConvModule
|
4 |
+
|
5 |
+
from ..builder import HEADS
|
6 |
+
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
7 |
+
from .decode_head import BaseDecodeHead
|
8 |
+
|
9 |
+
|
10 |
+
class PPMConcat(nn.ModuleList):
|
11 |
+
"""Pyramid Pooling Module that only concat the features of each layer.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
15 |
+
Module.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, pool_scales=(1, 3, 6, 8)):
|
19 |
+
super(PPMConcat, self).__init__(
|
20 |
+
[nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
|
21 |
+
|
22 |
+
def forward(self, feats):
|
23 |
+
"""Forward function."""
|
24 |
+
ppm_outs = []
|
25 |
+
for ppm in self:
|
26 |
+
ppm_out = ppm(feats)
|
27 |
+
ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
|
28 |
+
concat_outs = torch.cat(ppm_outs, dim=2)
|
29 |
+
return concat_outs
|
30 |
+
|
31 |
+
|
32 |
+
class SelfAttentionBlock(_SelfAttentionBlock):
|
33 |
+
"""Make a ANN used SelfAttentionBlock.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
low_in_channels (int): Input channels of lower level feature,
|
37 |
+
which is the key feature for self-attention.
|
38 |
+
high_in_channels (int): Input channels of higher level feature,
|
39 |
+
which is the query feature for self-attention.
|
40 |
+
channels (int): Output channels of key/query transform.
|
41 |
+
out_channels (int): Output channels.
|
42 |
+
share_key_query (bool): Whether share projection weight between key
|
43 |
+
and query projection.
|
44 |
+
query_scale (int): The scale of query feature map.
|
45 |
+
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
46 |
+
Module of key feature.
|
47 |
+
conv_cfg (dict|None): Config of conv layers.
|
48 |
+
norm_cfg (dict|None): Config of norm layers.
|
49 |
+
act_cfg (dict|None): Config of activation layers.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, low_in_channels, high_in_channels, channels,
|
53 |
+
out_channels, share_key_query, query_scale, key_pool_scales,
|
54 |
+
conv_cfg, norm_cfg, act_cfg):
|
55 |
+
key_psp = PPMConcat(key_pool_scales)
|
56 |
+
if query_scale > 1:
|
57 |
+
query_downsample = nn.MaxPool2d(kernel_size=query_scale)
|
58 |
+
else:
|
59 |
+
query_downsample = None
|
60 |
+
super(SelfAttentionBlock, self).__init__(
|
61 |
+
key_in_channels=low_in_channels,
|
62 |
+
query_in_channels=high_in_channels,
|
63 |
+
channels=channels,
|
64 |
+
out_channels=out_channels,
|
65 |
+
share_key_query=share_key_query,
|
66 |
+
query_downsample=query_downsample,
|
67 |
+
key_downsample=key_psp,
|
68 |
+
key_query_num_convs=1,
|
69 |
+
key_query_norm=True,
|
70 |
+
value_out_num_convs=1,
|
71 |
+
value_out_norm=False,
|
72 |
+
matmul_norm=True,
|
73 |
+
with_out=True,
|
74 |
+
conv_cfg=conv_cfg,
|
75 |
+
norm_cfg=norm_cfg,
|
76 |
+
act_cfg=act_cfg)
|
77 |
+
|
78 |
+
|
79 |
+
class AFNB(nn.Module):
|
80 |
+
"""Asymmetric Fusion Non-local Block(AFNB)
|
81 |
+
|
82 |
+
Args:
|
83 |
+
low_in_channels (int): Input channels of lower level feature,
|
84 |
+
which is the key feature for self-attention.
|
85 |
+
high_in_channels (int): Input channels of higher level feature,
|
86 |
+
which is the query feature for self-attention.
|
87 |
+
channels (int): Output channels of key/query transform.
|
88 |
+
out_channels (int): Output channels.
|
89 |
+
and query projection.
|
90 |
+
query_scales (tuple[int]): The scales of query feature map.
|
91 |
+
Default: (1,)
|
92 |
+
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
93 |
+
Module of key feature.
|
94 |
+
conv_cfg (dict|None): Config of conv layers.
|
95 |
+
norm_cfg (dict|None): Config of norm layers.
|
96 |
+
act_cfg (dict|None): Config of activation layers.
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, low_in_channels, high_in_channels, channels,
|
100 |
+
out_channels, query_scales, key_pool_scales, conv_cfg,
|
101 |
+
norm_cfg, act_cfg):
|
102 |
+
super(AFNB, self).__init__()
|
103 |
+
self.stages = nn.ModuleList()
|
104 |
+
for query_scale in query_scales:
|
105 |
+
self.stages.append(
|
106 |
+
SelfAttentionBlock(
|
107 |
+
low_in_channels=low_in_channels,
|
108 |
+
high_in_channels=high_in_channels,
|
109 |
+
channels=channels,
|
110 |
+
out_channels=out_channels,
|
111 |
+
share_key_query=False,
|
112 |
+
query_scale=query_scale,
|
113 |
+
key_pool_scales=key_pool_scales,
|
114 |
+
conv_cfg=conv_cfg,
|
115 |
+
norm_cfg=norm_cfg,
|
116 |
+
act_cfg=act_cfg))
|
117 |
+
self.bottleneck = ConvModule(
|
118 |
+
out_channels + high_in_channels,
|
119 |
+
out_channels,
|
120 |
+
1,
|
121 |
+
conv_cfg=conv_cfg,
|
122 |
+
norm_cfg=norm_cfg,
|
123 |
+
act_cfg=None)
|
124 |
+
|
125 |
+
def forward(self, low_feats, high_feats):
|
126 |
+
"""Forward function."""
|
127 |
+
priors = [stage(high_feats, low_feats) for stage in self.stages]
|
128 |
+
context = torch.stack(priors, dim=0).sum(dim=0)
|
129 |
+
output = self.bottleneck(torch.cat([context, high_feats], 1))
|
130 |
+
return output
|
131 |
+
|
132 |
+
|
133 |
+
class APNB(nn.Module):
|
134 |
+
"""Asymmetric Pyramid Non-local Block (APNB)
|
135 |
+
|
136 |
+
Args:
|
137 |
+
in_channels (int): Input channels of key/query feature,
|
138 |
+
which is the key feature for self-attention.
|
139 |
+
channels (int): Output channels of key/query transform.
|
140 |
+
out_channels (int): Output channels.
|
141 |
+
query_scales (tuple[int]): The scales of query feature map.
|
142 |
+
Default: (1,)
|
143 |
+
key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
144 |
+
Module of key feature.
|
145 |
+
conv_cfg (dict|None): Config of conv layers.
|
146 |
+
norm_cfg (dict|None): Config of norm layers.
|
147 |
+
act_cfg (dict|None): Config of activation layers.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self, in_channels, channels, out_channels, query_scales,
|
151 |
+
key_pool_scales, conv_cfg, norm_cfg, act_cfg):
|
152 |
+
super(APNB, self).__init__()
|
153 |
+
self.stages = nn.ModuleList()
|
154 |
+
for query_scale in query_scales:
|
155 |
+
self.stages.append(
|
156 |
+
SelfAttentionBlock(
|
157 |
+
low_in_channels=in_channels,
|
158 |
+
high_in_channels=in_channels,
|
159 |
+
channels=channels,
|
160 |
+
out_channels=out_channels,
|
161 |
+
share_key_query=True,
|
162 |
+
query_scale=query_scale,
|
163 |
+
key_pool_scales=key_pool_scales,
|
164 |
+
conv_cfg=conv_cfg,
|
165 |
+
norm_cfg=norm_cfg,
|
166 |
+
act_cfg=act_cfg))
|
167 |
+
self.bottleneck = ConvModule(
|
168 |
+
2 * in_channels,
|
169 |
+
out_channels,
|
170 |
+
1,
|
171 |
+
conv_cfg=conv_cfg,
|
172 |
+
norm_cfg=norm_cfg,
|
173 |
+
act_cfg=act_cfg)
|
174 |
+
|
175 |
+
def forward(self, feats):
|
176 |
+
"""Forward function."""
|
177 |
+
priors = [stage(feats, feats) for stage in self.stages]
|
178 |
+
context = torch.stack(priors, dim=0).sum(dim=0)
|
179 |
+
output = self.bottleneck(torch.cat([context, feats], 1))
|
180 |
+
return output
|
181 |
+
|
182 |
+
|
183 |
+
@HEADS.register_module()
|
184 |
+
class ANNHead(BaseDecodeHead):
|
185 |
+
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.
|
186 |
+
|
187 |
+
This head is the implementation of `ANNNet
|
188 |
+
<https://arxiv.org/abs/1908.07678>`_.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
project_channels (int): Projection channels for Nonlocal.
|
192 |
+
query_scales (tuple[int]): The scales of query feature map.
|
193 |
+
Default: (1,)
|
194 |
+
key_pool_scales (tuple[int]): The pooling scales of key feature map.
|
195 |
+
Default: (1, 3, 6, 8).
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self,
|
199 |
+
project_channels,
|
200 |
+
query_scales=(1, ),
|
201 |
+
key_pool_scales=(1, 3, 6, 8),
|
202 |
+
**kwargs):
|
203 |
+
super(ANNHead, self).__init__(
|
204 |
+
input_transform='multiple_select', **kwargs)
|
205 |
+
assert len(self.in_channels) == 2
|
206 |
+
low_in_channels, high_in_channels = self.in_channels
|
207 |
+
self.project_channels = project_channels
|
208 |
+
self.fusion = AFNB(
|
209 |
+
low_in_channels=low_in_channels,
|
210 |
+
high_in_channels=high_in_channels,
|
211 |
+
out_channels=high_in_channels,
|
212 |
+
channels=project_channels,
|
213 |
+
query_scales=query_scales,
|
214 |
+
key_pool_scales=key_pool_scales,
|
215 |
+
conv_cfg=self.conv_cfg,
|
216 |
+
norm_cfg=self.norm_cfg,
|
217 |
+
act_cfg=self.act_cfg)
|
218 |
+
self.bottleneck = ConvModule(
|
219 |
+
high_in_channels,
|
220 |
+
self.channels,
|
221 |
+
3,
|
222 |
+
padding=1,
|
223 |
+
conv_cfg=self.conv_cfg,
|
224 |
+
norm_cfg=self.norm_cfg,
|
225 |
+
act_cfg=self.act_cfg)
|
226 |
+
self.context = APNB(
|
227 |
+
in_channels=self.channels,
|
228 |
+
out_channels=self.channels,
|
229 |
+
channels=project_channels,
|
230 |
+
query_scales=query_scales,
|
231 |
+
key_pool_scales=key_pool_scales,
|
232 |
+
conv_cfg=self.conv_cfg,
|
233 |
+
norm_cfg=self.norm_cfg,
|
234 |
+
act_cfg=self.act_cfg)
|
235 |
+
|
236 |
+
def forward(self, inputs):
|
237 |
+
"""Forward function."""
|
238 |
+
low_feats, high_feats = self._transform_inputs(inputs)
|
239 |
+
output = self.fusion(low_feats, high_feats)
|
240 |
+
output = self.dropout(output)
|
241 |
+
output = self.bottleneck(output)
|
242 |
+
output = self.context(output)
|
243 |
+
output = self.cls_seg(output)
|
244 |
+
|
245 |
+
return output
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/apc_head.py
ADDED
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class ACM(nn.Module):
    """Adaptive Context Module used in APCNet.

    Args:
        pool_scale (int): Pooling scale used in Adaptive Context
            Module to extract region features.
        fusion (bool): Add one conv to fuse residual feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(ACM, self).__init__()
        self.pool_scale = pool_scale
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.pooled_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.global_info = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)

        self.residual_conv = ConvModule(
            self.channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
        # [batch_size, channels, h, w]
        x = self.input_redu_conv(x)
        # [batch_size, channels, pool_scale, pool_scale]
        pooled_x = self.pooled_redu_conv(pooled_x)
        batch_size = x.size(0)
        # [batch_size, pool_scale * pool_scale, channels]
        pooled_x = pooled_x.view(batch_size, self.channels,
                                 -1).permute(0, 2, 1).contiguous()
        # [batch_size, h * w, pool_scale * pool_scale]
        affinity_matrix = self.gla(x + resize(
            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
                                   ).permute(0, 2, 3, 1).reshape(
                                       batch_size, -1, self.pool_scale**2)
        affinity_matrix = F.sigmoid(affinity_matrix)
        # [batch_size, h * w, channels]
        z_out = torch.matmul(affinity_matrix, pooled_x)
        # [batch_size, channels, h * w]
        z_out = z_out.permute(0, 2, 1).contiguous()
        # [batch_size, channels, h, w]
        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
        z_out = self.residual_conv(z_out)
        z_out = F.relu(z_out + x)
        if self.fusion:
            z_out = self.fusion_conv(z_out)

        return z_out


@HEADS.register_module()
class APCHead(BaseDecodeHead):
    """Adaptive Pyramid Context Network for Semantic Segmentation.

    This head is the implementation of
    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
    He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
    CVPR_2019_paper.pdf>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
            Module. Default: (1, 2, 3, 6).
        fusion (bool): Add one conv to fuse residual feature.
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
        super(APCHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.fusion = fusion
        acm_modules = []
        for pool_scale in self.pool_scales:
            acm_modules.append(
                ACM(pool_scale,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.acm_modules = nn.ModuleList(acm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        acm_outs = [x]
        for acm_module in self.acm_modules:
            acm_outs.append(acm_module(x))
        acm_outs = torch.cat(acm_outs, dim=1)
        output = self.bottleneck(acm_outs)
        output = self.cls_seg(output)
        return output
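A plain-torch shape walk-through of the ACM affinity computation above, using illustrative sizes (channels=4, pool_scale=2, one 8x8 feature map); the random tensor stands in for the output of the gla conv and is not part of the commit.

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8, 8)                          # reduced feature map
    pooled = F.adaptive_avg_pool2d(x, 2)                 # [1, 4, 2, 2] region features
    pooled = pooled.view(1, 4, -1).permute(0, 2, 1)      # [1, pool_scale**2, channels]
    affinity = torch.sigmoid(torch.randn(1, 8 * 8, 4))   # stand-in for the gla output
    z = torch.matmul(affinity, pooled)                   # [1, 64, 4]: each pixel mixes regions
    z = z.permute(0, 2, 1).view(1, 4, 8, 8)              # back to [N, C, H, W]
    print(z.shape)                                       # torch.Size([1, 4, 8, 8])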
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/aspp_head.py
ADDED
@@ -0,0 +1,107 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class ASPPModule(nn.ModuleList):
    """Atrous Spatial Pyramid Pooling (ASPP) Module.

    Args:
        dilations (tuple[int]): Dilation rate of each layer.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg):
        super(ASPPModule, self).__init__()
        self.dilations = dilations
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for dilation in dilations:
            self.append(
                ConvModule(
                    self.in_channels,
                    self.channels,
                    1 if dilation == 1 else 3,
                    dilation=dilation,
                    padding=0 if dilation == 1 else dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

    def forward(self, x):
        """Forward function."""
        aspp_outs = []
        for aspp_module in self:
            aspp_outs.append(aspp_module(x))

        return aspp_outs


@HEADS.register_module()
class ASPPHead(BaseDecodeHead):
    """Rethinking Atrous Convolution for Semantic Image Segmentation.

    This head is the implementation of `DeepLabV3
    <https://arxiv.org/abs/1706.05587>`_.

    Args:
        dilations (tuple[int]): Dilation rates for ASPP module.
            Default: (1, 6, 12, 18).
    """

    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
        super(ASPPHead, self).__init__(**kwargs)
        assert isinstance(dilations, (list, tuple))
        self.dilations = dilations
        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvModule(
                self.in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        self.aspp_modules = ASPPModule(
            dilations,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            (len(dilations) + 1) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        output = self.cls_seg(output)
        return output
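A quick plain-torch check of the padding rule used by ASPPModule above: a 3x3 conv with dilation d and padding d (or a 1x1 conv when d == 1) keeps the spatial size, which is what lets all branches be concatenated with the pooled branch. Sizes are illustrative.

    import torch
    import torch.nn as nn

    x = torch.randn(1, 16, 32, 32)
    for d in (1, 6, 12, 18):
        k = 1 if d == 1 else 3
        p = 0 if d == 1 else d
        y = nn.Conv2d(16, 8, k, dilation=d, padding=p)(x)
        print(d, tuple(y.shape))   # every branch stays 32x32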
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cascade_decode_head.py
ADDED
@@ -0,0 +1,57 @@
from abc import ABCMeta, abstractmethod

from .decode_head import BaseDecodeHead


class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode head used in
    :class:`CascadeEncoderDecoder`."""

    def __init__(self, *args, **kwargs):
        super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        losses = self.losses(seg_logits, gt_semantic_seg)

        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output)
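A minimal sketch of a concrete cascade head built on the interface above. The class name and the fusion strategy are illustrative and not part of this commit; it assumes the head is built with in_channels equal to channels so cls_seg can be applied directly to the selected feature map.

    import torch.nn.functional as F
    from annotator.mmpkg.mmseg.models.decode_heads.cascade_decode_head import \
        BaseCascadeDecodeHead


    class ToyCascadeHead(BaseCascadeDecodeHead):
        """Hypothetical refinement head: add the previous stage's logits."""

        def forward(self, inputs, prev_output):
            x = self._transform_inputs(inputs)          # inputs[-1] by default
            logits = self.cls_seg(x)                     # [N, num_classes, h, w]
            prev = F.interpolate(
                prev_output, size=logits.shape[2:], mode='bilinear',
                align_corners=self.align_corners)
            return logits + prev                         # refine previous prediction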
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/cc_head.py
ADDED
@@ -0,0 +1,45 @@
import torch

from ..builder import HEADS
from .fcn_head import FCNHead

try:
    try:
        from mmcv.ops import CrissCrossAttention
    except ImportError:
        from annotator.mmpkg.mmcv.ops import CrissCrossAttention
except ModuleNotFoundError:
    CrissCrossAttention = None


@HEADS.register_module()
class CCHead(FCNHead):
    """CCNet: Criss-Cross Attention for Semantic Segmentation.

    This head is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Args:
        recurrence (int): Number of recurrence of Criss Cross Attention
            module. Default: 2.
    """

    def __init__(self, recurrence=2, **kwargs):
        if CrissCrossAttention is None:
            raise RuntimeError('Please install mmcv-full for '
                               'CrissCrossAttention ops')
        super(CCHead, self).__init__(num_convs=2, **kwargs)
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(self.channels)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        for _ in range(self.recurrence):
            output = self.cca(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
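An illustrative guard that mirrors the import fallback above: only build CCHead when the CrissCrossAttention op is actually available. The channel sizes are made up for the example; only the names exported by this file are used.

    from annotator.mmpkg.mmseg.models.decode_heads.cc_head import (
        CCHead, CrissCrossAttention)

    if CrissCrossAttention is not None:
        head = CCHead(in_channels=2048, channels=512, num_classes=19,
                      norm_cfg=dict(type='BN'))   # illustrative sizes only
    else:
        print('mmcv-full ops not installed; CCHead cannot be used')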
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/da_head.py
ADDED
@@ -0,0 +1,178 @@
import torch
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule, Scale
from torch import nn

from annotator.mmpkg.mmseg.core import add_prefix
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PAM(_SelfAttentionBlock):
    """Position Attention Module (PAM)

    Args:
        in_channels (int): Input channels of key/query feature.
        channels (int): Output channels of key/query transform.
    """

    def __init__(self, in_channels, channels):
        super(PAM, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=1,
            key_query_norm=False,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=False,
            with_out=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)

        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        out = super(PAM, self).forward(x, x)

        out = self.gamma(out) + x
        return out


class CAM(nn.Module):
    """Channel Attention Module (CAM)"""

    def __init__(self):
        super(CAM, self).__init__()
        self.gamma = Scale(0)

    def forward(self, x):
        """Forward function."""
        batch_size, channels, height, width = x.size()
        proj_query = x.view(batch_size, channels, -1)
        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(
            energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = F.softmax(energy_new, dim=-1)
        proj_value = x.view(batch_size, channels, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(batch_size, channels, height, width)

        out = self.gamma(out) + x
        return out


@HEADS.register_module()
class DAHead(BaseDecodeHead):
    """Dual Attention Network for Scene Segmentation.

    This head is the implementation of `DANet
    <https://arxiv.org/abs/1809.02983>`_.

    Args:
        pam_channels (int): The channels of Position Attention Module(PAM).
    """

    def __init__(self, pam_channels, **kwargs):
        super(DAHead, self).__init__(**kwargs)
        self.pam_channels = pam_channels
        self.pam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam = PAM(self.channels, pam_channels)
        self.pam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.pam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

        self.cam_in_conv = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam = CAM()
        self.cam_out_conv = ConvModule(
            self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.cam_conv_seg = nn.Conv2d(
            self.channels, self.num_classes, kernel_size=1)

    def pam_cls_seg(self, feat):
        """PAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.pam_conv_seg(feat)
        return output

    def cam_cls_seg(self, feat):
        """CAM feature classification."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.cam_conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        pam_feat = self.pam_in_conv(x)
        pam_feat = self.pam(pam_feat)
        pam_feat = self.pam_out_conv(pam_feat)
        pam_out = self.pam_cls_seg(pam_feat)

        cam_feat = self.cam_in_conv(x)
        cam_feat = self.cam(cam_feat)
        cam_feat = self.cam_out_conv(cam_feat)
        cam_out = self.cam_cls_seg(cam_feat)

        feat_sum = pam_feat + cam_feat
        pam_cam_out = self.cls_seg(feat_sum)

        return pam_cam_out, pam_out, cam_out

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, only ``pam_cam`` is used."""
        return self.forward(inputs)[0]

    def losses(self, seg_logit, seg_label):
        """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
        loss = dict()
        loss.update(
            add_prefix(
                super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
                'pam_cam'))
        loss.update(
            add_prefix(
                super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
        loss.update(
            add_prefix(
                super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
        return loss
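A plain-torch walk-through of the CAM channel attention above, with tiny illustrative sizes (N=1, C=3, H=W=2) to make the shapes explicit; the gamma scale and module wrapping are omitted.

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 2, 2)
    q = x.view(1, 3, -1)                    # [N, C, HW]
    k = x.view(1, 3, -1).permute(0, 2, 1)   # [N, HW, C]
    energy = torch.bmm(q, k)                # [N, C, C] channel-to-channel similarity
    energy = energy.max(-1, keepdim=True)[0].expand_as(energy) - energy
    attn = F.softmax(energy, dim=-1)        # each row sums to 1
    out = torch.bmm(attn, x.view(1, 3, -1)).view(1, 3, 2, 2)
    print(out.shape)                        # torch.Size([1, 3, 2, 2])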
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/decode_head.py
ADDED
@@ -0,0 +1,234 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import normal_init
from annotator.mmpkg.mmcv.runner import auto_fp16, force_fp32

from annotator.mmpkg.mmseg.core import build_pixel_sampler
from annotator.mmpkg.mmseg.ops import resize
from ..builder import build_loss
from ..losses import accuracy


class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.loss_decode = build_loss(loss_decode)
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be
        lists or tuples of the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    @force_fp32(apply_to=('seg_logit', ))
    def losses(self, seg_logit, seg_label):
        """Compute segmentation loss."""
        loss = dict()
        seg_logit = resize(
            input=seg_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logit, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)
        loss['loss_seg'] = self.loss_decode(
            seg_logit,
            seg_label,
            weight=seg_weight,
            ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(seg_logit, seg_label)
        return loss
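What the 'resize_concat' branch of _transform_inputs does, shown in plain torch with illustrative HRNet-like channel counts: the indexed feature maps are resized to the first one's spatial size and concatenated along channels, which is why in_channels becomes sum(in_channels) in _init_inputs.

    import torch
    import torch.nn.functional as F

    feats = [torch.randn(1, 18, 64, 64), torch.randn(1, 36, 32, 32),
             torch.randn(1, 72, 16, 16)]
    in_index = [0, 1, 2]
    picked = [feats[i] for i in in_index]
    up = [F.interpolate(f, size=picked[0].shape[2:], mode='bilinear',
                        align_corners=False) for f in picked]
    x = torch.cat(up, dim=1)
    print(x.shape)   # torch.Size([1, 126, 64, 64]); 126 == 18 + 36 + 72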
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dm_head.py
ADDED
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer

from ..builder import HEADS
from .decode_head import BaseDecodeHead


class DCM(nn.Module):
    """Dynamic Convolutional Module used in DMNet.

    Args:
        filter_size (int): The filter size of generated convolution kernel
            used in Dynamic Convolutional Module.
        fusion (bool): Add one conv to fuse DCM output feature.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict | None): Config of conv layers.
        norm_cfg (dict | None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
                 norm_cfg, act_cfg):
        super(DCM, self).__init__()
        self.filter_size = filter_size
        self.fusion = fusion
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
                                         0)

        self.input_redu_conv = ConvModule(
            self.in_channels,
            self.channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        if self.norm_cfg is not None:
            self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
        else:
            self.norm = None
        self.activate = build_activation_layer(self.act_cfg)

        if self.fusion:
            self.fusion_conv = ConvModule(
                self.channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, x):
        """Forward function."""
        generated_filter = self.filter_gen_conv(
            F.adaptive_avg_pool2d(x, self.filter_size))
        x = self.input_redu_conv(x)
        b, c, h, w = x.shape
        # [1, b * c, h, w], c = self.channels
        x = x.view(1, b * c, h, w)
        # [b * c, 1, filter_size, filter_size]
        generated_filter = generated_filter.view(b * c, 1, self.filter_size,
                                                 self.filter_size)
        pad = (self.filter_size - 1) // 2
        if (self.filter_size - 1) % 2 == 0:
            p2d = (pad, pad, pad, pad)
        else:
            p2d = (pad + 1, pad, pad + 1, pad)
        x = F.pad(input=x, pad=p2d, mode='constant', value=0)
        # [1, b * c, h, w]
        output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
        # [b, c, h, w]
        output = output.view(b, c, h, w)
        if self.norm is not None:
            output = self.norm(output)
        output = self.activate(output)

        if self.fusion:
            output = self.fusion_conv(output)

        return output


@HEADS.register_module()
class DMHead(BaseDecodeHead):
    """Dynamic Multi-scale Filters for Semantic Segmentation.

    This head is the implementation of
    `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
    He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
    ICCV_2019_paper.pdf>`_.

    Args:
        filter_sizes (tuple[int]): The size of generated convolutional filters
            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
        fusion (bool): Add one conv to fuse DCM output feature.
    """

    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
        super(DMHead, self).__init__(**kwargs)
        assert isinstance(filter_sizes, (list, tuple))
        self.filter_sizes = filter_sizes
        self.fusion = fusion
        dcm_modules = []
        for filter_size in self.filter_sizes:
            dcm_modules.append(
                DCM(filter_size,
                    self.fusion,
                    self.in_channels,
                    self.channels,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        self.dcm_modules = nn.ModuleList(dcm_modules)
        self.bottleneck = ConvModule(
            self.in_channels + len(filter_sizes) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        dcm_outs = [x]
        for dcm_module in self.dcm_modules:
            dcm_outs.append(dcm_module(x))
        dcm_outs = torch.cat(dcm_outs, dim=1)
        output = self.bottleneck(dcm_outs)
        output = self.cls_seg(output)
        return output
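A plain-torch check of the DCM padding rule above: for both odd and even filter sizes the grouped convolution returns a map with the original spatial size, so all DCM branches can be concatenated with the input. Sizes are illustrative.

    import torch
    import torch.nn.functional as F

    h = w = 9
    for filter_size in (1, 2, 3, 4, 5, 7):
        x = torch.randn(1, 6, h, w)                       # b * c folded into dim 1
        weight = torch.randn(6, 1, filter_size, filter_size)
        pad = (filter_size - 1) // 2
        p2d = (pad, pad, pad, pad) if (filter_size - 1) % 2 == 0 \
            else (pad + 1, pad, pad + 1, pad)
        y = F.conv2d(F.pad(x, p2d), weight, groups=6)
        assert y.shape[-2:] == (h, w), (filter_size, y.shape)
    print('spatial size preserved for all filter sizes')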
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/dnl_head.py
ADDED
@@ -0,0 +1,131 @@
import torch
from annotator.mmpkg.mmcv.cnn import NonLocal2d
from torch import nn

from ..builder import HEADS
from .fcn_head import FCNHead


class DisentangledNonLocal2d(NonLocal2d):
    """Disentangled Non-Local Blocks.

    Args:
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self, *arg, temperature, **kwargs):
        super().__init__(*arg, **kwargs)
        self.temperature = temperature
        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)

    def embedded_gaussian(self, theta_x, phi_x):
        """Embedded gaussian with temperature."""

        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight /= self.temperature
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def forward(self, x):
        # x: [N, C, H, W]
        n = x.size(0)

        # g_x: [N, HxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        # subtract mean
        theta_x -= theta_x.mean(dim=-2, keepdim=True)
        phi_x -= phi_x.mean(dim=-1, keepdim=True)

        pairwise_func = getattr(self, self.mode)
        # pairwise_weight: [N, HxW, HxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # y: [N, HxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # y: [N, C, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        # unary_mask: [N, 1, HxW]
        unary_mask = self.conv_mask(x)
        unary_mask = unary_mask.view(n, 1, -1)
        unary_mask = unary_mask.softmax(dim=-1)
        # unary_x: [N, 1, C]
        unary_x = torch.matmul(unary_mask, g_x)
        # unary_x: [N, C, 1, 1]
        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
            n, self.inter_channels, 1, 1)

        output = x + self.conv_out(y + unary_x)

        return output


@HEADS.register_module()
class DNLHead(FCNHead):
    """Disentangled Non-Local Neural Networks.

    This head is the implementation of `DNLNet
    <https://arxiv.org/abs/2006.06668>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian'.
        temperature (float): Temperature to adjust attention. Default: 0.05
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 temperature=0.05,
                 **kwargs):
        super(DNLHead, self).__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.temperature = temperature
        self.dnl_block = DisentangledNonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode,
            temperature=self.temperature)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.dnl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
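A plain-torch walk-through of the "unary" branch above: a single spatial softmax mask pools the value features into one global vector that is then broadcast-added to the pairwise term. Sizes are illustrative and the conv layers are omitted.

    import torch

    n, c, h, w = 1, 8, 4, 4
    g_x = torch.randn(n, h * w, c)              # value features, [N, HxW, C]
    unary_mask = torch.randn(n, 1, h * w).softmax(dim=-1)
    unary_x = torch.matmul(unary_mask, g_x)     # [N, 1, C] global descriptor
    unary_x = unary_x.permute(0, 2, 1).reshape(n, c, 1, 1)
    print(unary_x.shape)                        # torch.Size([1, 8, 1, 1])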
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ema_head.py
ADDED
@@ -0,0 +1,168 @@
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
    """Reduce mean when distributed training."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor


class EMAModule(nn.Module):
    """Expectation Maximization Attention Module used in EMANet.

    Args:
        channels (int): Channels of the whole module.
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
    """

    def __init__(self, channels, num_bases, num_stages, momentum):
        super(EMAModule, self).__init__()
        assert num_stages >= 1, 'num_stages must be at least 1!'
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.momentum = momentum

        bases = torch.zeros(1, channels, self.num_bases)
        bases.normal_(0, math.sqrt(2. / self.num_bases))
        # [1, channels, num_bases]
        bases = F.normalize(bases, dim=1, p=2)
        self.register_buffer('bases', bases)

    def forward(self, feats):
        """Forward function."""
        batch_size, channels, height, width = feats.size()
        # [batch_size, channels, height*width]
        feats = feats.view(batch_size, channels, height * width)
        # [batch_size, channels, num_bases]
        bases = self.bases.repeat(batch_size, 1, 1)

        with torch.no_grad():
            for i in range(self.num_stages):
                # [batch_size, height*width, num_bases]
                attention = torch.einsum('bcn,bck->bnk', feats, bases)
                attention = F.softmax(attention, dim=2)
                # l1 norm
                attention_normed = F.normalize(attention, dim=1, p=1)
                # [batch_size, channels, num_bases]
                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
                # l2 norm
                bases = F.normalize(bases, dim=1, p=2)

        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
        feats_recon = feats_recon.view(batch_size, channels, height, width)

        if self.training:
            bases = bases.mean(dim=0, keepdim=True)
            bases = reduce_mean(bases)
            # l2 norm
            bases = F.normalize(bases, dim=1, p=2)
            self.bases = (1 -
                          self.momentum) * self.bases + self.momentum * bases

        return feats_recon


@HEADS.register_module()
class EMAHead(BaseDecodeHead):
    """Expectation Maximization Attention Networks for Semantic Segmentation.

    This head is the implementation of `EMANet
    <https://arxiv.org/abs/1907.13426>`_.

    Args:
        ema_channels (int): EMA module channels
        num_bases (int): Number of bases.
        num_stages (int): Number of the EM iterations.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer. Default: True
        momentum (float): Momentum to update the base. Default: 0.1.
    """

    def __init__(self,
                 ema_channels,
                 num_bases,
                 num_stages,
                 concat_input=True,
                 momentum=0.1,
                 **kwargs):
        super(EMAHead, self).__init__(**kwargs)
        self.ema_channels = ema_channels
        self.num_bases = num_bases
        self.num_stages = num_stages
        self.concat_input = concat_input
        self.momentum = momentum
        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
                                    self.num_stages, self.momentum)

        self.ema_in_conv = ConvModule(
            self.in_channels,
            self.ema_channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # project (0, inf) -> (-inf, inf)
        self.ema_mid_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        for param in self.ema_mid_conv.parameters():
            param.requires_grad = False

        self.ema_out_conv = ConvModule(
            self.ema_channels,
            self.ema_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.bottleneck = ConvModule(
            self.ema_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.ema_in_conv(x)
        identity = feats
        feats = self.ema_mid_conv(feats)
        recon = self.ema_module(feats)
        recon = F.relu(recon, inplace=True)
        recon = self.ema_out_conv(recon)
        output = F.relu(identity + recon, inplace=True)
        output = self.bottleneck(output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
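One EM iteration from EMAModule above written out in plain torch to make the einsum shapes concrete (batch=2, channels=8, a 4x4 map flattened to 16 pixels, 3 bases); the momentum update and no_grad wrapping are omitted.

    import torch
    import torch.nn.functional as F

    feats = torch.randn(2, 8, 16)                        # [B, C, H*W]
    bases = F.normalize(torch.randn(2, 8, 3), dim=1)     # [B, C, K]
    attention = torch.einsum('bcn,bck->bnk', feats, bases).softmax(dim=2)   # [B, H*W, K]
    attention_normed = F.normalize(attention, dim=1, p=1)                   # l1 over pixels
    bases = F.normalize(torch.einsum('bcn,bnk->bck', feats, attention_normed), dim=1)
    recon = torch.einsum('bck,bnk->bcn', bases, attention)                  # [B, C, H*W]
    print(recon.shape)   # torch.Size([2, 8, 16])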
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/enc_head.py
ADDED
@@ -0,0 +1,187 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule, build_norm_layer

from annotator.mmpkg.mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super(EncModule, self).__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # change to 1d
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # fallback to BN1d
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output


@HEADS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether use Semantic Encoding Loss (SE-loss) to
            regularize the training. Default: True.
        add_lateral (bool): Whether use lateral connection to fuse features.
            Default: False.
        loss_se_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super(EncHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing, ignore se_loss."""
        if self.use_se_loss:
            return self.forward(inputs)[0]
        else:
            return self.forward(inputs)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
        """

        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def losses(self, seg_logit, seg_label):
        """Compute segmentation and semantic encoding loss."""
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fcn_head.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class FCNHead(BaseDecodeHead):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
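Note (not part of the diff): heads registered with @HEADS.register_module() are normally instantiated from a config dict rather than constructed directly. The sketch below assumes the package's build_head helper in annotator.mmpkg.mmseg.models.builder and uses illustrative channel/class numbers; treat every value as an example, not a project default.

# Hypothetical sketch: building FCNHead from an mmseg-style config dict.
from annotator.mmpkg.mmseg.models.builder import build_head  # assumed helper

fcn_head_cfg = dict(
    type='FCNHead',
    in_channels=2048,      # assumed backbone output channels
    channels=512,          # assumed intermediate channels
    num_convs=2,
    concat_input=True,
    num_classes=19,        # assumed number of classes
    norm_cfg=dict(type='BN', requires_grad=True),
    align_corners=False)

fcn_head = build_head(fcn_head_cfg)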
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/fpn_head.py
ADDED
@@ -0,0 +1,68 @@
import numpy as np
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class FPNHead(BaseDecodeHead):
    """Panoptic Feature Pyramid Networks.

    This head is the implementation of `Semantic FPN
    <https://arxiv.org/abs/1901.02446>`_.

    Args:
        feature_strides (tuple[int]): The strides for input feature maps.
            stack_lateral. All strides suppose to be power of 2. The first
            one is of largest resolution.
    """

    def __init__(self, feature_strides, **kwargs):
        super(FPNHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        self.scale_heads = nn.ModuleList()
        for i in range(len(feature_strides)):
            head_length = max(
                1,
                int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
            scale_head = []
            for k in range(head_length):
                scale_head.append(
                    ConvModule(
                        self.in_channels[i] if k == 0 else self.channels,
                        self.channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
                if feature_strides[i] != feature_strides[0]:
                    scale_head.append(
                        nn.Upsample(
                            scale_factor=2,
                            mode='bilinear',
                            align_corners=self.align_corners))
            self.scale_heads.append(nn.Sequential(*scale_head))

    def forward(self, inputs):

        x = self._transform_inputs(inputs)

        output = self.scale_heads[0](x[0])
        for i in range(1, len(self.feature_strides)):
            # non inplace
            output = output + resize(
                self.scale_heads[i](x[i]),
                size=output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        output = self.cls_seg(output)
        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/gc_head.py
ADDED
@@ -0,0 +1,47 @@
import torch
from annotator.mmpkg.mmcv.cnn import ContextBlock

from ..builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class GCHead(FCNHead):
    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

    This head is the implementation of `GCNet
    <https://arxiv.org/abs/1904.11492>`_.

    Args:
        ratio (float): Multiplier of channels ratio. Default: 1/4.
        pooling_type (str): The pooling type of context aggregation.
            Options are 'att', 'avg'. Default: 'avg'.
        fusion_types (tuple[str]): The fusion type for feature fusion.
            Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
    """

    def __init__(self,
                 ratio=1 / 4.,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 **kwargs):
        super(GCHead, self).__init__(num_convs=2, **kwargs)
        self.ratio = ratio
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        self.gc_block = ContextBlock(
            in_channels=self.channels,
            ratio=self.ratio,
            pooling_type=self.pooling_type,
            fusion_types=self.fusion_types)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.gc_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/lraspp_head.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv import is_tuple_of
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class LRASPPHead(BaseDecodeHead):
    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

    This head is the improved implementation of `Searching for MobileNetV3
    <https://ieeexplore.ieee.org/document/9008835>`_.

    Args:
        branch_channels (tuple[int]): The number of output channels in every
            each branch. Default: (32, 64).
    """

    def __init__(self, branch_channels=(32, 64), **kwargs):
        super(LRASPPHead, self).__init__(**kwargs)
        if self.input_transform != 'multiple_select':
            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
                             f'must be \'multiple_select\'. But received '
                             f'\'{self.input_transform}\'')
        assert is_tuple_of(branch_channels, int)
        assert len(branch_channels) == len(self.in_channels) - 1
        self.branch_channels = branch_channels

        self.convs = nn.Sequential()
        self.conv_ups = nn.Sequential()
        for i in range(len(branch_channels)):
            self.convs.add_module(
                f'conv{i}',
                nn.Conv2d(
                    self.in_channels[i], branch_channels[i], 1, bias=False))
            self.conv_ups.add_module(
                f'conv_up{i}',
                ConvModule(
                    self.channels + branch_channels[i],
                    self.channels,
                    1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=False))

        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)

        self.aspp_conv = ConvModule(
            self.in_channels[-1],
            self.channels,
            1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            bias=False)
        self.image_pool = nn.Sequential(
            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
            ConvModule(
                self.in_channels[2],
                self.channels,
                1,
                act_cfg=dict(type='Sigmoid'),
                bias=False))

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        x = inputs[-1]

        x = self.aspp_conv(x) * resize(
            self.image_pool(x),
            size=x.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        x = self.conv_up_input(x)

        for i in range(len(self.branch_channels) - 1, -1, -1):
            x = resize(
                x,
                size=inputs[i].size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            x = torch.cat([x, self.convs[i](inputs[i])], 1)
            x = self.conv_ups[i](x)

        return self.cls_seg(x)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/nl_head.py
ADDED
@@ -0,0 +1,49 @@
import torch
from annotator.mmpkg.mmcv.cnn import NonLocal2d

from ..builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class NLHead(FCNHead):
    """Non-local Neural Networks.

    This head is the implementation of `NLNet
    <https://arxiv.org/abs/1711.07971>`_.

    Args:
        reduction (int): Reduction factor of projection transform. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            sqrt(1/inter_channels). Default: True.
        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
            'dot_product'. Default: 'embedded_gaussian.'.
    """

    def __init__(self,
                 reduction=2,
                 use_scale=True,
                 mode='embedded_gaussian',
                 **kwargs):
        super(NLHead, self).__init__(num_convs=2, **kwargs)
        self.reduction = reduction
        self.use_scale = use_scale
        self.mode = mode
        self.nl_block = NonLocal2d(
            in_channels=self.channels,
            reduction=self.reduction,
            use_scale=self.use_scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            mode=self.mode)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs[0](x)
        output = self.nl_block(output)
        output = self.convs[1](output)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/ocr_head.py
ADDED
@@ -0,0 +1,127 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .cascade_decode_head import BaseCascadeDecodeHead


class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super(SpatialGatherModule, self).__init__()
        self.scale = scale

    def forward(self, feats, probs):
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)
        feats = feats.view(batch_size, channels, -1)
        # [batch_size, height*width, num_classes]
        feats = feats.permute(0, 2, 1)
        # [batch_size, channels, height*width]
        probs = F.softmax(self.scale * probs, dim=2)
        # [batch_size, channels, num_classes]
        ocr_context = torch.matmul(probs, feats)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
        return ocr_context


class ObjectAttentionBlock(_SelfAttentionBlock):
    """Make a OCR used SelfAttentionBlock."""

    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
                 act_cfg):
        if scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=scale)
        else:
            query_downsample = None
        super(ObjectAttentionBlock, self).__init__(
            key_in_channels=in_channels,
            query_in_channels=in_channels,
            channels=channels,
            out_channels=in_channels,
            share_key_query=False,
            query_downsample=query_downsample,
            key_downsample=None,
            key_query_num_convs=2,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.bottleneck = ConvModule(
            in_channels * 2,
            in_channels,
            1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, query_feats, key_feats):
        """Forward function."""
        context = super(ObjectAttentionBlock,
                        self).forward(query_feats, key_feats)
        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
        if self.query_downsample is not None:
            output = resize(query_feats)

        return output


@HEADS.register_module()
class OCRHead(BaseCascadeDecodeHead):
    """Object-Contextual Representations for Semantic Segmentation.

    This head is the implementation of `OCRNet
    <https://arxiv.org/abs/1909.11065>`_.

    Args:
        ocr_channels (int): The intermediate channels of OCR block.
        scale (int): The scale of probability map in SpatialGatherModule in
            Default: 1.
    """

    def __init__(self, ocr_channels, scale=1, **kwargs):
        super(OCRHead, self).__init__(**kwargs)
        self.ocr_channels = ocr_channels
        self.scale = scale
        self.object_context_block = ObjectAttentionBlock(
            self.channels,
            self.ocr_channels,
            self.scale,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.spatial_gather_module = SpatialGatherModule(self.scale)

        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs, prev_output):
        """Forward function."""
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x)
        context = self.spatial_gather_module(feats, prev_output)
        object_context = self.object_context_block(feats, context)
        output = self.cls_seg(object_context)

        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/point_head.py
ADDED
@@ -0,0 +1,354 @@
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa

import torch
import torch.nn as nn

try:
    from mmcv.cnn import ConvModule, normal_init
    from mmcv.ops import point_sample
except ImportError:
    from annotator.mmpkg.mmcv.cnn import ConvModule, normal_init
    from annotator.mmpkg.mmcv.ops import point_sample

from annotator.mmpkg.mmseg.models.builder import HEADS
from annotator.mmpkg.mmseg.ops import resize
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead


def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between top first and top second
    predicted logits.

    Args:
        seg_logits (Tensor): Semantic segmentation logits,
            shape (batch_size, num_classes, height, width).

    Returns:
        scores (Tensor): T uncertainty scores with the most uncertain
            locations having the highest uncertainty score, shape (
            batch_size, 1, height, width)
    """
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)


@HEADS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head use in PointRend.

    ``PointHead`` use shared multi-layer perceptron (equivalent to
    nn.Conv1d) to predict the logit of input points. The fine-grained feature
    and coarse feature will be concatenate together for predication.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        in_channels (int): Number of input channels. Default: 256.
        fc_channels (int): Number of fc channels. Default: 256.
        num_classes (int): Number of classes for logits. Default: 80.
        class_agnostic (bool): Whether use class agnostic classification.
            If so, the output channels of logits will be 1. Default: False.
        coarse_pred_each_layer (bool): Whether concatenate coarse feature with
            the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d'))
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        loss_point (dict): Dictionary to construct and config loss layer of
            point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
            loss_weight=1.0).
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super(PointHead, self).__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for k in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            fc_in_channels = fc_channels
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            self.dropout = nn.Dropout(self.dropout_ratio)
        delattr(self, 'conv_seg')

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.fc_seg, std=0.001)

    def cls_seg(self, feat):
        """Classify each pixel with fc."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from by neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """

        fine_grained_feats_list = [
            point_sample(_, points, align_corners=self.align_corners)
            for _ in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from fine grained features.

        Args:
            prev_output (list[Tensor]): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """

        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        losses = self.losses(point_logits, point_label)

        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """

        x = self._transform_inputs(inputs)
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return refined_seg_logits

    def losses(self, point_logits, point_label):
        """Compute segmentation loss."""
        loss = dict()
        loss['loss_point'] = self.loss_decode(
            point_logits, point_label, ignore_index=self.ignore_index)
        loss['acc_point'] = accuracy(point_logits, point_label)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results. To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid .
        """

        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords
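Note (not part of the diff): a minimal sketch of what calculate_uncertainty in the file above computes. The margin between the top-1 and top-2 class logits acts as a negative confidence score, so values closest to 0 mark the most uncertain locations that PointRend refines first. The tiny tensor below is an invented example.

# Hypothetical sketch of the top-2 logit margin used as an uncertainty score.
import torch

seg_logits = torch.tensor([[[[2.0]], [[1.9]], [[-3.0]]]])  # (1, 3, 1, 1)
top2 = torch.topk(seg_logits, k=2, dim=1)[0]
uncertainty = (top2[:, 1] - top2[:, 0]).unsqueeze(1)
print(uncertainty)  # tensor([[[[-0.1000]]]]) -- small margin => most uncertain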
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psa_head.py
ADDED
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead

try:
    try:
        from mmcv.ops import PSAMask
    except ImportError:
        from annotator.mmpkg.mmcv.ops import PSAMask
except ModuleNotFoundError:
    PSAMask = None


@HEADS.register_module()
class PSAHead(BaseDecodeHead):
    """Point-wise Spatial Attention Network for Scene Parsing.

    This head is the implementation of `PSANet
    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.

    Args:
        mask_size (tuple[int]): The PSA mask size. It usually equals input
            size.
        psa_type (str): The type of psa module. Options are 'collect',
            'distribute', 'bi-direction'. Default: 'bi-direction'
        compact (bool): Whether use compact map for 'collect' mode.
            Default: True.
        shrink_factor (int): The downsample factors of psa mask. Default: 2.
        normalization_factor (float): The normalize factor of attention.
        psa_softmax (bool): Whether use softmax for attention.
    """

    def __init__(self,
                 mask_size,
                 psa_type='bi-direction',
                 compact=False,
                 shrink_factor=2,
                 normalization_factor=1.0,
                 psa_softmax=True,
                 **kwargs):
        if PSAMask is None:
            raise RuntimeError('Please install mmcv-full for PSAMask ops')
        super(PSAHead, self).__init__(**kwargs)
        assert psa_type in ['collect', 'distribute', 'bi-direction']
        self.psa_type = psa_type
        self.compact = compact
        self.shrink_factor = shrink_factor
        self.mask_size = mask_size
        mask_h, mask_w = mask_size
        self.psa_softmax = psa_softmax
        if normalization_factor is None:
            normalization_factor = mask_h * mask_w
        self.normalization_factor = normalization_factor

        self.reduce = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.attention = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Conv2d(
                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
        if psa_type == 'bi-direction':
            self.reduce_p = ConvModule(
                self.in_channels,
                self.channels,
                kernel_size=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.attention_p = nn.Sequential(
                ConvModule(
                    self.channels,
                    self.channels,
                    kernel_size=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Conv2d(
                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
            self.psamask_collect = PSAMask('collect', mask_size)
            self.psamask_distribute = PSAMask('distribute', mask_size)
        else:
            self.psamask = PSAMask(psa_type, mask_size)
        self.proj = ConvModule(
            self.channels * (2 if psa_type == 'bi-direction' else 1),
            self.in_channels,
            kernel_size=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            self.in_channels * 2,
            self.channels,
            kernel_size=3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        identity = x
        align_corners = self.align_corners
        if self.psa_type in ['collect', 'distribute']:
            out = self.reduce(x)
            n, c, h, w = out.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                out = resize(
                    out,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y = self.attention(out)
            if self.compact:
                if self.psa_type == 'collect':
                    y = y.view(n, h * w,
                               h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y = self.psamask(y)
            if self.psa_softmax:
                y = F.softmax(y, dim=1)
            out = torch.bmm(
                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
        else:
            x_col = self.reduce(x)
            x_dis = self.reduce_p(x)
            n, c, h, w = x_col.size()
            if self.shrink_factor != 1:
                if h % self.shrink_factor and w % self.shrink_factor:
                    h = (h - 1) // self.shrink_factor + 1
                    w = (w - 1) // self.shrink_factor + 1
                    align_corners = True
                else:
                    h = h // self.shrink_factor
                    w = w // self.shrink_factor
                    align_corners = False
                x_col = resize(
                    x_col,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
                x_dis = resize(
                    x_dis,
                    size=(h, w),
                    mode='bilinear',
                    align_corners=align_corners)
            y_col = self.attention(x_col)
            y_dis = self.attention_p(x_dis)
            if self.compact:
                y_dis = y_dis.view(n, h * w,
                                   h * w).transpose(1, 2).view(n, h * w, h, w)
            else:
                y_col = self.psamask_collect(y_col)
                y_dis = self.psamask_distribute(y_dis)
            if self.psa_softmax:
                y_col = F.softmax(y_col, dim=1)
                y_dis = F.softmax(y_dis, dim=1)
            x_col = torch.bmm(
                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            x_dis = torch.bmm(
                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
                    n, c, h, w) * (1.0 / self.normalization_factor)
            out = torch.cat([x_col, x_dis], 1)
        out = self.proj(out)
        out = resize(
            out,
            size=identity.shape[2:],
            mode='bilinear',
            align_corners=align_corners)
        out = self.bottleneck(torch.cat((identity, out), dim=1))
        out = self.cls_seg(out)
        return out
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/psp_head.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


@HEADS.register_module()
class PSPHead(BaseDecodeHead):
    """Pyramid Scene Parsing Network.

    This head is the implementation of
    `PSPNet <https://arxiv.org/abs/1612.01105>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        assert isinstance(pool_scales, (list, tuple))
        self.pool_scales = pool_scales
        self.psp_modules = PPM(
            self.pool_scales,
            self.in_channels,
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)
        output = self.cls_seg(output)
        return output
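Note (not part of the diff): a minimal sketch of the shape flow inside the PPM above for pool_scales=(1, 2, 3, 6). It uses plain torch.nn.functional calls and an invented input size, and it omits the per-branch 1x1 ConvModule that PPM uses to reduce channels before upsampling, so the channel counts here differ from a real PSPHead.

# Hypothetical sketch: pool to 1x1/2x2/3x3/6x6 grids, resize back, concatenate.
import torch
import torch.nn.functional as F

x = torch.randn(2, 2048, 64, 64)           # assumed backbone feature map
pooled = [F.adaptive_avg_pool2d(x, s) for s in (1, 2, 3, 6)]
upsampled = [
    F.interpolate(p, size=x.shape[2:], mode='bilinear', align_corners=False)
    for p in pooled
]
out = torch.cat([x] + upsampled, dim=1)     # channels: 2048 + 4 * 2048 here
print(out.shape)                            # torch.Size([2, 10240, 64, 64])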
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_aspp_head.py
ADDED
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

from annotator.mmpkg.mmseg.ops import resize
from ..builder import HEADS
from .aspp_head import ASPPHead, ASPPModule


class DepthwiseSeparableASPPModule(ASPPModule):
    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
    conv."""

    def __init__(self, **kwargs):
        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
        for i, dilation in enumerate(self.dilations):
            if dilation > 1:
                self[i] = DepthwiseSeparableConvModule(
                    self.in_channels,
                    self.channels,
                    3,
                    dilation=dilation,
                    padding=dilation,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg)


@HEADS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
    Segmentation.

    This head is the implementation of `DeepLabV3+
    <https://arxiv.org/abs/1802.02611>`_.

    Args:
        c1_in_channels (int): The input channels of c1 decoder. If is 0,
            the no decoder will be used.
        c1_channels (int): The intermediate channels of c1 decoder.
    """

    def __init__(self, c1_in_channels, c1_channels, **kwargs):
        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
        assert c1_in_channels >= 0
        self.aspp_modules = DepthwiseSeparableASPPModule(
            dilations=self.dilations,
            in_channels=self.in_channels,
            channels=self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if c1_in_channels > 0:
            self.c1_bottleneck = ConvModule(
                c1_in_channels,
                c1_channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        else:
            self.c1_bottleneck = None
        self.sep_bottleneck = nn.Sequential(
            DepthwiseSeparableConvModule(
                self.channels + c1_channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg))

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        aspp_outs = [
            resize(
                self.image_pool(x),
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        ]
        aspp_outs.extend(self.aspp_modules(x))
        aspp_outs = torch.cat(aspp_outs, dim=1)
        output = self.bottleneck(aspp_outs)
        if self.c1_bottleneck is not None:
            c1_output = self.c1_bottleneck(inputs[0])
            output = resize(
                input=output,
                size=c1_output.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            output = torch.cat([output, c1_output], dim=1)
        output = self.sep_bottleneck(output)
        output = self.cls_seg(output)
        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/sep_fcn_head.py
ADDED
@@ -0,0 +1,51 @@
from annotator.mmpkg.mmcv.cnn import DepthwiseSeparableConvModule

from ..builder import HEADS
from .fcn_head import FCNHead


@HEADS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
    """Depthwise-Separable Fully Convolutional Network for Semantic
    Segmentation.

    This head is implemented according to Fast-SCNN paper.
    Args:
        in_channels(int): Number of output channels of FFM.
        channels(int): Number of middle-stage channels in the decode head.
        concat_input(bool): Whether to concatenate original decode input into
            the result of several consecutive convolution layers.
            Default: True.
        num_classes(int): Used to determine the dimension of
            final prediction tensor.
        in_index(int): Correspond with 'out_indices' in FastSCNN backbone.
        norm_cfg (dict | None): Config of norm layers.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        loss_decode(dict): Config of loss type and some
            relevant additional options.
    """

    def __init__(self, **kwargs):
        super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
        self.convs[0] = DepthwiseSeparableConvModule(
            self.in_channels,
            self.channels,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            norm_cfg=self.norm_cfg)
        for i in range(1, self.num_convs):
            self.convs[i] = DepthwiseSeparableConvModule(
                self.channels,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg)

        if self.concat_input:
            self.conv_cat = DepthwiseSeparableConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=self.kernel_size,
                padding=self.kernel_size // 2,
                norm_cfg=self.norm_cfg)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/decode_heads/uper_head.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from annotator.mmpkg.mmcv.cnn import ConvModule
|
4 |
+
|
5 |
+
from annotator.mmpkg.mmseg.ops import resize
|
6 |
+
from ..builder import HEADS
|
7 |
+
from .decode_head import BaseDecodeHead
|
8 |
+
from .psp_head import PPM
|
9 |
+
|
10 |
+
|
11 |
+
@HEADS.register_module()
|
12 |
+
class UPerHead(BaseDecodeHead):
|
13 |
+
"""Unified Perceptual Parsing for Scene Understanding.
|
14 |
+
|
15 |
+
This head is the implementation of `UPerNet
|
16 |
+
<https://arxiv.org/abs/1807.10221>`_.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
20 |
+
Module applied on the last feature. Default: (1, 2, 3, 6).
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
|
24 |
+
super(UPerHead, self).__init__(
|
25 |
+
input_transform='multiple_select', **kwargs)
|
26 |
+
# PSP Module
|
27 |
+
self.psp_modules = PPM(
|
28 |
+
pool_scales,
|
29 |
+
self.in_channels[-1],
|
30 |
+
self.channels,
|
31 |
+
conv_cfg=self.conv_cfg,
|
32 |
+
norm_cfg=self.norm_cfg,
|
33 |
+
act_cfg=self.act_cfg,
|
34 |
+
align_corners=self.align_corners)
|
35 |
+
self.bottleneck = ConvModule(
|
36 |
+
self.in_channels[-1] + len(pool_scales) * self.channels,
|
37 |
+
self.channels,
|
38 |
+
3,
|
39 |
+
padding=1,
|
40 |
+
conv_cfg=self.conv_cfg,
|
41 |
+
norm_cfg=self.norm_cfg,
|
42 |
+
act_cfg=self.act_cfg)
|
43 |
+
# FPN Module
|
44 |
+
self.lateral_convs = nn.ModuleList()
|
45 |
+
self.fpn_convs = nn.ModuleList()
|
46 |
+
for in_channels in self.in_channels[:-1]: # skip the top layer
|
47 |
+
l_conv = ConvModule(
|
48 |
+
in_channels,
|
49 |
+
self.channels,
|
50 |
+
1,
|
51 |
+
conv_cfg=self.conv_cfg,
|
52 |
+
norm_cfg=self.norm_cfg,
|
53 |
+
act_cfg=self.act_cfg,
|
54 |
+
inplace=False)
|
55 |
+
fpn_conv = ConvModule(
|
56 |
+
self.channels,
|
57 |
+
self.channels,
|
58 |
+
3,
|
59 |
+
padding=1,
|
60 |
+
conv_cfg=self.conv_cfg,
|
61 |
+
norm_cfg=self.norm_cfg,
|
62 |
+
act_cfg=self.act_cfg,
|
63 |
+
inplace=False)
|
64 |
+
self.lateral_convs.append(l_conv)
|
65 |
+
self.fpn_convs.append(fpn_conv)
|
66 |
+
|
67 |
+
self.fpn_bottleneck = ConvModule(
|
68 |
+
len(self.in_channels) * self.channels,
|
69 |
+
self.channels,
|
70 |
+
3,
|
71 |
+
padding=1,
|
72 |
+
conv_cfg=self.conv_cfg,
|
73 |
+
norm_cfg=self.norm_cfg,
|
74 |
+
act_cfg=self.act_cfg)
|
75 |
+
|
76 |
+
def psp_forward(self, inputs):
|
77 |
+
"""Forward function of PSP module."""
|
78 |
+
x = inputs[-1]
|
79 |
+
psp_outs = [x]
|
80 |
+
psp_outs.extend(self.psp_modules(x))
|
81 |
+
psp_outs = torch.cat(psp_outs, dim=1)
|
82 |
+
output = self.bottleneck(psp_outs)
|
83 |
+
|
84 |
+
return output
|
85 |
+
|
86 |
+
def forward(self, inputs):
|
87 |
+
"""Forward function."""
|
88 |
+
|
89 |
+
inputs = self._transform_inputs(inputs)
|
90 |
+
|
91 |
+
# build laterals
|
92 |
+
laterals = [
|
93 |
+
lateral_conv(inputs[i])
|
94 |
+
for i, lateral_conv in enumerate(self.lateral_convs)
|
95 |
+
]
|
96 |
+
|
97 |
+
laterals.append(self.psp_forward(inputs))
|
98 |
+
|
99 |
+
# build top-down path
|
100 |
+
used_backbone_levels = len(laterals)
|
101 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
102 |
+
prev_shape = laterals[i - 1].shape[2:]
|
103 |
+
laterals[i - 1] += resize(
|
104 |
+
laterals[i],
|
105 |
+
size=prev_shape,
|
106 |
+
mode='bilinear',
|
107 |
+
align_corners=self.align_corners)
|
108 |
+
|
109 |
+
# build outputs
|
110 |
+
fpn_outs = [
|
111 |
+
self.fpn_convs[i](laterals[i])
|
112 |
+
for i in range(used_backbone_levels - 1)
|
113 |
+
]
|
114 |
+
# append psp feature
|
115 |
+
fpn_outs.append(laterals[-1])
|
116 |
+
|
117 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
118 |
+
fpn_outs[i] = resize(
|
119 |
+
fpn_outs[i],
|
120 |
+
size=fpn_outs[0].shape[2:],
|
121 |
+
mode='bilinear',
|
122 |
+
align_corners=self.align_corners)
|
123 |
+
fpn_outs = torch.cat(fpn_outs, dim=1)
|
124 |
+
output = self.fpn_bottleneck(fpn_outs)
|
125 |
+
output = self.cls_seg(output)
|
126 |
+
return output
|
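For orientation, a minimal sketch of how this head is driven (illustrative only): UPerHead takes one feature map per selected backbone stage and emits logits at the resolution of the finest stage. The import path, channel counts, strides and class count below are assumptions, not part of the diff.

import torch
from annotator.mmpkg.mmseg.models.decode_heads import UPerHead  # path assumed

# Four backbone stages; channel counts and spatial sizes are hypothetical.
head = UPerHead(
    in_channels=[64, 128, 256, 512],
    in_index=[0, 1, 2, 3],
    channels=128,
    num_classes=19)
feats = [torch.randn(1, c, s, s)
         for c, s in zip([64, 128, 256, 512], [64, 32, 16, 8])]
logits = head(feats)
print(logits.shape)  # torch.Size([1, 19, 64, 64]) -- resolution of the finest level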
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/__init__.py
ADDED
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+                                 cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+
+__all__ = [
+    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/accuracy.py
ADDED
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+    """Calculate accuracy according to the prediction and target.
+
+    Args:
+        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+        target (torch.Tensor): The target of each prediction, shape (N, , ...)
+        topk (int | tuple[int], optional): If the predictions in ``topk``
+            matches the target, the predictions will be regarded as
+            correct ones. Defaults to 1.
+        thresh (float, optional): If not None, predictions with scores under
+            this threshold are considered incorrect. Default to None.
+
+    Returns:
+        float | tuple[float]: If the input ``topk`` is a single integer,
+            the function will return a single float as accuracy. If
+            ``topk`` is a tuple containing multiple integers, the
+            function will return a tuple containing accuracies of
+            each ``topk`` number.
+    """
+    assert isinstance(topk, (int, tuple))
+    if isinstance(topk, int):
+        topk = (topk, )
+        return_single = True
+    else:
+        return_single = False
+
+    maxk = max(topk)
+    if pred.size(0) == 0:
+        accu = [pred.new_tensor(0.) for i in range(len(topk))]
+        return accu[0] if return_single else accu
+    assert pred.ndim == target.ndim + 1
+    assert pred.size(0) == target.size(0)
+    assert maxk <= pred.size(1), \
+        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+    pred_value, pred_label = pred.topk(maxk, dim=1)
+    # transpose to shape (maxk, N, ...)
+    pred_label = pred_label.transpose(0, 1)
+    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+    if thresh is not None:
+        # Only prediction values larger than thresh are counted as correct
+        correct = correct & (pred_value > thresh).t()
+    res = []
+    for k in topk:
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100.0 / target.numel()))
+    return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+    """Accuracy calculation module."""
+
+    def __init__(self, topk=(1, ), thresh=None):
+        """Module to calculate the accuracy.
+
+        Args:
+            topk (tuple, optional): The criterion used to calculate the
+                accuracy. Defaults to (1,).
+            thresh (float, optional): If not None, predictions with scores
+                under this threshold are considered incorrect. Default to None.
+        """
+        super().__init__()
+        self.topk = topk
+        self.thresh = thresh
+
+    def forward(self, pred, target):
+        """Forward function to calculate accuracy.
+
+        Args:
+            pred (torch.Tensor): Prediction of models.
+            target (torch.Tensor): Target for each prediction.
+
+        Returns:
+            tuple[float]: The accuracies under different topk criterions.
+        """
+        return accuracy(pred, target, self.topk, self.thresh)
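A small worked example of the function above (values chosen for illustration; the import relies on the losses __init__ added earlier in this commit): with two samples and three classes, only the first sample is right at top-1, while both are right at top-2.

import torch
from annotator.mmpkg.mmseg.models.losses import accuracy

pred = torch.tensor([[0.1, 0.7, 0.2],
                     [0.6, 0.3, 0.1]])
target = torch.tensor([1, 1])
print(accuracy(pred, target, topk=1))       # tensor([50.])
print(accuracy(pred, target, topk=(1, 2)))  # [tensor([50.]), tensor([100.])]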
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/cross_entropy_loss.py
ADDED
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(pred,
+                  label,
+                  weight=None,
+                  class_weight=None,
+                  reduction='mean',
+                  avg_factor=None,
+                  ignore_index=-100):
+    """The wrapper function for :func:`F.cross_entropy`"""
+    # class_weight is a manual rescaling weight given to each class.
+    # If given, has to be a Tensor of size C element-wise losses
+    loss = F.cross_entropy(
+        pred,
+        label,
+        weight=class_weight,
+        reduction='none',
+        ignore_index=ignore_index)
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+    """Expand onehot labels to match the size of prediction."""
+    bin_labels = labels.new_zeros(target_shape)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
+    inds = torch.nonzero(valid_mask, as_tuple=True)
+
+    if inds[0].numel() > 0:
+        if labels.dim() == 3:
+            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+        else:
+            bin_labels[inds[0], labels[valid_mask]] = 1
+
+    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+    if label_weights is None:
+        bin_label_weights = valid_mask
+    else:
+        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+        bin_label_weights *= valid_mask
+
+    return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None,
+                         class_weight=None,
+                         ignore_index=255):
+    """Calculate the binary CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 1).
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (int | None): The label index to be ignored. Default: 255
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    if pred.dim() != label.dim():
+        assert (pred.dim() == 2 and label.dim() == 1) or (
+                pred.dim() == 4 and label.dim() == 3), \
+            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+            'H, W], label shape [N, H, W] are supported'
+        label, weight = _expand_onehot_labels(label, weight, pred.shape,
+                                              ignore_index)
+
+    # weighted element-wise losses
+    if weight is not None:
+        weight = weight.float()
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label.float(), pos_weight=class_weight, reduction='none')
+    # do the reduction for the weighted loss
+    loss = weight_reduce_loss(
+        loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def mask_cross_entropy(pred,
+                       target,
+                       label,
+                       reduction='mean',
+                       avg_factor=None,
+                       class_weight=None,
+                       ignore_index=None):
+    """Calculate the CrossEntropy loss for masks.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the mask'
+            corresponding object. This will be used to select the mask in the
+            of the class which the object belongs to when the mask prediction
+            if not class-agnostic.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other loss.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    assert ignore_index is None, 'BCE loss does not support ignore_index'
+    # TODO: handle these two reserved arguments
+    assert reduction == 'mean' and avg_factor is None
+    num_rois = pred.size()[0]
+    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+    pred_slice = pred[inds, label].squeeze(1)
+    return F.binary_cross_entropy_with_logits(
+        pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+    """CrossEntropyLoss.
+
+    Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            of softmax. Defaults to False.
+        use_mask (bool, optional): Whether to use mask cross entropy loss.
+            Defaults to False.
+        reduction (str, optional): . Defaults to 'mean'.
+            Options are "none", "mean" and "sum".
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 use_sigmoid=False,
+                 use_mask=False,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0):
+        super(CrossEntropyLoss, self).__init__()
+        assert (use_sigmoid is False) or (use_mask is False)
+        self.use_sigmoid = use_sigmoid
+        self.use_mask = use_mask
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = get_class_weight(class_weight)
+
+        if self.use_sigmoid:
+            self.cls_criterion = binary_cross_entropy
+        elif self.use_mask:
+            self.cls_criterion = mask_cross_entropy
+        else:
+            self.cls_criterion = cross_entropy
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function."""
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            weight,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_cls
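A brief sketch of calling the module above on dense segmentation logits (sizes and label values are illustrative; extra keyword arguments such as ignore_index are forwarded to the underlying criterion):

import torch
from annotator.mmpkg.mmseg.models.losses import CrossEntropyLoss

criterion = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
seg_logit = torch.randn(1, 4, 2, 2)            # 4-class logits for a 2x2 image
seg_label = torch.tensor([[[0, 1], [255, 3]]])  # 255 marks an ignored pixel
loss = criterion(seg_logit, seg_label, ignore_index=255)
print(loss)  # scalar tensor; ignored pixels contribute zero loss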
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/dice_loss.py
ADDED
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
+
+
+@weighted_loss
+def dice_loss(pred,
+              target,
+              valid_mask,
+              smooth=1,
+              exponent=2,
+              class_weight=None,
+              ignore_index=255):
+    assert pred.shape[0] == target.shape[0]
+    total_loss = 0
+    num_classes = pred.shape[1]
+    for i in range(num_classes):
+        if i != ignore_index:
+            dice_loss = binary_dice_loss(
+                pred[:, i],
+                target[..., i],
+                valid_mask=valid_mask,
+                smooth=smooth,
+                exponent=exponent)
+            if class_weight is not None:
+                dice_loss *= class_weight[i]
+            total_loss += dice_loss
+    return total_loss / num_classes
+
+
+@weighted_loss
+def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
+    assert pred.shape[0] == target.shape[0]
+    pred = pred.reshape(pred.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+
+    num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
+    den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+
+    return 1 - num / den
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+    """DiceLoss.
+
+    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
+    Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+
+    Args:
+        loss_type (str, optional): Binary or multi-class loss.
+            Default: 'multi_class'. Options are "binary" and "multi_class".
+        smooth (float): A float number to smooth loss, and avoid NaN error.
+            Default: 1
+        exponent (float): An float number to calculate denominator
+            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Default to 1.0.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+    """
+
+    def __init__(self,
+                 smooth=1,
+                 exponent=2,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0,
+                 ignore_index=255,
+                 **kwards):
+        super(DiceLoss, self).__init__()
+        self.smooth = smooth
+        self.exponent = exponent
+        self.reduction = reduction
+        self.class_weight = get_class_weight(class_weight)
+        self.loss_weight = loss_weight
+        self.ignore_index = ignore_index
+
+    def forward(self,
+                pred,
+                target,
+                avg_factor=None,
+                reduction_override=None,
+                **kwards):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = pred.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        pred = F.softmax(pred, dim=1)
+        num_classes = pred.shape[1]
+        one_hot_target = F.one_hot(
+            torch.clamp(target.long(), 0, num_classes - 1),
+            num_classes=num_classes)
+        valid_mask = (target != self.ignore_index).long()
+
+        loss = self.loss_weight * dice_loss(
+            pred,
+            one_hot_target,
+            valid_mask=valid_mask,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            smooth=self.smooth,
+            exponent=self.exponent,
+            class_weight=class_weight,
+            ignore_index=self.ignore_index)
+        return loss
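A minimal sketch of the class above on random inputs (shapes and class count are illustrative): the forward pass softmaxes the logits, one-hot encodes the labels, and averages the per-class Dice terms.

import torch
from annotator.mmpkg.mmseg.models.losses import DiceLoss

criterion = DiceLoss(smooth=1, exponent=2, loss_weight=1.0, ignore_index=255)
seg_logit = torch.randn(2, 3, 8, 8)          # batch of 3-class logits
seg_label = torch.randint(0, 3, (2, 8, 8))   # dense ground-truth labels
loss = criterion(seg_logit, seg_label)
print(loss)  # scalar Dice loss averaged over classes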
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/lovasz_loss.py
ADDED
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+
+import annotator.mmpkg.mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def lovasz_grad(gt_sorted):
+    """Computes gradient of the Lovasz extension w.r.t sorted errors.
+
+    See Alg. 1 in paper.
+    """
+    p = len(gt_sorted)
+    gts = gt_sorted.sum()
+    intersection = gts - gt_sorted.float().cumsum(0)
+    union = gts + (1 - gt_sorted).float().cumsum(0)
+    jaccard = 1. - intersection / union
+    if p > 1:  # cover 1-pixel case
+        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+    return jaccard
+
+
+def flatten_binary_logits(logits, labels, ignore_index=None):
+    """Flattens predictions in the batch (binary case) Remove labels equal to
+    'ignore_index'."""
+    logits = logits.view(-1)
+    labels = labels.view(-1)
+    if ignore_index is None:
+        return logits, labels
+    valid = (labels != ignore_index)
+    vlogits = logits[valid]
+    vlabels = labels[valid]
+    return vlogits, vlabels
+
+
+def flatten_probs(probs, labels, ignore_index=None):
+    """Flattens predictions in the batch."""
+    if probs.dim() == 3:
+        # assumes output of a sigmoid layer
+        B, H, W = probs.size()
+        probs = probs.view(B, 1, H, W)
+    B, C, H, W = probs.size()
+    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C=P,C
+    labels = labels.view(-1)
+    if ignore_index is None:
+        return probs, labels
+    valid = (labels != ignore_index)
+    vprobs = probs[valid.nonzero().squeeze()]
+    vlabels = labels[valid]
+    return vprobs, vlabels
+
+
+def lovasz_hinge_flat(logits, labels):
+    """Binary Lovasz hinge loss.
+
+    Args:
+        logits (torch.Tensor): [P], logits at each prediction
+            (between -infty and +infty).
+        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if len(labels) == 0:
+        # only void pixels, the gradients should be 0
+        return logits.sum() * 0.
+    signs = 2. * labels.float() - 1.
+    errors = (1. - logits * signs)
+    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+    perm = perm.data
+    gt_sorted = labels[perm]
+    grad = lovasz_grad(gt_sorted)
+    loss = torch.dot(F.relu(errors_sorted), grad)
+    return loss
+
+
+def lovasz_hinge(logits,
+                 labels,
+                 classes='present',
+                 per_image=False,
+                 class_weight=None,
+                 reduction='mean',
+                 avg_factor=None,
+                 ignore_index=255):
+    """Binary Lovasz hinge loss.
+
+    Args:
+        logits (torch.Tensor): [B, H, W], logits at each pixel
+            (between -infty and +infty).
+        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
+        classes (str | list[int], optional): Placeholder, to be consistent with
+            other loss. Default: None.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        class_weight (list[float], optional): Placeholder, to be consistent
+            with other loss. Default: None.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. This parameter only works when per_image is True.
+            Default: None.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if per_image:
+        loss = [
+            lovasz_hinge_flat(*flatten_binary_logits(
+                logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
+            for logit, label in zip(logits, labels)
+        ]
+        loss = weight_reduce_loss(
+            torch.stack(loss), None, reduction, avg_factor)
+    else:
+        loss = lovasz_hinge_flat(
+            *flatten_binary_logits(logits, labels, ignore_index))
+    return loss
+
+
+def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
+    """Multi-class Lovasz-Softmax loss.
+
+    Args:
+        probs (torch.Tensor): [P, C], class probabilities at each prediction
+            (between 0 and 1).
+        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if probs.numel() == 0:
+        # only void pixels, the gradients should be 0
+        return probs * 0.
+    C = probs.size(1)
+    losses = []
+    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+    for c in class_to_sum:
+        fg = (labels == c).float()  # foreground for class c
+        if (classes == 'present' and fg.sum() == 0):
+            continue
+        if C == 1:
+            if len(classes) > 1:
+                raise ValueError('Sigmoid output possible only with 1 class')
+            class_pred = probs[:, 0]
+        else:
+            class_pred = probs[:, c]
+        errors = (fg - class_pred).abs()
+        errors_sorted, perm = torch.sort(errors, 0, descending=True)
+        perm = perm.data
+        fg_sorted = fg[perm]
+        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
+        if class_weight is not None:
+            loss *= class_weight[c]
+        losses.append(loss)
+    return torch.stack(losses).mean()
+
+
+def lovasz_softmax(probs,
+                   labels,
+                   classes='present',
+                   per_image=False,
+                   class_weight=None,
+                   reduction='mean',
+                   avg_factor=None,
+                   ignore_index=255):
+    """Multi-class Lovasz-Softmax loss.
+
+    Args:
+        probs (torch.Tensor): [B, C, H, W], class probabilities at each
+            prediction (between 0 and 1).
+        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
+            C - 1).
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. This parameter only works when per_image is True.
+            Default: None.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+
+    if per_image:
+        loss = [
+            lovasz_softmax_flat(
+                *flatten_probs(
+                    prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
+                classes=classes,
+                class_weight=class_weight)
+            for prob, label in zip(probs, labels)
+        ]
+        loss = weight_reduce_loss(
+            torch.stack(loss), None, reduction, avg_factor)
+    else:
+        loss = lovasz_softmax_flat(
+            *flatten_probs(probs, labels, ignore_index),
+            classes=classes,
+            class_weight=class_weight)
+    return loss
+
+
+@LOSSES.register_module()
+class LovaszLoss(nn.Module):
+    """LovaszLoss.
+
+    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
+    for the optimization of the intersection-over-union measure in neural
+    networks <https://arxiv.org/abs/1705.08790>`_.
+
+    Args:
+        loss_type (str, optional): Binary or multi-class loss.
+            Default: 'multi_class'. Options are "binary" and "multi_class".
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 loss_type='multi_class',
+                 classes='present',
+                 per_image=False,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0):
+        super(LovaszLoss, self).__init__()
+        assert loss_type in ('binary', 'multi_class'), "loss_type should be \
+            'binary' or 'multi_class'."
+
+        if loss_type == 'binary':
+            self.cls_criterion = lovasz_hinge
+        else:
+            self.cls_criterion = lovasz_softmax
+        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
+        if not per_image:
+            assert reduction == 'none', "reduction should be 'none' when \
+                per_image is False."
+
+        self.classes = classes
+        self.per_image = per_image
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = get_class_weight(class_weight)
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function."""
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        # if multi-class loss, transform logits to probs
+        if self.cls_criterion == lovasz_softmax:
+            cls_score = F.softmax(cls_score, dim=1)
+
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            self.classes,
+            self.per_image,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_cls
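A short sketch of using the class above (shapes are illustrative): note that the constructor asserts reduction='none' whenever per_image=False, because in that case the Lovasz extension is computed once over the whole flattened batch.

import torch
from annotator.mmpkg.mmseg.models.losses import LovaszLoss

criterion = LovaszLoss(loss_type='multi_class', per_image=False,
                       reduction='none', loss_weight=1.0)
seg_logit = torch.randn(2, 4, 16, 16)
seg_label = torch.randint(0, 4, (2, 16, 16))
loss = criterion(seg_logit, seg_label)
print(loss)  # scalar: per-class terms are averaged inside lovasz_softmax_flat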
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/losses/utils.py
ADDED
@@ -0,0 +1,121 @@
+import functools
+
+import annotator.mmpkg.mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+
+
+def get_class_weight(class_weight):
+    """Get class weight for loss function.
+
+    Args:
+        class_weight (list[float] | str | None): If class_weight is a str,
+            take it as a file name and read from it.
+    """
+    if isinstance(class_weight, str):
+        # take it as a file path
+        if class_weight.endswith('.npy'):
+            class_weight = np.load(class_weight)
+        else:
+            # pkl, json or yaml
+            class_weight = mmcv.load(class_weight)
+
+    return class_weight
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are "none", "mean" and "sum".
+
+    Return:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights.
+        reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Avarage factor when computing the mean of losses.
+
+    Returns:
+        Tensor: Processed loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        assert weight.dim() == loss.dim()
+        if weight.dim() > 1:
+            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+        loss = loss * weight
+
+    # if avg_factor is not specified, just reduce the loss
+    if avg_factor is None:
+        loss = reduce_loss(loss, reduction)
+    else:
+        # if reduction is mean, then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor can not be used with reduction="sum"')
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have the signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    the signature like `loss_func(pred, target, weight=None, reduction='mean',
+    avg_factor=None, **kwargs)`.
+
+    :Example:
+
+    >>> import torch
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, avg_factor=2)
+    tensor(1.5000)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred,
+                target,
+                weight=None,
+                reduction='mean',
+                avg_factor=None,
+                **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+        return loss
+
+    return wrapper
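To complement the docstring example above, a runnable sketch of the decorator with a different element-wise loss (function name and values are illustrative):

import torch
from annotator.mmpkg.mmseg.models.losses import weighted_loss

@weighted_loss
def l2_loss(pred, target):
    # element-wise loss only; weighting and reduction come from the decorator
    return (pred - target) ** 2

pred = torch.tensor([0., 2., 3.])
target = torch.tensor([1., 1., 1.])
weight = torch.tensor([1., 0., 1.])
print(l2_loss(pred, target))                        # tensor(2.)     -> mean of [1, 1, 4]
print(l2_loss(pred, target, weight))                # tensor(1.6667) -> mean of [1, 0, 4]
print(l2_loss(pred, target, weight, avg_factor=2))  # tensor(2.5000) -> sum 5 / 2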
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/necks/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+
+__all__ = ['FPN', 'MultiLevelNeck']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/models/necks/fpn.py
ADDED
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.mmpkg.mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+    """Feature Pyramid Network.
+
+    This is an implementation of - Feature Pyramid Networks for Object
+    Detection (https://arxiv.org/abs/1612.03144)
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale)
+        num_outs (int): Number of output scales.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool | str): If bool, it decides whether to add conv
+            layers on top of the original feature maps. Default to False.
+            If True, its actual mode is specified by `extra_convs_on_inputs`.
+            If str, it specifies the source feature map of the extra convs.
+            Only the following options are allowed
+
+            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+            - 'on_lateral': Last feature map after lateral convs.
+            - 'on_output': The last output feature map after fpn convs.
+        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+            on the original feature from the backbone. If True,
+            it is equivalent to `add_extra_convs='on_input'`. If False, it is
+            equivalent to set `add_extra_convs='on_output'`. Default to True.
+        relu_before_extra_convs (bool): Whether to apply relu before the extra
+            conv. Default: False.
+        no_norm_on_lateral (bool): Whether to apply norm on lateral.
+            Default: False.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (str): Config dict for activation layer in ConvModule.
+            Default: None.
+        upsample_cfg (dict): Config dict for interpolate layer.
+            Default: `dict(mode='nearest')`
+
+    Example:
+        >>> import torch
+        >>> in_channels = [2, 3, 5, 7]
+        >>> scales = [340, 170, 84, 43]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for i in range(len(outputs)):
+        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
+        outputs[0].shape = torch.Size([1, 11, 340, 340])
+        outputs[1].shape = torch.Size([1, 11, 170, 170])
+        outputs[2].shape = torch.Size([1, 11, 84, 84])
+        outputs[3].shape = torch.Size([1, 11, 43, 43])
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs,
+                 start_level=0,
+                 end_level=-1,
+                 add_extra_convs=False,
+                 extra_convs_on_inputs=False,
+                 relu_before_extra_convs=False,
+                 no_norm_on_lateral=False,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=None,
+                 upsample_cfg=dict(mode='nearest')):
+        super(FPN, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_ins = len(in_channels)
+        self.num_outs = num_outs
+        self.relu_before_extra_convs = relu_before_extra_convs
+        self.no_norm_on_lateral = no_norm_on_lateral
+        self.fp16_enabled = False
+        self.upsample_cfg = upsample_cfg.copy()
+
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+            assert num_outs >= self.num_ins - start_level
+        else:
+            # if end_level < inputs, no extra level is allowed
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+            assert num_outs == end_level - start_level
+        self.start_level = start_level
+        self.end_level = end_level
+        self.add_extra_convs = add_extra_convs
+        assert isinstance(add_extra_convs, (str, bool))
+        if isinstance(add_extra_convs, str):
+            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+        elif add_extra_convs:  # True
+            if extra_convs_on_inputs:
+                # For compatibility with previous release
+                # TODO: deprecate `extra_convs_on_inputs`
+                self.add_extra_convs = 'on_input'
+            else:
+                self.add_extra_convs = 'on_output'
+
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+
+        for i in range(self.start_level, self.backbone_end_level):
+            l_conv = ConvModule(
+                in_channels[i],
+                out_channels,
+                1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+                act_cfg=act_cfg,
+                inplace=False)
+            fpn_conv = ConvModule(
+                out_channels,
+                out_channels,
+                3,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                inplace=False)
+
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        # add extra conv layers (e.g., RetinaNet)
+        extra_levels = num_outs - self.backbone_end_level + self.start_level
+        if self.add_extra_convs and extra_levels >= 1:
+            for i in range(extra_levels):
+                if i == 0 and self.add_extra_convs == 'on_input':
+                    in_channels = self.in_channels[self.backbone_end_level - 1]
+                else:
+                    in_channels = out_channels
+                extra_fpn_conv = ConvModule(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=2,
+                    padding=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    inplace=False)
+                self.fpn_convs.append(extra_fpn_conv)
+
+    # default init_weights for conv(msra) and norm in ConvModule
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+
+        # build laterals
+        laterals = [
+            lateral_conv(inputs[i + self.start_level])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+            # it cannot co-exist with `size` in `F.interpolate`.
+            if 'scale_factor' in self.upsample_cfg:
+                laterals[i - 1] += F.interpolate(laterals[i],
+                                                 **self.upsample_cfg)
+            else:
+                prev_shape = laterals[i - 1].shape[2:]
+                laterals[i - 1] += F.interpolate(
+                    laterals[i], size=prev_shape, **self.upsample_cfg)
+
+        # build outputs
+        # part 1: from original levels
+        outs = [
+            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+        ]
+        # part 2: add extra levels
+        if self.num_outs > len(outs):
+            # use max pool to get more levels on top of outputs
+            # (e.g., Faster R-CNN, Mask R-CNN)
+            if not self.add_extra_convs:
+                for i in range(self.num_outs - used_backbone_levels):
+                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+            # add conv layers on top of original feature maps (RetinaNet)
+            else:
+                if self.add_extra_convs == 'on_input':
+                    extra_source = inputs[self.backbone_end_level - 1]
+                elif self.add_extra_convs == 'on_lateral':
+                    extra_source = laterals[-1]
+                elif self.add_extra_convs == 'on_output':
+                    extra_source = outs[-1]
+                else:
+                    raise NotImplementedError
+                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+                for i in range(used_backbone_levels + 1, self.num_outs):
+                    if self.relu_before_extra_convs:
+                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+                    else:
+                        outs.append(self.fpn_convs[i](outs[-1]))
+        return tuple(outs)
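The class docstring already shows the basic case; as a supplementary sketch, this is how an extra pyramid level can be requested via add_extra_convs='on_output' (channel counts and sizes are illustrative):

import torch
from annotator.mmpkg.mmseg.models.necks import FPN

# One more output level than there are inputs; the extra level comes from a
# stride-2 conv applied to the last FPN output.
neck = FPN(in_channels=[64, 128, 256], out_channels=32, num_outs=4,
           add_extra_convs='on_output')
feats = [torch.randn(1, c, s, s)
         for c, s in zip([64, 128, 256], [64, 32, 16])]
outs = neck(feats)
print([o.shape[-1] for o in outs])  # [64, 32, 16, 8]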