b04c629d473b73ca15d008cade0dbdf01deeb1a8f48216e43b46a735e2975a9a
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py +55 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py +41 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py +61 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py +35 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py +92 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py +125 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py +44 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py +62 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py +206 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py +148 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py +96 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py +65 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py +412 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py +34 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py +29 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py +306 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py +144 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py +36 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py +88 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py +16 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py +21 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py +25 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py +595 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py +84 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py +180 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/builder.py +30 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py +316 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py +19 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py +599 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py +59 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py +59 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py +684 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py +175 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/__init__.py +8 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/test.py +202 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py +11 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py +1148 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py +7 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py +30 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py +36 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py +28 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py +24 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/io.py +151 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/parse.py +97 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/__init__.py +28 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/colorspace.py +306 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/geometric.py +728 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/io.py +258 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/misc.py +44 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/photometric.py +428 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    quantized_arr = np.minimum(
        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                   min_val) / levels + min_val

    return dequantized_arr
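A minimal usage sketch of the two functions above, with illustrative values (not part of the commit): inputs are clipped to [min_val, max_val], bucketed into `levels` bins, and dequantization returns bin centers, so the round-trip error is at most half a bin width.

import numpy as np

arr = np.array([-1.5, 0.0, 0.3, 2.7])
q = quantize(arr, min_val=-1.0, max_val=1.0, levels=10)  # ints in [0, 9]
d = dequantize(q, min_val=-1.0, max_val=1.0, levels=10)  # bin centers
# half a bin width = (max_val - min_val) / levels / 2 = 0.1
assert np.all(np.abs(np.clip(arr, -1.0, 1.0) - d) <= 0.1 + 1e-9)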
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
                     DepthwiseSeparableConvModule, GeneralizedAttention,
                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
                     build_activation_layer, build_conv_layer,
                     build_norm_layer, build_padding_layer, build_plugin_layer,
                     build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
                    constant_init, fuse_conv_bn, get_model_complexity_info,
                    initialize, kaiming_init, normal_init, trunc_normal_init,
                    uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer

__all__ = [
    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
    'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet backbone.

    Args:
        num_classes (int): number of classes for classification.
    """

    def __init__(self, num_classes=-1):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # use default initializer
            pass
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):

        x = self.features(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), 256 * 6 * 6)
            x = self.classifier(x)

        return x
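A quick shape sanity-check (a sketch, not part of the diff): the `256 * 6 * 6` flatten in `forward` assumes the standard 224x224 input resolution.

import torch

model = AlexNet(num_classes=10)
model.init_weights()  # pretrained=None -> keep PyTorch's default init
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 10])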
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv2d_adaptive_padding import Conv2dAdaptivePadding
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                       PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                       Linear, MaxPool2d, MaxPool3d)

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
    'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from annotator.mmpkg.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
from .registry import ACTIVATION_LAYERS

for module in [
        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
        nn.Sigmoid, nn.Tanh
]:
    ACTIVATION_LAYERS.register_module(module=module)


@ACTIVATION_LAYERS.register_module(name='Clip')
@ACTIVATION_LAYERS.register_module()
class Clamp(nn.Module):
    """Clamp activation layer.

    This activation function is to clamp the feature map value within
    :math:`[min, max]`. More details can be found in ``torch.clamp()``.

    Args:
        min (Number, optional): Lower-bound of the range to be clamped to.
            Default: -1.
        max (Number, optional): Upper-bound of the range to be clamped to.
            Default: 1.
    """

    def __init__(self, min=-1., max=1.):
        super(Clamp, self).__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Clamped tensor.
        """
        return torch.clamp(x, min=self.min, max=self.max)


class GELU(nn.Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math::
        \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the Cumulative Distribution Function for
    Gaussian Distribution.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input):
        return F.gelu(input)


if (TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.4')):
    ACTIVATION_LAYERS.register_module(module=GELU)
else:
    ACTIVATION_LAYERS.register_module(module=nn.GELU)


def build_activation_layer(cfg):
    """Build activation layer.

    Args:
        cfg (dict): The activation layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate an activation layer.

    Returns:
        nn.Module: Created activation layer.
    """
    return build_from_cfg(cfg, ACTIVATION_LAYERS)
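A small usage sketch of the registry-driven builder above: the `type` key selects a registered class ('Clip' is the alias registered for `Clamp`) and the remaining keys are forwarded to its constructor.

import torch

relu = build_activation_layer(dict(type='ReLU', inplace=True))
clamp = build_activation_layer(dict(type='Clip', min=0., max=6.))
print(clamp(torch.linspace(-2, 8, 5)))  # tensor([0.0000, 0.5000, 3.0000, 5.5000, 6.0000])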
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
    """ContextBlock module in GCNet.

    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
    (https://arxiv.org/abs/1904.11492) for details.

    Args:
        in_channels (int): Channels of the input feature map.
        ratio (float): Ratio of channels of transform bottleneck.
        pooling_type (str): Pooling method for context modeling.
            Options are 'att' and 'avg', stand for attention pooling and
            average pooling respectively. Default: 'att'.
        fusion_types (Sequence[str]): Fusion method for feature fusion.
            Options are 'channel_add', 'channel_mul', stand for channelwise
            addition and multiplication respectively. Default: ('channel_add',)
    """

    _abbr_ = 'context_block'

    def __init__(self,
                 in_channels,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out
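A shape-level usage sketch (illustrative, not from the commit): the block is channel-preserving, pooling a global [N, C, 1, 1] context that is broadcast back over H and W, so it can be dropped into a backbone residually.

import torch

gc_block = ContextBlock(in_channels=64, ratio=1. / 4)  # 16-channel bottleneck
x = torch.randn(2, 64, 32, 24)
assert gc_block(x).shape == x.shape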
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn

from .registry import CONV_LAYERS

CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)


def build_conv_layer(cfg, *args, **kwargs):
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a conv layer.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in CONV_LAYERS:
        raise KeyError(f'Unrecognized conv type {layer_type}')
    else:
        conv_layer = CONV_LAYERS.get(layer_type)

    layer = conv_layer(*args, **kwargs, **cfg_)

    return layer
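A usage sketch of the builder: positional and keyword arguments pass straight through to the selected layer's constructor, and `cfg=None` falls back to a plain `Conv2d`.

conv = build_conv_layer(None, 3, 16, 3, padding=1)          # nn.Conv2d
conv1d = build_conv_layer(dict(type='Conv1d'), 3, 16, 3)    # nn.Conv1d
conv_ws = build_conv_layer(dict(type='ConvWS'), 3, 16, 3)   # ConvWS2d, registered in conv_ws.py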
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

from torch import nn
from torch.nn import functional as F

from .registry import CONV_LAYERS


@CONV_LAYERS.register_module()
class Conv2dAdaptivePadding(nn.Conv2d):
    """Implementation of 2D convolution in TensorFlow with `padding` as
    "same", which applies padding to the input (if needed) so that the input
    image gets fully covered by the specified filter and stride. With stride
    1, this ensures that the output size equals the input size. With stride 2,
    for example, output dimensions will be half of the input.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x):
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        output_h = math.ceil(img_h / stride_h)
        output_w = math.ceil(img_w / stride_w)
        pad_h = (
            max((output_h - 1) * self.stride[0] +
                (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
        pad_w = (
            max((output_w - 1) * self.stride[1] +
                (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
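A sketch of the "same"-padding behaviour the docstring describes: the output spatial size is always ceil(input / stride), whether or not the input divides evenly.

import torch

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
for size in (7, 8):  # odd and even inputs
    out = conv(torch.randn(1, 3, size, size))
    assert out.shape[-1] == (size + 1) // 2  # ceil(size / 2) == 4 for both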
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn

from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm
from ..utils import constant_init, kaiming_init
from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer
from .registry import PLUGIN_LAYERS


@PLUGIN_LAYERS.register_module()
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
    supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = 'conv_block'

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            self.padding_layer = build_padding_layer(pad_cfg, padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        # 1. It is mainly for customized conv layers with their own
        #    initialization manners by calling their own ``init_weights()``,
        #    and we do not want ConvModule to override the initialization.
        # 2. For customized conv layers without their own initialization
        #    manners (that is, they don't have their own ``init_weights()``)
        #    and PyTorch's conv layers, they will be initialized by
        #    this method with default ``kaiming_init``.
        # Note: For PyTorch's conv layers, they will be overwritten by our
        #    initialization implementation using default ``kaiming_init``.
        if not hasattr(self.conv, 'init_weights'):
            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
                nonlinearity = 'leaky_relu'
                a = self.act_cfg.get('negative_slope', 0.01)
            else:
                nonlinearity = 'relu'
                a = 0
            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x
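A typical usage sketch of the bundle (assuming the standard mmcv 'BN' key for BatchNorm2d from bricks/norm.py, which is not reproduced above): with a norm config the conv bias is auto-disabled, and the forward runs conv -> norm -> act in the configured order.

import torch

block = ConvModule(
    3, 16, 3, padding=1,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'))
y = block(torch.randn(2, 3, 32, 32))  # shape (2, 16, 32, 32)
assert block.conv.bias is None  # bias='auto' and a norm layer follows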
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .registry import CONV_LAYERS


def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)


@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super(ConvWS2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.eps = eps

    def forward(self, x):
        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                          self.dilation, self.groups, self.eps)


@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """

        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
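A sketch tying the class to the functional form above: `ConvWS2d.forward` is exactly `conv_ws_2d` applied with the module's own parameters, so the stored `weight` itself is never standardized in place.

import torch

conv = ConvWS2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 5, 5)
expected = conv_ws_2d(x, conv.weight, conv.bias, padding=1, eps=conv.eps)
assert torch.allclose(conv(x), expected)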
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .conv_module import ConvModule


class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution module.

    See https://arxiv.org/pdf/1704.04861.pdf for details.

    This module can replace a ConvModule with the conv block replaced by two
    conv blocks: a depthwise conv block and a pointwise conv block. The
    depthwise conv block contains depthwise-conv/norm/activation layers. The
    pointwise conv block contains pointwise-conv/norm/activation layers. It
    should be noted that there will be norm/activation layers in the depthwise
    conv block if `norm_cfg` and `act_cfg` are specified.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``. Default: 1.
        norm_cfg (dict): Default norm config for both depthwise ConvModule and
            pointwise ConvModule. Default: None.
        act_cfg (dict): Default activation config for both depthwise ConvModule
            and pointwise ConvModule. Default: dict(type='ReLU').
        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
            'default', it will be the same as `norm_cfg`. Default: 'default'.
        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
            'default', it will be the same as `act_cfg`. Default: 'default'.
        kwargs (optional): Other shared arguments for depthwise and pointwise
            ConvModule. See ConvModule for ref.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 dw_norm_cfg='default',
                 dw_act_cfg='default',
                 pw_norm_cfg='default',
                 pw_act_cfg='default',
                 **kwargs):
        super(DepthwiseSeparableConvModule, self).__init__()
        assert 'groups' not in kwargs, 'groups should not be specified'

        # if norm/activation config of depthwise/pointwise ConvModule is not
        # specified, use default config.
        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg

        # depthwise convolution
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=dw_norm_cfg,
            act_cfg=dw_act_cfg,
            **kwargs)

        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            norm_cfg=pw_norm_cfg,
            act_cfg=pw_act_cfg,
            **kwargs)

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
        return x
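A usage sketch illustrating the parameter savings the module exists for (the counts below are computed from the layer shapes, not measured empirically):

import torch

dw = DepthwiseSeparableConvModule(64, 128, 3, padding=1)
ref = torch.nn.Conv2d(64, 128, 3, padding=1)
n_dw = sum(p.numel() for p in dw.parameters())    # 8960
n_ref = sum(p.numel() for p in ref.parameters())  # 73856, roughly 8x more
assert dw(torch.randn(1, 64, 16, 16)).shape == (1, 128, 16, 16)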
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py
ADDED
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from annotator.mmpkg.mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # handle tensors with different dimensions, not just 4D tensors.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    output = x.div(keep_prob) * random_tensor.floor()
    return output


@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501

    Args:
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

    def __init__(self, drop_prob=0.1):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
    """A wrapper for ``torch.nn.Dropout``. We rename the ``p`` of
    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
    ``DropPath``.

    Args:
        drop_prob (float): Probability of the elements to be
            zeroed. Default: 0.5.
        inplace (bool): Do the operation inplace or not. Default: False.
    """

    def __init__(self, drop_prob=0.5, inplace=False):
        super().__init__(p=drop_prob, inplace=inplace)


def build_dropout(cfg, default_args=None):
    """Builder for drop out layers."""
    return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
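A behaviour sketch for `DropPath`: in training mode whole samples are zeroed, and survivors are rescaled by 1 / keep_prob so the expectation is unchanged; in eval mode it is the identity.

import torch

dp = DropPath(drop_prob=0.2)
x = torch.ones(8, 16)
dp.train()
y = dp(x)  # each row is either all zeros or all 1 / 0.8 = 1.25
dp.eval()
assert torch.equal(dp(x), x)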
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from ..utils import kaiming_init
|
10 |
+
from .registry import PLUGIN_LAYERS


@PLUGIN_LAYERS.register_module()
class GeneralizedAttention(nn.Module):
    """GeneralizedAttention module.

    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
    (https://arxiv.org/abs/1904.05873) for details.

    Args:
        in_channels (int): Channels of the input feature map.
        spatial_range (int): The spatial range. -1 indicates no spatial range
            constraint. Default: -1.
        num_heads (int): The head number of empirical_attention module.
            Default: 9.
        position_embedding_dim (int): The position embedding dimension.
            Default: -1.
        position_magnitude (int): A multiplier acting on coord difference.
            Default: 1.
        kv_stride (int): The feature stride acting on key/value feature map.
            Default: 2.
        q_stride (int): The feature stride acting on query feature map.
            Default: 1.
        attention_type (str): A binary indicator string for indicating which
            items in generalized empirical_attention module are used.
            Default: '1111'.

            - '1000' indicates 'query and key content' (appr - appr) item,
            - '0100' indicates 'query content and relative position'
              (appr - position) item,
            - '0010' indicates 'key content only' (bias - appr) item,
            - '0001' indicates 'relative position only' (bias - position) item.
    """

    _abbr_ = 'gen_attention_block'

    def __init__(self,
                 in_channels,
                 spatial_range=-1,
                 num_heads=9,
                 position_embedding_dim=-1,
                 position_magnitude=1,
                 kv_stride=2,
                 q_stride=1,
                 attention_type='1111'):

        super(GeneralizedAttention, self).__init__()

        # hard range means local range for non-local operation
        self.position_embedding_dim = (
            position_embedding_dim
            if position_embedding_dim > 0 else in_channels)

        self.position_magnitude = position_magnitude
        self.num_heads = num_heads
        self.in_channels = in_channels
        self.spatial_range = spatial_range
        self.kv_stride = kv_stride
        self.q_stride = q_stride
        self.attention_type = [bool(int(_)) for _ in attention_type]
        self.qk_embed_dim = in_channels // num_heads
        out_c = self.qk_embed_dim * num_heads

        if self.attention_type[0] or self.attention_type[1]:
            self.query_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_c,
                kernel_size=1,
                bias=False)
            self.query_conv.kaiming_init = True

        if self.attention_type[0] or self.attention_type[2]:
            self.key_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_c,
                kernel_size=1,
                bias=False)
            self.key_conv.kaiming_init = True

        self.v_dim = in_channels // num_heads
        self.value_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=self.v_dim * num_heads,
            kernel_size=1,
            bias=False)
        self.value_conv.kaiming_init = True

        if self.attention_type[1] or self.attention_type[3]:
            self.appr_geom_fc_x = nn.Linear(
                self.position_embedding_dim // 2, out_c, bias=False)
            self.appr_geom_fc_x.kaiming_init = True

            self.appr_geom_fc_y = nn.Linear(
                self.position_embedding_dim // 2, out_c, bias=False)
            self.appr_geom_fc_y.kaiming_init = True

        if self.attention_type[2]:
            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
            appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
            self.appr_bias = nn.Parameter(appr_bias_value)

        if self.attention_type[3]:
            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
            geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
            self.geom_bias = nn.Parameter(geom_bias_value)

        self.proj_conv = nn.Conv2d(
            in_channels=self.v_dim * num_heads,
            out_channels=in_channels,
            kernel_size=1,
            bias=True)
        self.proj_conv.kaiming_init = True
        self.gamma = nn.Parameter(torch.zeros(1))

        if self.spatial_range >= 0:
            # only works when non local is after 3*3 conv
            if in_channels == 256:
                max_len = 84
            elif in_channels == 512:
                max_len = 42
            # note: `max_len` is only defined for 256- or 512-channel inputs

            max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
            local_constraint_map = np.ones(
                (max_len, max_len, max_len_kv, max_len_kv),
                dtype=int)  # np.int was removed in NumPy 1.24; int matches it
            for iy in range(max_len):
                for ix in range(max_len):
                    local_constraint_map[
                        iy, ix,
                        max((iy - self.spatial_range) //
                            self.kv_stride, 0):min((iy + self.spatial_range +
                                                    1) // self.kv_stride +
                                                   1, max_len),
                        max((ix - self.spatial_range) //
                            self.kv_stride, 0):min((ix + self.spatial_range +
                                                    1) // self.kv_stride +
                                                   1, max_len)] = 0

            self.local_constraint_map = nn.Parameter(
                torch.from_numpy(local_constraint_map).byte(),
                requires_grad=False)

        if self.q_stride > 1:
            self.q_downsample = nn.AvgPool2d(
                kernel_size=1, stride=self.q_stride)
        else:
            self.q_downsample = None

        if self.kv_stride > 1:
            self.kv_downsample = nn.AvgPool2d(
                kernel_size=1, stride=self.kv_stride)
        else:
            self.kv_downsample = None

        self.init_weights()

    def get_position_embedding(self,
                               h,
                               w,
                               h_kv,
                               w_kv,
                               q_stride,
                               kv_stride,
                               device,
                               dtype,
                               feat_dim,
                               wave_length=1000):
        # the default type of Tensor is float32, leading to type mismatch
        # in fp16 mode. Cast it to support fp16 mode.
        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
        h_idxs = h_idxs.view((h, 1)) * q_stride

        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
        w_idxs = w_idxs.view((w, 1)) * q_stride

        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
            device=device, dtype=dtype)
        h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride

        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
            device=device, dtype=dtype)
        w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride

        # (h, h_kv, 1)
        h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
        h_diff *= self.position_magnitude

        # (w, w_kv, 1)
        w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
        w_diff *= self.position_magnitude

        feat_range = torch.arange(0, feat_dim / 4).to(
            device=device, dtype=dtype)

        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
        dim_mat = dim_mat**((4. / feat_dim) * feat_range)
        dim_mat = dim_mat.view((1, 1, -1))

        embedding_x = torch.cat(
            ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)

        embedding_y = torch.cat(
            ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)

        return embedding_x, embedding_y

    def forward(self, x_input):
        num_heads = self.num_heads

        # use empirical_attention
        if self.q_downsample is not None:
            x_q = self.q_downsample(x_input)
        else:
            x_q = x_input
        n, _, h, w = x_q.shape

        if self.kv_downsample is not None:
            x_kv = self.kv_downsample(x_input)
        else:
            x_kv = x_input
        _, _, h_kv, w_kv = x_kv.shape

        if self.attention_type[0] or self.attention_type[1]:
            proj_query = self.query_conv(x_q).view(
                (n, num_heads, self.qk_embed_dim, h * w))
            proj_query = proj_query.permute(0, 1, 3, 2)

        if self.attention_type[0] or self.attention_type[2]:
            proj_key = self.key_conv(x_kv).view(
                (n, num_heads, self.qk_embed_dim, h_kv * w_kv))

        if self.attention_type[1] or self.attention_type[3]:
            position_embed_x, position_embed_y = self.get_position_embedding(
                h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
                x_input.device, x_input.dtype, self.position_embedding_dim)
            # (n, num_heads, w, w_kv, dim)
            position_feat_x = self.appr_geom_fc_x(position_embed_x).\
                view(1, w, w_kv, num_heads, self.qk_embed_dim).\
                permute(0, 3, 1, 2, 4).\
                repeat(n, 1, 1, 1, 1)

            # (n, num_heads, h, h_kv, dim)
            position_feat_y = self.appr_geom_fc_y(position_embed_y).\
                view(1, h, h_kv, num_heads, self.qk_embed_dim).\
                permute(0, 3, 1, 2, 4).\
                repeat(n, 1, 1, 1, 1)

            position_feat_x /= math.sqrt(2)
            position_feat_y /= math.sqrt(2)

        # accelerate for saliency only
        if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
            appr_bias = self.appr_bias.\
                view(1, num_heads, 1, self.qk_embed_dim).\
                repeat(n, 1, 1, 1)

            energy = torch.matmul(appr_bias, proj_key).\
                view(n, num_heads, 1, h_kv * w_kv)

            h = 1
            w = 1
        else:
            # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
            if not self.attention_type[0]:
                energy = torch.zeros(
                    n,
                    num_heads,
                    h,
                    w,
                    h_kv,
                    w_kv,
                    dtype=x_input.dtype,
                    device=x_input.device)

            # attention_type[0]: appr - appr
            # attention_type[1]: appr - position
            # attention_type[2]: bias - appr
            # attention_type[3]: bias - position
            if self.attention_type[0] or self.attention_type[2]:
                if self.attention_type[0] and self.attention_type[2]:
                    appr_bias = self.appr_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim)
                    energy = torch.matmul(proj_query + appr_bias, proj_key).\
                        view(n, num_heads, h, w, h_kv, w_kv)

                elif self.attention_type[0]:
                    energy = torch.matmul(proj_query, proj_key).\
                        view(n, num_heads, h, w, h_kv, w_kv)

                elif self.attention_type[2]:
                    appr_bias = self.appr_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim).\
                        repeat(n, 1, 1, 1)

                    energy += torch.matmul(appr_bias, proj_key).\
                        view(n, num_heads, 1, 1, h_kv, w_kv)

            if self.attention_type[1] or self.attention_type[3]:
                if self.attention_type[1] and self.attention_type[3]:
                    geom_bias = self.geom_bias.\
                        view(1, num_heads, 1, self.qk_embed_dim)

                    proj_query_reshape = (proj_query + geom_bias).\
                        view(n, num_heads, h, w, self.qk_embed_dim)

                    energy_x = torch.matmul(
                        proj_query_reshape.permute(0, 1, 3, 2, 4),
                        position_feat_x.permute(0, 1, 2, 4, 3))
                    energy_x = energy_x.\
                        permute(0, 1, 3, 2, 4).unsqueeze(4)

                    energy_y = torch.matmul(
                        proj_query_reshape,
                        position_feat_y.permute(0, 1, 2, 4, 3))
                    energy_y = energy_y.unsqueeze(5)

                    energy += energy_x + energy_y

                elif self.attention_type[1]:
                    proj_query_reshape = proj_query.\
                        view(n, num_heads, h, w, self.qk_embed_dim)
                    proj_query_reshape = proj_query_reshape.\
                        permute(0, 1, 3, 2, 4)
                    position_feat_x_reshape = position_feat_x.\
                        permute(0, 1, 2, 4, 3)
                    position_feat_y_reshape = position_feat_y.\
                        permute(0, 1, 2, 4, 3)

                    energy_x = torch.matmul(proj_query_reshape,
                                            position_feat_x_reshape)
                    energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)

                    energy_y = torch.matmul(proj_query_reshape,
                                            position_feat_y_reshape)
                    energy_y = energy_y.unsqueeze(5)

                    energy += energy_x + energy_y

                elif self.attention_type[3]:
                    geom_bias = self.geom_bias.\
                        view(1, num_heads, self.qk_embed_dim, 1).\
                        repeat(n, 1, 1, 1)

                    position_feat_x_reshape = position_feat_x.\
                        view(n, num_heads, w * w_kv, self.qk_embed_dim)

                    position_feat_y_reshape = position_feat_y.\
                        view(n, num_heads, h * h_kv, self.qk_embed_dim)

                    energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
                    energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)

                    energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
                    energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)

                    energy += energy_x + energy_y

            energy = energy.view(n, num_heads, h * w, h_kv * w_kv)

        if self.spatial_range >= 0:
            cur_local_constraint_map = \
                self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
                contiguous().\
                view(1, 1, h * w, h_kv * w_kv)

            energy = energy.masked_fill_(cur_local_constraint_map,
                                         float('-inf'))

        attention = F.softmax(energy, 3)

        proj_value = self.value_conv(x_kv)
        proj_value_reshape = proj_value.\
            view((n, num_heads, self.v_dim, h_kv * w_kv)).\
            permute(0, 1, 3, 2)

        out = torch.matmul(attention, proj_value_reshape).\
            permute(0, 1, 3, 2).\
            contiguous().\
            view(n, self.v_dim * self.num_heads, h, w)

        out = self.proj_conv(out)

        # output is downsampled, upsample back to input size
        if self.q_downsample is not None:
            out = F.interpolate(
                out,
                size=x_input.shape[2:],
                mode='bilinear',
                align_corners=False)

        out = self.gamma * out + x_input
        return out

    def init_weights(self):
        for m in self.modules():
            if hasattr(m, 'kaiming_init') and m.kaiming_init:
                kaiming_init(
                    m,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    bias=0,
                    distribution='uniform',
                    a=1)
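
A minimal usage sketch for the GeneralizedAttention plugin above, assuming the extension root is on sys.path so the vendored package imports as annotator.mmpkg.mmcv (shapes are illustrative):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.generalized_attention import \
    GeneralizedAttention

# 9 heads over 256 channels; '1111' enables all four attention terms
attn = GeneralizedAttention(in_channels=256, num_heads=9,
                            attention_type='1111')
x = torch.randn(2, 256, 32, 32)
out = attn(x)  # gamma-scaled residual output, same shape as the input
assert out.shape == x.shape
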
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .registry import ACTIVATION_LAYERS


@ACTIVATION_LAYERS.register_module()
class HSigmoid(nn.Module):
    """Hard Sigmoid Module. Apply the hard sigmoid function:
    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)

    Args:
        bias (float): Bias of the input feature map. Default: 1.0.
        divisor (float): Divisor of the input feature map. Default: 2.0.
        min_value (float): Lower bound value. Default: 0.0.
        max_value (float): Upper bound value. Default: 1.0.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
        super(HSigmoid, self).__init__()
        self.bias = bias
        self.divisor = divisor
        assert self.divisor != 0
        self.min_value = min_value
        self.max_value = max_value

    def forward(self, x):
        x = (x + self.bias) / self.divisor

        return x.clamp_(self.min_value, self.max_value)
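
A quick sketch of the clamp behaviour with the default bias/divisor (hand-checked values; the import path assumes the vendored package layout):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.hsigmoid import HSigmoid

act = HSigmoid()  # min(max((x + 1) / 2, 0), 1)
x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print(act(x))  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])
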
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .registry import ACTIVATION_LAYERS


@ACTIVATION_LAYERS.register_module()
class HSwish(nn.Module):
    """Hard Swish Module.

    This module applies the hard swish function:

    .. math::
        Hswish(x) = x * ReLU6(x + 3) / 6

    Args:
        inplace (bool): can optionally do the operation in-place.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self, inplace=False):
        super(HSwish, self).__init__()
        self.act = nn.ReLU6(inplace)

    def forward(self, x):
        return x * self.act(x + 3) / 6
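
For comparison, a short sketch of the hard-swish saturation (hand-checked values; import path as above):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.hswish import HSwish

act = HSwish()
x = torch.tensor([-4.0, 0.0, 4.0])
print(act(x))  # tensor([-0., 0., 4.]): zero below -3, identity above +3
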
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py
ADDED
@@ -0,0 +1,306 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta

import torch
import torch.nn as nn

from ..utils import constant_init, normal_init
from .conv_module import ConvModule
from .registry import PLUGIN_LAYERS


class _NonLocalNd(nn.Module, metaclass=ABCMeta):
    """Basic Non-local module.

    This module is proposed in
    "Non-local Neural Networks"
    Paper reference: https://arxiv.org/abs/1711.07971
    Code reference: https://github.com/AlexHex7/Non-local_pytorch

    Args:
        in_channels (int): Channels of the input feature map.
        reduction (int): Channel reduction ratio. Default: 2.
        use_scale (bool): Whether to scale pairwise_weight by
            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
            Default: True.
        conv_cfg (None | dict): The config dict for convolution layers.
            If not specified, it will use `nn.Conv2d` for convolution layers.
            Default: None.
        norm_cfg (None | dict): The config dict for normalization layers.
            Default: None. (This parameter is only applicable to conv_out.)
        mode (str): Options are `gaussian`, `concatenation`,
            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
    """

    def __init__(self,
                 in_channels,
                 reduction=2,
                 use_scale=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 mode='embedded_gaussian',
                 **kwargs):
        super(_NonLocalNd, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.use_scale = use_scale
        self.inter_channels = max(in_channels // reduction, 1)
        self.mode = mode

        if mode not in [
                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
        ]:
            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
                             f"'embedded_gaussian' or 'dot_product', but got "
                             f'{mode} instead.')

        # g, theta, phi are defaulted as `nn.ConvNd`.
        # Here we use ConvModule for potential usage.
        self.g = ConvModule(
            self.in_channels,
            self.inter_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)
        self.conv_out = ConvModule(
            self.inter_channels,
            self.in_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        if self.mode != 'gaussian':
            self.theta = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)
            self.phi = ConvModule(
                self.in_channels,
                self.inter_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                act_cfg=None)

        if self.mode == 'concatenation':
            self.concat_project = ConvModule(
                self.inter_channels * 2,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
                act_cfg=dict(type='ReLU'))

        self.init_weights(**kwargs)

    def init_weights(self, std=0.01, zeros_init=True):
        if self.mode != 'gaussian':
            for m in [self.g, self.theta, self.phi]:
                normal_init(m.conv, std=std)
        else:
            normal_init(self.g.conv, std=std)
        if zeros_init:
            if self.conv_out.norm_cfg is None:
                constant_init(self.conv_out.conv, 0)
            else:
                constant_init(self.conv_out.norm, 0)
        else:
            if self.conv_out.norm_cfg is None:
                normal_init(self.conv_out.conv, std=std)
            else:
                normal_init(self.conv_out.norm, std=std)

    def gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def embedded_gaussian(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        if self.use_scale:
            # theta_x.shape[-1] is `self.inter_channels`
            pairwise_weight /= theta_x.shape[-1]**0.5
        pairwise_weight = pairwise_weight.softmax(dim=-1)
        return pairwise_weight

    def dot_product(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = torch.matmul(theta_x, phi_x)
        pairwise_weight /= pairwise_weight.shape[-1]
        return pairwise_weight

    def concatenation(self, theta_x, phi_x):
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        h = theta_x.size(2)
        w = phi_x.size(3)
        theta_x = theta_x.repeat(1, 1, 1, w)
        phi_x = phi_x.repeat(1, 1, h, 1)

        concat_feature = torch.cat([theta_x, phi_x], dim=1)
        pairwise_weight = self.concat_project(concat_feature)
        n, _, h, w = pairwise_weight.size()
        pairwise_weight = pairwise_weight.view(n, h, w)
        pairwise_weight /= pairwise_weight.shape[-1]

        return pairwise_weight

    def forward(self, x):
        # Assume `reduction = 1`, then `inter_channels = C`
        # or `inter_channels = C` when `mode="gaussian"`

        # NonLocal1d x: [N, C, H]
        # NonLocal2d x: [N, C, H, W]
        # NonLocal3d x: [N, C, T, H, W]
        n = x.size(0)

        # NonLocal1d g_x: [N, H, C]
        # NonLocal2d g_x: [N, HxW, C]
        # NonLocal3d g_x: [N, TxHxW, C]
        g_x = self.g(x).view(n, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
        if self.mode == 'gaussian':
            theta_x = x.view(n, self.in_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            if self.sub_sample:
                phi_x = self.phi(x).view(n, self.in_channels, -1)
            else:
                phi_x = x.view(n, self.in_channels, -1)
        elif self.mode == 'concatenation':
            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
        else:
            theta_x = self.theta(x).view(n, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            phi_x = self.phi(x).view(n, self.inter_channels, -1)

        pairwise_func = getattr(self, self.mode)
        # NonLocal1d pairwise_weight: [N, H, H]
        # NonLocal2d pairwise_weight: [N, HxW, HxW]
        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
        pairwise_weight = pairwise_func(theta_x, phi_x)

        # NonLocal1d y: [N, H, C]
        # NonLocal2d y: [N, HxW, C]
        # NonLocal3d y: [N, TxHxW, C]
        y = torch.matmul(pairwise_weight, g_x)
        # NonLocal1d y: [N, C, H]
        # NonLocal2d y: [N, C, H, W]
        # NonLocal3d y: [N, C, T, H, W]
        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
                                                    *x.size()[2:])

        output = x + self.conv_out(y)

        return output


class NonLocal1d(_NonLocalNd):
    """1D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv1d').
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv1d'),
                 **kwargs):
        super(NonLocal1d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer


@PLUGIN_LAYERS.register_module()
class NonLocal2d(_NonLocalNd):
    """2D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv2d').
    """

    _abbr_ = 'nonlocal_block'

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv2d'),
                 **kwargs):
        super(NonLocal2d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)

        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer


class NonLocal3d(_NonLocalNd):
    """3D Non-local module.

    Args:
        in_channels (int): Same as `NonLocalND`.
        sub_sample (bool): Whether to apply max pooling after pairwise
            function (Note that the `sub_sample` is applied on spatial only).
            Default: False.
        conv_cfg (None | dict): Same as `NonLocalND`.
            Default: dict(type='Conv3d').
    """

    def __init__(self,
                 in_channels,
                 sub_sample=False,
                 conv_cfg=dict(type='Conv3d'),
                 **kwargs):
        super(NonLocal3d, self).__init__(
            in_channels, conv_cfg=conv_cfg, **kwargs)
        self.sub_sample = sub_sample

        if sub_sample:
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            self.g = nn.Sequential(self.g, max_pool_layer)
            if self.mode != 'gaussian':
                self.phi = nn.Sequential(self.phi, max_pool_layer)
            else:
                self.phi = max_pool_layer
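
A minimal sketch of the registered 2D block, assuming the bricks package has been imported so ConvModule can resolve dict(type='Conv2d'):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.non_local import NonLocal2d

block = NonLocal2d(in_channels=64, reduction=2, mode='embedded_gaussian')
x = torch.randn(1, 64, 20, 20)
out = block(x)  # residual: x + conv_out(attention-weighted values)
assert out.shape == x.shape
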
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py
ADDED
@@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect

import torch.nn as nn

from annotator.mmpkg.mmcv.utils import is_tuple_of
from annotator.mmpkg.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
from .registry import NORM_LAYERS

NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)


def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    When we build a norm layer with `build_norm_layer()`, we want to preserve
    the norm type in variable names, e.g., self.bn1, self.gn. This method will
    infer the abbreviation to map class types to abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
    InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
    "in" respectively.
    Rule 3: If the class name contains "batch", "group", "layer" or
    "instance", the abbreviation of this layer will be "bn", "gn", "ln" and
    "in" respectively.
    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".

    Args:
        class_type (type): The norm layer type.

    Returns:
        str: The inferred abbreviation.
    """
    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    if issubclass(class_type, _InstanceNorm):  # IN is a subclass of BN
        return 'in'
    elif issubclass(class_type, _BatchNorm):
        return 'bn'
    elif issubclass(class_type, nn.GroupNorm):
        return 'gn'
    elif issubclass(class_type, nn.LayerNorm):
        return 'ln'
    else:
        class_name = class_type.__name__.lower()
        if 'batch' in class_name:
            return 'bn'
        elif 'group' in class_name:
            return 'gn'
        elif 'layer' in class_name:
            return 'ln'
        elif 'instance' in class_name:
            return 'in'
        else:
            return 'norm_layer'


def build_norm_layer(cfg, num_features, postfix=''):
    """Build normalization layer.

    Args:
        cfg (dict): The norm layer config, which should contain:

            - type (str): Layer type.
            - layer args: Args needed to instantiate a norm layer.
            - requires_grad (bool, optional): Whether to stop gradient
              updates.
        num_features (int): Number of input channels.
        postfix (int | str): The postfix to be appended into norm abbreviation
            to create named layer.

    Returns:
        (str, nn.Module): The first element is the layer name consisting of
            abbreviation and postfix, e.g., bn1, gn. The second element is the
            created norm layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in NORM_LAYERS:
        raise KeyError(f'Unrecognized norm type {layer_type}')

    norm_layer = NORM_LAYERS.get(layer_type)
    abbr = infer_abbr(norm_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    requires_grad = cfg_.pop('requires_grad', True)
    cfg_.setdefault('eps', 1e-5)
    if layer_type != 'GN':
        layer = norm_layer(num_features, **cfg_)
        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
            layer._specify_ddp_gpu_num(1)
    else:
        assert 'num_groups' in cfg_
        layer = norm_layer(num_channels=num_features, **cfg_)

    for param in layer.parameters():
        param.requires_grad = requires_grad

    return name, layer


def is_norm(layer, exclude=None):
    """Check if a layer is a normalization layer.

    Args:
        layer (nn.Module): The layer to be checked.
        exclude (type | tuple[type]): Types to be excluded.

    Returns:
        bool: Whether the layer is a norm layer.
    """
    if exclude is not None:
        if not isinstance(exclude, tuple):
            exclude = (exclude, )
        if not is_tuple_of(exclude, type):
            raise TypeError(
                f'"exclude" must be either None or type or a tuple of types, '
                f'but got {type(exclude)}: {exclude}')

    if exclude and isinstance(layer, exclude):
        return False

    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
    return isinstance(layer, all_norm_bases)
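
A short sketch of how the (name, layer) pair comes out of build_norm_layer (import path assumes the vendored package layout):

from annotator.mmpkg.mmcv.cnn.bricks.norm import build_norm_layer

name, layer = build_norm_layer(dict(type='BN'), num_features=64, postfix=1)
print(name)  # 'bn1' -- inferred abbreviation plus postfix
# GN goes through the num_channels branch and requires num_groups in the cfg
name, layer = build_norm_layer(dict(type='GN', num_groups=8), num_features=64)
print(name)  # 'gn'
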
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from .registry import PADDING_LAYERS

PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)


def build_padding_layer(cfg, *args, **kwargs):
    """Build padding layer.

    Args:
        cfg (None or dict): The padding layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate a padding layer.

    Returns:
        nn.Module: Created padding layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    cfg_ = cfg.copy()
    padding_type = cfg_.pop('type')
    if padding_type not in PADDING_LAYERS:
        raise KeyError(f'Unrecognized padding type {padding_type}.')
    else:
        padding_layer = PADDING_LAYERS.get(padding_type)

    layer = padding_layer(*args, **kwargs, **cfg_)

    return layer
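
Usage is a one-liner; positional args are forwarded to the padding class (a sketch, assuming the vendored import path):

from annotator.mmpkg.mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # nn.ReflectionPad2d(2)
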
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py
ADDED
@@ -0,0 +1,88 @@
import inspect
import platform

from .registry import PLUGIN_LAYERS

if platform.system() == 'Windows':
    import regex as re
else:
    import re


def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.
    """

    def camel2snack(word):
        """Convert a camel-case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snack("FancyBlock")
            'fancy_block'
        """

        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    else:
        return camel2snack(class_type.__name__)


def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify plugin layer type.
            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')

    plugin_layer = PLUGIN_LAYERS.get(layer_type)
    abbr = infer_abbr(plugin_layer)

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)

    layer = plugin_layer(**kwargs, **cfg_)

    return name, layer
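
A sketch of building a registered plugin by config; NonLocal2d (registered in non_local.py above) declares _abbr_ = 'nonlocal_block', so the returned name is that abbreviation plus the postfix (assuming the bricks package has been imported so the registry is populated):

from annotator.mmpkg.mmcv.cnn.bricks.plugin import build_plugin_layer

name, layer = build_plugin_layer(
    dict(type='NonLocal2d'), postfix='_1', in_channels=64)
print(name)  # 'nonlocal_block_1'
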
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from annotator.mmpkg.mmcv.utils import Registry

CONV_LAYERS = Registry('conv layer')
NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')

DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
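
These registries let downstream code plug in custom layers by name; a sketch with a hypothetical activation (the class name and body are made up for illustration):

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn.bricks.registry import ACTIVATION_LAYERS

@ACTIVATION_LAYERS.register_module()
class ScaledTanh(nn.Module):  # hypothetical custom activation
    def forward(self, x):
        return 2 * x.tanh()

# configs such as dict(type='ScaledTanh') can now build this layer
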
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


class Scale(nn.Module):
    """A learnable scale parameter.

    This layer scales the input by a learnable factor. It multiplies a
    learnable scale parameter of shape (1,) with input of any shape.

    Args:
        scale (float): Initial value of scale factor. Default: 1.0
    """

    def __init__(self, scale=1.0):
        super(Scale, self).__init__()
        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))

    def forward(self, x):
        return x * self.scale
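
A tiny sketch of the learnable scalar in action (illustrative values):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.scale import Scale

s = Scale(scale=0.5)  # one trainable parameter, initialised to 0.5
x = torch.ones(2, 3)
print(s(x))  # every element is 0.5; the factor is trained like any parameter
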
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from .registry import ACTIVATION_LAYERS


@ACTIVATION_LAYERS.register_module()
class Swish(nn.Module):
    """Swish Module.

    This module applies the swish function:

    .. math::
        Swish(x) = x * Sigmoid(x)

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)
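
A quick numeric sketch of x * sigmoid(x) (approximate, hand-checked values):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.swish import Swish

act = Swish()
x = torch.tensor([-1.0, 0.0, 1.0])
print(act(x))  # roughly tensor([-0.2689, 0.0000, 0.7311])
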
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py
ADDED
@@ -0,0 +1,595 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import torch
import torch.nn as nn

from annotator.mmpkg.mmcv import ConfigDict, deprecated_api_warning
from annotator.mmpkg.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from annotator.mmpkg.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from annotator.mmpkg.mmcv.utils import build_from_cfg
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
    from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
    warnings.warn(
        ImportWarning(
            '``MultiScaleDeformableAttention`` has been moved to '
            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
            '``from annotator.mmpkg.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
            'to ``from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
        ))

except ImportError:
    warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
                  '``mmcv.ops.multi_scale_deform_attn``, '
                  'You should install ``mmcv-full`` if you need this module. ')


def build_positional_encoding(cfg, default_args=None):
    """Builder for Position Encoding."""
    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)


def build_attention(cfg, default_args=None):
    """Builder for attention."""
    return build_from_cfg(cfg, ATTENTION, default_args)


def build_feedforward_network(cfg, default_args=None):
    """Builder for feed-forward network (FFN)."""
    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)


def build_transformer_layer(cfg, default_args=None):
    """Builder for transformer layer."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)


def build_transformer_layer_sequence(cfg, default_args=None):
    """Builder for transformer encoder and transformer decoder."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)


@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn('The argument `dropout` in MultiheadAttention '
                          'has been deprecated, now you can separately '
                          'set `attn_drop`(float), proj_drop(float), '
                          'and `dropout_layer`(dict) ')
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link.
                If None, `x` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will
                be added to `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
                [num_queries, bs, embed_dims]
                if self.batch_first is False, else
                [bs, num_queries, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))


@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super(FFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds x to the output tensor if `identity` is None.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN` or `LN` layers and
    different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            `prenorm` is supported when the first element is `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The argument `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should only contain operations ' \
            f"among {['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfgs {len(attn_cfgs)} is not consistent with ' \
                f'the number of attentions ({num_attn}) ' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
|
420 |
+
key_padding_mask=None,
|
421 |
+
**kwargs):
|
422 |
+
"""Forward function for `TransformerDecoderLayer`.
|
423 |
+
|
424 |
+
**kwargs contains some specific arguments of attentions.
|
425 |
+
|
426 |
+
Args:
|
427 |
+
query (Tensor): The input query with shape
|
428 |
+
[num_queries, bs, embed_dims] if
|
429 |
+
self.batch_first is False, else
|
430 |
+
[bs, num_queries embed_dims].
|
431 |
+
key (Tensor): The key tensor with shape [num_keys, bs,
|
432 |
+
embed_dims] if self.batch_first is False, else
|
433 |
+
[bs, num_keys, embed_dims] .
|
434 |
+
value (Tensor): The value tensor with same shape as `key`.
|
435 |
+
query_pos (Tensor): The positional encoding for `query`.
|
436 |
+
Default: None.
|
437 |
+
key_pos (Tensor): The positional encoding for `key`.
|
438 |
+
Default: None.
|
439 |
+
attn_masks (List[Tensor] | None): 2D Tensor used in
|
440 |
+
calculation of corresponding attention. The length of
|
441 |
+
it should equal to the number of `attention` in
|
442 |
+
`operation_order`. Default: None.
|
443 |
+
query_key_padding_mask (Tensor): ByteTensor for `query`, with
|
444 |
+
shape [bs, num_queries]. Only used in `self_attn` layer.
|
445 |
+
Defaults to None.
|
446 |
+
key_padding_mask (Tensor): ByteTensor for `query`, with
|
447 |
+
shape [bs, num_keys]. Default: None.
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
|
451 |
+
"""
|
452 |
+
|
453 |
+
norm_index = 0
|
454 |
+
attn_index = 0
|
455 |
+
ffn_index = 0
|
456 |
+
identity = query
|
457 |
+
if attn_masks is None:
|
458 |
+
attn_masks = [None for _ in range(self.num_attn)]
|
459 |
+
elif isinstance(attn_masks, torch.Tensor):
|
460 |
+
attn_masks = [
|
461 |
+
copy.deepcopy(attn_masks) for _ in range(self.num_attn)
|
462 |
+
]
|
463 |
+
warnings.warn(f'Use same attn_mask in all attentions in '
|
464 |
+
f'{self.__class__.__name__} ')
|
465 |
+
else:
|
466 |
+
assert len(attn_masks) == self.num_attn, f'The length of ' \
|
467 |
+
f'attn_masks {len(attn_masks)} must be equal ' \
|
468 |
+
f'to the number of attention in ' \
|
469 |
+
f'operation_order {self.num_attn}'
|
470 |
+
|
471 |
+
for layer in self.operation_order:
|
472 |
+
if layer == 'self_attn':
|
473 |
+
temp_key = temp_value = query
|
474 |
+
query = self.attentions[attn_index](
|
475 |
+
query,
|
476 |
+
temp_key,
|
477 |
+
temp_value,
|
478 |
+
identity if self.pre_norm else None,
|
479 |
+
query_pos=query_pos,
|
480 |
+
key_pos=query_pos,
|
481 |
+
attn_mask=attn_masks[attn_index],
|
482 |
+
key_padding_mask=query_key_padding_mask,
|
483 |
+
**kwargs)
|
484 |
+
attn_index += 1
|
485 |
+
identity = query
|
486 |
+
|
487 |
+
elif layer == 'norm':
|
488 |
+
query = self.norms[norm_index](query)
|
489 |
+
norm_index += 1
|
490 |
+
|
491 |
+
elif layer == 'cross_attn':
|
492 |
+
query = self.attentions[attn_index](
|
493 |
+
query,
|
494 |
+
key,
|
495 |
+
value,
|
496 |
+
identity if self.pre_norm else None,
|
497 |
+
query_pos=query_pos,
|
498 |
+
key_pos=key_pos,
|
499 |
+
attn_mask=attn_masks[attn_index],
|
500 |
+
key_padding_mask=key_padding_mask,
|
501 |
+
**kwargs)
|
502 |
+
attn_index += 1
|
503 |
+
identity = query
|
504 |
+
|
505 |
+
elif layer == 'ffn':
|
506 |
+
query = self.ffns[ffn_index](
|
507 |
+
query, identity if self.pre_norm else None)
|
508 |
+
ffn_index += 1
|
509 |
+
|
510 |
+
return query
|
511 |
+
|
512 |
+
|
513 |
+
@TRANSFORMER_LAYER_SEQUENCE.register_module()
|
514 |
+
class TransformerLayerSequence(BaseModule):
|
515 |
+
"""Base class for TransformerEncoder and TransformerDecoder in vision
|
516 |
+
transformer.
|
517 |
+
|
518 |
+
As base-class of Encoder and Decoder in vision transformer.
|
519 |
+
Support customization such as specifying different kind
|
520 |
+
of `transformer_layer` in `transformer_coder`.
|
521 |
+
|
522 |
+
Args:
|
523 |
+
transformerlayer (list[obj:`mmcv.ConfigDict`] |
|
524 |
+
obj:`mmcv.ConfigDict`): Config of transformerlayer
|
525 |
+
in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
|
526 |
+
it would be repeated `num_layer` times to a
|
527 |
+
list[`mmcv.ConfigDict`]. Default: None.
|
528 |
+
num_layers (int): The number of `TransformerLayer`. Default: None.
|
529 |
+
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
530 |
+
Default: None.
|
531 |
+
"""
|
532 |
+
|
533 |
+
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
|
534 |
+
super(TransformerLayerSequence, self).__init__(init_cfg)
|
535 |
+
if isinstance(transformerlayers, dict):
|
536 |
+
transformerlayers = [
|
537 |
+
copy.deepcopy(transformerlayers) for _ in range(num_layers)
|
538 |
+
]
|
539 |
+
else:
|
540 |
+
assert isinstance(transformerlayers, list) and \
|
541 |
+
len(transformerlayers) == num_layers
|
542 |
+
self.num_layers = num_layers
|
543 |
+
self.layers = ModuleList()
|
544 |
+
for i in range(num_layers):
|
545 |
+
self.layers.append(build_transformer_layer(transformerlayers[i]))
|
546 |
+
self.embed_dims = self.layers[0].embed_dims
|
547 |
+
self.pre_norm = self.layers[0].pre_norm
|
548 |
+
|
549 |
+
def forward(self,
|
550 |
+
query,
|
551 |
+
key,
|
552 |
+
value,
|
553 |
+
query_pos=None,
|
554 |
+
key_pos=None,
|
555 |
+
attn_masks=None,
|
556 |
+
query_key_padding_mask=None,
|
557 |
+
key_padding_mask=None,
|
558 |
+
**kwargs):
|
559 |
+
"""Forward function for `TransformerCoder`.
|
560 |
+
|
561 |
+
Args:
|
562 |
+
query (Tensor): Input query with shape
|
563 |
+
`(num_queries, bs, embed_dims)`.
|
564 |
+
key (Tensor): The key tensor with shape
|
565 |
+
`(num_keys, bs, embed_dims)`.
|
566 |
+
value (Tensor): The value tensor with shape
|
567 |
+
`(num_keys, bs, embed_dims)`.
|
568 |
+
query_pos (Tensor): The positional encoding for `query`.
|
569 |
+
Default: None.
|
570 |
+
key_pos (Tensor): The positional encoding for `key`.
|
571 |
+
Default: None.
|
572 |
+
attn_masks (List[Tensor], optional): Each element is 2D Tensor
|
573 |
+
which is used in calculation of corresponding attention in
|
574 |
+
operation_order. Default: None.
|
575 |
+
query_key_padding_mask (Tensor): ByteTensor for `query`, with
|
576 |
+
shape [bs, num_queries]. Only used in self-attention
|
577 |
+
Default: None.
|
578 |
+
key_padding_mask (Tensor): ByteTensor for `query`, with
|
579 |
+
shape [bs, num_keys]. Default: None.
|
580 |
+
|
581 |
+
Returns:
|
582 |
+
Tensor: results with shape [num_queries, bs, embed_dims].
|
583 |
+
"""
|
584 |
+
for layer in self.layers:
|
585 |
+
query = layer(
|
586 |
+
query,
|
587 |
+
key,
|
588 |
+
value,
|
589 |
+
query_pos=query_pos,
|
590 |
+
key_pos=key_pos,
|
591 |
+
attn_masks=attn_masks,
|
592 |
+
query_key_padding_mask=query_key_padding_mask,
|
593 |
+
key_padding_mask=key_padding_mask,
|
594 |
+
**kwargs)
|
595 |
+
return query
|
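
As a usage sketch (not part of this commit): the layer above can be built directly from config dicts. The vendored import path and the `MultiheadAttention` type registered earlier in this file are assumptions here.

import torch
from annotator.mmpkg.mmcv.cnn.bricks.transformer import BaseTransformerLayer

# One post-norm encoder-style layer: self_attn -> norm -> ffn -> norm.
layer = BaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

x = torch.rand(100, 2, 256)  # (num_queries, bs, embed_dims); batch_first=False
out = layer(x)  # key/value default to the query inside 'self_attn'
assert out.shape == x.shape
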
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS

UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)


@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
    achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer to expand the
            channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x


def build_upsample_layer(cfg, *args, **kwargs):
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    else:
        upsample = UPSAMPLE_LAYERS.get(layer_type)

    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
    return layer
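
A minimal sketch of `build_upsample_layer` (illustrative, assuming the vendored import path): `'nearest'` and `'bilinear'` both resolve to `nn.Upsample` with the type name reused as `mode`, while `'pixel_shuffle'` builds the `PixelShufflePack` above.

import torch
from annotator.mmpkg.mmcv.cnn.bricks.upsample import build_upsample_layer

up = build_upsample_layer(dict(type='bilinear', scale_factor=2))
ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=16, out_channels=16,
         scale_factor=2, upsample_kernel=3))

x = torch.rand(1, 16, 8, 8)
assert up(x).shape == (1, 16, 16, 16)  # nn.Upsample with mode='bilinear'
assert ps(x).shape == (1, 16, 16, 16)  # conv to 64 channels, then shuffle
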
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py  # noqa: E501

Wrap some nn modules to support empty tensor input. Currently, these wrappers
are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
heads are trained on only positive RoIs.
"""
import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair, _triple

from .registry import CONV_LAYERS, UPSAMPLE_LAYERS

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
    # torch.__version__ could be 1.3.1+cu92, we only need the first two
    # numbers for comparison
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])


def obsolete_torch_version(torch_version, version_threshold):
    return torch_version == 'parrots' or torch_version <= version_threshold


class NewEmptyTensorOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return NewEmptyTensorOp.apply(grad, shape), None


@CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
                                     self.padding, self.stride, self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                     self.padding, self.stride, self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


class MaxPool2d(nn.MaxPool2d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                     _pair(self.padding), _pair(self.stride),
                                     _pair(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)


class MaxPool3d(nn.MaxPool3d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                     _triple(self.padding),
                                     _triple(self.stride),
                                     _triple(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)


class Linear(torch.nn.Linear):

    def forward(self, x):
        # empty tensor forward of Linear layer is supported in PyTorch 1.6
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
            out_shape = [x.shape[0], self.out_features]
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
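
The wrappers only diverge from their parents when the input is empty and the installed torch predates native support; the spatial size of the empty output follows the usual convolution formula o = (i + 2p - (d(k - 1) + 1)) // s + 1. A quick sketch (assuming the vendored import path):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.wrappers import Conv2d

conv = Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.rand(0, 3, 32, 32)  # zero RoIs, e.g. no positive samples
out = conv(x)  # empty either way: the wrapper on old torch, native on new
assert out.shape == (0, 8, 16, 16)  # (32 + 2 - 3) // 2 + 1 == 16
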
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/builder.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules, which is either a
            config dict or a list of config dicts. If cfg is a list, the
            built modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
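
A sketch of how the `MODELS` registry behaves (the `ToyBlock` module is hypothetical, used only for illustration, and `Registry.build` dispatching to the `build_func` is assumed from this mmcv version): a single config dict builds one module, while a list of dicts is wrapped into a `Sequential`.

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn.builder import MODELS

@MODELS.register_module()
class ToyBlock(nn.Module):  # hypothetical module for the example

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

single = MODELS.build(dict(type='ToyBlock', channels=8))
stacked = MODELS.build([dict(type='ToyBlock', channels=8)] * 2)  # Sequential
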
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py
ADDED
@@ -0,0 +1,316 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn
import torch.utils.checkpoint as cp

from .utils import constant_init, kaiming_init


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        super(BasicBlock, self).__init__()
        assert style in ['pytorch', 'caffe']
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        assert not with_cp

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False):
        """Bottleneck block.

        If style is "pytorch", the stride-two layer is the 3x3 conv layer; if
        it is "caffe", the stride-two layer is the first 1x1 conv layer.
        """
        super(Bottleneck, self).__init__()
        assert style in ['pytorch', 'caffe']
        if style == 'pytorch':
            conv1_stride = 1
            conv2_stride = stride
        else:
            conv1_stride = stride
            conv2_stride = 1
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=conv2_stride,
            padding=dilation,
            dilation=dilation,
            bias=False)

        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    def forward(self, x):

        def _inner_forward(x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def make_res_layer(block,
                   inplanes,
                   planes,
                   blocks,
                   stride=1,
                   dilation=1,
                   style='pytorch',
                   with_cp=False):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes * block.expansion,
                kernel_size=1,
                stride=stride,
                bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )

    layers = []
    layers.append(
        block(
            inplanes,
            planes,
            stride,
            dilation,
            downsample,
            style=style,
            with_cp=with_cp))
    inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(
            block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))

    return nn.Sequential(*layers)


class ResNet(nn.Module):
    """ResNet backbone.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        num_stages (int): Resnet stages, normally 4.
        strides (Sequence[int]): Strides of the first block of each stage.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3),
                 style='pytorch',
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 with_cp=False):
        super(ResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        assert num_stages >= 1 and num_stages <= 4
        block, stage_blocks = self.arch_settings[depth]
        stage_blocks = stage_blocks[:num_stages]
        assert len(strides) == len(dilations) == num_stages
        assert max(out_indices) < num_stages

        self.out_indices = out_indices
        self.style = style
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.with_cp = with_cp

        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.res_layers = []
        for i, num_blocks in enumerate(stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            planes = 64 * 2**i
            res_layer = make_res_layer(
                block,
                self.inplanes,
                planes,
                num_blocks,
                stride=stride,
                dilation=dilation,
                style=self.style,
                with_cp=with_cp)
            self.inplanes = planes * block.expansion
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(ResNet, self).train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False
        if mode and self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for param in self.bn1.parameters():
                param.requires_grad = False
            self.bn1.eval()
            self.bn1.weight.requires_grad = False
            self.bn1.bias.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                mod = getattr(self, f'layer{i}')
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False
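
For reference, a forward-pass sketch of the backbone above (assuming the vendored import path); with the default `out_indices=(0, 1, 2, 3)` it returns one feature map per stage at strides 4/8/16/32:

import torch
from annotator.mmpkg.mmcv.cnn.resnet import ResNet

model = ResNet(depth=50)
model.init_weights()  # random Kaiming/constant init since pretrained is None
model.eval()

with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
# Bottleneck expansion is 4, so stage channels are 256/512/1024/2048.
for f, c in zip(feats, (256, 512, 1024, 2048)):
    assert f.shape[1] == c
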
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_counter import get_model_complexity_info
from .fuse_conv_bn import fuse_conv_bn
from .sync_bn import revert_sync_batchnorm
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, xavier_init)

__all__ = [
    'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
    'Caffe2XavierInit', 'revert_sync_batchnorm'
]
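
A small sketch tying these exports together (illustrative; the toy model is an assumption): `constant_init` fills a module's weights in place, and `get_model_complexity_info` (defined in the next file) reports FLOPs and parameter counts.

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn.utils import (constant_init,
                                            get_model_complexity_info)

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
constant_init(model[0], val=1, bias=0)  # conv weight -> 1, bias -> 0

# One multiply-add is counted as one FLOP (see flops_to_string below).
flops, params = get_model_complexity_info(
    model, (3, 32, 32), print_per_layer_stat=False)
print(flops, params)  # e.g. strings like '0.0 GFLOPs', '0.31 k'
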
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py
ADDED
@@ -0,0 +1,599 @@
# Modified from flops-counter.pytorch by Vladislav Sovrasov
# original repo: https://github.com/sovrasov/flops-counter.pytorch

# MIT License

# Copyright (c) 2018 Vladislav Sovrasov

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import sys
from functools import partial

import numpy as np
import torch
import torch.nn as nn

import annotator.mmpkg.mmcv as mmcv


def get_model_complexity_info(model,
                              input_shape,
                              print_per_layer_stat=True,
                              as_strings=True,
                              input_constructor=None,
                              flush=False,
                              ost=sys.stdout):
    """Get complexity information of a model.

    This method can calculate FLOPs and parameter counts of a model with
    corresponding input shape. It can also print complexity information for
    each layer in a model.

    Supported layers are listed as below:
        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
          ``nn.ReLU6``.
        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
          ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
          ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
          ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
          ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
          ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
          ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
        - Linear: ``nn.Linear``.
        - Deconvolution: ``nn.ConvTranspose2d``.
        - Upsample: ``nn.Upsample``.

    Args:
        model (nn.Module): The model for complexity calculation.
        input_shape (tuple): Input shape used for calculation.
        print_per_layer_stat (bool): Whether to print complexity information
            for each layer in a model. Default: True.
        as_strings (bool): Output FLOPs and params counts in a string form.
            Default: True.
        input_constructor (None | callable): If specified, it takes a callable
            method that generates input. Otherwise, it will generate a random
            tensor with input shape to calculate FLOPs. Default: None.
        flush (bool): same as that in :func:`print`. Default: False.
        ost (stream): same as ``file`` param in :func:`print`.
            Default: sys.stdout.

    Returns:
        tuple[float | str]: If ``as_strings`` is set to True, it will return
            FLOPs and parameter counts in a string format. Otherwise, it will
            return those in a float number format.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
    flops_model = add_flops_counting_methods(model)
    flops_model.eval()
    flops_model.start_flops_count()
    if input_constructor:
        input = input_constructor(input_shape)
        _ = flops_model(**input)
    else:
        try:
            batch = torch.ones(()).new_empty(
                (1, *input_shape),
                dtype=next(flops_model.parameters()).dtype,
                device=next(flops_model.parameters()).device)
        except StopIteration:
            # Avoid StopIteration for models which have no parameters,
            # like `nn.Relu()`, `nn.AvgPool2d`, etc.
            batch = torch.ones(()).new_empty((1, *input_shape))

        _ = flops_model(batch)

    flops_count, params_count = flops_model.compute_average_flops_cost()
    if print_per_layer_stat:
        print_model_with_flops(
            flops_model, flops_count, params_count, ost=ost, flush=flush)
    flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count


def flops_to_string(flops, units='GFLOPs', precision=2):
    """Convert FLOPs number into a string.

    Note that here we take a multiply-add count as one FLOP.

    Args:
        flops (float): FLOPs number to be converted.
        units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
            'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
            choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
        precision (int): Digit number after the decimal point. Default: 2.

    Returns:
        str: The converted FLOPs number with units.

    Examples:
        >>> flops_to_string(1e9)
        '1.0 GFLOPs'
        >>> flops_to_string(2e5, 'MFLOPs')
        '0.2 MFLOPs'
        >>> flops_to_string(3e-9, None)
        '3e-09 FLOPs'
    """
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9, precision)) + ' GFLOPs'
        elif flops // 10**6 > 0:
            return str(round(flops / 10.**6, precision)) + ' MFLOPs'
        elif flops // 10**3 > 0:
            return str(round(flops / 10.**3, precision)) + ' KFLOPs'
        else:
            return str(flops) + ' FLOPs'
    else:
        if units == 'GFLOPs':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPs':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPs':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPs'


def params_to_string(num_params, units=None, precision=2):
    """Convert parameter number into a string.

    Args:
        num_params (float): Parameter number to be converted.
        units (str | None): Converted parameter number units. Options are
            None, 'M', 'K' and ''. If set to None, it will automatically
            choose the most suitable unit for the parameter number.
            Default: None.
        precision (int): Digit number after the decimal point. Default: 2.

    Returns:
        str: The converted parameter number with units.

    Examples:
        >>> params_to_string(1e9)
        '1000.0 M'
        >>> params_to_string(2e5)
        '200.0 k'
        >>> params_to_string(3e-9)
        '3e-09'
    """
    if units is None:
        if num_params // 10**6 > 0:
            return str(round(num_params / 10**6, precision)) + ' M'
        elif num_params // 10**3:
            return str(round(num_params / 10**3, precision)) + ' k'
        else:
            return str(num_params)
    else:
        if units == 'M':
            return str(round(num_params / 10.**6, precision)) + ' ' + units
        elif units == 'K':
            return str(round(num_params / 10.**3, precision)) + ' ' + units
        else:
            return str(num_params)


def print_model_with_flops(model,
                           total_flops,
                           total_params,
                           units='GFLOPs',
                           precision=3,
                           ost=sys.stdout,
                           flush=False):
    """Print a model with FLOPs for each layer.

    Args:
        model (nn.Module): The model to be printed.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
        units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
        precision (int): Digit number after the decimal point. Default: 3.
        ost (stream): same as `file` param in :func:`print`.
            Default: sys.stdout.
        flush (bool): same as that in :func:`print`. Default: False.

    Example:
        >>> class ExampleModel(nn.Module):

        >>> def __init__(self):
        >>>     super().__init__()
        >>>     self.conv1 = nn.Conv2d(3, 8, 3)
        >>>     self.conv2 = nn.Conv2d(8, 256, 3)
        >>>     self.conv3 = nn.Conv2d(256, 8, 3)
        >>>     self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        >>>     self.flatten = nn.Flatten()
        >>>     self.fc = nn.Linear(8, 1)

        >>> def forward(self, x):
        >>>     x = self.conv1(x)
        >>>     x = self.conv2(x)
        >>>     x = self.conv3(x)
        >>>     x = self.avg_pool(x)
        >>>     x = self.flatten(x)
        >>>     x = self.fc(x)
        >>>     return x

        >>> model = ExampleModel()
        >>> x = (3, 16, 16)
        to print the complexity information state for each layer, you can use
        >>> get_model_complexity_info(model, x)
        or directly use
        >>> print_model_with_flops(model, 4579784.0, 37361)
        ExampleModel(
          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
        )
    """

    def accumulate_params(self):
        if is_supported_instance(self):
            return self.__params__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_params()
            return sum

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        accumulated_num_params = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([
            params_to_string(
                accumulated_num_params, units='M', precision=precision),
            '{:.3%} Params'.format(accumulated_num_params / total_params),
            flops_to_string(
                accumulated_flops_cost, units=units, precision=precision),
            '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
            self.original_extra_repr()
        ])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model, file=ost, flush=flush)
    model.apply(del_extra_repr)


def get_model_parameters_number(model):
    """Calculate parameter number of a model.

    Args:
        model (nn.module): The model for parameter number calculation.

    Returns:
        float: Parameter number of the model.
    """
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return num_params


def add_flops_counting_methods(net_main_module):
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    net_main_module.start_flops_count = start_flops_count.__get__(
        net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(
        net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(
        net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
        net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module


def compute_average_flops_cost(self):
    """Compute average FLOPs cost.

    A method to compute average FLOPs cost, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.

    Returns:
        tuple[float, float]: Current mean FLOPs consumption per image and
            the parameter count of the model.
    """
    batches_count = self.__batch_counter__
    flops_sum = 0
    for module in self.modules():
        if is_supported_instance(module):
            flops_sum += module.__flops__
    params_sum = get_model_parameters_number(self)
    return flops_sum / batches_count, params_sum


def start_flops_count(self):
    """Activate the computation of mean flops consumption per image.

    A method to activate the computation of mean flops consumption per image,
    which will be available after ``add_flops_counting_methods()`` is called on
    a desired net object. It should be called before running the network.
    """
    add_batch_counter_hook_function(self)

    def add_flops_counter_hook_function(module):
        if is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return

            else:
                handle = module.register_forward_hook(
                    get_modules_mapping()[type(module)])

                module.__flops_handle__ = handle

    self.apply(partial(add_flops_counter_hook_function))


def stop_flops_count(self):
    """Stop computing the mean flops consumption per image.

    A method to stop computing the mean flops consumption per image, which will
    be available after ``add_flops_counting_methods()`` is called on a desired
    net object. It can be called to pause the computation whenever needed.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """Reset statistics computed so far.

    A method to reset computed statistics, which will be available after
    `add_flops_counting_methods()` is called on a desired net object.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)


# ---- Internal functions
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    output_size = output[0]
    batch_size = output_size.shape[0]
    output_elements_count = batch_size
    for val in output_size.shape[1:]:
        output_elements_count *= val
    module.__flops__ += int(output_elements_count)


def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    output_last_dim = output.shape[
        -1]  # pytorch checks dimensions, so here we don't care much
    module.__flops__ += int(np.prod(input.shape) * output_last_dim)


def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def norm_flops_counter_hook(module, input, output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    if (getattr(module, 'affine', False)
            or getattr(module, 'elementwise_affine', False)):
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


def deconv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    input_height, input_width = input.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = (
        kernel_height * kernel_width * in_channels * filters_per_channel)

    active_elements_count = batch_size * input_height * input_width
    overall_conv_flops = conv_per_position_flops * active_elements_count
    bias_flops = 0
    if conv_module.bias is not None:
        output_height, output_width = output.shape[2:]
        bias_flops = out_channels * batch_size * output_height * output_width
    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(
        np.prod(kernel_dims)) * in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0

    if conv_module.bias is not None:

        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)


def batch_counter_hook(module, input, output):
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        input = input[0]
        batch_size = len(input)
    else:
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
    module.__batch_counter__ += batch_size


def add_batch_counter_variables_or_reset(module):

    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__


def add_flops_counter_variable_or_reset(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module ' + type(module).__name__ +
                  '. ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = get_model_parameters_number(module)


def is_supported_instance(module):
    if type(module) in get_modules_mapping():
        return True
    return False


def remove_flops_counter_hook_function(module):
|
547 |
+
if is_supported_instance(module):
|
548 |
+
if hasattr(module, '__flops_handle__'):
|
549 |
+
module.__flops_handle__.remove()
|
550 |
+
del module.__flops_handle__
|
551 |
+
|
552 |
+
|
553 |
+
def get_modules_mapping():
|
554 |
+
return {
|
555 |
+
# convolutions
|
556 |
+
nn.Conv1d: conv_flops_counter_hook,
|
557 |
+
nn.Conv2d: conv_flops_counter_hook,
|
558 |
+
mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
|
559 |
+
nn.Conv3d: conv_flops_counter_hook,
|
560 |
+
mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
|
561 |
+
# activations
|
562 |
+
nn.ReLU: relu_flops_counter_hook,
|
563 |
+
nn.PReLU: relu_flops_counter_hook,
|
564 |
+
nn.ELU: relu_flops_counter_hook,
|
565 |
+
nn.LeakyReLU: relu_flops_counter_hook,
|
566 |
+
nn.ReLU6: relu_flops_counter_hook,
|
567 |
+
# poolings
|
568 |
+
nn.MaxPool1d: pool_flops_counter_hook,
|
569 |
+
nn.AvgPool1d: pool_flops_counter_hook,
|
570 |
+
nn.AvgPool2d: pool_flops_counter_hook,
|
571 |
+
nn.MaxPool2d: pool_flops_counter_hook,
|
572 |
+
mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
|
573 |
+
nn.MaxPool3d: pool_flops_counter_hook,
|
574 |
+
mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
|
575 |
+
nn.AvgPool3d: pool_flops_counter_hook,
|
576 |
+
nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
|
577 |
+
nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
|
578 |
+
nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
|
579 |
+
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
|
580 |
+
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
|
581 |
+
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
|
582 |
+
# normalizations
|
583 |
+
nn.BatchNorm1d: norm_flops_counter_hook,
|
584 |
+
nn.BatchNorm2d: norm_flops_counter_hook,
|
585 |
+
nn.BatchNorm3d: norm_flops_counter_hook,
|
586 |
+
nn.GroupNorm: norm_flops_counter_hook,
|
587 |
+
nn.InstanceNorm1d: norm_flops_counter_hook,
|
588 |
+
nn.InstanceNorm2d: norm_flops_counter_hook,
|
589 |
+
nn.InstanceNorm3d: norm_flops_counter_hook,
|
590 |
+
nn.LayerNorm: norm_flops_counter_hook,
|
591 |
+
# FC
|
592 |
+
nn.Linear: linear_flops_counter_hook,
|
593 |
+
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
|
594 |
+
# Upscale
|
595 |
+
nn.Upsample: upsample_flops_counter_hook,
|
596 |
+
# Deconvolution
|
597 |
+
nn.ConvTranspose2d: deconv_flops_counter_hook,
|
598 |
+
mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
|
599 |
+
}
|
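A minimal usage sketch of the counters above (not part of the commit). It assumes the `get_model_complexity_info` helper defined earlier in this file is re-exported from the `cnn` package, as in upstream mmcv; adjust the import if this vendored copy differs.

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn import get_model_complexity_info  # assumed export

# Hooks from get_modules_mapping() are attached to every supported layer
# and accumulate __flops__ during one dummy forward pass.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4))
flops, params = get_model_complexity_info(model, (3, 32, 32))
print(flops, params)  # human-readable strings by default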
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn


def _fuse_conv_bn(conv, bn):
    """Fuse conv and bn into one module.

    Args:
        conv (nn.Module): Conv to be fused.
        bn (nn.Module): BN to be fused.

    Returns:
        nn.Module: Fused module.
    """
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)

    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv.weight = nn.Parameter(conv_w *
                               factor.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
    return conv


def fuse_conv_bn(module):
    """Recursively fuse conv and bn in a module.

    During inference, the function of batch norm layers is turned off and
    only the per-channel mean and var are used, which exposes the chance to
    fuse it with the preceding conv layers to save computations and simplify
    network structures.

    Args:
        module (nn.Module): Module to be fused.

    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module
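A quick sanity-check sketch for the fusion above (not part of the commit): in eval mode the fused network should reproduce the original outputs up to floating-point error.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.eval()  # fusion is only valid with frozen BN statistics
x = torch.randn(1, 3, 16, 16)
y_ref = net(x)
fused = fuse_conv_bn(net)  # Conv2d absorbs BN; the BN slot becomes nn.Identity
assert torch.allclose(y_ref, fused(x), atol=1e-5)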
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py
ADDED
@@ -0,0 +1,59 @@
import torch

import annotator.mmpkg.mmcv as mmcv


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
    is `_check_input_dim`, which is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of
    converting SyncBatchNorm.
    """

    def _check_input_dim(self, input):
        return


def revert_sync_batchnorm(module):
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
    if hasattr(mmcv, 'ops'):
        module_checklist.append(mmcv.ops.SyncBatchNorm)
    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here but
            # just to be consistent with `convert_sync_batchnorm()`
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(name, revert_sync_batchnorm(child))
    del module
    return module_output
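A usage sketch (not part of the commit), e.g. for running a DDP-trained checkpoint on CPU, where SyncBatchNorm would otherwise fail:

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net = nn.SyncBatchNorm.convert_sync_batchnorm(net)  # as done for DDP training
net = revert_sync_batchnorm(net)  # SyncBatchNorm -> _BatchNormXd, stats kept
print(net)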
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py
ADDED
@@ -0,0 +1,684 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg, get_logger, print_log

INITIALIZERS = Registry('initializer')


def update_init_info(module, init_info):
    """Update the `_params_init_info` in the module if the values of
    parameters are changed.

    Args:
        module (obj:`nn.Module`): The module of PyTorch with a user-defined
            attribute `_params_init_info` which records the initialization
            information.
        init_info (str): The string that describes the initialization.
    """
    assert hasattr(
        module,
        '_params_init_info'), f'Can not find `_params_init_info` in {module}'
    for name, param in module.named_parameters():

        assert param in module._params_init_info, (
            f'Find a new :obj:`Parameter` '
            f'named `{name}` during executing the '
            f'`init_weights` of '
            f'`{module.__class__.__name__}`. '
            f'Please do not add or '
            f'replace parameters during executing '
            f'the `init_weights`. ')

        # The parameter has been changed during executing the
        # `init_weights` of module
        mean_value = param.data.mean()
        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
            module._params_init_info[param]['init_info'] = init_info
            module._params_init_info[param]['tmp_mean_value'] = mean_value


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module, gain=1, bias=0, distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)  # type: ignore


def uniform_init(module, a=0, b=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.uniform_(module.weight, a, b)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def caffe2_xavier_init(module, bias=0):
    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    # Acknowledgment to FAIR's internal code
    kaiming_init(
        module,
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        bias=bias,
        distribution='uniform')


def bias_init_with_prob(prior_prob):
    """Initialize conv/fc bias value according to a given probability value."""
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init


def _get_bases_name(m):
    return [b.__name__ for b in m.__class__.__bases__]


class BaseInit(object):

    def __init__(self, *, bias=0, bias_prob=None, layer=None):
        self.wholemodule = False
        if not isinstance(bias, (int, float)):
            raise TypeError(f'bias must be a number, but got a {type(bias)}')

        if bias_prob is not None:
            if not isinstance(bias_prob, float):
                raise TypeError(f'bias_prob type must be float, \
                    but got {type(bias_prob)}')

        if layer is not None:
            if not isinstance(layer, (str, list)):
                raise TypeError(f'layer must be a str or a list of str, \
                    but got a {type(layer)}')
        else:
            layer = []

        if bias_prob is not None:
            self.bias = bias_init_with_prob(bias_prob)
        else:
            self.bias = bias
        self.layer = [layer] if isinstance(layer, str) else layer

    def _get_init_info(self):
        info = f'{self.__class__.__name__}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Constant')
class ConstantInit(BaseInit):
    """Initialize module parameters with constant values.

    Args:
        val (int | float): the value to fill the weights in the module with.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self, val, **kwargs):
        super().__init__(**kwargs)
        self.val = val

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                constant_init(m, self.val, self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    constant_init(m, self.val, self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Xavier')
class XavierInit(BaseInit):
    r"""Initialize module parameters with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks - Glorot, X. & Bengio, Y. (2010).
    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_

    Args:
        gain (int | float): an optional scaling factor. Defaults to 1.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        distribution (str): distribution either be ``'normal'``
            or ``'uniform'``. Defaults to ``'normal'``.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self, gain=1, distribution='normal', **kwargs):
        super().__init__(**kwargs)
        self.gain = gain
        self.distribution = distribution

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                xavier_init(m, self.gain, self.bias, self.distribution)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    xavier_init(m, self.gain, self.bias, self.distribution)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: gain={self.gain}, ' \
               f'distribution={self.distribution}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Normal')
class NormalInit(BaseInit):
    r"""Initialize module parameters with the values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        mean (int | float): the mean of the normal distribution.
            Defaults to 0.
        std (int | float): the standard deviation of the normal distribution.
            Defaults to 1.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self, mean=0, std=1, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                normal_init(m, self.mean, self.std, self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    normal_init(m, self.mean, self.std, self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: mean={self.mean},' \
               f' std={self.std}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='TruncNormal')
class TruncNormalInit(BaseInit):
    r"""Initialize module parameters with the values drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.
        std (float): the standard deviation of the normal distribution.
            Defaults to 1.
        a (float): The minimum cutoff value.
        b (float): The maximum cutoff value.
        bias (float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self,
                 mean: float = 0,
                 std: float = 1,
                 a: float = -2,
                 b: float = 2,
                 **kwargs) -> None:
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std
        self.a = a
        self.b = b

    def __call__(self, module: nn.Module) -> None:

        def init(m):
            if self.wholemodule:
                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                  self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
                                      self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
               f' mean={self.mean}, std={self.std}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Uniform')
class UniformInit(BaseInit):
    r"""Initialize module parameters with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        a (int | float): the lower bound of the uniform distribution.
            Defaults to 0.
        b (int | float): the upper bound of the uniform distribution.
            Defaults to 1.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self, a=0, b=1, **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.b = b

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                uniform_init(m, self.a, self.b, self.bias)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    uniform_init(m, self.a, self.b, self.bias)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a},' \
               f' b={self.b}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Kaiming')
class KaimingInit(BaseInit):
    r"""Initialize module parameters with the values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification - He, K. et al. (2015).
    <https://www.cv-foundation.org/openaccess/content_iccv_2015/
    papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_

    Args:
        a (int | float): the negative slope of the rectifier used after this
            layer (only used with ``'leaky_relu'``). Defaults to 0.
        mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
            ``'fan_in'`` preserves the magnitude of the variance of the
            weights in the forward pass. Choosing ``'fan_out'`` preserves the
            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
        nonlinearity (str): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
            Defaults to 'relu'.
        bias (int | float): the value to fill the bias. Defaults to 0.
        bias_prob (float, optional): the probability for bias initialization.
            Defaults to None.
        distribution (str): distribution either be ``'normal'`` or
            ``'uniform'``. Defaults to ``'normal'``.
        layer (str | list[str], optional): the layer(s) to be initialized.
            Defaults to None.
    """

    def __init__(self,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 distribution='normal',
                 **kwargs):
        super().__init__(**kwargs)
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity
        self.distribution = distribution

    def __call__(self, module):

        def init(m):
            if self.wholemodule:
                kaiming_init(m, self.a, self.mode, self.nonlinearity,
                             self.bias, self.distribution)
            else:
                layername = m.__class__.__name__
                basesname = _get_bases_name(m)
                if len(set(self.layer) & set([layername] + basesname)):
                    kaiming_init(m, self.a, self.mode, self.nonlinearity,
                                 self.bias, self.distribution)

        module.apply(init)
        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
               f'nonlinearity={self.nonlinearity}, ' \
               f'distribution={self.distribution}, bias={self.bias}'
        return info


@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    # Acknowledgment to FAIR's internal code
    def __init__(self, **kwargs):
        super().__init__(
            a=1,
            mode='fan_in',
            nonlinearity='leaky_relu',
            distribution='uniform',
            **kwargs)

    def __call__(self, module):
        super().__call__(module)


@INITIALIZERS.register_module(name='Pretrained')
class PretrainedInit(object):
    """Initialize module by loading a pretrained model.

    Args:
        checkpoint (str): the checkpoint file of the pretrained model to be
            loaded.
        prefix (str, optional): the prefix of a sub-module in the pretrained
            model. It is for loading a part of the pretrained model to
            initialize. For example, if we would like to only load the
            backbone of a detector model, we can set ``prefix='backbone.'``.
            Defaults to None.
        map_location (str): map tensors into proper locations.
    """

    def __init__(self, checkpoint, prefix=None, map_location=None):
        self.checkpoint = checkpoint
        self.prefix = prefix
        self.map_location = map_location

    def __call__(self, module):
        from annotator.mmpkg.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
                                                 load_state_dict)
        logger = get_logger('mmcv')
        if self.prefix is None:
            print_log(f'load model from: {self.checkpoint}', logger=logger)
            load_checkpoint(
                module,
                self.checkpoint,
                map_location=self.map_location,
                strict=False,
                logger=logger)
        else:
            print_log(
                f'load {self.prefix} in model from: {self.checkpoint}',
                logger=logger)
            state_dict = _load_checkpoint_with_prefix(
                self.prefix, self.checkpoint, map_location=self.map_location)
            load_state_dict(module, state_dict, strict=False, logger=logger)

        if hasattr(module, '_params_init_info'):
            update_init_info(module, init_info=self._get_init_info())

    def _get_init_info(self):
        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
        return info


def _initialize(module, cfg, wholemodule=False):
    func = build_from_cfg(cfg, INITIALIZERS)
    # wholemodule flag is for override mode, there is no layer key in override
    # and initializer will give init values for the whole module with the name
    # in override.
    func.wholemodule = wholemodule
    func(module)


def _initialize_override(module, override, cfg):
    if not isinstance(override, (dict, list)):
        raise TypeError(f'override must be a dict or a list of dict, \
                but got {type(override)}')

    override = [override] if isinstance(override, dict) else override

    for override_ in override:

        cp_override = copy.deepcopy(override_)
        name = cp_override.pop('name', None)
        if name is None:
            raise ValueError('`override` must contain the key "name", '
                             f'but got {cp_override}')
        # if override only has name key, it means use args in init_cfg
        if not cp_override:
            cp_override.update(cfg)
        # if override has name key and other args except type key, it will
        # raise error
        elif 'type' not in cp_override.keys():
            raise ValueError(
                f'`override` need "type" key, but got {cp_override}')

        if hasattr(module, name):
            _initialize(getattr(module, name), cp_override, wholemodule=True)
        else:
            raise RuntimeError(f'module did not have attribute {name}, '
                               f'but init_cfg is {cp_override}.')


def initialize(module, init_cfg):
    """Initialize a module.

    Args:
        module (``torch.nn.Module``): the module to be initialized.
        init_cfg (dict | list[dict]): initialization configuration dict to
            define initializer. OpenMMLab has implemented 6 initializers
            including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
            ``Kaiming``, and ``Pretrained``.
    Example:
        >>> module = nn.Linear(2, 3, bias=True)
        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
        >>> initialize(module, init_cfg)

        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1, 2))
        >>> # define key ``'layer'`` for initializing layers with different
        >>> # configurations
        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
                dict(type='Constant', layer='Linear', val=2)]
        >>> initialize(module, init_cfg)

        >>> # define key ``'override'`` to initialize some specific part in
        >>> # module
        >>> class FooNet(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.feat = nn.Conv2d(3, 16, 3)
        >>>         self.reg = nn.Conv2d(16, 10, 3)
        >>>         self.cls = nn.Conv2d(16, 5, 3)
        >>> model = FooNet()
        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
        >>> initialize(model, init_cfg)

        >>> model = ResNet(depth=50)
        >>> # Initialize weights with the pretrained model.
        >>> init_cfg = dict(type='Pretrained',
                checkpoint='torchvision://resnet50')
        >>> initialize(model, init_cfg)

        >>> # Initialize weights of a sub-module with the specific part of
        >>> # a pretrained model by using "prefix".
        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
        >>>     'retinanet_r50_fpn_1x_coco/'\
        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
        >>> init_cfg = dict(type='Pretrained',
                checkpoint=url, prefix='backbone.')
    """
    if not isinstance(init_cfg, (dict, list)):
        raise TypeError(f'init_cfg must be a dict or a list of dict, \
                but got {type(init_cfg)}')

    if isinstance(init_cfg, dict):
        init_cfg = [init_cfg]

    for cfg in init_cfg:
        # should deeply copy the original config because cfg may be used by
        # other modules, e.g., one init_cfg shared by multiple bottleneck
        # blocks, the expected cfg will be changed after pop and will change
        # the initialization behavior of other modules
        cp_cfg = copy.deepcopy(cfg)
        override = cp_cfg.pop('override', None)
        _initialize(module, cp_cfg)

        if override is not None:
            cp_cfg.pop('layer', None)
            _initialize_override(module, override, cp_cfg)
        else:
            # All attributes in module have same initialization.
            pass


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Modified from
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then
        # translate to [2lower-1, 2upper-1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
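A usage sketch of the config-driven interface above (not part of the commit), mirroring the docstring examples:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.Linear(16, 4))
init_cfg = [
    dict(type='Kaiming', layer='Conv2d'),  # kaiming_init on the conv
    dict(type='Constant', layer='BatchNorm2d', val=1, bias=0),
    dict(type='Normal', layer='Linear', std=0.01),
]
initialize(model, init_cfg)  # layers matched by class name (or base class)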
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py
ADDED
@@ -0,0 +1,175 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn

from .utils import constant_init, kaiming_init, normal_init


def conv3x3(in_planes, out_planes, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=dilation,
        dilation=dilation)


def make_vgg_layer(inplanes,
                   planes,
                   num_blocks,
                   dilation=1,
                   with_bn=False,
                   ceil_mode=False):
    layers = []
    for _ in range(num_blocks):
        layers.append(conv3x3(inplanes, planes, dilation))
        if with_bn:
            layers.append(nn.BatchNorm2d(planes))
        layers.append(nn.ReLU(inplace=True))
        inplanes = planes
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))

    return layers


class VGG(nn.Module):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_bn (bool): Use BatchNorm or not.
        num_classes (int): number of classes for classification.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
    """

    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages
        assert max(out_indices) <= num_stages

        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen

        self.inplanes = 3
        start_idx = 0
        vgg_layers = []
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            num_modules = num_blocks * (2 + with_bn) + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            planes = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.inplanes,
                planes,
                num_blocks,
                dilation=dilation,
                with_bn=with_bn,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.inplanes = planes
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i in range(len(self.stage_blocks)):
            for j in range(*self.range_sub_modules[i]):
                vgg_layer = vgg_layers[j]
                x = vgg_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        super(VGG, self).train(mode)
        if self.bn_eval:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.bn_frozen:
                        for params in m.parameters():
                            params.requires_grad = False
        vgg_layers = getattr(self, self.module_name)
        if mode and self.frozen_stages >= 0:
            for i in range(self.frozen_stages):
                for j in range(*self.range_sub_modules[i]):
                    mod = vgg_layers[j]
                    mod.eval()
                    for param in mod.parameters():
                        param.requires_grad = False
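A usage sketch (not part of the commit):

import torch

model = VGG(depth=16, out_indices=(4,))  # features from the last stage only
model.init_weights()
model.eval()
feat = model(torch.randn(1, 3, 224, 224))
print(feat.shape)  # torch.Size([1, 512, 7, 7]) after five stride-2 pools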
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
                   single_gpu_test)

__all__ = [
    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
    'single_gpu_test'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/test.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
import pickle
|
4 |
+
import shutil
|
5 |
+
import tempfile
|
6 |
+
import time
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
|
11 |
+
import annotator.mmpkg.mmcv as mmcv
|
12 |
+
from annotator.mmpkg.mmcv.runner import get_dist_info
|
13 |
+
|
14 |
+
|
15 |
+
def single_gpu_test(model, data_loader):
|
16 |
+
"""Test model with a single gpu.
|
17 |
+
|
18 |
+
This method tests model with a single gpu and displays test progress bar.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
model (nn.Module): Model to be tested.
|
22 |
+
data_loader (nn.Dataloader): Pytorch data loader.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
list: The prediction results.
|
26 |
+
"""
|
27 |
+
model.eval()
|
28 |
+
results = []
|
29 |
+
dataset = data_loader.dataset
|
30 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
31 |
+
for data in data_loader:
|
32 |
+
with torch.no_grad():
|
33 |
+
result = model(return_loss=False, **data)
|
34 |
+
results.extend(result)
|
35 |
+
|
36 |
+
# Assume result has the same length of batch_size
|
37 |
+
# refer to https://github.com/open-mmlab/mmcv/issues/985
|
38 |
+
batch_size = len(result)
|
39 |
+
for _ in range(batch_size):
|
40 |
+
prog_bar.update()
|
41 |
+
return results
|
42 |
+
|
43 |
+
|
44 |
+
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
|
45 |
+
"""Test model with multiple gpus.
|
46 |
+
|
47 |
+
This method tests model with multiple gpus and collects the results
|
48 |
+
under two different modes: gpu and cpu modes. By setting
|
49 |
+
``gpu_collect=True``, it encodes results to gpu tensors and use gpu
|
50 |
+
communication for results collection. On cpu mode it saves the results on
|
51 |
+
different gpus to ``tmpdir`` and collects them by the rank 0 worker.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
model (nn.Module): Model to be tested.
|
55 |
+
data_loader (nn.Dataloader): Pytorch data loader.
|
56 |
+
tmpdir (str): Path of directory to save the temporary results from
|
57 |
+
different gpus under cpu mode.
|
58 |
+
gpu_collect (bool): Option to use either gpu or cpu to collect results.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
list: The prediction results.
|
62 |
+
"""
|
63 |
+
model.eval()
|
64 |
+
results = []
|
65 |
+
dataset = data_loader.dataset
|
66 |
+
rank, world_size = get_dist_info()
|
67 |
+
if rank == 0:
|
68 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
69 |
+
time.sleep(2) # This line can prevent deadlock problem in some cases.
|
70 |
+
for i, data in enumerate(data_loader):
|
71 |
+
with torch.no_grad():
|
72 |
+
result = model(return_loss=False, **data)
|
73 |
+
results.extend(result)
|
74 |
+
|
75 |
+
if rank == 0:
|
76 |
+
batch_size = len(result)
|
77 |
+
batch_size_all = batch_size * world_size
|
78 |
+
if batch_size_all + prog_bar.completed > len(dataset):
|
79 |
+
batch_size_all = len(dataset) - prog_bar.completed
|
80 |
+
for _ in range(batch_size_all):
|
81 |
+
prog_bar.update()
|
82 |
+
|
83 |
+
# collect results from all ranks
|
84 |
+
if gpu_collect:
|
85 |
+
results = collect_results_gpu(results, len(dataset))
|
86 |
+
else:
|
87 |
+
results = collect_results_cpu(results, len(dataset), tmpdir)
|
88 |
+
return results
|
89 |
+
|
90 |
+
|
91 |
+
def collect_results_cpu(result_part, size, tmpdir=None):
|
92 |
+
"""Collect results under cpu mode.
|
93 |
+
|
94 |
+
On cpu mode, this function will save the results on different gpus to
|
95 |
+
``tmpdir`` and collect them by the rank 0 worker.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
result_part (list): Result list containing result parts
|
99 |
+
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.
        tmpdir (str | None): Temporary directory for collected results to
            be stored. If set to None, a random temporary directory will be
            created for it.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mmcv.mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f'part_{i}.pkl')
            part_result = mmcv.load(part_file)
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    """Collect results under gpu mode.

    On gpu mode, this function will encode results to gpu tensors and use gpu
    communication for results collection.

    Args:
        result_part (list): Result list containing result parts
            to be collected.
        size (int): Size of the results, commonly equal to length of
            the results.

    Returns:
        list: The collected results.
    """
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
            # When data is severely insufficient, an empty part_result
            # on a certain gpu could make the overall outputs empty.
            if part_result:
                part_list.append(part_result)
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
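As a usage sketch only (the surrounding `model`, sharded `data_loader`, and `gpu_collect` flag are assumptions for illustration, not shown in this hunk), the two collectors above are typically driven from a distributed test loop like this:

    # Hypothetical driver loop: each rank runs inference on its shard,
    # then rank 0 gathers the full, ordered result list.
    results = []
    for data in data_loader:  # assumed: a DistributedSampler shards the data
        with torch.no_grad():
            results.extend(model(return_loss=False, **data))

    # gpu_collect selects between the two collectors defined above
    if gpu_collect:
        all_results = collect_results_gpu(results, len(data_loader.dataset))
    else:
        all_results = collect_results_cpu(results, len(data_loader.dataset))
    if all_results is not None:  # only rank 0 receives the gathered list
        evaluate(all_results)    # hypothetical downstream step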
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file

__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]
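The re-exported `load`/`dump` pair is the one-call API over all of the backends below; a minimal sketch of its intended use, with placeholder paths (the format is inferred from the file extension):

    from annotator.mmpkg.mmcv.fileio import dump, load

    dump({'acc': 0.9}, 'results.json')  # JsonHandler picked via '.json'
    data = load('results.json')
    dump(data, 'results.pkl')           # PickleHandler picked via '.pkl'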
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py
ADDED
@@ -0,0 +1,1148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import os
import os.path as osp
import re
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional, Tuple, Union
from urllib.request import urlopen

import annotator.mmpkg.mmcv as mmcv
from annotator.mmpkg.mmcv.utils.misc import has_method
from annotator.mmpkg.mmcv.utils.path import is_filepath


class BaseStorageBackend(metaclass=ABCMeta):
    """Abstract class of storage backends.

    All backends need to implement two apis: ``get()`` and ``get_text()``.
    ``get()`` reads the file as a byte stream and ``get_text()`` reads the
    file as text.
    """

    # a flag to indicate whether the backend can create a symlink for a file
    _allow_symlink = False

    @property
    def name(self):
        return self.__class__.__name__

    @property
    def allow_symlink(self):
        return self._allow_symlink

    @abstractmethod
    def get(self, filepath):
        pass

    @abstractmethod
    def get_text(self, filepath):
        pass


class CephBackend(BaseStorageBackend):
    """Ceph storage backend (for internal use).

    Args:
        path_mapping (dict|None): path mapping dict from local path to Petrel
            path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
            will be replaced by ``dst``. Default: None.

    .. warning::
        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
    """

    def __init__(self, path_mapping=None):
        try:
            import ceph
        except ImportError:
            raise ImportError('Please install ceph to enable CephBackend.')

        warnings.warn(
            'CephBackend will be deprecated, please use PetrelBackend instead')
        self._client = ceph.S3Client()
        assert isinstance(path_mapping, dict) or path_mapping is None
        self.path_mapping = path_mapping

    def get(self, filepath):
        filepath = str(filepath)
        if self.path_mapping is not None:
            for k, v in self.path_mapping.items():
                filepath = filepath.replace(k, v)
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class PetrelBackend(BaseStorageBackend):
    """Petrel storage backend (for internal use).

    PetrelBackend supports reading and writing data to multiple clusters.
    If the file path contains the cluster name, PetrelBackend will read data
    from the specified cluster or write data to it. Otherwise, PetrelBackend
    will access the default cluster.

    Args:
        path_mapping (dict, optional): Path mapping dict from local path to
            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
            ``filepath`` will be replaced by ``dst``. Default: None.
        enable_mc (bool, optional): Whether to enable memcached support.
            Default: True.

    Examples:
        >>> filepath1 = 's3://path/of/file'
        >>> filepath2 = 'cluster-name:s3://path/of/file'
        >>> client = PetrelBackend()
        >>> client.get(filepath1)  # get data from default cluster
        >>> client.get(filepath2)  # get data from 'cluster-name' cluster
    """

    def __init__(self,
                 path_mapping: Optional[dict] = None,
                 enable_mc: bool = True):
        try:
            from petrel_client import client
        except ImportError:
            raise ImportError('Please install petrel_client to enable '
                              'PetrelBackend.')

        self._client = client.Client(enable_mc=enable_mc)
        assert isinstance(path_mapping, dict) or path_mapping is None
        self.path_mapping = path_mapping

    def _map_path(self, filepath: Union[str, Path]) -> str:
        """Map ``filepath`` to a string path whose prefix will be replaced by
        :attr:`self.path_mapping`.

        Args:
            filepath (str): Path to be mapped.
        """
        filepath = str(filepath)
        if self.path_mapping is not None:
            for k, v in self.path_mapping.items():
                filepath = filepath.replace(k, v)
        return filepath

    def _format_path(self, filepath: str) -> str:
        """Convert a ``filepath`` to the standard format of petrel oss.

        If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
        environment, the ``filepath`` will be in the format of
        's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
        above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.

        Args:
            filepath (str): Path to be formatted.
        """
        return re.sub(r'\\+', '/', filepath)

    def get(self, filepath: Union[str, Path]) -> memoryview:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            memoryview: A memory view of the expected bytes object to avoid
                copying. The memoryview object can be converted to bytes by
                ``value_buf.tobytes()``.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        value = self._client.Get(filepath)
        value_buf = memoryview(value)
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text read from ``filepath``.
        """
        return str(self.get(filepath), encoding=encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (bytes): Data to be saved.
            filepath (str or Path): Path to write data.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.put(filepath, obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Save data to a given ``filepath``.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to encode the ``obj``.
                Default: 'utf-8'.
        """
        self.put(bytes(obj, encoding=encoding), filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.
        """
        if not has_method(self._client, 'delete'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `delete` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        self._client.delete(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether it exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        if not (has_method(self._client, 'contains')
                and has_method(self._client, 'isdir')):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` and `isdir` methods, please use a higher'
                 ' version or dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath) or self._client.isdir(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        if not has_method(self._client, 'isdir'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `isdir` method, please use a higher version or dev'
                 ' branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        if not has_method(self._client, 'contains'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `contains` method, please use a higher version or '
                 'dev branch instead.'))

        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        return self._client.contains(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result after concatenation.
        """
        filepath = self._format_path(self._map_path(filepath))
        if filepath.endswith('/'):
            filepath = filepath[:-1]
        formatted_paths = [filepath]
        for path in filepaths:
            formatted_paths.append(self._format_path(self._map_path(path)))
        return '/'.join(formatted_paths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
        """Download a file from ``filepath`` and return a temporary path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str | Path): Download a file from ``filepath``.

        Examples:
            >>> client = PetrelBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with client.get_local_path('s3://path/of/your/file') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one temporary path.
        """
        filepath = self._map_path(filepath)
        filepath = self._format_path(filepath)
        assert self.isfile(filepath)
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.get(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            Petrel has no concept of directories but it simulates the
            directory hierarchy in the filesystem through public prefixes. In
            addition, if the returned path ends with '/', it means the path
            is a public prefix which is a logical directory.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
            In addition, the returned path of a directory will not contain the
            suffix '/', which is consistent with other backends.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        if not has_method(self._client, 'list'):
            raise NotImplementedError(
                ('Current version of Petrel Python SDK has not supported '
                 'the `list` method, please use a higher version or dev'
                 ' branch instead.'))

        dir_path = self._map_path(dir_path)
        dir_path = self._format_path(dir_path)
        if list_dir and suffix is not None:
            raise TypeError(
                '`list_dir` should be False when `suffix` is not None')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        # Petrel's simulated directory hierarchy assumes that directory paths
        # should end with `/`
        if not dir_path.endswith('/'):
            dir_path += '/'

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                              recursive):
            for path in self._client.list(dir_path):
                # the `self.isdir` is not used here to determine whether path
                # is a directory, because `self.isdir` relies on
                # `self._client.list`
                if path.endswith('/'):  # a directory path
                    next_dir_path = self.join_path(dir_path, path)
                    if list_dir:
                        # get the relative path and exclude the last
                        # character '/'
                        rel_dir = next_dir_path[len(root):-1]
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(next_dir_path, list_dir,
                                                     list_file, suffix,
                                                     recursive)
                else:  # a file path
                    absolute_path = self.join_path(dir_path, path)
                    rel_path = absolute_path[len(root):]
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                                 recursive)


class MemcachedBackend(BaseStorageBackend):
    """Memcached storage backend.

    Attributes:
        server_list_cfg (str): Config file for memcached server list.
        client_cfg (str): Config file for memcached client.
        sys_path (str | None): Additional path to be appended to `sys.path`.
            Default: None.
    """

    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
        if sys_path is not None:
            import sys
            sys.path.append(sys_path)
        try:
            import mc
        except ImportError:
            raise ImportError(
                'Please install memcached to enable MemcachedBackend.')

        self.server_list_cfg = server_list_cfg
        self.client_cfg = client_cfg
        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
                                                      self.client_cfg)
        # mc.pyvector serves as a pointer which points to a memory cache
        self._mc_buffer = mc.pyvector()

    def get(self, filepath):
        filepath = str(filepath)
        import mc
        self._client.Get(filepath, self._mc_buffer)
        value_buf = mc.ConvertBuffer(self._mc_buffer)
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class LmdbBackend(BaseStorageBackend):
    """Lmdb storage backend.

    Args:
        db_path (str): Lmdb database path.
        readonly (bool, optional): Lmdb environment parameter. If True,
            disallow any write operations. Default: True.
        lock (bool, optional): Lmdb environment parameter. If False, when
            concurrent access occurs, do not lock the database. Default: False.
        readahead (bool, optional): Lmdb environment parameter. If False,
            disable the OS filesystem readahead mechanism, which may improve
            random read performance when a database is larger than RAM.
            Default: False.

    Attributes:
        db_path (str): Lmdb database path.
    """

    def __init__(self,
                 db_path,
                 readonly=True,
                 lock=False,
                 readahead=False,
                 **kwargs):
        try:
            import lmdb
        except ImportError:
            raise ImportError('Please install lmdb to enable LmdbBackend.')

        self.db_path = str(db_path)
        self._client = lmdb.open(
            self.db_path,
            readonly=readonly,
            lock=lock,
            readahead=readahead,
            **kwargs)

    def get(self, filepath):
        """Get values according to the filepath.

        Args:
            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
        """
        filepath = str(filepath)
        with self._client.begin(write=False) as txn:
            value_buf = txn.get(filepath.encode('ascii'))
        return value_buf

    def get_text(self, filepath, encoding=None):
        raise NotImplementedError


class HardDiskBackend(BaseStorageBackend):
    """Raw hard disks storage backend."""

    _allow_symlink = True

    def get(self, filepath: Union[str, Path]) -> bytes:
        """Read data from a given ``filepath`` with 'rb' mode.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes: Expected bytes object.
        """
        with open(filepath, 'rb') as f:
            value_buf = f.read()
        return value_buf

    def get_text(self,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text read from ``filepath``.
        """
        with open(filepath, 'r', encoding=encoding) as f:
            value_buf = f.read()
        return value_buf

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` will create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        mmcv.mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, 'wb') as f:
            f.write(obj)

    def put_text(self,
                 obj: str,
                 filepath: Union[str, Path],
                 encoding: str = 'utf-8') -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` will create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.
        """
        mmcv.mkdir_or_exist(osp.dirname(filepath))
        with open(filepath, 'w', encoding=encoding) as f:
            f.write(obj)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str or Path): Path to be removed.
        """
        os.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether it exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return osp.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        return osp.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        return osp.isfile(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of *filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        return osp.join(filepath, *filepaths)

    @contextmanager
    def get_local_path(
            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
        """Only for unified API and do nothing."""
        yield filepath

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        if list_dir and suffix is not None:
            raise TypeError('`suffix` should be None when `list_dir` is True')

        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
            raise TypeError('`suffix` must be a string or tuple of strings')

        root = dir_path

        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                              recursive):
            for entry in os.scandir(dir_path):
                if not entry.name.startswith('.') and entry.is_file():
                    rel_path = osp.relpath(entry.path, root)
                    if (suffix is None
                            or rel_path.endswith(suffix)) and list_file:
                        yield rel_path
                elif osp.isdir(entry.path):
                    if list_dir:
                        rel_dir = osp.relpath(entry.path, root)
                        yield rel_dir
                    if recursive:
                        yield from _list_dir_or_file(entry.path, list_dir,
                                                     list_file, suffix,
                                                     recursive)

        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
                                 recursive)


class HTTPBackend(BaseStorageBackend):
    """HTTP and HTTPS storage backend."""

    def get(self, filepath):
        value_buf = urlopen(filepath).read()
        return value_buf

    def get_text(self, filepath, encoding='utf-8'):
        value_buf = urlopen(filepath).read()
        return value_buf.decode(encoding)

    @contextmanager
    def get_local_path(self, filepath: str) -> Iterable[str]:
        """Download a file from ``filepath``.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Args:
            filepath (str): Download a file from ``filepath``.

        Examples:
            >>> client = HTTPBackend()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with client.get_local_path('http://path/of/your/file') as path:
            ...     # do something here
        """
        try:
            f = tempfile.NamedTemporaryFile(delete=False)
            f.write(self.get(filepath))
            f.close()
            yield f.name
        finally:
            os.remove(f.name)


class FileClient:
    """A general file client to access files in different backends.

    The client loads a file or text in a specified backend from its path
    and returns it as a binary or text file. There are two ways to choose a
    backend: the name of the backend and the prefix of the path. Although both
    of them can be used to choose a storage backend, ``backend`` has a higher
    priority, that is, if they are both set, the storage backend will be
    chosen by the ``backend`` argument. If they are both ``None``, the disk
    backend will be chosen. Note that it can also register other backend
    accessors with a given name, prefixes, and backend class. In addition, we
    use the singleton pattern to avoid repeated object creation. If the
    arguments are the same, the same object will be returned.

    Args:
        backend (str, optional): The storage backend type. Options are "disk",
            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
        prefix (str, optional): The prefix of the registered storage backend.
            Options are "s3", "http", "https". Default: None.

    Examples:
        >>> # only set backend
        >>> file_client = FileClient(backend='petrel')
        >>> # only set prefix
        >>> file_client = FileClient(prefix='s3')
        >>> # set both backend and prefix but use backend to choose client
        >>> file_client = FileClient(backend='petrel', prefix='s3')
        >>> # if the arguments are the same, the same object is returned
        >>> file_client1 = FileClient(backend='petrel')
        >>> file_client1 is file_client
        True

    Attributes:
        client (:obj:`BaseStorageBackend`): The backend object.
    """

    _backends = {
        'disk': HardDiskBackend,
        'ceph': CephBackend,
        'memcached': MemcachedBackend,
        'lmdb': LmdbBackend,
        'petrel': PetrelBackend,
        'http': HTTPBackend,
    }
    # This collection is used to record the overridden backends, and when a
    # backend appears in the collection, the singleton pattern is disabled for
    # that backend, because if the singleton pattern is used, then the object
    # returned will be the backend before overwriting
    _overridden_backends = set()
    _prefix_to_backends = {
        's3': PetrelBackend,
        'http': HTTPBackend,
        'https': HTTPBackend,
    }
    _overridden_prefixes = set()

    _instances = {}

    def __new__(cls, backend=None, prefix=None, **kwargs):
        if backend is None and prefix is None:
            backend = 'disk'
        if backend is not None and backend not in cls._backends:
            raise ValueError(
                f'Backend {backend} is not supported. Currently supported ones'
                f' are {list(cls._backends.keys())}')
        if prefix is not None and prefix not in cls._prefix_to_backends:
            raise ValueError(
                f'prefix {prefix} is not supported. Currently supported ones '
                f'are {list(cls._prefix_to_backends.keys())}')

        # concatenate the arguments to a unique key for determining whether
        # objects with the same arguments were created
        arg_key = f'{backend}:{prefix}'
        for key, value in kwargs.items():
            arg_key += f':{key}:{value}'

        # if a backend was overridden, it will create a new object
        if (arg_key in cls._instances
                and backend not in cls._overridden_backends
                and prefix not in cls._overridden_prefixes):
            _instance = cls._instances[arg_key]
        else:
            # create a new object and put it to _instances
            _instance = super().__new__(cls)
            if backend is not None:
                _instance.client = cls._backends[backend](**kwargs)
            else:
                _instance.client = cls._prefix_to_backends[prefix](**kwargs)

            cls._instances[arg_key] = _instance

        return _instance

    @property
    def name(self):
        return self.client.name

    @property
    def allow_symlink(self):
        return self.client.allow_symlink

    @staticmethod
    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
        """Parse the prefix of a uri.

        Args:
            uri (str | Path): Uri to be parsed that contains the file prefix.

        Examples:
            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
            's3'

        Returns:
            str | None: Return the prefix of uri if the uri contains '://',
                else ``None``.
        """
        assert is_filepath(uri)
        uri = str(uri)
        if '://' not in uri:
            return None
        else:
            prefix, _ = uri.split('://')
            # In the case of PetrelBackend, the prefix may contain the
            # cluster name like clusterName:s3
            if ':' in prefix:
                _, prefix = prefix.split(':')
            return prefix

    @classmethod
    def infer_client(cls,
                     file_client_args: Optional[dict] = None,
                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
        """Infer a suitable file client based on the URI and arguments.

        Args:
            file_client_args (dict, optional): Arguments to instantiate a
                FileClient. Default: None.
            uri (str | Path, optional): Uri to be parsed that contains the
                file prefix. Default: None.

        Examples:
            >>> uri = 's3://path/of/your/file'
            >>> file_client = FileClient.infer_client(uri=uri)
            >>> file_client_args = {'backend': 'petrel'}
            >>> file_client = FileClient.infer_client(file_client_args)

        Returns:
            FileClient: Instantiated FileClient object.
        """
        assert file_client_args is not None or uri is not None
        if file_client_args is None:
            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
            return cls(prefix=file_prefix)
        else:
            return cls(**file_client_args)

    @classmethod
    def _register_backend(cls, name, backend, force=False, prefixes=None):
        if not isinstance(name, str):
            raise TypeError('the backend name should be a string, '
                            f'but got {type(name)}')
        if not inspect.isclass(backend):
            raise TypeError(
                f'backend should be a class but got {type(backend)}')
        if not issubclass(backend, BaseStorageBackend):
            raise TypeError(
                f'backend {backend} is not a subclass of BaseStorageBackend')
        if not force and name in cls._backends:
            raise KeyError(
                f'{name} is already registered as a storage backend, '
                'add "force=True" if you want to override it')

        if name in cls._backends and force:
            cls._overridden_backends.add(name)
        cls._backends[name] = backend

        if prefixes is not None:
            if isinstance(prefixes, str):
                prefixes = [prefixes]
            else:
                assert isinstance(prefixes, (list, tuple))
            for prefix in prefixes:
                if prefix not in cls._prefix_to_backends:
                    cls._prefix_to_backends[prefix] = backend
                elif (prefix in cls._prefix_to_backends) and force:
                    cls._overridden_prefixes.add(prefix)
                    cls._prefix_to_backends[prefix] = backend
                else:
                    raise KeyError(
                        f'{prefix} is already registered as a storage backend,'
                        ' add "force=True" if you want to override it')

    @classmethod
    def register_backend(cls, name, backend=None, force=False, prefixes=None):
        """Register a backend to FileClient.

        This method can be used as a normal class method or a decorator.

        .. code-block:: python

            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

            FileClient.register_backend('new', NewBackend)

        or

        .. code-block:: python

            @FileClient.register_backend('new')
            class NewBackend(BaseStorageBackend):

                def get(self, filepath):
                    return filepath

                def get_text(self, filepath):
                    return filepath

        Args:
            name (str): The name of the registered backend.
            backend (class, optional): The backend class to be registered,
                which must be a subclass of :class:`BaseStorageBackend`.
                When this method is used as a decorator, backend is None.
                Defaults to None.
            force (bool, optional): Whether to override the backend if the
                name has already been registered. Defaults to False.
            prefixes (str or list[str] or tuple[str], optional): The prefixes
                of the registered storage backend. Default: None.
                `New in version 1.3.15.`
        """
        if backend is not None:
            cls._register_backend(
                name, backend, force=force, prefixes=prefixes)
            return

        def _register(backend_cls):
            cls._register_backend(
                name, backend_cls, force=force, prefixes=prefixes)
            return backend_cls

        return _register

    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
        """Read data from a given ``filepath`` with 'rb' mode.

        Note:
            There are two types of return values for ``get``, one is ``bytes``
            and the other is ``memoryview``. The advantage of using memoryview
            is that you can avoid copying, and if you want to convert it to
            ``bytes``, you can use ``.tobytes()``.

        Args:
            filepath (str or Path): Path to read data.

        Returns:
            bytes | memoryview: Expected bytes object or a memory view of the
                bytes object.
        """
        return self.client.get(filepath)

    def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
        """Read data from a given ``filepath`` with 'r' mode.

        Args:
            filepath (str or Path): Path to read data.
            encoding (str): The encoding format used to open the ``filepath``.
                Default: 'utf-8'.

        Returns:
            str: Expected text read from ``filepath``.
        """
        return self.client.get_text(filepath, encoding)

    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'wb' mode.

        Note:
            ``put`` should create a directory if the directory of ``filepath``
            does not exist.

        Args:
            obj (bytes): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self.client.put(obj, filepath)

    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
        """Write data to a given ``filepath`` with 'w' mode.

        Note:
            ``put_text`` should create a directory if the directory of
            ``filepath`` does not exist.

        Args:
            obj (str): Data to be written.
            filepath (str or Path): Path to write data.
        """
        self.client.put_text(obj, filepath)

    def remove(self, filepath: Union[str, Path]) -> None:
        """Remove a file.

        Args:
            filepath (str, Path): Path to be removed.
        """
        self.client.remove(filepath)

    def exists(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path exists.

        Args:
            filepath (str or Path): Path to be checked whether it exists.

        Returns:
            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
        """
        return self.client.exists(filepath)

    def isdir(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a directory.

        Args:
            filepath (str or Path): Path to be checked whether it is a
                directory.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a directory,
                ``False`` otherwise.
        """
        return self.client.isdir(filepath)

    def isfile(self, filepath: Union[str, Path]) -> bool:
        """Check whether a file path is a file.

        Args:
            filepath (str or Path): Path to be checked whether it is a file.

        Returns:
            bool: Return ``True`` if ``filepath`` points to a file, ``False``
                otherwise.
        """
        return self.client.isfile(filepath)

    def join_path(self, filepath: Union[str, Path],
                  *filepaths: Union[str, Path]) -> str:
        """Concatenate all file paths.

        Join one or more filepath components intelligently. The return value
        is the concatenation of filepath and any members of *filepaths.

        Args:
            filepath (str or Path): Path to be concatenated.

        Returns:
            str: The result of concatenation.
        """
        return self.client.join_path(filepath, *filepaths)

    @contextmanager
    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
        """Download data from ``filepath`` and write the data to a local path.

        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It can be called with a ``with`` statement, and when exiting from the
        ``with`` statement, the temporary path will be released.

        Note:
            If the ``filepath`` is a local path, just return itself.

        .. warning::
            ``get_local_path`` is an experimental interface that may change
            in the future.

        Args:
            filepath (str or Path): Path to read data.

        Examples:
            >>> file_client = FileClient(prefix='s3')
            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
            ...     # do something here

        Yields:
            Iterable[str]: Only yield one path.
        """
        with self.client.get_local_path(str(filepath)) as local_path:
            yield local_path

    def list_dir_or_file(self,
                         dir_path: Union[str, Path],
                         list_dir: bool = True,
                         list_file: bool = True,
                         suffix: Optional[Union[str, Tuple[str]]] = None,
                         recursive: bool = False) -> Iterator[str]:
        """Scan a directory to find the interested directories or files in
        arbitrary order.

        Note:
            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.

        Args:
            dir_path (str | Path): Path of the directory.
            list_dir (bool): List the directories. Default: True.
            list_file (bool): List the path of files. Default: True.
            suffix (str or tuple[str], optional): File suffix
                that we are interested in. Default: None.
            recursive (bool): If set to True, recursively scan the
                directory. Default: False.

        Yields:
            Iterable[str]: A relative path to ``dir_path``.
        """
        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
                                                suffix, recursive)
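A minimal sketch of the registration flow documented in `register_backend` above; `MemoryBackend` is a made-up example backend, not shipped with this commit:

    # Hypothetical in-memory backend registered under name 'memory' and
    # URI prefix 'mem'; the decorator route mirrors the docstring example.
    @FileClient.register_backend('memory', prefixes='mem')
    class MemoryBackend(BaseStorageBackend):
        _store = {}  # shared dict standing in for real storage

        def get(self, filepath):
            return self._store[str(filepath)]

        def get_text(self, filepath, encoding='utf-8'):
            return self._store[str(filepath)].decode(encoding)

    client = FileClient(prefix='mem')  # resolved via _prefix_to_backends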
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFileHandler
from .json_handler import JsonHandler
from .pickle_handler import PickleHandler
from .yaml_handler import YamlHandler

__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the type of file object is
    # str-like or bytes-like. Pickle only processes bytes-like objects but
    # json only processes str-like objects. If it is a str-like object,
    # `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        pass

    def load_from_path(self, filepath, mode='r', **kwargs):
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
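A custom format plugs in by implementing the three abstract methods; `TxtHandler` below is an illustrative sketch, not part of this commit:

    # Hypothetical plain-text handler built on BaseFileHandler; str_like
    # stays True so the I/O layer routes it through StringIO buffers.
    class TxtHandler(BaseFileHandler):

        def load_from_fileobj(self, file, **kwargs):
            return file.read()

        def dump_to_fileobj(self, obj, file, **kwargs):
            file.write(str(obj))

        def dump_to_str(self, obj, **kwargs):
            return str(obj)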
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json

import numpy as np

from .base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of python built-in types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')


class JsonHandler(BaseFileHandler):

    def load_from_fileobj(self, file):
        return json.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
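A quick illustration of the `set_default` hook: numpy scalars, arrays, and sets are not JSON-serializable on their own, but pass through cleanly once the hook is installed (which `JsonHandler` does by default):

    import json
    import numpy as np

    payload = {'ids': {1, 2, 3}, 'score': np.float32(0.5), 'vec': np.arange(3)}
    print(json.dumps(payload, default=set_default))
    # -> {"ids": [1, 2, 3], "score": 0.5, "vec": [0, 1, 2]}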
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pickle

from .base import BaseFileHandler


class PickleHandler(BaseFileHandler):

    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super(PickleHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('protocol', 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        super(PickleHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)
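A round-trip sketch of the handler above: the protocol defaults to 2 for broad compatibility with older pickle consumers, and can be overridden per call through kwargs:

    handler = PickleHandler()
    blob = handler.dump_to_str({'a': 1})          # bytes, protocol 2
    assert pickle.loads(blob) == {'a': 1}
    # override the default protocol when compatibility is not a concern
    blob_hi = handler.dump_to_str({'a': 1}, protocol=pickle.HIGHEST_PROTOCOL)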
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py
ADDED
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
import yaml

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

from .base import BaseFileHandler  # isort:skip


class YamlHandler(BaseFileHandler):

    def load_from_fileobj(self, file, **kwargs):
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)
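A usage sketch: the try/except import above silently falls back from the C-accelerated `CLoader`/`CDumper` to the pure-Python `Loader`/`Dumper` when libyaml bindings are absent, so behaviour is identical either way and only speed differs:

    from io import StringIO

    handler = YamlHandler()
    text = handler.dump_to_str({'lr': 0.01, 'steps': [8, 11]})
    assert handler.load_from_fileobj(StringIO(text)) == {
        'lr': 0.01, 'steps': [8, 11]}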
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/io.py
ADDED
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = {
+    'json': JsonHandler(),
+    'yaml': YamlHandler(),
+    'yml': YamlHandler(),
+    'pickle': PickleHandler(),
+    'pkl': PickleHandler()
+}
+
+
+def load(file, file_format=None, file_client_args=None, **kwargs):
+    """Load data from json/yaml/pickle files.
+
+    This method provides a unified api for loading data from serialized files.
+
+    Note:
+        In v1.3.16 and later, ``load`` supports loading data from serialized
+        files those can be storaged in different backends.
+
+    Args:
+        file (str or :obj:`Path` or file-like object): Filename or a file-like
+            object.
+        file_format (str, optional): If not specified, the file format will be
+            inferred from the file extension, otherwise use the specified one.
+            Currently supported formats include "json", "yaml/yml" and
+            "pickle/pkl".
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> load('/path/of/your/file')  # file is storaged in disk
+        >>> load('https://path/of/your/file')  # file is storaged in Internet
+        >>> load('s3://path/of/your/file')  # file is storaged in petrel
+
+    Returns:
+        The content from the file.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None and is_str(file):
+        file_format = file.split('.')[-1]
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.str_like:
+            with StringIO(file_client.get_text(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
+        else:
+            with BytesIO(file_client.get(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
+    elif hasattr(file, 'read'):
+        obj = handler.load_from_fileobj(file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filepath str or a file-object')
+    return obj
+
+
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+    """Dump data to json/yaml/pickle strings or files.
+
+    This method provides a unified api for dumping data as strings or to files,
+    and also supports custom arguments for each file format.
+
+    Note:
+        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+        files which is saved to different backends.
+
+    Args:
+        obj (any): The python object to be dumped.
+        file (str or :obj:`Path` or file-like object, optional): If not
+            specified, then the object is dumped to a str, otherwise to a file
+            specified by the filename or file-like object.
+        file_format (str, optional): Same as :func:`load`.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dump('hello world', '/path/of/your/file')  # disk
+        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel
+
+    Returns:
+        bool: True for success, False otherwise.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None:
+        if is_str(file):
+            file_format = file.split('.')[-1]
+        elif file is None:
+            raise ValueError(
+                'file_format must be specified since file is None')
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if file is None:
+        return handler.dump_to_str(obj, **kwargs)
+    elif is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.str_like:
+            with StringIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put_text(f.getvalue(), file)
+        else:
+            with BytesIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put(f.getvalue(), file)
+    elif hasattr(file, 'write'):
+        handler.dump_to_fileobj(obj, file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+    """Register a handler for some file extensions.
+
+    Args:
+        handler (:obj:`BaseFileHandler`): Handler to be registered.
+        file_formats (str or list[str]): File formats to be handled by this
+            handler.
+    """
+    if not isinstance(handler, BaseFileHandler):
+        raise TypeError(
+            f'handler must be a child of BaseFileHandler, not {type(handler)}')
+    if isinstance(file_formats, str):
+        file_formats = [file_formats]
+    if not is_list_of(file_formats, str):
+        raise TypeError('file_formats must be a str or a list of str')
+    for ext in file_formats:
+        file_handlers[ext] = handler
+
+
+def register_handler(file_formats, **kwargs):
+
+    def wrap(cls):
+        _register_handler(cls(**kwargs), file_formats)
+        return cls
+
+    return wrap
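
The ``register_handler`` decorator above is the extension point for new formats. A sketch (not part of the commit; ``TxtHandler`` and the path are hypothetical) of wiring a new extension into ``file_handlers`` so ``load``/``dump`` route to it:

from annotator.mmpkg.mmcv.fileio.handlers import BaseFileHandler
from annotator.mmpkg.mmcv.fileio.io import dump, load, register_handler


@register_handler('txt')
class TxtHandler(BaseFileHandler):

    def load_from_fileobj(self, file):
        return file.read()

    def dump_to_fileobj(self, obj, file):
        file.write(str(obj))

    def dump_to_str(self, obj):
        return str(obj)


dump('hello', '/tmp/demo.txt')          # dispatched to TxtHandler via the 'txt' key
assert load('/tmp/demo.txt') == 'hello'
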
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/parse.py
ADDED
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+                   prefix='',
+                   offset=0,
+                   max_num=0,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a list of strings.
+
+    Note:
+        In v1.3.16 and later, ``list_from_file`` supports loading a text file
+        which can be storaged in different backends and parsing the content as
+        a list for strings.
+
+    Args:
+        filename (str): Filename.
+        prefix (str): The prefix to be inserted to the beginning of each item.
+        offset (int): The offset of lines.
+        max_num (int): The maximum number of lines to be read,
+            zeros and negatives mean no limitation.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> list_from_file('/path/of/your/file')  # disk
+        ['hello', 'world']
+        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
+        ['hello', 'world']
+
+    Returns:
+        list[str]: A list of strings.
+    """
+    cnt = 0
+    item_list = []
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for _ in range(offset):
+            f.readline()
+        for line in f:
+            if 0 < max_num <= cnt:
+                break
+            item_list.append(prefix + line.rstrip('\n\r'))
+            cnt += 1
+    return item_list
+
+
+def dict_from_file(filename,
+                   key_type=str,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a dict.
+
+    Each line of the text file will be two or more columns split by
+    whitespaces or tabs. The first column will be parsed as dict keys, and
+    the following columns will be parsed as dict values.
+
+    Note:
+        In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+        which can be storaged in different backends and parsing the content as
+        a dict.
+
+    Args:
+        filename(str): Filename.
+        key_type(type): Type of the dict keys. str is user by default and
+            type conversion will be performed if specified.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dict_from_file('/path/of/your/file')  # disk
+        {'key1': 'value1', 'key2': 'value2'}
+        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
+        {'key1': 'value1', 'key2': 'value2'}
+
+    Returns:
+        dict: The parsed contents.
+    """
+    mapping = {}
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for line in f:
+            items = line.rstrip('\n').split()
+            assert len(items) >= 2
+            key = key_type(items[0])
+            val = items[1:] if len(items) > 2 else items[1]
+            mapping[key] = val
+    return mapping
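
A sketch of the two parsers (not part of the commit; the file and its content are hypothetical): given a text file containing the lines ``cat 1`` and ``dog 2``,

from annotator.mmpkg.mmcv.fileio.parse import dict_from_file, list_from_file

# list_from_file keeps whole lines; dict_from_file splits on whitespace,
# using the first column as key and the remainder as value.
assert list_from_file('/tmp/pairs.txt') == ['cat 1', 'dog 2']
assert dict_from_file('/tmp/pairs.txt') == {'cat': '1', 'dog': '2'}
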
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/__init__.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+                         gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+                         rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+                        impad_to_multiple, imrescale, imresize, imresize_like,
+                        imresize_to_multiple, imrotate, imshear, imtranslate,
+                        rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+                          adjust_lighting, adjust_sharpness, auto_contrast,
+                          clahe, imdenormalize, imequalize, iminvert,
+                          imnormalize, imnormalize_, lut_transform, posterize,
+                          solarize)
+
+__all__ = [
+    'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+    'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+    'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+    'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+    'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+    'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+    'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+    'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+    'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/colorspace.py
ADDED
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+
+def imconvert(img, src, dst):
+    """Convert an image from the src colorspace to dst colorspace.
+
+    Args:
+        img (ndarray): The input image.
+        src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+        dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+    Returns:
+        ndarray: The converted image.
+    """
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+    out_img = cv2.cvtColor(img, code)
+    return out_img
+
+
+def bgr2gray(img, keepdim=False):
+    """Convert a BGR image to grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): If False (by default), then return the grayscale image
+            with 2 dims, otherwise 3 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    if keepdim:
+        out_img = out_img[..., None]
+    return out_img
+
+
+def rgb2gray(img, keepdim=False):
+    """Convert a RGB image to grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): If False (by default), then return the grayscale image
+            with 2 dims, otherwise 3 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    if keepdim:
+        out_img = out_img[..., None]
+    return out_img
+
+
+def gray2bgr(img):
+    """Convert a grayscale image to BGR image.
+
+    Args:
+        img (ndarray): The input image.
+
+    Returns:
+        ndarray: The converted BGR image.
+    """
+    img = img[..., None] if img.ndim == 2 else img
+    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    return out_img
+
+
+def gray2rgb(img):
+    """Convert a grayscale image to RGB image.
+
+    Args:
+        img (ndarray): The input image.
+
+    Returns:
+        ndarray: The converted RGB image.
+    """
+    img = img[..., None] if img.ndim == 2 else img
+    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+    return out_img
+
+
+def _convert_input_type_range(img):
+    """Convert the type and range of the input image.
+
+    It converts the input image to np.float32 type and range of [0, 1].
+    It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with type of np.float32 and range of
+            [0, 1].
+    """
+    img_type = img.dtype
+    img = img.astype(np.float32)
+    if img_type == np.float32:
+        pass
+    elif img_type == np.uint8:
+        img /= 255.
+    else:
+        raise TypeError('The img type should be np.float32 or np.uint8, '
+                        f'but got {img_type}')
+    return img
+
+
+def _convert_output_type_range(img, dst_type):
+    """Convert the type and range of the image according to dst_type.
+
+    It converts the image to desired type and range. If `dst_type` is np.uint8,
+    images will be converted to np.uint8 type with range [0, 255]. If
+    `dst_type` is np.float32, it converts the image to np.float32 type with
+    range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+    functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+            converts the image to np.uint8 type with range [0, 255]. If
+            dst_type is np.float32, it converts the image to np.float32 type
+            with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, '
+                        f'but got {dst_type}')
+    if dst_type == np.uint8:
+        img = img.round()
+    else:
+        img /= 255.
+    return img.astype(dst_type)
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert a RGB image to YCbCr image.
+
+    This function produces the same results as Matlab's `rgb2ycbcr` function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                  [24.966, 112.0, -18.214]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to RGB image.
+
+    This function produces the same results as Matlab's ycbcr2rgb function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+                              [0, -0.00153632, 0.00791071],
+                              [0.00625893, -0.00318811, 0]]) * 255.0 + [
+                                  -222.921, 135.576, -276.836
+                              ]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to BGR image.
+
+    The bgr version of ycbcr2rgb.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+                              [0.00791071, -0.00153632, 0],
+                              [0, -0.00318811, 0.00625893]]) * 255.0 + [
+                                  -276.836, 135.576, -222.921
+                              ]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def convert_color_factory(src, dst):
+
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+    def convert_color(img):
+        out_img = cv2.cvtColor(img, code)
+        return out_img
+
+    convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+        image.
+
+    Args:
+        img (ndarray or str): The input image.
+
+    Returns:
+        ndarray: The converted {dst.upper()} image.
+    """
+
+    return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
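
A sketch (not part of the commit) of the dtype contract documented above: the factory-built converters are thin ``cv2.cvtColor`` wrappers, while the BT.601 helpers preserve the input type and range.

import numpy as np

from annotator.mmpkg.mmcv.image.colorspace import bgr2rgb, rgb2ycbcr

img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)  # BGR, uint8
rgb = bgr2rgb(img)                 # produced by convert_color_factory
y = rgb2ycbcr(rgb, y_only=True)    # uint8 in -> uint8 out, Y in [16, 235]
assert y.dtype == np.uint8 and y.shape == (4, 4)
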
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/geometric.py
ADDED
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+
+
+def _scale_size(size, scale):
+    """Rescale a size by a ratio.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): Scaling factor.
+
+    Returns:
+        tuple[int]: scaled size.
+    """
+    if isinstance(scale, (float, int)):
+        scale = (scale, scale)
+    w, h = size
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+    'nearest': cv2.INTER_NEAREST,
+    'bilinear': cv2.INTER_LINEAR,
+    'bicubic': cv2.INTER_CUBIC,
+    'area': cv2.INTER_AREA,
+    'lanczos': cv2.INTER_LANCZOS4
+}
+
+if Image is not None:
+    pillow_interp_codes = {
+        'nearest': Image.NEAREST,
+        'bilinear': Image.BILINEAR,
+        'bicubic': Image.BICUBIC,
+        'box': Image.BOX,
+        'lanczos': Image.LANCZOS,
+        'hamming': Image.HAMMING
+    }
+
+
+def imresize(img,
+             size,
+             return_scale=False,
+             interpolation='bilinear',
+             out=None,
+             backend=None):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = imread_backend
+    if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize.'
+                         f"Supported backends are 'cv2', 'pillow'")
+
+    if backend == 'pillow':
+        assert img.dtype == np.uint8, 'Pillow backend only support uint8 type'
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+    if not return_scale:
+        return resized_img
+    else:
+        w_scale = size[0] / w
+        h_scale = size[1] / h
+        return resized_img, w_scale, h_scale
+
+
+def imresize_to_multiple(img,
+                         divisor,
+                         size=None,
+                         scale_factor=None,
+                         keep_ratio=False,
+                         return_scale=False,
+                         interpolation='bilinear',
+                         out=None,
+                         backend=None):
+    """Resize image according to a given size or scale factor and then rounds
+    up the the resized or rescaled image size to the nearest value that can be
+    divided by the divisor.
+
+    Args:
+        img (ndarray): The input image.
+        divisor (int | tuple): Resized image size will be a multiple of
+            divisor. If divisor is a tuple, divisor should be
+            (w_divisor, h_divisor).
+        size (None | int | tuple[int]): Target size (w, h). Default: None.
+        scale_factor (None | float | tuple[float]): Multiplier for spatial
+            size. Should match input size if it is a tuple and the 2D style is
+            (w_scale_factor, h_scale_factor). Default: None.
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image. Default: False.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if size is not None and scale_factor is not None:
+        raise ValueError('only one of size or scale_factor should be defined')
+    elif size is None and scale_factor is None:
+        raise ValueError('one of size or scale_factor should be defined')
+    elif size is not None:
+        size = to_2tuple(size)
+        if keep_ratio:
+            size = rescale_size((w, h), size, return_scale=False)
+    else:
+        size = _scale_size((w, h), scale_factor)
+
+    divisor = to_2tuple(divisor)
+    size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+    resized_img, w_scale, h_scale = imresize(
+        img,
+        size,
+        return_scale=True,
+        interpolation=interpolation,
+        out=out,
+        backend=backend)
+    if return_scale:
+        return resized_img, w_scale, h_scale
+    else:
+        return resized_img
+
+
+def imresize_like(img,
+                  dst_img,
+                  return_scale=False,
+                  interpolation='bilinear',
+                  backend=None):
+    """Resize image to the same size of a given image.
+
+    Args:
+        img (ndarray): The input image.
+        dst_img (ndarray): The target image.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Same as :func:`resize`.
+        backend (str | None): Same as :func:`resize`.
+
+    Returns:
+        tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = dst_img.shape[:2]
+    return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new rescaled image size.
+    """
+    w, h = old_size
+    if isinstance(scale, (float, int)):
+        if scale <= 0:
+            raise ValueError(f'Invalid scale {scale}, must be positive.')
+        scale_factor = scale
+    elif isinstance(scale, tuple):
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w),
+                           max_short_edge / min(h, w))
+    else:
+        raise TypeError(
+            f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+    new_size = _scale_size((w, h), scale_factor)
+
+    if return_scale:
+        return new_size, scale_factor
+    else:
+        return new_size
+
+
+def imrescale(img,
+              scale,
+              return_scale=False,
+              interpolation='bilinear',
+              backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`resize`.
+        backend (str | None): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The rescaled image.
+    """
+    h, w = img.shape[:2]
+    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+    rescaled_img = imresize(
+        img, new_size, interpolation=interpolation, backend=backend)
+    if return_scale:
+        return rescaled_img, scale_factor
+    else:
+        return rescaled_img
+
+
+def imflip(img, direction='horizontal'):
+    """Flip an image horizontally or vertically.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image.
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    if direction == 'horizontal':
+        return np.flip(img, axis=1)
+    elif direction == 'vertical':
+        return np.flip(img, axis=0)
+    else:
+        return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+    """Inplace flip an image horizontally or vertically.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image (inplace).
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    if direction == 'horizontal':
+        return cv2.flip(img, 1, img)
+    elif direction == 'vertical':
+        return cv2.flip(img, 0, img)
+    else:
+        return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+             angle,
+             center=None,
+             scale=1.0,
+             border_value=0,
+             interpolation='bilinear',
+             auto_bound=False):
+    """Rotate an image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees, positive values mean
+            clockwise rotation.
+        center (tuple[float], optional): Center point (w, h) of the rotation in
+            the source image. If not specified, the center of the image will be
+            used.
+        scale (float): Isotropic scale factor.
+        border_value (int): Border value.
+        interpolation (str): Same as :func:`resize`.
+        auto_bound (bool): Whether to adjust the image size to cover the whole
+            rotated image.
+
+    Returns:
+        ndarray: The rotated image.
+    """
+    if center is not None and auto_bound:
+        raise ValueError('`auto_bound` conflicts with `center`')
+    h, w = img.shape[:2]
+    if center is None:
+        center = ((w - 1) * 0.5, (h - 1) * 0.5)
+    assert isinstance(center, tuple)
+
+    matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+    if auto_bound:
+        cos = np.abs(matrix[0, 0])
+        sin = np.abs(matrix[0, 1])
+        new_w = h * sin + w * cos
+        new_h = h * cos + w * sin
+        matrix[0, 2] += (new_w - w) * 0.5
+        matrix[1, 2] += (new_h - h) * 0.5
+        w = int(np.round(new_w))
+        h = int(np.round(new_h))
+    rotated = cv2.warpAffine(
+        img,
+        matrix, (w, h),
+        flags=cv2_interp_codes[interpolation],
+        borderValue=border_value)
+    return rotated
+
+
+def bbox_clip(bboxes, img_shape):
+    """Clip bboxes to fit the image shape.
+
+    Args:
+        bboxes (ndarray): Shape (..., 4*k)
+        img_shape (tuple[int]): (height, width) of the image.
+
+    Returns:
+        ndarray: Clipped bboxes.
+    """
+    assert bboxes.shape[-1] % 4 == 0
+    cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+    cmin[0::2] = img_shape[1] - 1
+    cmin[1::2] = img_shape[0] - 1
+    clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+    return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+    """Scaling bboxes w.r.t the box center.
+
+    Args:
+        bboxes (ndarray): Shape(..., 4).
+        scale (float): Scaling factor.
+        clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+            boundary will be clipped according to the given shape (h, w).
+
+    Returns:
+        ndarray: Scaled bboxes.
+    """
+    if float(scale) == 1.0:
+        scaled_bboxes = bboxes.copy()
+    else:
+        w = bboxes[..., 2] - bboxes[..., 0] + 1
+        h = bboxes[..., 3] - bboxes[..., 1] + 1
+        dw = (w * (scale - 1)) * 0.5
+        dh = (h * (scale - 1)) * 0.5
+        scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+    if clip_shape is not None:
+        return bbox_clip(scaled_bboxes, clip_shape)
+    else:
+        return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+    """Crop image patches.
+
+    3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+    Args:
+        img (ndarray): Image to be cropped.
+        bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+        scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no padding.
+        pad_fill (Number | list[Number]): Value to be filled for padding.
+            Default: None, which means no padding.
+
+    Returns:
+        list[ndarray] | ndarray: The cropped image patches.
+    """
+    chn = 1 if img.ndim == 2 else img.shape[2]
+    if pad_fill is not None:
+        if isinstance(pad_fill, (int, float)):
+            pad_fill = [pad_fill for _ in range(chn)]
+        assert len(pad_fill) == chn
+
+    _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+    scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+    clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+    patches = []
+    for i in range(clipped_bbox.shape[0]):
+        x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+        if pad_fill is None:
+            patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+        else:
+            _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+            if chn == 1:
+                patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+            else:
+                patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+            patch = np.array(
+                pad_fill, dtype=img.dtype) * np.ones(
+                    patch_shape, dtype=img.dtype)
+            x_start = 0 if _x1 >= 0 else -_x1
+            y_start = 0 if _y1 >= 0 else -_y1
+            w = x2 - x1 + 1
+            h = y2 - y1 + 1
+            patch[y_start:y_start + h, x_start:x_start + w,
+                  ...] = img[y1:y1 + h, x1:x1 + w, ...]
+        patches.append(patch)
+
+    if bboxes.ndim == 1:
+        return patches[0]
+    else:
+        return patches
+
+
+def impad(img,
+          *,
+          shape=None,
+          padding=None,
+          pad_val=0,
+          padding_mode='constant'):
+    """Pad the given image to a certain shape or pad on all sides with
+    specified padding mode and padding value.
+
+    Args:
+        img (ndarray): Image to be padded.
+        shape (tuple[int]): Expected padding shape (h, w). Default: None.
+        padding (int or tuple[int]): Padding on each border. If a single int is
+            provided this is used to pad all borders. If tuple of length 2 is
+            provided this is the padding on left/right and top/bottom
+            respectively. If a tuple of length 4 is provided this is the
+            padding for the left, top, right and bottom borders respectively.
+            Default: None. Note that `shape` and `padding` can not be both
+            set.
+        pad_val (Number | Sequence[Number]): Values to be filled in padding
+            areas when padding_mode is 'constant'. Default: 0.
+        padding_mode (str): Type of padding. Should be: constant, edge,
+            reflect or symmetric. Default: constant.
+
+            - constant: pads with a constant value, this value is specified
+              with pad_val.
+            - edge: pads with the last value at the edge of the image.
+            - reflect: pads with reflection of image without repeating the
+              last value on the edge. For example, padding [1, 2, 3, 4]
+              with 2 elements on both sides in reflect mode will result
+              in [3, 2, 1, 2, 3, 4, 3, 2].
+            - symmetric: pads with reflection of image repeating the last
+              value on the edge. For example, padding [1, 2, 3, 4] with
+              2 elements on both sides in symmetric mode will result in
+              [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Returns:
+        ndarray: The padded image.
+    """
+
+    assert (shape is not None) ^ (padding is not None)
+    if shape is not None:
+        padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
+
+    # check pad_val
+    if isinstance(pad_val, tuple):
+        assert len(pad_val) == img.shape[-1]
+    elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be a int or a tuple. '
+                        f'But received {type(pad_val)}')
+
+    # check padding
+    if isinstance(padding, tuple) and len(padding) in [2, 4]:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+    elif isinstance(padding, numbers.Number):
+        padding = (padding, padding, padding, padding)
+    else:
+        raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
+                         f'But received {padding}')
+
+    # check padding mode
+    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+    border_type = {
+        'constant': cv2.BORDER_CONSTANT,
+        'edge': cv2.BORDER_REPLICATE,
+        'reflect': cv2.BORDER_REFLECT_101,
+        'symmetric': cv2.BORDER_REFLECT
+    }
+    img = cv2.copyMakeBorder(
+        img,
+        padding[1],
+        padding[3],
+        padding[0],
+        padding[2],
+        border_type[padding_mode],
+        value=pad_val)
+
+    return img
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image to ensure each edge to be multiple to some number.
+
+    Args:
+        img (ndarray): Image to be padded.
+        divisor (int): Padded image edges will be multiple to divisor.
+        pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+    Returns:
+        ndarray: The padded image.
+    """
+    pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+    pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+    return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+    """Randomly cut out a rectangle from the original img.
+
+    Args:
+        img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as a
+            int, the value will be used for both h and w.
+        pad_val (int | float | tuple[int | float]): Values to be filled in the
+            cut area. Defaults to 0.
+
+    Returns:
+        ndarray: The cutout image.
+    """
+
+    channels = 1 if img.ndim == 2 else img.shape[2]
+    if isinstance(shape, int):
+        cut_h, cut_w = shape, shape
+    else:
+        assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be a int or a tuple with length 2, but got type ' \
+            f'{type(shape)} instead.'
+        cut_h, cut_w = shape
+    if isinstance(pad_val, (int, float)):
+        pad_val = tuple([pad_val] * channels)
+    elif isinstance(pad_val, tuple):
+        assert len(pad_val) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(pad_val), channels)
+    else:
+        raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+    img_h, img_w = img.shape[:2]
+    y0 = np.random.uniform(img_h)
+    x0 = np.random.uniform(img_w)
+
+    y1 = int(max(0, y0 - cut_h / 2.))
+    x1 = int(max(0, x0 - cut_w / 2.))
+    y2 = min(img_h, y1 + cut_h)
+    x2 = min(img_w, x1 + cut_w)
+
+    if img.ndim == 2:
+        patch_shape = (y2 - y1, x2 - x1)
+    else:
+        patch_shape = (y2 - y1, x2 - x1, channels)
+
+    img_cutout = img.copy()
+    patch = np.array(
+        pad_val, dtype=img.dtype) * np.ones(
+            patch_shape, dtype=img.dtype)
+    img_cutout[y1:y2, x1:x2, ...] = patch
+
+    return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+    """Generate the shear matrix for transformation.
+
+    Args:
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The flip direction, either "horizontal"
+            or "vertical".
+
+    Returns:
+        ndarray: The shear matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+    elif direction == 'vertical':
+        shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+    return shear_matrix
+
+
+def imshear(img,
+            magnitude,
+            direction='horizontal',
+            border_value=0,
+            interpolation='bilinear'):
+    """Shear an image.
+
+    Args:
+        img (ndarray): Image to be sheared with format (h, w)
+            or (h, w, c).
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The flip direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The sheared image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`')
+    shear_matrix = _get_shear_matrix(magnitude, direction)
+    sheared = cv2.warpAffine(
+        img,
+        shear_matrix,
+        (width, height),
+        # Note case when the number elements in `border_value`
+        # greater than 3 (e.g. shearing masks whose channels large
+        # than 3) will raise TypeError in `cv2.warpAffine`.
+        # Here simply slice the first 3 values in `border_value`.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+    """Generate the translate matrix.
+
+    Args:
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either
+            "horizontal" or "vertical".
+
+    Returns:
+        ndarray: The translate matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+    elif direction == 'vertical':
+        translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+    return translate_matrix
+
+
+def imtranslate(img,
+                offset,
+                direction='horizontal',
+                border_value=0,
+                interpolation='bilinear'):
+    """Translate an image.
+
+    Args:
+        img (ndarray): Image to be translated with format
+            (h, w) or (h, w, c).
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The translated image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`.')
+    translate_matrix = _get_translate_matrix(offset, direction)
+    translated = cv2.warpAffine(
+        img,
+        translate_matrix,
+        (width, height),
+        # Note case when the number elements in `border_value`
+        # greater than 3 (e.g. translating masks whose channels
+        # large than 3) will raise TypeError in `cv2.warpAffine`.
+        # Here simply slice the first 3 values in `border_value`.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return translated
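
A sketch (not part of the commit) contrasting the resize helpers above: ``imresize`` hits an exact (w, h), ``imrescale`` preserves the aspect ratio within a bound, and ``impad_to_multiple`` rounds each edge up to a divisor.

import numpy as np

from annotator.mmpkg.mmcv.image.geometric import (impad_to_multiple,
                                                  imrescale, imresize)

img = np.zeros((100, 150, 3), dtype=np.uint8)        # (h, w, c)
exact = imresize(img, (60, 40))                      # size is (w, h) -> 40x60
assert exact.shape[:2] == (40, 60)
kept, scale = imrescale(img, (80, 80), return_scale=True)  # long edge <= 80
padded = impad_to_multiple(img, 32)                  # h, w -> next multiples of 32
assert padded.shape[:2] == (128, 160)
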
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/io.py
ADDED
@@ -0,0 +1,258 @@
# Copyright (c) OpenMMLab. All rights reserved.
import io
import os.path as osp
from pathlib import Path

import cv2
import numpy as np
from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
                 IMREAD_UNCHANGED)

from annotator.mmpkg.mmcv.utils import check_file_exist, is_str, mkdir_or_exist

try:
    from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
except ImportError:
    TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None

try:
    from PIL import Image, ImageOps
except ImportError:
    Image = None

try:
    import tifffile
except ImportError:
    tifffile = None

jpeg = None
supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']

imread_flags = {
    'color': IMREAD_COLOR,
    'grayscale': IMREAD_GRAYSCALE,
    'unchanged': IMREAD_UNCHANGED,
    'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
    'grayscale_ignore_orientation':
    IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
}

imread_backend = 'cv2'


def use_backend(backend):
    """Select a backend for image decoding.

    Args:
        backend (str): The image decoding backend type. Options are `cv2`,
            `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
            and `tifffile`. `turbojpeg` is faster but it only supports the
            `.jpeg` file format.
    """
    assert backend in supported_backends
    global imread_backend
    imread_backend = backend
    if imread_backend == 'turbojpeg':
        if TurboJPEG is None:
            raise ImportError('`PyTurboJPEG` is not installed')
        global jpeg
        if jpeg is None:
            jpeg = TurboJPEG()
    elif imread_backend == 'pillow':
        if Image is None:
            raise ImportError('`Pillow` is not installed')
    elif imread_backend == 'tifffile':
        if tifffile is None:
            raise ImportError('`tifffile` is not installed')


def _jpegflag(flag='color', channel_order='bgr'):
    channel_order = channel_order.lower()
    if channel_order not in ['rgb', 'bgr']:
        raise ValueError('channel order must be either "rgb" or "bgr"')

    if flag == 'color':
        if channel_order == 'bgr':
            return TJPF_BGR
        elif channel_order == 'rgb':
            return TJCS_RGB
    elif flag == 'grayscale':
        return TJPF_GRAY
    else:
        raise ValueError('flag must be "color" or "grayscale"')

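A quick sketch of how the backend switch above is meant to be used. The import path follows this vendored package layout and the file name is a placeholder:

from annotator.mmpkg.mmcv.image.io import use_backend, imread

use_backend('pillow')  # decode via PIL from here on (raises if not installed)
img = imread('demo.png', flag='color', channel_order='rgb')
use_backend('cv2')     # restore the default backend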

def _pillow2array(img, flag='color', channel_order='bgr'):
    """Convert a pillow image to numpy array.

    Args:
        img (:obj:`PIL.Image.Image`): The image loaded using PIL.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are 'color', 'grayscale' and 'unchanged'.
            Default to 'color'.
        channel_order (str): The channel order of the output image array,
            candidates are 'bgr' and 'rgb'. Default to 'bgr'.

    Returns:
        np.ndarray: The converted numpy array.
    """
    channel_order = channel_order.lower()
    if channel_order not in ['rgb', 'bgr']:
        raise ValueError('channel order must be either "rgb" or "bgr"')

    if flag == 'unchanged':
        array = np.array(img)
        if array.ndim >= 3 and array.shape[2] >= 3:  # color image
            array[:, :, :3] = array[:, :, (2, 1, 0)]  # RGB to BGR
    else:
        # Handle exif orientation tag
        if flag in ['color', 'grayscale']:
            img = ImageOps.exif_transpose(img)
        # If the image mode is not 'RGB', convert it to 'RGB' first.
        if img.mode != 'RGB':
            if img.mode != 'LA':
                # Most formats except 'LA' can be directly converted to RGB
                img = img.convert('RGB')
            else:
                # When the mode is 'LA', the default conversion will fill in
                # the canvas with black, which sometimes shadows black objects
                # in the foreground.
                #
                # Therefore, a random color (124, 117, 104) is used for canvas
                img_rgba = img.convert('RGBA')
                img = Image.new('RGB', img_rgba.size, (124, 117, 104))
                img.paste(img_rgba, mask=img_rgba.split()[3])  # 3 is alpha
        if flag in ['color', 'color_ignore_orientation']:
            array = np.array(img)
            if channel_order != 'rgb':
                array = array[:, :, ::-1]  # RGB to BGR
        elif flag in ['grayscale', 'grayscale_ignore_orientation']:
            img = img.convert('L')
            array = np.array(img)
        else:
            raise ValueError(
                'flag must be "color", "grayscale", "unchanged", '
                f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
                f' but got {flag}')
    return array


def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
    """Read an image.

    Args:
        img_or_path (ndarray or str or Path): Either a numpy array or str or
            pathlib.Path. If it is a numpy array (loaded image), then
            it will be returned as is.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale`, `unchanged`,
            `color_ignore_orientation` and `grayscale_ignore_orientation`.
            By default, `cv2` and `pillow` backend would rotate the image
            according to its EXIF info unless called with `unchanged` or
            `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
            always ignore image's EXIF info regardless of the flag.
            The `turbojpeg` backend only supports `color` and `grayscale`.
        channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
        backend (str | None): The image decoding backend type. Options are
            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
            If backend is None, the global imread_backend specified by
            ``mmcv.use_backend()`` will be used. Default: None.

    Returns:
        ndarray: Loaded image array.
    """

    if backend is None:
        backend = imread_backend
    if backend not in supported_backends:
        raise ValueError(f'backend: {backend} is not supported. Supported '
                         "backends are 'cv2', 'turbojpeg', 'pillow', "
                         "'tifffile'")
    if isinstance(img_or_path, Path):
        img_or_path = str(img_or_path)

    if isinstance(img_or_path, np.ndarray):
        return img_or_path
    elif is_str(img_or_path):
        check_file_exist(img_or_path,
                         f'img file does not exist: {img_or_path}')
        if backend == 'turbojpeg':
            with open(img_or_path, 'rb') as in_file:
                img = jpeg.decode(in_file.read(),
                                  _jpegflag(flag, channel_order))
                if img.shape[-1] == 1:
                    img = img[:, :, 0]
            return img
        elif backend == 'pillow':
            img = Image.open(img_or_path)
            img = _pillow2array(img, flag, channel_order)
            return img
        elif backend == 'tifffile':
            img = tifffile.imread(img_or_path)
            return img
        else:
            flag = imread_flags[flag] if is_str(flag) else flag
            img = cv2.imread(img_or_path, flag)
            if flag == IMREAD_COLOR and channel_order == 'rgb':
                cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
            return img
    else:
        raise TypeError('"img" must be a numpy array or a str or '
                        'a pathlib.Path object')

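A short sketch of `imread` with the flags documented above; 'photo.jpg' is a placeholder path:

img_bgr = imread('photo.jpg')                       # HxWx3, BGR, EXIF applied
img_rgb = imread('photo.jpg', channel_order='rgb')  # same image, RGB order
img_gray = imread('photo.jpg', flag='grayscale')    # HxW single channel
img_raw = imread('photo.jpg', flag='unchanged')     # no EXIF rotation, keeps alpha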

def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Same as :func:`imread`.
        backend (str | None): The image decoding backend type. Options are
            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
            global imread_backend specified by ``mmcv.use_backend()`` will be
            used. Default: None.

    Returns:
        ndarray: Loaded image array.
    """

    if backend is None:
        backend = imread_backend
    if backend not in supported_backends:
        raise ValueError(f'backend: {backend} is not supported. Supported '
                         "backends are 'cv2', 'turbojpeg', 'pillow'")
    if backend == 'turbojpeg':
        img = jpeg.decode(content, _jpegflag(flag, channel_order))
        if img.shape[-1] == 1:
            img = img[:, :, 0]
        return img
    elif backend == 'pillow':
        buff = io.BytesIO(content)
        img = Image.open(buff)
        img = _pillow2array(img, flag, channel_order)
        return img
    else:
        img_np = np.frombuffer(content, np.uint8)
        flag = imread_flags[flag] if is_str(flag) else flag
        img = cv2.imdecode(img_np, flag)
        if flag == IMREAD_COLOR and channel_order == 'rgb':
            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
        return img


def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    if auto_mkdir:
        dir_name = osp.abspath(osp.dirname(file_path))
        mkdir_or_exist(dir_name)
    return cv2.imwrite(file_path, img, params)
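Sketch of the bytes-based path plus `imwrite` (paths are placeholders). `imfrombytes` decodes in memory, so it suits data that arrives from sockets or object stores rather than the filesystem:

with open('photo.jpg', 'rb') as f:
    content = f.read()
img = imfrombytes(content, flag='color')
imwrite(img, 'out/copy.jpg')  # 'out/' is created because auto_mkdir=True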
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/misc.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

import annotator.mmpkg.mmcv as mmcv

try:
    import torch
except ImportError:
    torch = None


def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
    """Convert tensor to 3-channel images.

    Args:
        tensor (torch.Tensor): Tensor that contains multiple images,
            shape (N, C, H, W).
        mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
        std (tuple[float], optional): Standard deviation of images.
            Defaults to (1, 1, 1).
        to_rgb (bool, optional): Whether the tensor was converted to RGB
            format in the first place. If so, convert it back to BGR.
            Defaults to True.

    Returns:
        list[np.ndarray]: A list that contains multiple images.
    """

    if torch is None:
        raise RuntimeError('pytorch is not installed')
    assert torch.is_tensor(tensor) and tensor.ndim == 4
    assert len(mean) == 3
    assert len(std) == 3

    num_imgs = tensor.size(0)
    mean = np.array(mean, dtype=np.float32)
    std = np.array(std, dtype=np.float32)
    imgs = []
    for img_id in range(num_imgs):
        img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
        img = mmcv.imdenormalize(
            img, mean, std, to_bgr=to_rgb).astype(np.uint8)
        imgs.append(np.ascontiguousarray(img))
    return imgs
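A minimal sketch of `tensor2imgs`. The mean/std mirror common ImageNet statistics and are only illustrative, as is the random batch standing in for real network input:

import torch

batch = torch.randn(2, 3, 224, 224)  # stand-in for a normalized NCHW batch
imgs = tensor2imgs(batch,
                   mean=(123.675, 116.28, 103.53),
                   std=(58.395, 57.12, 57.375),
                   to_rgb=True)
assert len(imgs) == 2 and imgs[0].shape == (224, 224, 3)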
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/photometric.py
ADDED
@@ -0,0 +1,428 @@
# Copyright (c) OpenMMLab. All rights reserved.
import cv2
import numpy as np

from ..utils import is_tuple_of
from .colorspace import bgr2gray, gray2bgr


def imnormalize(img, mean, std, to_rgb=True):
    """Normalize an image with mean and std.

    Args:
        img (ndarray): Image to be normalized.
        mean (ndarray): The mean to be used for normalize.
        std (ndarray): The std to be used for normalize.
        to_rgb (bool): Whether to convert to rgb.

    Returns:
        ndarray: The normalized image.
    """
    img = img.copy().astype(np.float32)
    return imnormalize_(img, mean, std, to_rgb)


def imnormalize_(img, mean, std, to_rgb=True):
    """Inplace normalize an image with mean and std.

    Args:
        img (ndarray): Image to be normalized.
        mean (ndarray): The mean to be used for normalize.
        std (ndarray): The std to be used for normalize.
        to_rgb (bool): Whether to convert to rgb.

    Returns:
        ndarray: The normalized image.
    """
    # cv2 inplace normalization does not accept uint8
    assert img.dtype != np.uint8
    mean = np.float64(mean.reshape(1, -1))
    stdinv = 1 / np.float64(std.reshape(1, -1))
    if to_rgb:
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
    cv2.subtract(img, mean, img)  # inplace
    cv2.multiply(img, stdinv, img)  # inplace
    return img


def imdenormalize(img, mean, std, to_bgr=True):
    assert img.dtype != np.uint8
    mean = mean.reshape(1, -1).astype(np.float64)
    std = std.reshape(1, -1).astype(np.float64)
    img = cv2.multiply(img, std)  # make a copy
    cv2.add(img, mean, img)  # inplace
    if to_bgr:
        cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)  # inplace
    return img

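The two functions above form an inverse pair: `imnormalize` computes (img - mean) / std channel-wise after an optional BGR-to-RGB swap, and `imdenormalize` undoes both steps. A round-trip sketch with illustrative statistics:

mean = np.array([123.675, 116.28, 103.53])  # RGB-order stats, illustrative
std = np.array([58.395, 57.12, 57.375])
img = np.random.randint(0, 256, (4, 4, 3)).astype(np.float32)  # a BGR image
norm = imnormalize(img, mean, std, to_rgb=True)
back = imdenormalize(norm, mean, std, to_bgr=True)
assert np.allclose(back, img, atol=1e-3)  # recovered up to float32 rounding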

def iminvert(img):
    """Invert (negate) an image.

    Args:
        img (ndarray): Image to be inverted.

    Returns:
        ndarray: The inverted image.
    """
    return np.full_like(img, 255) - img


def solarize(img, thr=128):
    """Solarize an image (invert all pixel values above a threshold).

    Args:
        img (ndarray): Image to be solarized.
        thr (int): Threshold for solarizing (0 - 255).

    Returns:
        ndarray: The solarized image.
    """
    img = np.where(img < thr, img, 255 - img)
    return img


def posterize(img, bits):
    """Posterize an image (reduce the number of bits for each color channel).

    Args:
        img (ndarray): Image to be posterized.
        bits (int): Number of bits (1 to 8) to use for posterizing.

    Returns:
        ndarray: The posterized image.
    """
    shift = 8 - bits
    img = np.left_shift(np.right_shift(img, shift), shift)
    return img

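The bit arithmetic in `posterize` keeps only the top `bits` bits of each value (shift right, then left, by 8 - bits). A tiny worked example alongside `solarize`:

pix = np.array([0, 37, 130, 255], dtype=np.uint8)
print(posterize(pix, 3))   # [  0  32 128 224]  (e.g. 37 >> 5 == 1, 1 << 5 == 32)
print(solarize(pix, 128))  # [  0  37 125   0]  (values >= 128 become 255 - v)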

def adjust_color(img, alpha=1, beta=None, gamma=0):
    r"""It blends the source image and its gray image:

    .. math::
        output = img * alpha + gray\_img * beta + gamma

    Args:
        img (ndarray): The input source image.
        alpha (int | float): Weight for the source image. Default 1.
        beta (int | float): Weight for the converted gray image.
            If None, it's assigned the value (1 - `alpha`).
        gamma (int | float): Scalar added to each sum.
            Same as :func:`cv2.addWeighted`. Default 0.

    Returns:
        ndarray: Colored image which has the same size and dtype as input.
    """
    gray_img = bgr2gray(img)
    gray_img = np.tile(gray_img[..., None], [1, 1, 3])
    if beta is None:
        beta = 1 - alpha
    colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
    if not colored_img.dtype == np.uint8:
        # Note when the dtype of `img` is not the default `np.uint8`
        # (e.g. np.float32), the value in `colored_img` got from cv2
        # is not guaranteed to be in range [0, 255], so here clip
        # is needed.
        colored_img = np.clip(colored_img, 0, 255)
    return colored_img


def imequalize(img):
    """Equalize the image histogram.

    This function applies a non-linear mapping to the input image,
    in order to create a uniform distribution of grayscale values
    in the output image.

    Args:
        img (ndarray): Image to be equalized.

    Returns:
        ndarray: The equalized image.
    """

    def _scale_channel(im, c):
        """Scale the data in the corresponding channel."""
        im = im[:, :, c]
        # Compute the histogram of the image channel.
        histo = np.histogram(im, 256, (0, 255))[0]
        # For computing the step, filter out the nonzeros.
        nonzero_histo = histo[histo > 0]
        step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
        if not step:
            lut = np.array(range(256))
        else:
            # Compute the cumulative sum, shifted by step // 2
            # and then normalized by step.
            lut = (np.cumsum(histo) + (step // 2)) // step
            # Shift lut, prepending with 0.
            lut = np.concatenate([[0], lut[:-1]], 0)
            # handle potential integer overflow
            lut[lut > 255] = 255
        # If step is zero, return the original image.
        # Otherwise, index from lut.
        return np.where(np.equal(step, 0), im, lut[im])

    # Scales each channel independently and then stacks
    # the result.
    s1 = _scale_channel(img, 0)
    s2 = _scale_channel(img, 1)
    s3 = _scale_channel(img, 2)
    equalized_img = np.stack([s1, s2, s3], axis=-1)
    return equalized_img.astype(img.dtype)


def adjust_brightness(img, factor=1.):
    """Adjust image brightness.

    This function controls the brightness of an image. An
    enhancement factor of 0.0 gives a black image.
    A factor of 1.0 gives the original image. This function
    blends the source image and the degenerated black image:

    .. math::
        output = img * factor + degenerated * (1 - factor)

    Args:
        img (ndarray): Image to be brightened.
        factor (float): A value that controls the enhancement.
            Factor 1.0 returns the original image, lower
            factors mean less color (brightness, contrast,
            etc), and higher values more. Default 1.

    Returns:
        ndarray: The brightened image.
    """
    degenerated = np.zeros_like(img)
    # Note manually convert the dtype to np.float32, to
    # achieve as close results as PIL.ImageEnhance.Brightness.
    # Set beta=1-factor, and gamma=0
    brightened_img = cv2.addWeighted(
        img.astype(np.float32), factor, degenerated.astype(np.float32),
        1 - factor, 0)
    brightened_img = np.clip(brightened_img, 0, 255)
    return brightened_img.astype(img.dtype)


def adjust_contrast(img, factor=1.):
    """Adjust image contrast.

    This function controls the contrast of an image. An
    enhancement factor of 0.0 gives a solid grey
    image. A factor of 1.0 gives the original image. It
    blends the source image and the degenerated mean image:

    .. math::
        output = img * factor + degenerated * (1 - factor)

    Args:
        img (ndarray): Image to be contrasted. BGR order.
        factor (float): Same as :func:`mmcv.adjust_brightness`.

    Returns:
        ndarray: The contrasted image.
    """
    gray_img = bgr2gray(img)
    hist = np.histogram(gray_img, 256, (0, 255))[0]
    mean = round(np.sum(gray_img) / np.sum(hist))
    degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
    degenerated = gray2bgr(degenerated)
    contrasted_img = cv2.addWeighted(
        img.astype(np.float32), factor, degenerated.astype(np.float32),
        1 - factor, 0)
    contrasted_img = np.clip(contrasted_img, 0, 255)
    return contrasted_img.astype(img.dtype)

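The enhancement functions above all share the blend output = img * factor + degenerated * (1 - factor); only the degenerated image differs (gray copy, black image, or mean image). A small sketch with illustrative data:

img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
same = adjust_brightness(img, 1.0)    # factor 1.0 is the identity
darker = adjust_brightness(img, 0.5)  # halfway toward the black image
flatter = adjust_contrast(img, 0.5)   # halfway toward the mean image
assert np.array_equal(same, img)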

def auto_contrast(img, cutoff=0):
    """Auto adjust image contrast.

    This function maximizes (normalizes) image contrast by first removing
    the cutoff percent of the lightest and darkest pixels from the histogram
    and remapping the image so that the darkest pixel becomes black (0), and
    the lightest becomes white (255).

    Args:
        img (ndarray): Image to be contrasted. BGR order.
        cutoff (int | float | tuple): The cutoff percent of the lightest and
            darkest pixels to be removed. If given as tuple, it shall be
            (low, high). Otherwise, the single value will be used for both.
            Defaults to 0.

    Returns:
        ndarray: The contrasted image.
    """

    def _auto_contrast_channel(im, c, cutoff):
        im = im[:, :, c]
        # Compute the histogram of the image channel.
        histo = np.histogram(im, 256, (0, 255))[0]
        # Remove cut-off percent pixels from histo
        histo_sum = np.cumsum(histo)
        cut_low = histo_sum[-1] * cutoff[0] // 100
        cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
        histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
        histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)

        # Compute mapping
        low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
        # If all the values have been cut off, return the origin img
        if low >= high:
            return im
        scale = 255.0 / (high - low)
        offset = -low * scale
        lut = np.array(range(256))
        lut = lut * scale + offset
        lut = np.clip(lut, 0, 255)
        return lut[im]

    if isinstance(cutoff, (int, float)):
        cutoff = (cutoff, cutoff)
    else:
        assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
            f'float or tuple, but got {type(cutoff)} instead.'
    # Auto adjusts contrast for each channel independently and then stacks
    # the result.
    s1 = _auto_contrast_channel(img, 0, cutoff)
    s2 = _auto_contrast_channel(img, 1, cutoff)
    s3 = _auto_contrast_channel(img, 2, cutoff)
    contrasted_img = np.stack([s1, s2, s3], axis=-1)
    return contrasted_img.astype(img.dtype)

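A sketch of `auto_contrast` on a deliberately low-contrast image: with no cutoff, each channel's darkest value maps to 0 and its lightest to 255, while a tuple cutoff trims the histogram tails before the remap:

img = np.random.randint(90, 160, (16, 16, 3), dtype=np.uint8)  # dull image
stretched = auto_contrast(img)              # full-range remap per channel
robust = auto_contrast(img, cutoff=(2, 2))  # drop 2% from each tail first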

def adjust_sharpness(img, factor=1., kernel=None):
    """Adjust image sharpness.

    This function controls the sharpness of an image. An
    enhancement factor of 0.0 gives a blurred image. A
    factor of 1.0 gives the original image. And a factor
    of 2.0 gives a sharpened image. It blends the source
    image and the degenerated (smoothed) image:

    .. math::
        output = img * factor + degenerated * (1 - factor)

    Args:
        img (ndarray): Image to be sharpened. BGR order.
        factor (float): Same as :func:`mmcv.adjust_brightness`.
        kernel (np.ndarray, optional): Filter kernel to be applied on the img
            to obtain the degenerated img. Defaults to None.

    Note:
        No value sanity check is enforced on the kernel set by users. So with
        an inappropriate kernel, ``adjust_sharpness`` may fail to perform
        the function its name indicates and instead perform whatever
        transform the kernel determines.

    Returns:
        ndarray: The sharpened image.
    """

    if kernel is None:
        # adopted from PIL.ImageFilter.SMOOTH
        kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
    assert isinstance(kernel, np.ndarray), \
        f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
    assert kernel.ndim == 2, \
        f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'

    degenerated = cv2.filter2D(img, -1, kernel)
    sharpened_img = cv2.addWeighted(
        img.astype(np.float32), factor, degenerated.astype(np.float32),
        1 - factor, 0)
    sharpened_img = np.clip(sharpened_img, 0, 255)
    return sharpened_img.astype(img.dtype)


def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
    """AlexNet-style PCA jitter.

    This data augmentation is proposed in `ImageNet Classification with Deep
    Convolutional Neural Networks
    <https://dl.acm.org/doi/pdf/10.1145/3065386>`_.

    Args:
        img (ndarray): Image whose lighting is to be adjusted. BGR order.
        eigval (ndarray): the eigenvalues of the covariance matrix of pixel
            values.
        eigvec (ndarray): the eigenvectors of the covariance matrix of pixel
            values.
        alphastd (float): The standard deviation for distribution of alpha.
            Defaults to 0.1
        to_rgb (bool): Whether to convert img to rgb.

    Returns:
        ndarray: The adjusted image.
    """
    assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
        f'eigval and eigvec should both be of type np.ndarray, got ' \
        f'{type(eigval)} and {type(eigvec)} instead.'

    assert eigval.ndim == 1 and eigvec.ndim == 2
    assert eigvec.shape == (3, eigval.shape[0])
    n_eigval = eigval.shape[0]
    assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
        f'got {type(alphastd)} instead.'

    img = img.copy().astype(np.float32)
    if to_rgb:
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace

    alpha = np.random.normal(0, alphastd, n_eigval)
    alter = eigvec \
        * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
        * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
    alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
    img_adjusted = img + alter
    return img_adjusted


def lut_transform(img, lut_table):
    """Transform array by look-up table.

    The function lut_transform fills the output array with values from the
    look-up table. Indices of the entries are taken from the input array.

    Args:
        img (ndarray): Image to be transformed.
        lut_table (ndarray): look-up table of 256 elements; in case of
            multi-channel input array, the table should either have a single
            channel (in this case the same table is used for all channels) or
            the same number of channels as in the input array.

    Returns:
        ndarray: The transformed image.
    """
    assert isinstance(img, np.ndarray)
    assert 0 <= np.min(img) and np.max(img) <= 255
    assert isinstance(lut_table, np.ndarray)
    assert lut_table.shape == (256, )

    return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)


def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
    """Use CLAHE method to process the image.

    See `ZUIDERVELD, K. Contrast Limited Adaptive Histogram Equalization[J].
    Graphics Gems, 1994: 474-485.` for more information.

    Args:
        img (ndarray): Image to be processed.
        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
            Input image will be divided into equally sized rectangular tiles.
            It defines the number of tiles in row and column. Default: (8, 8).

    Returns:
        ndarray: The processed image.
    """
    assert isinstance(img, np.ndarray)
    assert img.ndim == 2
    assert isinstance(clip_limit, (float, int))
    assert is_tuple_of(tile_grid_size, int)
    assert len(tile_grid_size) == 2

    clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
    return clahe.apply(np.array(img, dtype=np.uint8))
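Finally, a sketch of `lut_transform` (here building a hypothetical gamma-correction table) and `clahe`, which expects a single-channel image:

gamma = 0.5
table = np.clip(((np.arange(256) / 255.0) ** gamma) * 255, 0, 255)
table = table.astype(np.uint8)  # 256-entry LUT, one output value per intensity
brightened = lut_transform(np.random.randint(0, 256, (32, 32, 3)), table)

gray = np.random.randint(0, 256, (64, 64)).astype(np.uint8)
enhanced = clahe(gray, clip_limit=2.0, tile_grid_size=(8, 8))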