toto10 committed on
Commit 91281ee
1 Parent(s): 05c7648

b04c629d473b73ca15d008cade0dbdf01deeb1a8f48216e43b46a735e2975a9a

Files changed (50)
  1. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py +55 -0
  2. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py +41 -0
  3. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py +61 -0
  4. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py +35 -0
  5. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py +92 -0
  6. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py +125 -0
  7. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py +44 -0
  8. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py +62 -0
  9. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py +206 -0
  10. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py +148 -0
  11. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py +96 -0
  12. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py +65 -0
  13. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py +412 -0
  14. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py +34 -0
  15. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py +29 -0
  16. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py +306 -0
  17. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py +144 -0
  18. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py +36 -0
  19. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py +88 -0
  20. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py +16 -0
  21. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py +21 -0
  22. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py +25 -0
  23. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py +595 -0
  24. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py +84 -0
  25. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py +180 -0
  26. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/builder.py +30 -0
  27. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py +316 -0
  28. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py +19 -0
  29. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py +599 -0
  30. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py +59 -0
  31. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py +59 -0
  32. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py +684 -0
  33. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py +175 -0
  34. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/__init__.py +8 -0
  35. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/test.py +202 -0
  36. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py +11 -0
  37. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py +1148 -0
  38. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py +7 -0
  39. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py +30 -0
  40. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py +36 -0
  41. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py +28 -0
  42. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py +24 -0
  43. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/io.py +151 -0
  44. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/parse.py +97 -0
  45. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/__init__.py +28 -0
  46. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/colorspace.py +306 -0
  47. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/geometric.py +728 -0
  48. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/io.py +258 -0
  49. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/misc.py +44 -0
  50. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/photometric.py +428 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/arraymisc/quantization.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import numpy as np
+
+
+ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+     """Quantize an array of (-inf, inf) to [0, levels-1].
+
+     Args:
+         arr (ndarray): Input array.
+         min_val (scalar): Minimum value to be clipped.
+         max_val (scalar): Maximum value to be clipped.
+         levels (int): Quantization levels.
+         dtype (np.type): The type of the quantized array.
+
+     Returns:
+         ndarray: Quantized array.
+     """
+     if not (isinstance(levels, int) and levels > 1):
+         raise ValueError(
+             f'levels must be an integer greater than 1, but got {levels}')
+     if min_val >= max_val:
+         raise ValueError(
+             f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+     arr = np.clip(arr, min_val, max_val) - min_val
+     quantized_arr = np.minimum(
+         np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+     return quantized_arr
+
+
+ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+     """Dequantize an array.
+
+     Args:
+         arr (ndarray): Input array.
+         min_val (scalar): Minimum value to be clipped.
+         max_val (scalar): Maximum value to be clipped.
+         levels (int): Quantization levels.
+         dtype (np.type): The type of the dequantized array.
+
+     Returns:
+         ndarray: Dequantized array.
+     """
+     if not (isinstance(levels, int) and levels > 1):
+         raise ValueError(
+             f'levels must be an integer greater than 1, but got {levels}')
+     if min_val >= max_val:
+         raise ValueError(
+             f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+     dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+                                                    min_val) / levels + min_val
+
+     return dequantized_arr
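
Taken together, quantize and dequantize implement uniform binning whose round-trip error is at most half a bin width, since dequantize returns bin centers (the +0.5 shift). A minimal sketch of that round trip, assuming the vendored package resolves as annotator.mmpkg.mmcv:

import numpy as np

from annotator.mmpkg.mmcv.arraymisc.quantization import dequantize, quantize

arr = np.array([-0.3, 0.0, 0.7])
q = quantize(arr, min_val=-1.0, max_val=1.0, levels=16)  # integers in [0, 15]
r = dequantize(q, min_val=-1.0, max_val=1.0, levels=16)  # bin centers
# Each value comes back within half a bin width: (max_val - min_val) / levels / 2.
assert np.all(np.abs(arr - r) <= (2.0 / 16) / 2 + 1e-12)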
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/__init__.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .alexnet import AlexNet
+ # yapf: disable
+ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                      PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+                      ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                      ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                      DepthwiseSeparableConvModule, GeneralizedAttention,
+                      HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+                      NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+                      build_activation_layer, build_conv_layer,
+                      build_norm_layer, build_padding_layer, build_plugin_layer,
+                      build_upsample_layer, conv_ws_2d, is_norm)
+ from .builder import MODELS, build_model_from_cfg
+ # yapf: enable
+ from .resnet import ResNet, make_res_layer
+ from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+                     NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                     XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                     constant_init, fuse_conv_bn, get_model_complexity_info,
+                     initialize, kaiming_init, normal_init, trunc_normal_init,
+                     uniform_init, xavier_init)
+ from .vgg import VGG, make_vgg_layer
+
+ __all__ = [
+     'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+     'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+     'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+     'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+     'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+     'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+     'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+     'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+     'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+     'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+     'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+     'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+     'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+     'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+     'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/alexnet.py ADDED
@@ -0,0 +1,61 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import logging
+
+ import torch.nn as nn
+
+
+ class AlexNet(nn.Module):
+     """AlexNet backbone.
+
+     Args:
+         num_classes (int): number of classes for classification.
+     """
+
+     def __init__(self, num_classes=-1):
+         super(AlexNet, self).__init__()
+         self.num_classes = num_classes
+         self.features = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=3, stride=2),
+             nn.Conv2d(64, 192, kernel_size=5, padding=2),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=3, stride=2),
+             nn.Conv2d(192, 384, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(384, 256, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(256, 256, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True),
+             nn.MaxPool2d(kernel_size=3, stride=2),
+         )
+         if self.num_classes > 0:
+             self.classifier = nn.Sequential(
+                 nn.Dropout(),
+                 nn.Linear(256 * 6 * 6, 4096),
+                 nn.ReLU(inplace=True),
+                 nn.Dropout(),
+                 nn.Linear(4096, 4096),
+                 nn.ReLU(inplace=True),
+                 nn.Linear(4096, num_classes),
+             )
+
+     def init_weights(self, pretrained=None):
+         if isinstance(pretrained, str):
+             logger = logging.getLogger()
+             from ..runner import load_checkpoint
+             load_checkpoint(self, pretrained, strict=False, logger=logger)
+         elif pretrained is None:
+             # use default initializer
+             pass
+         else:
+             raise TypeError('pretrained must be a str or None')
+
+     def forward(self, x):
+
+         x = self.features(x)
+         if self.num_classes > 0:
+             x = x.view(x.size(0), 256 * 6 * 6)
+             x = self.classifier(x)
+
+         return x
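
Because the classifier head is only built when num_classes > 0, the same class doubles as a headless feature extractor. A quick shape check, a sketch assuming the vendored import path and the canonical 224x224 input that the 256 * 6 * 6 flatten expects:

import torch

from annotator.mmpkg.mmcv.cnn import AlexNet

model = AlexNet(num_classes=10)
model.init_weights()                       # default PyTorch init (pretrained=None)
x = torch.randn(1, 3, 224, 224)
assert model(x).shape == (1, 10)           # logits from the classifier head

backbone = AlexNet()                       # num_classes=-1: features only
assert backbone(x).shape == (1, 256, 6, 6)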
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/__init__.py ADDED
@@ -0,0 +1,35 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .activation import build_activation_layer
+ from .context_block import ContextBlock
+ from .conv import build_conv_layer
+ from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+ from .conv_module import ConvModule
+ from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+ from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+ from .drop import Dropout, DropPath
+ from .generalized_attention import GeneralizedAttention
+ from .hsigmoid import HSigmoid
+ from .hswish import HSwish
+ from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+ from .norm import build_norm_layer, is_norm
+ from .padding import build_padding_layer
+ from .plugin import build_plugin_layer
+ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                        PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+ from .scale import Scale
+ from .swish import Swish
+ from .upsample import build_upsample_layer
+ from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                        Linear, MaxPool2d, MaxPool3d)
+
+ __all__ = [
+     'ConvModule', 'build_activation_layer', 'build_conv_layer',
+     'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+     'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+     'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+     'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+     'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+     'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/activation.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+ from .registry import ACTIVATION_LAYERS
+
+ for module in [
+         nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+         nn.Sigmoid, nn.Tanh
+ ]:
+     ACTIVATION_LAYERS.register_module(module=module)
+
+
+ @ACTIVATION_LAYERS.register_module(name='Clip')
+ @ACTIVATION_LAYERS.register_module()
+ class Clamp(nn.Module):
+     """Clamp activation layer.
+
+     This activation function is to clamp the feature map value within
+     :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+     Args:
+         min (Number | optional): Lower-bound of the range to be clamped to.
+             Default to -1.
+         max (Number | optional): Upper-bound of the range to be clamped to.
+             Default to 1.
+     """
+
+     def __init__(self, min=-1., max=1.):
+         super(Clamp, self).__init__()
+         self.min = min
+         self.max = max
+
+     def forward(self, x):
+         """Forward function.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: Clamped tensor.
+         """
+         return torch.clamp(x, min=self.min, max=self.max)
+
+
+ class GELU(nn.Module):
+     r"""Applies the Gaussian Error Linear Units function:
+
+     .. math::
+         \text{GELU}(x) = x * \Phi(x)
+     where :math:`\Phi(x)` is the Cumulative Distribution Function for
+     Gaussian Distribution.
+
+     Shape:
+         - Input: :math:`(N, *)` where `*` means, any number of additional
+           dimensions
+         - Output: :math:`(N, *)`, same shape as the input
+
+     .. image:: scripts/activation_images/GELU.png
+
+     Examples::
+
+         >>> m = nn.GELU()
+         >>> input = torch.randn(2)
+         >>> output = m(input)
+     """
+
+     def forward(self, input):
+         return F.gelu(input)
+
+
+ if (TORCH_VERSION == 'parrots'
+         or digit_version(TORCH_VERSION) < digit_version('1.4')):
+     ACTIVATION_LAYERS.register_module(module=GELU)
+ else:
+     ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+ def build_activation_layer(cfg):
+     """Build activation layer.
+
+     Args:
+         cfg (dict): The activation layer config, which should contain:
+             - type (str): Layer type.
+             - layer args: Args needed to instantiate an activation layer.
+
+     Returns:
+         nn.Module: Created activation layer.
+     """
+     return build_from_cfg(cfg, ACTIVATION_LAYERS)
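
Everything above funnels into build_activation_layer: the 'type' key selects a class registered in ACTIVATION_LAYERS and the remaining keys become constructor arguments. A hedged usage sketch (vendored import path assumed):

import torch

from annotator.mmpkg.mmcv.cnn import build_activation_layer

act = build_activation_layer(dict(type='LeakyReLU', negative_slope=0.1))
clamp = build_activation_layer(dict(type='Clip', min=0., max=6.))  # 'Clip' aliases Clamp
x = torch.tensor([-2.0, 8.0])
print(act(x))    # tensor([-0.2000,  8.0000])
print(clamp(x))  # tensor([0., 6.])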
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/context_block.py ADDED
@@ -0,0 +1,125 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ from torch import nn
+
+ from ..utils import constant_init, kaiming_init
+ from .registry import PLUGIN_LAYERS
+
+
+ def last_zero_init(m):
+     if isinstance(m, nn.Sequential):
+         constant_init(m[-1], val=0)
+     else:
+         constant_init(m, val=0)
+
+
+ @PLUGIN_LAYERS.register_module()
+ class ContextBlock(nn.Module):
+     """ContextBlock module in GCNet.
+
+     See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+     (https://arxiv.org/abs/1904.11492) for details.
+
+     Args:
+         in_channels (int): Channels of the input feature map.
+         ratio (float): Ratio of channels of transform bottleneck.
+         pooling_type (str): Pooling method for context modeling.
+             Options are 'att' and 'avg', which stand for attention pooling
+             and average pooling respectively. Default: 'att'.
+         fusion_types (Sequence[str]): Fusion method for feature fusion.
+             Options are 'channel_add' and 'channel_mul', which stand for
+             channel-wise addition and multiplication respectively.
+             Default: ('channel_add',)
+     """
+
+     _abbr_ = 'context_block'
+
+     def __init__(self,
+                  in_channels,
+                  ratio,
+                  pooling_type='att',
+                  fusion_types=('channel_add', )):
+         super(ContextBlock, self).__init__()
+         assert pooling_type in ['avg', 'att']
+         assert isinstance(fusion_types, (list, tuple))
+         valid_fusion_types = ['channel_add', 'channel_mul']
+         assert all([f in valid_fusion_types for f in fusion_types])
+         assert len(fusion_types) > 0, 'at least one fusion should be used'
+         self.in_channels = in_channels
+         self.ratio = ratio
+         self.planes = int(in_channels * ratio)
+         self.pooling_type = pooling_type
+         self.fusion_types = fusion_types
+         if pooling_type == 'att':
+             self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+             self.softmax = nn.Softmax(dim=2)
+         else:
+             self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         if 'channel_add' in fusion_types:
+             self.channel_add_conv = nn.Sequential(
+                 nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                 nn.LayerNorm([self.planes, 1, 1]),
+                 nn.ReLU(inplace=True),  # yapf: disable
+                 nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+         else:
+             self.channel_add_conv = None
+         if 'channel_mul' in fusion_types:
+             self.channel_mul_conv = nn.Sequential(
+                 nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                 nn.LayerNorm([self.planes, 1, 1]),
+                 nn.ReLU(inplace=True),  # yapf: disable
+                 nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+         else:
+             self.channel_mul_conv = None
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         if self.pooling_type == 'att':
+             kaiming_init(self.conv_mask, mode='fan_in')
+             self.conv_mask.inited = True
+
+         if self.channel_add_conv is not None:
+             last_zero_init(self.channel_add_conv)
+         if self.channel_mul_conv is not None:
+             last_zero_init(self.channel_mul_conv)
+
+     def spatial_pool(self, x):
+         batch, channel, height, width = x.size()
+         if self.pooling_type == 'att':
+             input_x = x
+             # [N, C, H * W]
+             input_x = input_x.view(batch, channel, height * width)
+             # [N, 1, C, H * W]
+             input_x = input_x.unsqueeze(1)
+             # [N, 1, H, W]
+             context_mask = self.conv_mask(x)
+             # [N, 1, H * W]
+             context_mask = context_mask.view(batch, 1, height * width)
+             # [N, 1, H * W]
+             context_mask = self.softmax(context_mask)
+             # [N, 1, H * W, 1]
+             context_mask = context_mask.unsqueeze(-1)
+             # [N, 1, C, 1]
+             context = torch.matmul(input_x, context_mask)
+             # [N, C, 1, 1]
+             context = context.view(batch, channel, 1, 1)
+         else:
+             # [N, C, 1, 1]
+             context = self.avg_pool(x)
+
+         return context
+
+     def forward(self, x):
+         # [N, C, 1, 1]
+         context = self.spatial_pool(x)
+
+         out = x
+         if self.channel_mul_conv is not None:
+             # [N, C, 1, 1]
+             channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+             out = out * channel_mul_term
+         if self.channel_add_conv is not None:
+             # [N, C, 1, 1]
+             channel_add_term = self.channel_add_conv(context)
+             out = out + channel_add_term
+
+         return out
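
In use the block is shape-preserving: spatial_pool collapses the input to a [N, C, 1, 1] context vector, which the fusion branches broadcast back onto the feature map. A minimal sketch, assuming the vendored path:

import torch

from annotator.mmpkg.mmcv.cnn import ContextBlock

block = ContextBlock(in_channels=64, ratio=1. / 16)  # bottleneck of 64/16 = 4 channels
x = torch.randn(2, 64, 20, 20)
out = block(x)
assert out.shape == x.shape  # residual-style fusion keeps the input shape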
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv.py ADDED
@@ -0,0 +1,44 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from torch import nn
+
+ from .registry import CONV_LAYERS
+
+ CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+ CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+ CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+ def build_conv_layer(cfg, *args, **kwargs):
+     """Build convolution layer.
+
+     Args:
+         cfg (None or dict): The conv layer config, which should contain:
+             - type (str): Layer type.
+             - layer args: Args needed to instantiate a conv layer.
+         args (argument list): Arguments passed to the `__init__`
+             method of the corresponding conv layer.
+         kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+             method of the corresponding conv layer.
+
+     Returns:
+         nn.Module: Created conv layer.
+     """
+     if cfg is None:
+         cfg_ = dict(type='Conv2d')
+     else:
+         if not isinstance(cfg, dict):
+             raise TypeError('cfg must be a dict')
+         if 'type' not in cfg:
+             raise KeyError('the cfg dict must contain the key "type"')
+         cfg_ = cfg.copy()
+
+     layer_type = cfg_.pop('type')
+     if layer_type not in CONV_LAYERS:
+         raise KeyError(f'Unrecognized conv type {layer_type}')
+     else:
+         conv_layer = CONV_LAYERS.get(layer_type)
+
+     layer = conv_layer(*args, **kwargs, **cfg_)
+
+     return layer
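
The builder is a thin dispatch: cfg['type'] picks the registered class, the remaining cfg keys plus *args/**kwargs go to its constructor, and cfg=None falls back to nn.Conv2d. A sketch:

from annotator.mmpkg.mmcv.cnn import build_conv_layer

conv = build_conv_layer(None, 16, 32, 3, padding=1)        # plain nn.Conv2d
conv1d = build_conv_layer(dict(type='Conv1d'), 16, 32, 3)  # registry lookup
print(type(conv).__name__, type(conv1d).__name__)          # Conv2d Conv1d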
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv2d_adaptive_padding.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import math
+
+ from torch import nn
+ from torch.nn import functional as F
+
+ from .registry import CONV_LAYERS
+
+
+ @CONV_LAYERS.register_module()
+ class Conv2dAdaptivePadding(nn.Conv2d):
+     """TensorFlow-style 2D convolution with `padding` as "same", which pads
+     the input (if needed) so that it is fully covered by the specified filter
+     and stride. With stride 1, the output size equals the input size; with
+     stride 2, the output dimensions are halved (rounded up), for example.
+
+     Args:
+         in_channels (int): Number of channels in the input image
+         out_channels (int): Number of channels produced by the convolution
+         kernel_size (int or tuple): Size of the convolving kernel
+         stride (int or tuple, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to both sides of
+             the input. Default: 0
+         dilation (int or tuple, optional): Spacing between kernel elements.
+             Default: 1
+         groups (int, optional): Number of blocked connections from input
+             channels to output channels. Default: 1
+         bias (bool, optional): If ``True``, adds a learnable bias to the
+             output. Default: ``True``
+     """
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias=True):
+         super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+                          dilation, groups, bias)
+
+     def forward(self, x):
+         img_h, img_w = x.size()[-2:]
+         kernel_h, kernel_w = self.weight.size()[-2:]
+         stride_h, stride_w = self.stride
+         output_h = math.ceil(img_h / stride_h)
+         output_w = math.ceil(img_w / stride_w)
+         pad_h = (
+             max((output_h - 1) * self.stride[0] +
+                 (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+         pad_w = (
+             max((output_w - 1) * self.stride[1] +
+                 (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+         if pad_h > 0 or pad_w > 0:
+             x = F.pad(x, [
+                 pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+             ])
+         return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+                         self.dilation, self.groups)
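
The pad_h/pad_w formulas reproduce TensorFlow's "same" rule: the output spatial size is always ceil(input / stride), independent of kernel size. A quick check, a sketch using the vendored module path:

import torch

from annotator.mmpkg.mmcv.cnn.bricks.conv2d_adaptive_padding import \
    Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(8, 8, kernel_size=3, stride=2)
x = torch.randn(1, 8, 15, 15)
assert conv(x).shape[-2:] == (8, 8)  # ceil(15 / 2) == 8 even for odd sizes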
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_module.py ADDED
@@ -0,0 +1,206 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import warnings
+
+ import torch.nn as nn
+
+ from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm
+ from ..utils import constant_init, kaiming_init
+ from .activation import build_activation_layer
+ from .conv import build_conv_layer
+ from .norm import build_norm_layer
+ from .padding import build_padding_layer
+ from .registry import PLUGIN_LAYERS
+
+
+ @PLUGIN_LAYERS.register_module()
+ class ConvModule(nn.Module):
+     """A conv block that bundles conv/norm/activation layers.
+
+     This block simplifies the usage of convolution layers, which are commonly
+     used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+     It is based upon three build methods: `build_conv_layer()`,
+     `build_norm_layer()` and `build_activation_layer()`.
+
+     Besides, we add some additional features in this module.
+     1. Automatically set `bias` of the conv layer.
+     2. Spectral norm is supported.
+     3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+        supports zero and circular padding, and we add "reflect" padding mode.
+
+     Args:
+         in_channels (int): Number of channels in the input feature map.
+             Same as that in ``nn._ConvNd``.
+         out_channels (int): Number of channels produced by the convolution.
+             Same as that in ``nn._ConvNd``.
+         kernel_size (int | tuple[int]): Size of the convolving kernel.
+             Same as that in ``nn._ConvNd``.
+         stride (int | tuple[int]): Stride of the convolution.
+             Same as that in ``nn._ConvNd``.
+         padding (int | tuple[int]): Zero-padding added to both sides of
+             the input. Same as that in ``nn._ConvNd``.
+         dilation (int | tuple[int]): Spacing between kernel elements.
+             Same as that in ``nn._ConvNd``.
+         groups (int): Number of blocked connections from input channels to
+             output channels. Same as that in ``nn._ConvNd``.
+         bias (bool | str): If specified as `auto`, it will be decided by the
+             norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+             False. Default: "auto".
+         conv_cfg (dict): Config dict for convolution layer. Default: None,
+             which means using conv2d.
+         norm_cfg (dict): Config dict for normalization layer. Default: None.
+         act_cfg (dict): Config dict for activation layer.
+             Default: dict(type='ReLU').
+         inplace (bool): Whether to use inplace mode for activation.
+             Default: True.
+         with_spectral_norm (bool): Whether to use spectral norm in conv module.
+             Default: False.
+         padding_mode (str): If the `padding_mode` has not been supported by
+             current `Conv2d` in PyTorch, we will use our own padding layer
+             instead. Currently, we support ['zeros', 'circular'] with official
+             implementation and ['reflect'] with our own implementation.
+             Default: 'zeros'.
+         order (tuple[str]): The order of conv/norm/activation layers. It is a
+             sequence of "conv", "norm" and "act". Common examples are
+             ("conv", "norm", "act") and ("act", "conv", "norm").
+             Default: ('conv', 'norm', 'act').
+     """
+
+     _abbr_ = 'conv_block'
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias='auto',
+                  conv_cfg=None,
+                  norm_cfg=None,
+                  act_cfg=dict(type='ReLU'),
+                  inplace=True,
+                  with_spectral_norm=False,
+                  padding_mode='zeros',
+                  order=('conv', 'norm', 'act')):
+         super(ConvModule, self).__init__()
+         assert conv_cfg is None or isinstance(conv_cfg, dict)
+         assert norm_cfg is None or isinstance(norm_cfg, dict)
+         assert act_cfg is None or isinstance(act_cfg, dict)
+         official_padding_mode = ['zeros', 'circular']
+         self.conv_cfg = conv_cfg
+         self.norm_cfg = norm_cfg
+         self.act_cfg = act_cfg
+         self.inplace = inplace
+         self.with_spectral_norm = with_spectral_norm
+         self.with_explicit_padding = padding_mode not in official_padding_mode
+         self.order = order
+         assert isinstance(self.order, tuple) and len(self.order) == 3
+         assert set(order) == set(['conv', 'norm', 'act'])
+
+         self.with_norm = norm_cfg is not None
+         self.with_activation = act_cfg is not None
+         # if the conv layer is before a norm layer, bias is unnecessary.
+         if bias == 'auto':
+             bias = not self.with_norm
+         self.with_bias = bias
+
+         if self.with_explicit_padding:
+             pad_cfg = dict(type=padding_mode)
+             self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+         # reset padding to 0 for conv module
+         conv_padding = 0 if self.with_explicit_padding else padding
+         # build convolution layer
+         self.conv = build_conv_layer(
+             conv_cfg,
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=conv_padding,
+             dilation=dilation,
+             groups=groups,
+             bias=bias)
+         # export the attributes of self.conv to a higher level for convenience
+         self.in_channels = self.conv.in_channels
+         self.out_channels = self.conv.out_channels
+         self.kernel_size = self.conv.kernel_size
+         self.stride = self.conv.stride
+         self.padding = padding
+         self.dilation = self.conv.dilation
+         self.transposed = self.conv.transposed
+         self.output_padding = self.conv.output_padding
+         self.groups = self.conv.groups
+
+         if self.with_spectral_norm:
+             self.conv = nn.utils.spectral_norm(self.conv)
+
+         # build normalization layers
+         if self.with_norm:
+             # norm layer is after conv layer
+             if order.index('norm') > order.index('conv'):
+                 norm_channels = out_channels
+             else:
+                 norm_channels = in_channels
+             self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+             self.add_module(self.norm_name, norm)
+             if self.with_bias:
+                 if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+                     warnings.warn(
+                         'Unnecessary conv bias before batch/instance norm')
+         else:
+             self.norm_name = None
+
+         # build activation layer
+         if self.with_activation:
+             act_cfg_ = act_cfg.copy()
+             # nn.Tanh has no 'inplace' argument
+             if act_cfg_['type'] not in [
+                     'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+             ]:
+                 act_cfg_.setdefault('inplace', inplace)
+             self.activate = build_activation_layer(act_cfg_)
+
+         # Use msra init by default
+         self.init_weights()
+
+     @property
+     def norm(self):
+         if self.norm_name:
+             return getattr(self, self.norm_name)
+         else:
+             return None
+
+     def init_weights(self):
+         # 1. It is mainly for customized conv layers with their own
+         #    initialization manners by calling their own ``init_weights()``,
+         #    and we do not want ConvModule to override the initialization.
+         # 2. For customized conv layers without their own initialization
+         #    manners (that is, they don't have their own ``init_weights()``)
+         #    and PyTorch's conv layers, they will be initialized by
+         #    this method with default ``kaiming_init``.
+         # Note: For PyTorch's conv layers, they will be overwritten by our
+         #    initialization implementation using default ``kaiming_init``.
+         if not hasattr(self.conv, 'init_weights'):
+             if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+                 nonlinearity = 'leaky_relu'
+                 a = self.act_cfg.get('negative_slope', 0.01)
+             else:
+                 nonlinearity = 'relu'
+                 a = 0
+             kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+         if self.with_norm:
+             constant_init(self.norm, 1, bias=0)
+
+     def forward(self, x, activate=True, norm=True):
+         for layer in self.order:
+             if layer == 'conv':
+                 if self.with_explicit_padding:
+                     x = self.padding_layer(x)
+                 x = self.conv(x)
+             elif layer == 'norm' and norm and self.with_norm:
+                 x = self.norm(x)
+             elif layer == 'act' and activate and self.with_activation:
+                 x = self.activate(x)
+         return x
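
A short usage sketch of the bias='auto' rule and the conv/norm/act ordering; it assumes the norm type 'BN' is registered in the vendored norm.py, as it is upstream:

import torch

from annotator.mmpkg.mmcv.cnn import ConvModule

m = ConvModule(3, 16, 3, padding=1,
               norm_cfg=dict(type='BN'),
               act_cfg=dict(type='ReLU'))
assert m.conv.bias is None          # bias dropped automatically: a norm follows
y = m(torch.randn(2, 3, 32, 32))    # runs conv -> BN -> ReLU
assert y.shape == (2, 16, 32, 32)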
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/conv_ws.py ADDED
@@ -0,0 +1,148 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .registry import CONV_LAYERS
+
+
+ def conv_ws_2d(input,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                eps=1e-5):
+     c_in = weight.size(0)
+     weight_flat = weight.view(c_in, -1)
+     mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+     std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+     weight = (weight - mean) / (std + eps)
+     return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+ @CONV_LAYERS.register_module('ConvWS')
+ class ConvWS2d(nn.Conv2d):
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias=True,
+                  eps=1e-5):
+         super(ConvWS2d, self).__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             bias=bias)
+         self.eps = eps
+
+     def forward(self, x):
+         return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+                           self.dilation, self.groups, self.eps)
+
+
+ @CONV_LAYERS.register_module(name='ConvAWS')
+ class ConvAWS2d(nn.Conv2d):
+     """AWS (Adaptive Weight Standardization).
+
+     This is a variant of Weight Standardization
+     (https://arxiv.org/pdf/1903.10520.pdf).
+     It is used in DetectoRS to avoid NaN
+     (https://arxiv.org/pdf/2006.02334.pdf).
+
+     Args:
+         in_channels (int): Number of channels in the input image
+         out_channels (int): Number of channels produced by the convolution
+         kernel_size (int or tuple): Size of the conv kernel
+         stride (int or tuple, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to both sides of
+             the input. Default: 0
+         dilation (int or tuple, optional): Spacing between kernel elements.
+             Default: 1
+         groups (int, optional): Number of blocked connections from input
+             channels to output channels. Default: 1
+         bias (bool, optional): If set True, adds a learnable bias to the
+             output. Default: True
+     """
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  groups=1,
+                  bias=True):
+         super().__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             bias=bias)
+         self.register_buffer('weight_gamma',
+                              torch.ones(self.out_channels, 1, 1, 1))
+         self.register_buffer('weight_beta',
+                              torch.zeros(self.out_channels, 1, 1, 1))
+
+     def _get_weight(self, weight):
+         weight_flat = weight.view(weight.size(0), -1)
+         mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+         std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+         weight = (weight - mean) / std
+         weight = self.weight_gamma * weight + self.weight_beta
+         return weight
+
+     def forward(self, x):
+         weight = self._get_weight(self.weight)
+         return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+                         self.dilation, self.groups)
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                               missing_keys, unexpected_keys, error_msgs):
+         """Override default load function.
+
+         AWS overrides the function _load_from_state_dict to recover
+         weight_gamma and weight_beta if they are missing. If weight_gamma and
+         weight_beta are found in the checkpoint, this function will return
+         after super()._load_from_state_dict. Otherwise, it will compute the
+         mean and std of the pretrained weights and store them in weight_beta
+         and weight_gamma.
+         """
+
+         self.weight_gamma.data.fill_(-1)
+         local_missing_keys = []
+         super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                       strict, local_missing_keys,
+                                       unexpected_keys, error_msgs)
+         if self.weight_gamma.data.mean() > 0:
+             for k in local_missing_keys:
+                 missing_keys.append(k)
+             return
+         weight = self.weight.data
+         weight_flat = weight.view(weight.size(0), -1)
+         mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+         std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+         self.weight_beta.data.copy_(mean)
+         self.weight_gamma.data.copy_(std)
+         missing_gamma_beta = [
+             k for k in local_missing_keys
+             if k.endswith('weight_gamma') or k.endswith('weight_beta')
+         ]
+         for k in missing_gamma_beta:
+             local_missing_keys.remove(k)
+         for k in local_missing_keys:
+             missing_keys.append(k)
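
Both classes standardize the kernel on the fly inside forward, so the stored parameter keeps its raw statistics and checkpoints stay compatible with plain convs. A sketch:

import torch

from annotator.mmpkg.mmcv.cnn import ConvWS2d

conv = ConvWS2d(4, 8, kernel_size=3, padding=1)
y = conv(torch.randn(1, 4, 10, 10))  # conv_ws_2d normalizes the weight per call
assert y.shape == (1, 8, 10, 10)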
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/depthwise_separable_conv_module.py ADDED
@@ -0,0 +1,96 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch.nn as nn
+
+ from .conv_module import ConvModule
+
+
+ class DepthwiseSeparableConvModule(nn.Module):
+     """Depthwise separable convolution module.
+
+     See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+     This module can replace a ConvModule with the conv block replaced by two
+     conv blocks: a depthwise conv block and a pointwise conv block. The
+     depthwise conv block contains depthwise-conv/norm/activation layers. The
+     pointwise conv block contains pointwise-conv/norm/activation layers. It
+     should be noted that there will be norm/activation layers in the depthwise
+     conv block if `norm_cfg` and `act_cfg` are specified.
+
+     Args:
+         in_channels (int): Number of channels in the input feature map.
+             Same as that in ``nn._ConvNd``.
+         out_channels (int): Number of channels produced by the convolution.
+             Same as that in ``nn._ConvNd``.
+         kernel_size (int | tuple[int]): Size of the convolving kernel.
+             Same as that in ``nn._ConvNd``.
+         stride (int | tuple[int]): Stride of the convolution.
+             Same as that in ``nn._ConvNd``. Default: 1.
+         padding (int | tuple[int]): Zero-padding added to both sides of
+             the input. Same as that in ``nn._ConvNd``. Default: 0.
+         dilation (int | tuple[int]): Spacing between kernel elements.
+             Same as that in ``nn._ConvNd``. Default: 1.
+         norm_cfg (dict): Default norm config for both depthwise ConvModule and
+             pointwise ConvModule. Default: None.
+         act_cfg (dict): Default activation config for both depthwise ConvModule
+             and pointwise ConvModule. Default: dict(type='ReLU').
+         dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+             'default', it will be the same as `norm_cfg`. Default: 'default'.
+         dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+             'default', it will be the same as `act_cfg`. Default: 'default'.
+         pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+             'default', it will be the same as `norm_cfg`. Default: 'default'.
+         pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+             'default', it will be the same as `act_cfg`. Default: 'default'.
+         kwargs (optional): Other shared arguments for depthwise and pointwise
+             ConvModule. See ConvModule for ref.
+     """
+
+     def __init__(self,
+                  in_channels,
+                  out_channels,
+                  kernel_size,
+                  stride=1,
+                  padding=0,
+                  dilation=1,
+                  norm_cfg=None,
+                  act_cfg=dict(type='ReLU'),
+                  dw_norm_cfg='default',
+                  dw_act_cfg='default',
+                  pw_norm_cfg='default',
+                  pw_act_cfg='default',
+                  **kwargs):
+         super(DepthwiseSeparableConvModule, self).__init__()
+         assert 'groups' not in kwargs, 'groups should not be specified'
+
+         # if norm/activation config of depthwise/pointwise ConvModule is not
+         # specified, use default config.
+         dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+         dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+         pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+         pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+         # depthwise convolution
+         self.depthwise_conv = ConvModule(
+             in_channels,
+             in_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=in_channels,
+             norm_cfg=dw_norm_cfg,
+             act_cfg=dw_act_cfg,
+             **kwargs)
+
+         self.pointwise_conv = ConvModule(
+             in_channels,
+             out_channels,
+             1,
+             norm_cfg=pw_norm_cfg,
+             act_cfg=pw_act_cfg,
+             **kwargs)
+
+     def forward(self, x):
+         x = self.depthwise_conv(x)
+         x = self.pointwise_conv(x)
+         return x
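
The factorization is the point: a k x k depthwise pass plus a 1 x 1 pointwise pass needs roughly C_in * k^2 + C_in * C_out weights versus C_in * C_out * k^2 for a dense conv. A sketch comparing parameter counts:

import torch

from annotator.mmpkg.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

dw = DepthwiseSeparableConvModule(64, 128, 3, padding=1)
dense = ConvModule(64, 128, 3, padding=1)

def n_params(m):
    return sum(p.numel() for p in m.parameters())

print(n_params(dw), n_params(dense))   # roughly 9k vs 74k parameters
x = torch.randn(1, 64, 16, 16)
assert dw(x).shape == dense(x).shape   # drop-in replacement shape-wise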
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/drop.py ADDED
@@ -0,0 +1,65 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+
+ from annotator.mmpkg.mmcv import build_from_cfg
+ from .registry import DROPOUT_LAYERS
+
+
+ def drop_path(x, drop_prob=0., training=False):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of
+     residual blocks).
+
+     We follow the implementation
+     https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     # handle tensors with different dimensions, not just 4D tensors.
+     shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+     random_tensor = keep_prob + torch.rand(
+         shape, dtype=x.dtype, device=x.device)
+     output = x.div(keep_prob) * random_tensor.floor()
+     return output
+
+
+ @DROPOUT_LAYERS.register_module()
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of
+     residual blocks).
+
+     We follow the implementation
+     https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+
+     Args:
+         drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+     """
+
+     def __init__(self, drop_prob=0.1):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+
+ @DROPOUT_LAYERS.register_module()
+ class Dropout(nn.Dropout):
+     """A wrapper for ``torch.nn.Dropout``. We rename the ``p`` of
+     ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+     ``DropPath``.
+
+     Args:
+         drop_prob (float): Probability of the elements to be
+             zeroed. Default: 0.5.
+         inplace (bool): Do the operation inplace or not. Default: False.
+     """
+
+     def __init__(self, drop_prob=0.5, inplace=False):
+         super().__init__(p=drop_prob, inplace=inplace)
+
+
+ def build_dropout(cfg, default_args=None):
+     """Builder for drop out layers."""
+     return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
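
Behaviourally, DropPath is the identity in eval mode; in train mode each sample is zeroed with probability drop_prob and survivors are rescaled by 1/keep_prob so the expectation is preserved. A sketch:

import torch

from annotator.mmpkg.mmcv.cnn.bricks.drop import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(8, 16)
dp.eval()
assert torch.equal(dp(x), x)                # identity at inference
dp.train()
y = dp(x)                                   # each row is all 0.0 or all 2.0
assert set(y[:, 0].tolist()) <= {0.0, 2.0}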
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/generalized_attention.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ..utils import kaiming_init
10
+ from .registry import PLUGIN_LAYERS
11
+
12
+
13
+ @PLUGIN_LAYERS.register_module()
14
+ class GeneralizedAttention(nn.Module):
15
+ """GeneralizedAttention module.
16
+
17
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
18
+ (https://arxiv.org/abs/1711.07971) for details.
19
+
20
+ Args:
21
+ in_channels (int): Channels of the input feature map.
22
+ spatial_range (int): The spatial range. -1 indicates no spatial range
23
+ constraint. Default: -1.
24
+ num_heads (int): The head number of empirical_attention module.
25
+ Default: 9.
26
+ position_embedding_dim (int): The position embedding dimension.
27
+ Default: -1.
28
+ position_magnitude (int): A multiplier acting on coord difference.
29
+ Default: 1.
30
+ kv_stride (int): The feature stride acting on key/value feature map.
31
+ Default: 2.
32
+ q_stride (int): The feature stride acting on query feature map.
33
+ Default: 1.
34
+ attention_type (str): A binary indicator string for indicating which
35
+ items in generalized empirical_attention module are used.
36
+ Default: '1111'.
37
+
38
+ - '1000' indicates 'query and key content' (appr - appr) item,
39
+ - '0100' indicates 'query content and relative position'
40
+ (appr - position) item,
41
+ - '0010' indicates 'key content only' (bias - appr) item,
42
+ - '0001' indicates 'relative position only' (bias - position) item.
43
+ """
44
+
45
+ _abbr_ = 'gen_attention_block'
46
+
47
+ def __init__(self,
48
+ in_channels,
49
+ spatial_range=-1,
50
+ num_heads=9,
51
+ position_embedding_dim=-1,
52
+ position_magnitude=1,
53
+ kv_stride=2,
54
+ q_stride=1,
55
+ attention_type='1111'):
56
+
57
+ super(GeneralizedAttention, self).__init__()
58
+
59
+ # hard range means local range for non-local operation
60
+ self.position_embedding_dim = (
61
+ position_embedding_dim
62
+ if position_embedding_dim > 0 else in_channels)
63
+
64
+ self.position_magnitude = position_magnitude
65
+ self.num_heads = num_heads
66
+ self.in_channels = in_channels
67
+ self.spatial_range = spatial_range
68
+ self.kv_stride = kv_stride
69
+ self.q_stride = q_stride
70
+ self.attention_type = [bool(int(_)) for _ in attention_type]
71
+ self.qk_embed_dim = in_channels // num_heads
72
+ out_c = self.qk_embed_dim * num_heads
73
+
74
+ if self.attention_type[0] or self.attention_type[1]:
75
+ self.query_conv = nn.Conv2d(
76
+ in_channels=in_channels,
77
+ out_channels=out_c,
78
+ kernel_size=1,
79
+ bias=False)
80
+ self.query_conv.kaiming_init = True
81
+
82
+ if self.attention_type[0] or self.attention_type[2]:
83
+ self.key_conv = nn.Conv2d(
84
+ in_channels=in_channels,
85
+ out_channels=out_c,
86
+ kernel_size=1,
87
+ bias=False)
88
+ self.key_conv.kaiming_init = True
89
+
90
+ self.v_dim = in_channels // num_heads
91
+ self.value_conv = nn.Conv2d(
92
+ in_channels=in_channels,
93
+ out_channels=self.v_dim * num_heads,
94
+ kernel_size=1,
95
+ bias=False)
96
+ self.value_conv.kaiming_init = True
97
+
98
+ if self.attention_type[1] or self.attention_type[3]:
99
+ self.appr_geom_fc_x = nn.Linear(
100
+ self.position_embedding_dim // 2, out_c, bias=False)
101
+ self.appr_geom_fc_x.kaiming_init = True
102
+
103
+ self.appr_geom_fc_y = nn.Linear(
104
+ self.position_embedding_dim // 2, out_c, bias=False)
105
+ self.appr_geom_fc_y.kaiming_init = True
106
+
107
+ if self.attention_type[2]:
108
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
109
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
110
+ self.appr_bias = nn.Parameter(appr_bias_value)
111
+
112
+ if self.attention_type[3]:
113
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
114
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
115
+ self.geom_bias = nn.Parameter(geom_bias_value)
116
+
117
+ self.proj_conv = nn.Conv2d(
118
+ in_channels=self.v_dim * num_heads,
119
+ out_channels=in_channels,
120
+ kernel_size=1,
121
+ bias=True)
122
+ self.proj_conv.kaiming_init = True
123
+ self.gamma = nn.Parameter(torch.zeros(1))
124
+
125
+ if self.spatial_range >= 0:
126
+ # only works when non local is after 3*3 conv
127
+ if in_channels == 256:
128
+ max_len = 84
129
+ elif in_channels == 512:
130
+ max_len = 42
131
+
132
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
133
+ local_constraint_map = np.ones(
134
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
135
+ for iy in range(max_len):
136
+ for ix in range(max_len):
137
+ local_constraint_map[
138
+ iy, ix,
139
+ max((iy - self.spatial_range) //
140
+ self.kv_stride, 0):min((iy + self.spatial_range +
141
+ 1) // self.kv_stride +
142
+ 1, max_len),
143
+ max((ix - self.spatial_range) //
144
+ self.kv_stride, 0):min((ix + self.spatial_range +
145
+ 1) // self.kv_stride +
146
+ 1, max_len)] = 0
147
+
148
+ self.local_constraint_map = nn.Parameter(
149
+ torch.from_numpy(local_constraint_map).byte(),
150
+ requires_grad=False)
151
+
152
+ if self.q_stride > 1:
153
+ self.q_downsample = nn.AvgPool2d(
154
+ kernel_size=1, stride=self.q_stride)
155
+ else:
156
+ self.q_downsample = None
157
+
158
+ if self.kv_stride > 1:
159
+ self.kv_downsample = nn.AvgPool2d(
160
+ kernel_size=1, stride=self.kv_stride)
161
+ else:
162
+ self.kv_downsample = None
163
+
164
+ self.init_weights()
165
+
166
+ def get_position_embedding(self,
167
+ h,
168
+ w,
169
+ h_kv,
170
+ w_kv,
171
+ q_stride,
172
+ kv_stride,
173
+ device,
174
+ dtype,
175
+ feat_dim,
176
+ wave_length=1000):
177
+ # the default type of Tensor is float32, leading to type mismatch
178
+ # in fp16 mode. Cast it to support fp16 mode.
179
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
180
+ h_idxs = h_idxs.view((h, 1)) * q_stride
181
+
182
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
183
+ w_idxs = w_idxs.view((w, 1)) * q_stride
184
+
185
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
186
+ device=device, dtype=dtype)
187
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
188
+
189
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
190
+ device=device, dtype=dtype)
191
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
192
+
193
+ # (h, h_kv, 1)
194
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
195
+ h_diff *= self.position_magnitude
196
+
197
+ # (w, w_kv, 1)
198
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
199
+ w_diff *= self.position_magnitude
200
+
201
+ feat_range = torch.arange(0, feat_dim / 4).to(
202
+ device=device, dtype=dtype)
203
+
204
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
205
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
206
+ dim_mat = dim_mat.view((1, 1, -1))
207
+
208
+ embedding_x = torch.cat(
209
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
210
+
211
+ embedding_y = torch.cat(
212
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
213
+
214
+ return embedding_x, embedding_y
215
+
+     def forward(self, x_input):
+         num_heads = self.num_heads
+
+         # use empirical_attention
+         if self.q_downsample is not None:
+             x_q = self.q_downsample(x_input)
+         else:
+             x_q = x_input
+         n, _, h, w = x_q.shape
+
+         if self.kv_downsample is not None:
+             x_kv = self.kv_downsample(x_input)
+         else:
+             x_kv = x_input
+         _, _, h_kv, w_kv = x_kv.shape
+
+         if self.attention_type[0] or self.attention_type[1]:
+             proj_query = self.query_conv(x_q).view(
+                 (n, num_heads, self.qk_embed_dim, h * w))
+             proj_query = proj_query.permute(0, 1, 3, 2)
+
+         if self.attention_type[0] or self.attention_type[2]:
+             proj_key = self.key_conv(x_kv).view(
+                 (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+         if self.attention_type[1] or self.attention_type[3]:
+             position_embed_x, position_embed_y = self.get_position_embedding(
+                 h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+                 x_input.device, x_input.dtype, self.position_embedding_dim)
+             # (n, num_heads, w, w_kv, dim)
+             position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+                 view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+                 permute(0, 3, 1, 2, 4).\
+                 repeat(n, 1, 1, 1, 1)
+
+             # (n, num_heads, h, h_kv, dim)
+             position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+                 view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+                 permute(0, 3, 1, 2, 4).\
+                 repeat(n, 1, 1, 1, 1)
+
+             position_feat_x /= math.sqrt(2)
+             position_feat_y /= math.sqrt(2)
+
+         # accelerate for saliency only
+         if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+             appr_bias = self.appr_bias.\
+                 view(1, num_heads, 1, self.qk_embed_dim).\
+                 repeat(n, 1, 1, 1)
+
+             energy = torch.matmul(appr_bias, proj_key).\
+                 view(n, num_heads, 1, h_kv * w_kv)
+
+             h = 1
+             w = 1
+         else:
+             # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
+             if not self.attention_type[0]:
+                 energy = torch.zeros(
+                     n,
+                     num_heads,
+                     h,
+                     w,
+                     h_kv,
+                     w_kv,
+                     dtype=x_input.dtype,
+                     device=x_input.device)
+
+             # attention_type[0]: appr - appr
+             # attention_type[1]: appr - position
+             # attention_type[2]: bias - appr
+             # attention_type[3]: bias - position
+             if self.attention_type[0] or self.attention_type[2]:
+                 if self.attention_type[0] and self.attention_type[2]:
+                     appr_bias = self.appr_bias.\
+                         view(1, num_heads, 1, self.qk_embed_dim)
+                     energy = torch.matmul(proj_query + appr_bias, proj_key).\
+                         view(n, num_heads, h, w, h_kv, w_kv)
+
+                 elif self.attention_type[0]:
+                     energy = torch.matmul(proj_query, proj_key).\
+                         view(n, num_heads, h, w, h_kv, w_kv)
+
+                 elif self.attention_type[2]:
+                     appr_bias = self.appr_bias.\
+                         view(1, num_heads, 1, self.qk_embed_dim).\
+                         repeat(n, 1, 1, 1)
+
+                     energy += torch.matmul(appr_bias, proj_key).\
+                         view(n, num_heads, 1, 1, h_kv, w_kv)
+
+             if self.attention_type[1] or self.attention_type[3]:
+                 if self.attention_type[1] and self.attention_type[3]:
+                     geom_bias = self.geom_bias.\
+                         view(1, num_heads, 1, self.qk_embed_dim)
+
+                     proj_query_reshape = (proj_query + geom_bias).\
+                         view(n, num_heads, h, w, self.qk_embed_dim)
+
+                     energy_x = torch.matmul(
+                         proj_query_reshape.permute(0, 1, 3, 2, 4),
+                         position_feat_x.permute(0, 1, 2, 4, 3))
+                     energy_x = energy_x.\
+                         permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                     energy_y = torch.matmul(
+                         proj_query_reshape,
+                         position_feat_y.permute(0, 1, 2, 4, 3))
+                     energy_y = energy_y.unsqueeze(5)
+
+                     energy += energy_x + energy_y
+
+                 elif self.attention_type[1]:
+                     proj_query_reshape = proj_query.\
+                         view(n, num_heads, h, w, self.qk_embed_dim)
+                     proj_query_reshape = proj_query_reshape.\
+                         permute(0, 1, 3, 2, 4)
+                     position_feat_x_reshape = position_feat_x.\
+                         permute(0, 1, 2, 4, 3)
+                     position_feat_y_reshape = position_feat_y.\
+                         permute(0, 1, 2, 4, 3)
+
+                     energy_x = torch.matmul(proj_query_reshape,
+                                             position_feat_x_reshape)
+                     energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                     energy_y = torch.matmul(proj_query_reshape,
+                                             position_feat_y_reshape)
+                     energy_y = energy_y.unsqueeze(5)
+
+                     energy += energy_x + energy_y
+
+                 elif self.attention_type[3]:
+                     geom_bias = self.geom_bias.\
+                         view(1, num_heads, self.qk_embed_dim, 1).\
+                         repeat(n, 1, 1, 1)
+
+                     position_feat_x_reshape = position_feat_x.\
+                         view(n, num_heads, w * w_kv, self.qk_embed_dim)
+
+                     position_feat_y_reshape = position_feat_y.\
+                         view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+                     energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+                     energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+                     energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+                     energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+                     energy += energy_x + energy_y
+
+             energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+         if self.spatial_range >= 0:
+             cur_local_constraint_map = \
+                 self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+                 contiguous().\
+                 view(1, 1, h * w, h_kv * w_kv)
+
+             energy = energy.masked_fill_(cur_local_constraint_map,
+                                          float('-inf'))
+
+         attention = F.softmax(energy, 3)
+
+         proj_value = self.value_conv(x_kv)
+         proj_value_reshape = proj_value.\
+             view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+             permute(0, 1, 3, 2)
+
+         out = torch.matmul(attention, proj_value_reshape).\
+             permute(0, 1, 3, 2).\
+             contiguous().\
+             view(n, self.v_dim * self.num_heads, h, w)
+
+         out = self.proj_conv(out)
+
+         # output is downsampled, upsample back to input size
+         if self.q_downsample is not None:
+             out = F.interpolate(
+                 out,
+                 size=x_input.shape[2:],
+                 mode='bilinear',
+                 align_corners=False)
+
+         out = self.gamma * out + x_input
+         return out
+
+     def init_weights(self):
+         for m in self.modules():
+             if hasattr(m, 'kaiming_init') and m.kaiming_init:
+                 kaiming_init(
+                     m,
+                     mode='fan_in',
+                     nonlinearity='leaky_relu',
+                     bias=0,
+                     distribution='uniform',
+                     a=1)
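The forward pass above sums up to four energy terms (appr-appr, appr-position, bias-appr, bias-position) selected by `attention_type`, softmax-normalizes them over key positions, and gates the aggregated values with the learned `gamma` before the residual add. A minimal smoke test, assuming the constructor signature of upstream mmcv's GeneralizedAttention (the argument names here are that assumption, not shown in this hunk):

import torch
from annotator.mmpkg.mmcv.cnn.bricks.generalized_attention import GeneralizedAttention

# '1111' enables all four energy terms; kv_stride=2 downsamples keys/values.
attn = GeneralizedAttention(in_channels=16, num_heads=8, kv_stride=2,
                            attention_type='1111')
x = torch.rand(2, 16, 14, 14)
out = attn(x)          # gamma-gated residual output, same shape as the input
assert out.shape == x.shape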
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hsigmoid.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch.nn as nn
+
+ from .registry import ACTIVATION_LAYERS
+
+
+ @ACTIVATION_LAYERS.register_module()
+ class HSigmoid(nn.Module):
+     """Hard Sigmoid Module. Apply the hard sigmoid function:
+     Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+     Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+     Args:
+         bias (float): Bias of the input feature map. Default: 1.0.
+         divisor (float): Divisor of the input feature map. Default: 2.0.
+         min_value (float): Lower bound value. Default: 0.0.
+         max_value (float): Upper bound value. Default: 1.0.
+
+     Returns:
+         Tensor: The output tensor.
+     """
+
+     def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+         super(HSigmoid, self).__init__()
+         self.bias = bias
+         self.divisor = divisor
+         assert self.divisor != 0
+         self.min_value = min_value
+         self.max_value = max_value
+
+     def forward(self, x):
+         x = (x + self.bias) / self.divisor
+
+         return x.clamp_(self.min_value, self.max_value)
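A quick numeric check of the default configuration, min(max((x + 1) / 2, 0), 1) — a minimal sketch:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.hsigmoid import HSigmoid

act = HSigmoid()   # bias=1.0, divisor=2.0, clamped to [0, 1]
x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print(act(x))      # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])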
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/hswish.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch.nn as nn
+
+ from .registry import ACTIVATION_LAYERS
+
+
+ @ACTIVATION_LAYERS.register_module()
+ class HSwish(nn.Module):
+     """Hard Swish Module.
+
+     This module applies the hard swish function:
+
+     .. math::
+         Hswish(x) = x * ReLU6(x + 3) / 6
+
+     Args:
+         inplace (bool): can optionally do the operation in-place.
+             Default: False.
+
+     Returns:
+         Tensor: The output tensor.
+     """
+
+     def __init__(self, inplace=False):
+         super(HSwish, self).__init__()
+         self.act = nn.ReLU6(inplace)
+
+     def forward(self, x):
+         return x * self.act(x + 3) / 6
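Hard swish approximates x * sigmoid(x) with the cheaper ReLU6: it is exactly zero for x <= -3 and approaches the identity for x >= 3. A minimal check:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.hswish import HSwish

act = HSwish()
x = torch.tensor([-4.0, -1.5, 0.0, 4.0])
print(act(x))      # tensor([ 0.0000, -0.3750,  0.0000,  4.0000])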
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/non_local.py ADDED
@@ -0,0 +1,306 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from abc import ABCMeta
+
+ import torch
+ import torch.nn as nn
+
+ from ..utils import constant_init, normal_init
+ from .conv_module import ConvModule
+ from .registry import PLUGIN_LAYERS
+
+
+ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+     """Basic Non-local module.
+
+     This module is proposed in
+     "Non-local Neural Networks"
+     Paper reference: https://arxiv.org/abs/1711.07971
+     Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+     Args:
+         in_channels (int): Channels of the input feature map.
+         reduction (int): Channel reduction ratio. Default: 2.
+         use_scale (bool): Whether to scale pairwise_weight by
+             `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+             Default: True.
+         conv_cfg (None | dict): The config dict for convolution layers.
+             If not specified, it will use `nn.Conv2d` for convolution layers.
+             Default: None.
+         norm_cfg (None | dict): The config dict for normalization layers.
+             Default: None. (This parameter is only applicable to conv_out.)
+         mode (str): Options are `gaussian`, `concatenation`,
+             `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+     """
+
+     def __init__(self,
+                  in_channels,
+                  reduction=2,
+                  use_scale=True,
+                  conv_cfg=None,
+                  norm_cfg=None,
+                  mode='embedded_gaussian',
+                  **kwargs):
+         super(_NonLocalNd, self).__init__()
+         self.in_channels = in_channels
+         self.reduction = reduction
+         self.use_scale = use_scale
+         self.inter_channels = max(in_channels // reduction, 1)
+         self.mode = mode
+
+         if mode not in [
+                 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+         ]:
+             raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+                              f"'embedded_gaussian' or 'dot_product', but got "
+                              f'{mode} instead.')
+
+         # g, theta, phi are defaulted as `nn.ConvNd`.
+         # Here we use ConvModule for potential usage.
+         self.g = ConvModule(
+             self.in_channels,
+             self.inter_channels,
+             kernel_size=1,
+             conv_cfg=conv_cfg,
+             act_cfg=None)
+         self.conv_out = ConvModule(
+             self.inter_channels,
+             self.in_channels,
+             kernel_size=1,
+             conv_cfg=conv_cfg,
+             norm_cfg=norm_cfg,
+             act_cfg=None)
+
+         if self.mode != 'gaussian':
+             self.theta = ConvModule(
+                 self.in_channels,
+                 self.inter_channels,
+                 kernel_size=1,
+                 conv_cfg=conv_cfg,
+                 act_cfg=None)
+             self.phi = ConvModule(
+                 self.in_channels,
+                 self.inter_channels,
+                 kernel_size=1,
+                 conv_cfg=conv_cfg,
+                 act_cfg=None)
+
+         if self.mode == 'concatenation':
+             self.concat_project = ConvModule(
+                 self.inter_channels * 2,
+                 1,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 bias=False,
+                 act_cfg=dict(type='ReLU'))
+
+         self.init_weights(**kwargs)
+
+     def init_weights(self, std=0.01, zeros_init=True):
+         if self.mode != 'gaussian':
+             for m in [self.g, self.theta, self.phi]:
+                 normal_init(m.conv, std=std)
+         else:
+             normal_init(self.g.conv, std=std)
+         if zeros_init:
+             if self.conv_out.norm_cfg is None:
+                 constant_init(self.conv_out.conv, 0)
+             else:
+                 constant_init(self.conv_out.norm, 0)
+         else:
+             if self.conv_out.norm_cfg is None:
+                 normal_init(self.conv_out.conv, std=std)
+             else:
+                 normal_init(self.conv_out.norm, std=std)
+
+     def gaussian(self, theta_x, phi_x):
+         # NonLocal1d pairwise_weight: [N, H, H]
+         # NonLocal2d pairwise_weight: [N, HxW, HxW]
+         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+         pairwise_weight = torch.matmul(theta_x, phi_x)
+         pairwise_weight = pairwise_weight.softmax(dim=-1)
+         return pairwise_weight
+
+     def embedded_gaussian(self, theta_x, phi_x):
+         # NonLocal1d pairwise_weight: [N, H, H]
+         # NonLocal2d pairwise_weight: [N, HxW, HxW]
+         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+         pairwise_weight = torch.matmul(theta_x, phi_x)
+         if self.use_scale:
+             # theta_x.shape[-1] is `self.inter_channels`
+             pairwise_weight /= theta_x.shape[-1]**0.5
+         pairwise_weight = pairwise_weight.softmax(dim=-1)
+         return pairwise_weight
+
+     def dot_product(self, theta_x, phi_x):
+         # NonLocal1d pairwise_weight: [N, H, H]
+         # NonLocal2d pairwise_weight: [N, HxW, HxW]
+         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+         pairwise_weight = torch.matmul(theta_x, phi_x)
+         pairwise_weight /= pairwise_weight.shape[-1]
+         return pairwise_weight
+
+     def concatenation(self, theta_x, phi_x):
+         # NonLocal1d pairwise_weight: [N, H, H]
+         # NonLocal2d pairwise_weight: [N, HxW, HxW]
+         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+         h = theta_x.size(2)
+         w = phi_x.size(3)
+         theta_x = theta_x.repeat(1, 1, 1, w)
+         phi_x = phi_x.repeat(1, 1, h, 1)
+
+         concat_feature = torch.cat([theta_x, phi_x], dim=1)
+         pairwise_weight = self.concat_project(concat_feature)
+         n, _, h, w = pairwise_weight.size()
+         pairwise_weight = pairwise_weight.view(n, h, w)
+         pairwise_weight /= pairwise_weight.shape[-1]
+
+         return pairwise_weight
+
+     def forward(self, x):
+         # Assume `reduction = 1`, then `inter_channels = C`
+         # (in `gaussian` mode theta/phi act on the raw input, so their
+         # channel dim is `in_channels` regardless of reduction)
+
+         # NonLocal1d x: [N, C, H]
+         # NonLocal2d x: [N, C, H, W]
+         # NonLocal3d x: [N, C, T, H, W]
+         n = x.size(0)
+
+         # NonLocal1d g_x: [N, H, C]
+         # NonLocal2d g_x: [N, HxW, C]
+         # NonLocal3d g_x: [N, TxHxW, C]
+         g_x = self.g(x).view(n, self.inter_channels, -1)
+         g_x = g_x.permute(0, 2, 1)
+
+         # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+         # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+         # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+         if self.mode == 'gaussian':
+             theta_x = x.view(n, self.in_channels, -1)
+             theta_x = theta_x.permute(0, 2, 1)
+             if self.sub_sample:
+                 phi_x = self.phi(x).view(n, self.in_channels, -1)
+             else:
+                 phi_x = x.view(n, self.in_channels, -1)
+         elif self.mode == 'concatenation':
+             theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+             phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+         else:
+             theta_x = self.theta(x).view(n, self.inter_channels, -1)
+             theta_x = theta_x.permute(0, 2, 1)
+             phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+         pairwise_func = getattr(self, self.mode)
+         # NonLocal1d pairwise_weight: [N, H, H]
+         # NonLocal2d pairwise_weight: [N, HxW, HxW]
+         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+         pairwise_weight = pairwise_func(theta_x, phi_x)
+
+         # NonLocal1d y: [N, H, C]
+         # NonLocal2d y: [N, HxW, C]
+         # NonLocal3d y: [N, TxHxW, C]
+         y = torch.matmul(pairwise_weight, g_x)
+         # NonLocal1d y: [N, C, H]
+         # NonLocal2d y: [N, C, H, W]
+         # NonLocal3d y: [N, C, T, H, W]
+         y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                     *x.size()[2:])
+
+         output = x + self.conv_out(y)
+
+         return output
+
+
+ class NonLocal1d(_NonLocalNd):
+     """1D Non-local module.
+
+     Args:
+         in_channels (int): Same as `NonLocalND`.
+         sub_sample (bool): Whether to apply max pooling after pairwise
+             function (Note that the `sub_sample` is applied on spatial only).
+             Default: False.
+         conv_cfg (None | dict): Same as `NonLocalND`.
+             Default: dict(type='Conv1d').
+     """
+
+     def __init__(self,
+                  in_channels,
+                  sub_sample=False,
+                  conv_cfg=dict(type='Conv1d'),
+                  **kwargs):
+         super(NonLocal1d, self).__init__(
+             in_channels, conv_cfg=conv_cfg, **kwargs)
+
+         self.sub_sample = sub_sample
+
+         if sub_sample:
+             max_pool_layer = nn.MaxPool1d(kernel_size=2)
+             self.g = nn.Sequential(self.g, max_pool_layer)
+             if self.mode != 'gaussian':
+                 self.phi = nn.Sequential(self.phi, max_pool_layer)
+             else:
+                 self.phi = max_pool_layer
+
+
+ @PLUGIN_LAYERS.register_module()
+ class NonLocal2d(_NonLocalNd):
+     """2D Non-local module.
+
+     Args:
+         in_channels (int): Same as `NonLocalND`.
+         sub_sample (bool): Whether to apply max pooling after pairwise
+             function (Note that the `sub_sample` is applied on spatial only).
+             Default: False.
+         conv_cfg (None | dict): Same as `NonLocalND`.
+             Default: dict(type='Conv2d').
+     """
+
+     _abbr_ = 'nonlocal_block'
+
+     def __init__(self,
+                  in_channels,
+                  sub_sample=False,
+                  conv_cfg=dict(type='Conv2d'),
+                  **kwargs):
+         super(NonLocal2d, self).__init__(
+             in_channels, conv_cfg=conv_cfg, **kwargs)
+
+         self.sub_sample = sub_sample
+
+         if sub_sample:
+             max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+             self.g = nn.Sequential(self.g, max_pool_layer)
+             if self.mode != 'gaussian':
+                 self.phi = nn.Sequential(self.phi, max_pool_layer)
+             else:
+                 self.phi = max_pool_layer
+
+
+ class NonLocal3d(_NonLocalNd):
+     """3D Non-local module.
+
+     Args:
+         in_channels (int): Same as `NonLocalND`.
+         sub_sample (bool): Whether to apply max pooling after pairwise
+             function (Note that the `sub_sample` is applied on spatial only).
+             Default: False.
+         conv_cfg (None | dict): Same as `NonLocalND`.
+             Default: dict(type='Conv3d').
+     """
+
+     def __init__(self,
+                  in_channels,
+                  sub_sample=False,
+                  conv_cfg=dict(type='Conv3d'),
+                  **kwargs):
+         super(NonLocal3d, self).__init__(
+             in_channels, conv_cfg=conv_cfg, **kwargs)
+         self.sub_sample = sub_sample
+
+         if sub_sample:
+             max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+             self.g = nn.Sequential(self.g, max_pool_layer)
+             if self.mode != 'gaussian':
+                 self.phi = nn.Sequential(self.phi, max_pool_layer)
+             else:
+                 self.phi = max_pool_layer
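The three subclasses only choose the conv dimensionality and the pooling applied by `sub_sample`; the pairwise attention itself lives in `_NonLocalNd.forward`. A minimal sketch of the 2D variant:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.non_local import NonLocal2d

# embedded_gaussian is the default pairwise function; sub_sample max-pools
# keys/values to shrink the (HxW) x (HxW) attention matrix.
block = NonLocal2d(in_channels=32, reduction=2, sub_sample=True)
x = torch.rand(2, 32, 20, 20)
assert block(x).shape == x.shape   # residual design preserves the shape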
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/norm.py ADDED
@@ -0,0 +1,144 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import inspect
+
+ import torch.nn as nn
+
+ from annotator.mmpkg.mmcv.utils import is_tuple_of
+ from annotator.mmpkg.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+ from .registry import NORM_LAYERS
+
+ NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+ NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+ NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+ NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+ NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+ NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+ NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+ NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+ NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+ NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+ NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+ def infer_abbr(class_type):
+     """Infer abbreviation from the class name.
+
+     When we build a norm layer with `build_norm_layer()`, we want to preserve
+     the norm type in variable names, e.g., self.bn1, self.gn. This method
+     will infer the abbreviation to map class types to abbreviations.
+
+     Rule 1: If the class has the property "_abbr_", return the property.
+     Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+     InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+     "in" respectively.
+     Rule 3: If the class name contains "batch", "group", "layer" or
+     "instance", the abbreviation of this layer will be "bn", "gn", "ln" and
+     "in" respectively.
+     Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+     Args:
+         class_type (type): The norm layer type.
+
+     Returns:
+         str: The inferred abbreviation.
+     """
+     if not inspect.isclass(class_type):
+         raise TypeError(
+             f'class_type must be a type, but got {type(class_type)}')
+     if hasattr(class_type, '_abbr_'):
+         return class_type._abbr_
+     if issubclass(class_type, _InstanceNorm):  # IN is a subclass of BN
+         return 'in'
+     elif issubclass(class_type, _BatchNorm):
+         return 'bn'
+     elif issubclass(class_type, nn.GroupNorm):
+         return 'gn'
+     elif issubclass(class_type, nn.LayerNorm):
+         return 'ln'
+     else:
+         class_name = class_type.__name__.lower()
+         if 'batch' in class_name:
+             return 'bn'
+         elif 'group' in class_name:
+             return 'gn'
+         elif 'layer' in class_name:
+             return 'ln'
+         elif 'instance' in class_name:
+             return 'in'
+         else:
+             return 'norm_layer'
+
+
+ def build_norm_layer(cfg, num_features, postfix=''):
+     """Build normalization layer.
+
+     Args:
+         cfg (dict): The norm layer config, which should contain:
+
+             - type (str): Layer type.
+             - layer args: Args needed to instantiate a norm layer.
+             - requires_grad (bool, optional): Whether the layer's parameters
+               require gradient updates.
+         num_features (int): Number of input channels.
+         postfix (int | str): The postfix to be appended to the norm
+             abbreviation to create the named layer.
+
+     Returns:
+         (str, nn.Module): The first element is the layer name consisting of
+         abbreviation and postfix, e.g., bn1, gn. The second element is the
+         created norm layer.
+     """
+     if not isinstance(cfg, dict):
+         raise TypeError('cfg must be a dict')
+     if 'type' not in cfg:
+         raise KeyError('the cfg dict must contain the key "type"')
+     cfg_ = cfg.copy()
+
+     layer_type = cfg_.pop('type')
+     if layer_type not in NORM_LAYERS:
+         raise KeyError(f'Unrecognized norm type {layer_type}')
+
+     norm_layer = NORM_LAYERS.get(layer_type)
+     abbr = infer_abbr(norm_layer)
+
+     assert isinstance(postfix, (int, str))
+     name = abbr + str(postfix)
+
+     requires_grad = cfg_.pop('requires_grad', True)
+     cfg_.setdefault('eps', 1e-5)
+     if layer_type != 'GN':
+         layer = norm_layer(num_features, **cfg_)
+         if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+             layer._specify_ddp_gpu_num(1)
+     else:
+         assert 'num_groups' in cfg_
+         layer = norm_layer(num_channels=num_features, **cfg_)
+
+     for param in layer.parameters():
+         param.requires_grad = requires_grad
+
+     return name, layer
+
+
+ def is_norm(layer, exclude=None):
+     """Check if a layer is a normalization layer.
+
+     Args:
+         layer (nn.Module): The layer to be checked.
+         exclude (type | tuple[type]): Types to be excluded.
+
+     Returns:
+         bool: Whether the layer is a norm layer.
+     """
+     if exclude is not None:
+         if not isinstance(exclude, tuple):
+             exclude = (exclude, )
+         if not is_tuple_of(exclude, type):
+             raise TypeError(
+                 f'"exclude" must be either None or type or a tuple of types, '
+                 f'but got {type(exclude)}: {exclude}')
+
+     if exclude and isinstance(layer, exclude):
+         return False
+
+     all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+     return isinstance(layer, all_norm_bases)
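The returned name encodes the inferred abbreviation plus the postfix, which is what backbones use with `add_module`. A minimal sketch:

from annotator.mmpkg.mmcv.cnn.bricks.norm import build_norm_layer

name, bn = build_norm_layer(dict(type='BN', requires_grad=False), 64, postfix=1)
print(name)        # 'bn1'
assert all(not p.requires_grad for p in bn.parameters())

# GN is the one type built with num_channels; it needs num_groups in the cfg.
name, gn = build_norm_layer(dict(type='GN', num_groups=8), 64)
print(name)        # 'gn'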
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/padding.py ADDED
@@ -0,0 +1,36 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch.nn as nn
+
+ from .registry import PADDING_LAYERS
+
+ PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+ PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+ def build_padding_layer(cfg, *args, **kwargs):
+     """Build padding layer.
+
+     Args:
+         cfg (dict): The padding layer config, which should contain:
+             - type (str): Layer type.
+             - layer args: Args needed to instantiate a padding layer.
+
+     Returns:
+         nn.Module: Created padding layer.
+     """
+     if not isinstance(cfg, dict):
+         raise TypeError('cfg must be a dict')
+     if 'type' not in cfg:
+         raise KeyError('the cfg dict must contain the key "type"')
+
+     cfg_ = cfg.copy()
+     padding_type = cfg_.pop('type')
+     if padding_type not in PADDING_LAYERS:
+         raise KeyError(f'Unrecognized padding type {padding_type}.')
+     else:
+         padding_layer = PADDING_LAYERS.get(padding_type)
+
+     layer = padding_layer(*args, **kwargs, **cfg_)
+
+     return layer
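Positional arguments after `cfg` are forwarded to the layer constructor, here the per-side padding width of nn.ReflectionPad2d. A minimal sketch:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)
x = torch.rand(1, 3, 8, 8)
assert pad(x).shape == (1, 3, 12, 12)   # 2 pixels added on each border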
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/plugin.py ADDED
@@ -0,0 +1,88 @@
+ import inspect
+ import platform
+
+ from .registry import PLUGIN_LAYERS
+
+ if platform.system() == 'Windows':
+     import regex as re
+ else:
+     import re
+
+
+ def infer_abbr(class_type):
+     """Infer abbreviation from the class name.
+
+     This method will infer the abbreviation to map class types to
+     abbreviations.
+
+     Rule 1: If the class has the property "_abbr_", return the property.
+     Rule 2: Otherwise, the abbreviation falls back to snake case of class
+     name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+     Args:
+         class_type (type): The plugin layer type.
+
+     Returns:
+         str: The inferred abbreviation.
+     """
+
+     def camel2snack(word):
+         """Convert a camel-case word into snake case.
+
+         Modified from `inflection lib
+         <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.
+
+         Example::
+
+             >>> camel2snack("FancyBlock")
+             'fancy_block'
+         """
+
+         word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+         word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+         word = word.replace('-', '_')
+         return word.lower()
+
+     if not inspect.isclass(class_type):
+         raise TypeError(
+             f'class_type must be a type, but got {type(class_type)}')
+     if hasattr(class_type, '_abbr_'):
+         return class_type._abbr_
+     else:
+         return camel2snack(class_type.__name__)
+
+
+ def build_plugin_layer(cfg, postfix='', **kwargs):
+     """Build plugin layer.
+
+     Args:
+         cfg (dict): cfg should contain:
+
+             - type (str): identify plugin layer type.
+             - layer args: args needed to instantiate a plugin layer.
+         postfix (int, str): appended into norm abbreviation to
+             create named layer. Default: ''.
+
+     Returns:
+         tuple[str, nn.Module]:
+             name (str): abbreviation + postfix
+             layer (nn.Module): created plugin layer
+     """
+     if not isinstance(cfg, dict):
+         raise TypeError('cfg must be a dict')
+     if 'type' not in cfg:
+         raise KeyError('the cfg dict must contain the key "type"')
+     cfg_ = cfg.copy()
+
+     layer_type = cfg_.pop('type')
+     if layer_type not in PLUGIN_LAYERS:
+         raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+     plugin_layer = PLUGIN_LAYERS.get(layer_type)
+     abbr = infer_abbr(plugin_layer)
+
+     assert isinstance(postfix, (int, str))
+     name = abbr + str(postfix)
+
+     layer = plugin_layer(**kwargs, **cfg_)
+
+     return name, layer
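NonLocal2d from non_local.py is registered in PLUGIN_LAYERS and defines `_abbr_ = 'nonlocal_block'`, so Rule 1 of `infer_abbr` applies. A minimal sketch:

from annotator.mmpkg.mmcv.cnn.bricks.plugin import build_plugin_layer

name, layer = build_plugin_layer(dict(type='NonLocal2d'), postfix='_1',
                                 in_channels=16)
print(name)   # 'nonlocal_block_1'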
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/registry.py ADDED
@@ -0,0 +1,16 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from annotator.mmpkg.mmcv.utils import Registry
+
+ CONV_LAYERS = Registry('conv layer')
+ NORM_LAYERS = Registry('norm layer')
+ ACTIVATION_LAYERS = Registry('activation layer')
+ PADDING_LAYERS = Registry('padding layer')
+ UPSAMPLE_LAYERS = Registry('upsample layer')
+ PLUGIN_LAYERS = Registry('plugin layer')
+
+ DROPOUT_LAYERS = Registry('drop out layers')
+ POSITIONAL_ENCODING = Registry('position encoding')
+ ATTENTION = Registry('attention')
+ FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+ TRANSFORMER_LAYER = Registry('transformerLayer')
+ TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
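Any nn.Module can be registered into one of these registries and then built from a plain config dict, the same way the built-in bricks are. A minimal sketch with a hypothetical activation (ShiftedReLU is made up for illustration):

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn.bricks.registry import ACTIVATION_LAYERS

@ACTIVATION_LAYERS.register_module()
class ShiftedReLU(nn.Module):   # hypothetical example layer

    def __init__(self, shift=1.0):
        super().__init__()
        self.shift = shift

    def forward(self, x):
        return (x - self.shift).clamp(min=0)

# now buildable from a config, e.g. dict(type='ShiftedReLU', shift=0.5)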
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/scale.py ADDED
@@ -0,0 +1,21 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+
+
+ class Scale(nn.Module):
+     """A learnable scale parameter.
+
+     This layer scales the input by a learnable factor. It multiplies a
+     learnable scale parameter of shape (1,) with input of any shape.
+
+     Args:
+         scale (float): Initial value of scale factor. Default: 1.0
+     """
+
+     def __init__(self, scale=1.0):
+         super(Scale, self).__init__()
+         self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+     def forward(self, x):
+         return x * self.scale
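A Scale initialised to zero is a common way to let a residual branch start as an identity mapping and grow during training. A minimal sketch:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.scale import Scale

gate = Scale(scale=0.0)               # one learnable scalar parameter
x = torch.rand(4, 8)
assert torch.equal(gate(x), torch.zeros_like(x))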
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/swish.py ADDED
@@ -0,0 +1,25 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+
+ from .registry import ACTIVATION_LAYERS
+
+
+ @ACTIVATION_LAYERS.register_module()
+ class Swish(nn.Module):
+     """Swish Module.
+
+     This module applies the swish function:
+
+     .. math::
+         Swish(x) = x * Sigmoid(x)
+
+     Returns:
+         Tensor: The output tensor.
+     """
+
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, x):
+         return x * torch.sigmoid(x)
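Because Swish is registered in ACTIVATION_LAYERS it can be built from a config dict, e.g. as the `act_cfg` of a ConvModule — a sketch assuming `build_activation_layer` from the sibling activation.py in this commit:

from annotator.mmpkg.mmcv.cnn.bricks.activation import build_activation_layer

act = build_activation_layer(dict(type='Swish'))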
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/transformer.py ADDED
@@ -0,0 +1,595 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import copy
+ import warnings
+
+ import torch
+ import torch.nn as nn
+
+ from annotator.mmpkg.mmcv import ConfigDict, deprecated_api_warning
+ from annotator.mmpkg.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+ from annotator.mmpkg.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+ from annotator.mmpkg.mmcv.utils import build_from_cfg
+ from .drop import build_dropout
+ from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+                        TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+ # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+ try:
+     from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
+     warnings.warn(
+         ImportWarning(
+             '``MultiScaleDeformableAttention`` has been moved to '
+             '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
+             '``from annotator.mmpkg.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
+             'to ``from annotator.mmpkg.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
+         ))
+
+ except ImportError:
+     warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
+                   '``mmcv.ops.multi_scale_deform_attn``, '
+                   'you should install ``mmcv-full`` if you need this module. ')
+
+
+ def build_positional_encoding(cfg, default_args=None):
+     """Builder for Position Encoding."""
+     return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+ def build_attention(cfg, default_args=None):
+     """Builder for attention."""
+     return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+ def build_feedforward_network(cfg, default_args=None):
+     """Builder for feed-forward network (FFN)."""
+     return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+ def build_transformer_layer(cfg, default_args=None):
+     """Builder for transformer layer."""
+     return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+ def build_transformer_layer_sequence(cfg, default_args=None):
+     """Builder for transformer encoder and transformer decoder."""
+     return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+ @ATTENTION.register_module()
+ class MultiheadAttention(BaseModule):
+     """A wrapper for ``torch.nn.MultiheadAttention``.
+
+     This module implements MultiheadAttention with identity connection,
+     and positional encoding is also passed as input.
+
+     Args:
+         embed_dims (int): The embedding dimension.
+         num_heads (int): Parallel attention heads.
+         attn_drop (float): A Dropout layer on attn_output_weights.
+             Default: 0.0.
+         proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+             Default: 0.0.
+         dropout_layer (obj:`ConfigDict`): The dropout_layer used
+             when adding the shortcut.
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+         batch_first (bool): When it is True, Key, Query and Value are in
+             shape of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+             Default: False.
+     """
+
+     def __init__(self,
+                  embed_dims,
+                  num_heads,
+                  attn_drop=0.,
+                  proj_drop=0.,
+                  dropout_layer=dict(type='Dropout', drop_prob=0.),
+                  init_cfg=None,
+                  batch_first=False,
+                  **kwargs):
+         super(MultiheadAttention, self).__init__(init_cfg)
+         if 'dropout' in kwargs:
+             warnings.warn('The argument `dropout` in MultiheadAttention '
+                           'has been deprecated; now you can separately '
+                           'set `attn_drop`(float), `proj_drop`(float), '
+                           'and `dropout_layer`(dict) ')
+             attn_drop = kwargs['dropout']
+             dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+         self.embed_dims = embed_dims
+         self.num_heads = num_heads
+         self.batch_first = batch_first
+
+         self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+                                           **kwargs)
+
+         self.proj_drop = nn.Dropout(proj_drop)
+         self.dropout_layer = build_dropout(
+             dropout_layer) if dropout_layer else nn.Identity()
+
+     @deprecated_api_warning({'residual': 'identity'},
+                             cls_name='MultiheadAttention')
+     def forward(self,
+                 query,
+                 key=None,
+                 value=None,
+                 identity=None,
+                 query_pos=None,
+                 key_pos=None,
+                 attn_mask=None,
+                 key_padding_mask=None,
+                 **kwargs):
+         """Forward function for `MultiheadAttention`.
+
+         **kwargs allow passing a more general data flow when combining
+         with other operations in `transformerlayer`.
+
+         Args:
+             query (Tensor): The input query with shape [num_queries, bs,
+                 embed_dims] if self.batch_first is False, else
+                 [bs, num_queries, embed_dims].
+             key (Tensor): The key tensor with shape [num_keys, bs,
+                 embed_dims] if self.batch_first is False, else
+                 [bs, num_keys, embed_dims].
+                 If None, the ``query`` will be used. Defaults to None.
+             value (Tensor): The value tensor with same shape as `key`.
+                 Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                 If None, the `key` will be used.
+             identity (Tensor): This tensor, with the same shape as `query`,
+                 will be used for the identity link.
+                 If None, `query` will be used. Defaults to None.
+             query_pos (Tensor): The positional encoding for query, with
+                 the same shape as `query`. If not None, it will
+                 be added to `query` before forward function.
+                 Defaults to None.
+             key_pos (Tensor): The positional encoding for `key`, with the
+                 same shape as `key`. If not None, it will be added to
+                 `key` before forward function. If None, and `query_pos`
+                 has the same shape as `key`, then `query_pos` will be
+                 used for `key_pos`. Defaults to None.
+             attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                 num_keys]. Same in `nn.MultiheadAttention.forward`.
+                 Defaults to None.
+             key_padding_mask (Tensor): ByteTensor for `key`, with shape
+                 [bs, num_keys]. Defaults to None.
+
+         Returns:
+             Tensor: forwarded results with shape
+             [num_queries, bs, embed_dims]
+             if self.batch_first is False, else
+             [bs, num_queries, embed_dims].
+         """
+
+         if key is None:
+             key = query
+         if value is None:
+             value = key
+         if identity is None:
+             identity = query
+         if key_pos is None:
+             if query_pos is not None:
+                 # use query_pos if key_pos is not available
+                 if query_pos.shape == key.shape:
+                     key_pos = query_pos
+                 else:
+                     warnings.warn(f'position encoding of key is '
+                                   f'missing in {self.__class__.__name__}.')
+         if query_pos is not None:
+             query = query + query_pos
+         if key_pos is not None:
+             key = key + key_pos
+
+         # Because the dataflow('key', 'query', 'value') of
+         # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+         # embed_dims), we should adjust the shape of dataflow from
+         # batch_first (batch, num_query, embed_dims) to num_query_first
+         # (num_query, batch, embed_dims), and recover ``attn_output``
+         # from num_query_first to batch_first.
+         if self.batch_first:
+             query = query.transpose(0, 1)
+             key = key.transpose(0, 1)
+             value = value.transpose(0, 1)
+
+         out = self.attn(
+             query=query,
+             key=key,
+             value=value,
+             attn_mask=attn_mask,
+             key_padding_mask=key_padding_mask)[0]
+
+         if self.batch_first:
+             out = out.transpose(0, 1)
+
+         return identity + self.dropout_layer(self.proj_drop(out))
+
+
+ @FEEDFORWARD_NETWORK.register_module()
+ class FFN(BaseModule):
+     """Implements feed-forward networks (FFNs) with identity connection.
+
+     Args:
+         embed_dims (int): The feature dimension. Same as
+             `MultiheadAttention`. Default: 256.
+         feedforward_channels (int): The hidden dimension of FFNs.
+             Default: 1024.
+         num_fcs (int, optional): The number of fully-connected layers in
+             FFNs. Default: 2.
+         act_cfg (dict, optional): The activation config for FFNs.
+             Default: dict(type='ReLU')
+         ffn_drop (float, optional): Probability of an element to be
+             zeroed in FFN. Default: 0.0.
+         add_identity (bool, optional): Whether to add the
+             identity connection. Default: `True`.
+         dropout_layer (obj:`ConfigDict`): The dropout_layer used
+             when adding the shortcut.
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+     """
+
+     @deprecated_api_warning(
+         {
+             'dropout': 'ffn_drop',
+             'add_residual': 'add_identity'
+         },
+         cls_name='FFN')
+     def __init__(self,
+                  embed_dims=256,
+                  feedforward_channels=1024,
+                  num_fcs=2,
+                  act_cfg=dict(type='ReLU', inplace=True),
+                  ffn_drop=0.,
+                  dropout_layer=None,
+                  add_identity=True,
+                  init_cfg=None,
+                  **kwargs):
+         super(FFN, self).__init__(init_cfg)
+         assert num_fcs >= 2, 'num_fcs should be no less ' \
+             f'than 2, but got {num_fcs}.'
+         self.embed_dims = embed_dims
+         self.feedforward_channels = feedforward_channels
+         self.num_fcs = num_fcs
+         self.act_cfg = act_cfg
+         self.activate = build_activation_layer(act_cfg)
+
+         layers = []
+         in_channels = embed_dims
+         for _ in range(num_fcs - 1):
+             layers.append(
+                 Sequential(
+                     Linear(in_channels, feedforward_channels), self.activate,
+                     nn.Dropout(ffn_drop)))
+             in_channels = feedforward_channels
+         layers.append(Linear(feedforward_channels, embed_dims))
+         layers.append(nn.Dropout(ffn_drop))
+         self.layers = Sequential(*layers)
+         self.dropout_layer = build_dropout(
+             dropout_layer) if dropout_layer else torch.nn.Identity()
+         self.add_identity = add_identity
+
+     @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+     def forward(self, x, identity=None):
+         """Forward function for `FFN`.
+
+         The function adds `x` to the output tensor if `identity` is None.
+         """
+         out = self.layers(x)
+         if not self.add_identity:
+             return self.dropout_layer(out)
+         if identity is None:
+             identity = x
+         return identity + self.dropout_layer(out)
+
+
+ @TRANSFORMER_LAYER.register_module()
+ class BaseTransformerLayer(BaseModule):
+     """Base `TransformerLayer` for vision transformer.
+
+     It can be built from `mmcv.ConfigDict` and supports flexible
+     customization, for example, using any number of `FFN or LN` and
+     using different kinds of `attention` by specifying a list of `ConfigDict`
+     named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+     when you specify `norm` as the first element of `operation_order`.
+     More details about the `prenorm`: `On Layer Normalization in the
+     Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
+
+     Args:
+         attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+             Configs for `self_attention` or `cross_attention` modules.
+             The order of the configs in the list should be consistent with
+             corresponding attentions in operation_order.
+             If it is a dict, all of the attention modules in operation_order
+             will be built with this config. Default: None.
+         ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+             Configs for FFN. The order of the configs in the list should be
+             consistent with corresponding ffn in operation_order.
+             If it is a dict, all of the FFN modules in operation_order
+             will be built with this config.
+         operation_order (tuple[str]): The execution order of operation
+             in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+             Supports `prenorm` when you specify the first element as `norm`.
+             Default: None.
+         norm_cfg (dict): Config dict for normalization layer.
+             Default: dict(type='LN').
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+         batch_first (bool): Key, Query and Value are in shape
+             of (batch, n, embed_dim)
+             or (n, batch, embed_dim). Default: False.
+     """
+
+     def __init__(self,
+                  attn_cfgs=None,
+                  ffn_cfgs=dict(
+                      type='FFN',
+                      embed_dims=256,
+                      feedforward_channels=1024,
+                      num_fcs=2,
+                      ffn_drop=0.,
+                      act_cfg=dict(type='ReLU', inplace=True),
+                  ),
+                  operation_order=None,
+                  norm_cfg=dict(type='LN'),
+                  init_cfg=None,
+                  batch_first=False,
+                  **kwargs):
+
+         deprecated_args = dict(
+             feedforward_channels='feedforward_channels',
+             ffn_dropout='ffn_drop',
+             ffn_num_fcs='num_fcs')
+         for ori_name, new_name in deprecated_args.items():
+             if ori_name in kwargs:
+                 warnings.warn(
+                     f'The argument `{ori_name}` in BaseTransformerLayer '
+                     f'has been deprecated; now you should set `{new_name}` '
+                     f'and other FFN related arguments '
+                     f'to a dict named `ffn_cfgs`. ')
+                 ffn_cfgs[new_name] = kwargs[ori_name]
+
+         super(BaseTransformerLayer, self).__init__(init_cfg)
+
+         self.batch_first = batch_first
+
+         assert set(operation_order) & set(
+             ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+             set(operation_order), f'The operation_order of' \
+             f' {self.__class__.__name__} should only contain operation' \
+             f" types in {['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
+         num_attn = operation_order.count('self_attn') + operation_order.count(
+             'cross_attn')
+         if isinstance(attn_cfgs, dict):
+             attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+         else:
+             assert num_attn == len(attn_cfgs), f'The length ' \
+                 f'of attn_cfgs {len(attn_cfgs)} is ' \
+                 f'not consistent with the number of attentions ' \
+                 f'in operation_order {operation_order}.'
+
+         self.num_attn = num_attn
+         self.operation_order = operation_order
+         self.norm_cfg = norm_cfg
+         self.pre_norm = operation_order[0] == 'norm'
+         self.attentions = ModuleList()
+
+         index = 0
+         for operation_name in operation_order:
+             if operation_name in ['self_attn', 'cross_attn']:
+                 if 'batch_first' in attn_cfgs[index]:
+                     assert self.batch_first == attn_cfgs[index]['batch_first']
+                 else:
+                     attn_cfgs[index]['batch_first'] = self.batch_first
+                 attention = build_attention(attn_cfgs[index])
+                 # Some custom attentions used as `self_attn`
+                 # or `cross_attn` can have different behavior.
+                 attention.operation_name = operation_name
+                 self.attentions.append(attention)
+                 index += 1
+
+         self.embed_dims = self.attentions[0].embed_dims
+
+         self.ffns = ModuleList()
+         num_ffns = operation_order.count('ffn')
+         if isinstance(ffn_cfgs, dict):
+             ffn_cfgs = ConfigDict(ffn_cfgs)
+         if isinstance(ffn_cfgs, dict):
+             ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+         assert len(ffn_cfgs) == num_ffns
+         for ffn_index in range(num_ffns):
+             if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                 ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+             else:
+                 assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+             self.ffns.append(
+                 build_feedforward_network(ffn_cfgs[ffn_index],
+                                           dict(type='FFN')))
+
+         self.norms = ModuleList()
+         num_norms = operation_order.count('norm')
+         for _ in range(num_norms):
+             self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+     def forward(self,
+                 query,
+                 key=None,
+                 value=None,
+                 query_pos=None,
+                 key_pos=None,
+                 attn_masks=None,
+                 query_key_padding_mask=None,
+                 key_padding_mask=None,
+                 **kwargs):
+         """Forward function for `BaseTransformerLayer`.
+
+         **kwargs contains some specific arguments of attentions.
+
+         Args:
+             query (Tensor): The input query with shape
+                 [num_queries, bs, embed_dims] if
+                 self.batch_first is False, else
+                 [bs, num_queries, embed_dims].
+             key (Tensor): The key tensor with shape [num_keys, bs,
+                 embed_dims] if self.batch_first is False, else
+                 [bs, num_keys, embed_dims].
+             value (Tensor): The value tensor with same shape as `key`.
+             query_pos (Tensor): The positional encoding for `query`.
+                 Default: None.
+             key_pos (Tensor): The positional encoding for `key`.
+                 Default: None.
+             attn_masks (List[Tensor] | None): Each element is a 2D Tensor
+                 used in the calculation of the corresponding attention.
+                 Its length should equal the number of `attention` entries
+                 in `operation_order`. Default: None.
+             query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                 shape [bs, num_queries]. Only used in `self_attn` layer.
+                 Defaults to None.
+             key_padding_mask (Tensor): ByteTensor for `key`, with
+                 shape [bs, num_keys]. Default: None.
+
+         Returns:
+             Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+         """
+
+         norm_index = 0
+         attn_index = 0
+         ffn_index = 0
+         identity = query
+         if attn_masks is None:
+             attn_masks = [None for _ in range(self.num_attn)]
+         elif isinstance(attn_masks, torch.Tensor):
+             attn_masks = [
+                 copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+             ]
+             warnings.warn(f'Use same attn_mask in all attentions in '
+                           f'{self.__class__.__name__} ')
+         else:
+             assert len(attn_masks) == self.num_attn, f'The length of ' \
+                 f'attn_masks {len(attn_masks)} must be equal ' \
+                 f'to the number of attention in ' \
+                 f'operation_order {self.num_attn}'
+
+         for layer in self.operation_order:
+             if layer == 'self_attn':
+                 temp_key = temp_value = query
+                 query = self.attentions[attn_index](
+                     query,
+                     temp_key,
+                     temp_value,
+                     identity if self.pre_norm else None,
+                     query_pos=query_pos,
+                     key_pos=query_pos,
+                     attn_mask=attn_masks[attn_index],
+                     key_padding_mask=query_key_padding_mask,
+                     **kwargs)
+                 attn_index += 1
+                 identity = query
+
+             elif layer == 'norm':
+                 query = self.norms[norm_index](query)
+                 norm_index += 1
+
+             elif layer == 'cross_attn':
+                 query = self.attentions[attn_index](
+                     query,
+                     key,
+                     value,
+                     identity if self.pre_norm else None,
+                     query_pos=query_pos,
+                     key_pos=key_pos,
+                     attn_mask=attn_masks[attn_index],
+                     key_padding_mask=key_padding_mask,
+                     **kwargs)
+                 attn_index += 1
+                 identity = query
+
+             elif layer == 'ffn':
+                 query = self.ffns[ffn_index](
+                     query, identity if self.pre_norm else None)
+                 ffn_index += 1
+
+         return query
+
+
+ @TRANSFORMER_LAYER_SEQUENCE.register_module()
+ class TransformerLayerSequence(BaseModule):
+     """Base class for TransformerEncoder and TransformerDecoder in vision
+     transformer.
+
+     As the base class of Encoder and Decoder in vision transformer,
+     it supports customization such as specifying different kinds
+     of `transformer_layer` in `transformer_coder`.
+
+     Args:
+         transformerlayers (list[obj:`mmcv.ConfigDict`] |
+             obj:`mmcv.ConfigDict`): Config of transformerlayer
+             in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
+             it would be repeated `num_layers` times to a
+             list[`mmcv.ConfigDict`]. Default: None.
+         num_layers (int): The number of `TransformerLayer`. Default: None.
+         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+             Default: None.
+     """
+
+     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+         super(TransformerLayerSequence, self).__init__(init_cfg)
+         if isinstance(transformerlayers, dict):
+             transformerlayers = [
+                 copy.deepcopy(transformerlayers) for _ in range(num_layers)
+             ]
+         else:
+             assert isinstance(transformerlayers, list) and \
+                 len(transformerlayers) == num_layers
+         self.num_layers = num_layers
+         self.layers = ModuleList()
+         for i in range(num_layers):
+             self.layers.append(build_transformer_layer(transformerlayers[i]))
+         self.embed_dims = self.layers[0].embed_dims
+         self.pre_norm = self.layers[0].pre_norm
+
+     def forward(self,
+                 query,
+                 key,
+                 value,
+                 query_pos=None,
+                 key_pos=None,
+                 attn_masks=None,
+                 query_key_padding_mask=None,
+                 key_padding_mask=None,
+                 **kwargs):
+         """Forward function for `TransformerCoder`.
+
+         Args:
+             query (Tensor): Input query with shape
+                 `(num_queries, bs, embed_dims)`.
+             key (Tensor): The key tensor with shape
+                 `(num_keys, bs, embed_dims)`.
+             value (Tensor): The value tensor with shape
+                 `(num_keys, bs, embed_dims)`.
+             query_pos (Tensor): The positional encoding for `query`.
+                 Default: None.
+             key_pos (Tensor): The positional encoding for `key`.
+                 Default: None.
+             attn_masks (List[Tensor], optional): Each element is a 2D Tensor
+                 which is used in the calculation of the corresponding
+                 attention in operation_order. Default: None.
+             query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                 shape [bs, num_queries]. Only used in self-attention.
+                 Default: None.
+             key_padding_mask (Tensor): ByteTensor for `key`, with
+                 shape [bs, num_keys]. Default: None.
+
+         Returns:
+             Tensor: results with shape [num_queries, bs, embed_dims].
+         """
+         for layer in self.layers:
+             query = layer(
+                 query,
+                 key,
+                 value,
+                 query_pos=query_pos,
+                 key_pos=key_pos,
+                 attn_masks=attn_masks,
+                 query_key_padding_mask=query_key_padding_mask,
+                 key_padding_mask=key_padding_mask,
+                 **kwargs)
+         return query
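Putting the pieces together: a single post-norm encoder-style layer (self-attention then FFN, each followed by LayerNorm), built purely from config dicts. A minimal sketch:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.transformer import BaseTransformerLayer

layer = BaseTransformerLayer(
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=64, num_heads=4),
    ffn_cfgs=dict(type='FFN', embed_dims=64, feedforward_channels=128),
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

query = torch.rand(10, 2, 64)    # (num_queries, bs, embed_dims)
out = layer(query)               # key/value default to the query itself
assert out.shape == (10, 2, 64)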
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/upsample.py ADDED
@@ -0,0 +1,84 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ..utils import xavier_init
+ from .registry import UPSAMPLE_LAYERS
+
+ UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+ UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+ @UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+ class PixelShufflePack(nn.Module):
+     """Pixel Shuffle upsample layer.
+
+     This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+     achieve a simple upsampling with pixel shuffle.
+
+     Args:
+         in_channels (int): Number of input channels.
+         out_channels (int): Number of output channels.
+         scale_factor (int): Upsample ratio.
+         upsample_kernel (int): Kernel size of the conv layer to expand the
+             channels.
+     """
+
+     def __init__(self, in_channels, out_channels, scale_factor,
+                  upsample_kernel):
+         super(PixelShufflePack, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.scale_factor = scale_factor
+         self.upsample_kernel = upsample_kernel
+         self.upsample_conv = nn.Conv2d(
+             self.in_channels,
+             self.out_channels * scale_factor * scale_factor,
+             self.upsample_kernel,
+             padding=(self.upsample_kernel - 1) // 2)
+         self.init_weights()
+
+     def init_weights(self):
+         xavier_init(self.upsample_conv, distribution='uniform')
+
+     def forward(self, x):
+         x = self.upsample_conv(x)
+         x = F.pixel_shuffle(x, self.scale_factor)
+         return x
+
+
+ def build_upsample_layer(cfg, *args, **kwargs):
+     """Build an upsample layer.
+
+     Args:
+         cfg (dict): The upsample layer config, which should contain:
+
+             - type (str): Layer type.
+             - scale_factor (int): Upsample ratio, which is not applicable to
+               deconv.
+             - layer args: Args needed to instantiate an upsample layer.
+         args (argument list): Arguments passed to the ``__init__``
+             method of the corresponding upsample layer.
+         kwargs (keyword arguments): Keyword arguments passed to the
+             ``__init__`` method of the corresponding upsample layer.
+
+     Returns:
+         nn.Module: Created upsample layer.
+     """
+     if not isinstance(cfg, dict):
+         raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+     if 'type' not in cfg:
+         raise KeyError(
+             f'the cfg dict must contain the key "type", but got {cfg}')
+     cfg_ = cfg.copy()
+
+     layer_type = cfg_.pop('type')
+     if layer_type not in UPSAMPLE_LAYERS:
+         raise KeyError(f'Unrecognized upsample type {layer_type}')
+     else:
+         upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+     if upsample is nn.Upsample:
+         cfg_['mode'] = layer_type
+     layer = upsample(*args, **kwargs, **cfg_)
+     return layer
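For the 'nearest'/'bilinear' entries the registered module is nn.Upsample and the type string is reused as its interpolation mode; 'pixel_shuffle' builds the conv + pixel-shuffle pack above. A minimal sketch:

import torch
from annotator.mmpkg.mmcv.cnn.bricks.upsample import build_upsample_layer

up = build_upsample_layer(dict(type='nearest', scale_factor=2))
ps = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=8, out_channels=8,
         scale_factor=2, upsample_kernel=3))
x = torch.rand(1, 8, 16, 16)
assert up(x).shape == (1, 8, 32, 32)
assert ps(x).shape == (1, 8, 32, 32)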
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/bricks/wrappers.py ADDED
@@ -0,0 +1,180 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
+
+ Wrap some nn modules to support empty tensor input. Currently, these wrappers
+ are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+ heads are trained on only positive RoIs.
+ """
+ import math
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.utils import _pair, _triple
+
+ from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+ if torch.__version__ == 'parrots':
+     TORCH_VERSION = torch.__version__
+ else:
+     # torch.__version__ could be 1.3.1+cu92, we only need the first two
+     # for comparison
+     TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+ def obsolete_torch_version(torch_version, version_threshold):
+     return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+ class NewEmptyTensorOp(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, x, new_shape):
+         ctx.shape = x.shape
+         return x.new_empty(new_shape)
+
+     @staticmethod
+     def backward(ctx, grad):
+         shape = ctx.shape
+         return NewEmptyTensorOp.apply(grad, shape), None
+
+
+ @CONV_LAYERS.register_module('Conv', force=True)
+ class Conv2d(nn.Conv2d):
+
+     def forward(self, x):
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+             out_shape = [x.shape[0], self.out_channels]
+             for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+                                      self.padding, self.stride, self.dilation):
+                 o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                 out_shape.append(o)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             if self.training:
+                 # produce dummy gradient to avoid DDP warning.
+                 dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                 return empty + dummy
+             else:
+                 return empty
+
+         return super().forward(x)
+
+
+ @CONV_LAYERS.register_module('Conv3d', force=True)
+ class Conv3d(nn.Conv3d):
+
+     def forward(self, x):
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+             out_shape = [x.shape[0], self.out_channels]
+             for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+                                      self.padding, self.stride, self.dilation):
+                 o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                 out_shape.append(o)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             if self.training:
+                 # produce dummy gradient to avoid DDP warning.
+                 dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                 return empty + dummy
+             else:
+                 return empty
+
+         return super().forward(x)
+
+
+ @CONV_LAYERS.register_module()
+ @CONV_LAYERS.register_module('deconv')
+ @UPSAMPLE_LAYERS.register_module('deconv', force=True)
+ class ConvTranspose2d(nn.ConvTranspose2d):
+
+     def forward(self, x):
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+             out_shape = [x.shape[0], self.out_channels]
+             for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+                                          self.padding, self.stride,
+                                          self.dilation, self.output_padding):
+                 out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             if self.training:
+                 # produce dummy gradient to avoid DDP warning.
+                 dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                 return empty + dummy
+             else:
+                 return empty
+
+         return super().forward(x)
+
+
+ @CONV_LAYERS.register_module()
+ @CONV_LAYERS.register_module('deconv3d')
+ @UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+ class ConvTranspose3d(nn.ConvTranspose3d):
+
+     def forward(self, x):
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+             out_shape = [x.shape[0], self.out_channels]
+             for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+                                          self.padding, self.stride,
+                                          self.dilation, self.output_padding):
+                 out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             if self.training:
+                 # produce dummy gradient to avoid DDP warning.
+                 dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                 return empty + dummy
+             else:
+                 return empty
+
+         return super().forward(x)
+
+
+ class MaxPool2d(nn.MaxPool2d):
+
+     def forward(self, x):
+         # PyTorch 1.9 does not support empty tensor inference yet
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+             out_shape = list(x.shape[:2])
+             for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+                                      _pair(self.padding), _pair(self.stride),
+                                      _pair(self.dilation)):
+                 o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                 o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                 out_shape.append(o)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             return empty
+
+         return super().forward(x)
+
+
+ class MaxPool3d(nn.MaxPool3d):
+
+     def forward(self, x):
+         # PyTorch 1.9 does not support empty tensor inference yet
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+             out_shape = list(x.shape[:2])
+             for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+                                      _triple(self.padding),
+                                      _triple(self.stride),
+                                      _triple(self.dilation)):
+                 o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                 o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                 out_shape.append(o)
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             return empty
+
+         return super().forward(x)
+
+
+ class Linear(torch.nn.Linear):
+
+     def forward(self, x):
+         # empty tensor forward of Linear layer is supported in Pytorch 1.6
+         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+             out_shape = [x.shape[0], self.out_features]
+             empty = NewEmptyTensorOp.apply(x, out_shape)
+             if self.training:
+                 # produce dummy gradient to avoid DDP warning.
+                 dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                 return empty + dummy
+             else:
+                 return empty
+
+         return super().forward(x)
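A short sketch of what these wrappers buy on old PyTorch versions (illustrative only; on builds newer than the thresholds above, the plain nn modules already handle empty inputs):

    import torch

    # Conv2d here is the wrapper defined above, not torch.nn.Conv2d.
    conv = Conv2d(3, 8, kernel_size=3, padding=1)
    rois = torch.zeros((0, 3, 14, 14))  # a batch with zero positive RoIs
    out = conv(rois)  # shape (0, 8, 14, 14) instead of a crash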
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/builder.py ADDED
@@ -0,0 +1,30 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from ..runner import Sequential
+ from ..utils import Registry, build_from_cfg
+
+
+ def build_model_from_cfg(cfg, registry, default_args=None):
+     """Build a PyTorch model from config dict(s). Different from
+     ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+     Args:
+         cfg (dict, list[dict]): The config of modules; it is either a config
+             dict or a list of config dicts. If cfg is a list, the built
+             modules will be wrapped with ``nn.Sequential``.
+         registry (:obj:`Registry`): A registry the module belongs to.
+         default_args (dict, optional): Default arguments to build the module.
+             Defaults to None.
+
+     Returns:
+         nn.Module: A built nn module.
+     """
+     if isinstance(cfg, list):
+         modules = [
+             build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+         ]
+         return Sequential(*modules)
+     else:
+         return build_from_cfg(cfg, registry, default_args)
+
+
+ MODELS = Registry('model', build_func=build_model_from_cfg)
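A sketch of the list behaviour (hypothetical config; 'MyBlock' stands for any module type registered in MODELS):

    cfg = [dict(type='MyBlock', channels=16), dict(type='MyBlock', channels=32)]
    model = MODELS.build(cfg)  # same as build_model_from_cfg(cfg, MODELS)
    # -> an nn.Sequential wrapping the two built blocks, in order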
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/resnet.py ADDED
@@ -0,0 +1,316 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import logging
+
+ import torch.nn as nn
+ import torch.utils.checkpoint as cp
+
+ from .utils import constant_init, kaiming_init
+
+
+ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+     """3x3 convolution with padding."""
+     return nn.Conv2d(
+         in_planes,
+         out_planes,
+         kernel_size=3,
+         stride=stride,
+         padding=dilation,
+         dilation=dilation,
+         bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self,
+                  inplanes,
+                  planes,
+                  stride=1,
+                  dilation=1,
+                  downsample=None,
+                  style='pytorch',
+                  with_cp=False):
+         super(BasicBlock, self).__init__()
+         assert style in ['pytorch', 'caffe']
+         self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+         self.dilation = dilation
+         assert not with_cp
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self,
+                  inplanes,
+                  planes,
+                  stride=1,
+                  dilation=1,
+                  downsample=None,
+                  style='pytorch',
+                  with_cp=False):
+         """Bottleneck block.
+
+         If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+         it is "caffe", the stride-two layer is the first 1x1 conv layer.
+         """
+         super(Bottleneck, self).__init__()
+         assert style in ['pytorch', 'caffe']
+         if style == 'pytorch':
+             conv1_stride = 1
+             conv2_stride = stride
+         else:
+             conv1_stride = stride
+             conv2_stride = 1
+         self.conv1 = nn.Conv2d(
+             inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+         self.conv2 = nn.Conv2d(
+             planes,
+             planes,
+             kernel_size=3,
+             stride=conv2_stride,
+             padding=dilation,
+             dilation=dilation,
+             bias=False)
+
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(
+             planes, planes * self.expansion, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+         self.dilation = dilation
+         self.with_cp = with_cp
+
+     def forward(self, x):
+
+         def _inner_forward(x):
+             residual = x
+
+             out = self.conv1(x)
+             out = self.bn1(out)
+             out = self.relu(out)
+
+             out = self.conv2(out)
+             out = self.bn2(out)
+             out = self.relu(out)
+
+             out = self.conv3(out)
+             out = self.bn3(out)
+
+             if self.downsample is not None:
+                 residual = self.downsample(x)
+
+             out += residual
+
+             return out
+
+         if self.with_cp and x.requires_grad:
+             out = cp.checkpoint(_inner_forward, x)
+         else:
+             out = _inner_forward(x)
+
+         out = self.relu(out)
+
+         return out
+
+
+ def make_res_layer(block,
+                    inplanes,
+                    planes,
+                    blocks,
+                    stride=1,
+                    dilation=1,
+                    style='pytorch',
+                    with_cp=False):
+     downsample = None
+     if stride != 1 or inplanes != planes * block.expansion:
+         downsample = nn.Sequential(
+             nn.Conv2d(
+                 inplanes,
+                 planes * block.expansion,
+                 kernel_size=1,
+                 stride=stride,
+                 bias=False),
+             nn.BatchNorm2d(planes * block.expansion),
+         )
+
+     layers = []
+     layers.append(
+         block(
+             inplanes,
+             planes,
+             stride,
+             dilation,
+             downsample,
+             style=style,
+             with_cp=with_cp))
+     inplanes = planes * block.expansion
+     for _ in range(1, blocks):
+         layers.append(
+             block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+     return nn.Sequential(*layers)
+
+
+ class ResNet(nn.Module):
+     """ResNet backbone.
+
+     Args:
+         depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+         num_stages (int): Resnet stages, normally 4.
+         strides (Sequence[int]): Strides of the first block of each stage.
+         dilations (Sequence[int]): Dilation of each stage.
+         out_indices (Sequence[int]): Output from which stages.
+         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+             layer is the 3x3 conv layer, otherwise the stride-two layer is
+             the first 1x1 conv layer.
+         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+             not freezing any parameters.
+         bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+             running stats (mean and var).
+         bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+             memory while slowing down the training speed.
+     """
+
+     arch_settings = {
+         18: (BasicBlock, (2, 2, 2, 2)),
+         34: (BasicBlock, (3, 4, 6, 3)),
+         50: (Bottleneck, (3, 4, 6, 3)),
+         101: (Bottleneck, (3, 4, 23, 3)),
+         152: (Bottleneck, (3, 8, 36, 3))
+     }
+
+     def __init__(self,
+                  depth,
+                  num_stages=4,
+                  strides=(1, 2, 2, 2),
+                  dilations=(1, 1, 1, 1),
+                  out_indices=(0, 1, 2, 3),
+                  style='pytorch',
+                  frozen_stages=-1,
+                  bn_eval=True,
+                  bn_frozen=False,
+                  with_cp=False):
+         super(ResNet, self).__init__()
+         if depth not in self.arch_settings:
+             raise KeyError(f'invalid depth {depth} for resnet')
+         assert num_stages >= 1 and num_stages <= 4
+         block, stage_blocks = self.arch_settings[depth]
+         stage_blocks = stage_blocks[:num_stages]
+         assert len(strides) == len(dilations) == num_stages
+         assert max(out_indices) < num_stages
+
+         self.out_indices = out_indices
+         self.style = style
+         self.frozen_stages = frozen_stages
+         self.bn_eval = bn_eval
+         self.bn_frozen = bn_frozen
+         self.with_cp = with_cp
+
+         self.inplanes = 64
+         self.conv1 = nn.Conv2d(
+             3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+         self.res_layers = []
+         for i, num_blocks in enumerate(stage_blocks):
+             stride = strides[i]
+             dilation = dilations[i]
+             planes = 64 * 2**i
+             res_layer = make_res_layer(
+                 block,
+                 self.inplanes,
+                 planes,
+                 num_blocks,
+                 stride=stride,
+                 dilation=dilation,
+                 style=self.style,
+                 with_cp=with_cp)
+             self.inplanes = planes * block.expansion
+             layer_name = f'layer{i + 1}'
+             self.add_module(layer_name, res_layer)
+             self.res_layers.append(layer_name)
+
+         self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+
+     def init_weights(self, pretrained=None):
+         if isinstance(pretrained, str):
+             logger = logging.getLogger()
+             from ..runner import load_checkpoint
+             load_checkpoint(self, pretrained, strict=False, logger=logger)
+         elif pretrained is None:
+             for m in self.modules():
+                 if isinstance(m, nn.Conv2d):
+                     kaiming_init(m)
+                 elif isinstance(m, nn.BatchNorm2d):
+                     constant_init(m, 1)
+         else:
+             raise TypeError('pretrained must be a str or None')
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+         outs = []
+         for i, layer_name in enumerate(self.res_layers):
+             res_layer = getattr(self, layer_name)
+             x = res_layer(x)
+             if i in self.out_indices:
+                 outs.append(x)
+         if len(outs) == 1:
+             return outs[0]
+         else:
+             return tuple(outs)
+
+     def train(self, mode=True):
+         super(ResNet, self).train(mode)
+         if self.bn_eval:
+             for m in self.modules():
+                 if isinstance(m, nn.BatchNorm2d):
+                     m.eval()
+                     if self.bn_frozen:
+                         for params in m.parameters():
+                             params.requires_grad = False
+         if mode and self.frozen_stages >= 0:
+             for param in self.conv1.parameters():
+                 param.requires_grad = False
+             for param in self.bn1.parameters():
+                 param.requires_grad = False
+             self.bn1.eval()
+             self.bn1.weight.requires_grad = False
+             self.bn1.bias.requires_grad = False
+             for i in range(1, self.frozen_stages + 1):
+                 mod = getattr(self, f'layer{i}')
+                 mod.eval()
+                 for param in mod.parameters():
+                     param.requires_grad = False
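A quick sketch of the backbone's multi-scale output (shapes assume a 224x224 input and depth=50):

    import torch

    model = ResNet(depth=50, out_indices=(0, 1, 2, 3))
    model.init_weights()
    model.eval()
    feats = model(torch.randn(1, 3, 224, 224))
    # tuple of 4 feature maps at strides 4/8/16/32 with
    # 256/512/1024/2048 channels, e.g. feats[0].shape == (1, 256, 56, 56)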
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .flops_counter import get_model_complexity_info
+ from .fuse_conv_bn import fuse_conv_bn
+ from .sync_bn import revert_sync_batchnorm
+ from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+                           KaimingInit, NormalInit, PretrainedInit,
+                           TruncNormalInit, UniformInit, XavierInit,
+                           bias_init_with_prob, caffe2_xavier_init,
+                           constant_init, initialize, kaiming_init, normal_init,
+                           trunc_normal_init, uniform_init, xavier_init)
+
+ __all__ = [
+     'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+     'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+     'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+     'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+     'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+     'Caffe2XavierInit', 'revert_sync_batchnorm'
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/flops_counter.py ADDED
@@ -0,0 +1,599 @@
+ # Modified from flops-counter.pytorch by Vladislav Sovrasov
+ # original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+ # MIT License
+
+ # Copyright (c) 2018 Vladislav Sovrasov
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in
+ # all copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ import sys
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ import annotator.mmpkg.mmcv as mmcv
+
+
+ def get_model_complexity_info(model,
+                               input_shape,
+                               print_per_layer_stat=True,
+                               as_strings=True,
+                               input_constructor=None,
+                               flush=False,
+                               ost=sys.stdout):
+     """Get complexity information of a model.
+
+     This method can calculate FLOPs and parameter counts of a model with
+     corresponding input shape. It can also print complexity information for
+     each layer in a model.
+
+     Supported layers are listed as below:
+         - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+         - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+           ``nn.ReLU6``.
+         - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+           ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+           ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+           ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+           ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+         - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+           ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+           ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+         - Linear: ``nn.Linear``.
+         - Deconvolution: ``nn.ConvTranspose2d``.
+         - Upsample: ``nn.Upsample``.
+
+     Args:
+         model (nn.Module): The model for complexity calculation.
+         input_shape (tuple): Input shape used for calculation.
+         print_per_layer_stat (bool): Whether to print complexity information
+             for each layer in a model. Default: True.
+         as_strings (bool): Output FLOPs and params counts in a string form.
+             Default: True.
+         input_constructor (None | callable): If specified, it takes a callable
+             method that generates input. Otherwise, it will generate a random
+             tensor with input shape to calculate FLOPs. Default: None.
+         flush (bool): same as that in :func:`print`. Default: False.
+         ost (stream): same as ``file`` param in :func:`print`.
+             Default: sys.stdout.
+
+     Returns:
+         tuple[float | str]: If ``as_strings`` is set to True, it will return
+             FLOPs and parameter counts in a string format. Otherwise, it will
+             return those in a float number format.
+     """
+     assert type(input_shape) is tuple
+     assert len(input_shape) >= 1
+     assert isinstance(model, nn.Module)
+     flops_model = add_flops_counting_methods(model)
+     flops_model.eval()
+     flops_model.start_flops_count()
+     if input_constructor:
+         input = input_constructor(input_shape)
+         _ = flops_model(**input)
+     else:
+         try:
+             batch = torch.ones(()).new_empty(
+                 (1, *input_shape),
+                 dtype=next(flops_model.parameters()).dtype,
+                 device=next(flops_model.parameters()).device)
+         except StopIteration:
+             # Avoid StopIteration for models which have no parameters,
+             # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+             batch = torch.ones(()).new_empty((1, *input_shape))
+
+         _ = flops_model(batch)
+
+     flops_count, params_count = flops_model.compute_average_flops_cost()
+     if print_per_layer_stat:
+         print_model_with_flops(
+             flops_model, flops_count, params_count, ost=ost, flush=flush)
+     flops_model.stop_flops_count()
+
+     if as_strings:
+         return flops_to_string(flops_count), params_to_string(params_count)
+
+     return flops_count, params_count
+
+
+ def flops_to_string(flops, units='GFLOPs', precision=2):
+     """Convert FLOPs number into a string.
+
+     Note that here we take one multiply-add as a single FLOP.
+
+     Args:
+         flops (float): FLOPs number to be converted.
+         units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+             'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+             choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+         precision (int): Digit number after the decimal point. Default: 2.
+
+     Returns:
+         str: The converted FLOPs number with units.
+
+     Examples:
+         >>> flops_to_string(1e9)
+         '1.0 GFLOPs'
+         >>> flops_to_string(2e5, 'MFLOPs')
+         '0.2 MFLOPs'
+         >>> flops_to_string(3e-9, None)
+         '3e-09 FLOPs'
+     """
+     if units is None:
+         if flops // 10**9 > 0:
+             return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+         elif flops // 10**6 > 0:
+             return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+         elif flops // 10**3 > 0:
+             return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+         else:
+             return str(flops) + ' FLOPs'
+     else:
+         if units == 'GFLOPs':
+             return str(round(flops / 10.**9, precision)) + ' ' + units
+         elif units == 'MFLOPs':
+             return str(round(flops / 10.**6, precision)) + ' ' + units
+         elif units == 'KFLOPs':
+             return str(round(flops / 10.**3, precision)) + ' ' + units
+         else:
+             return str(flops) + ' FLOPs'
+
+
+ def params_to_string(num_params, units=None, precision=2):
+     """Convert parameter number into a string.
+
+     Args:
+         num_params (float): Parameter number to be converted.
+         units (str | None): Converted parameter-count units. Options are None,
+             'M', 'K' and ''. If set to None, it will automatically choose the
+             most suitable unit for the parameter number. Default: None.
+         precision (int): Digit number after the decimal point. Default: 2.
+
+     Returns:
+         str: The converted parameter number with units.
+
+     Examples:
+         >>> params_to_string(1e9)
+         '1000.0 M'
+         >>> params_to_string(2e5)
+         '200.0 k'
+         >>> params_to_string(3e-9)
+         '3e-09'
+     """
+     if units is None:
+         if num_params // 10**6 > 0:
+             return str(round(num_params / 10**6, precision)) + ' M'
+         elif num_params // 10**3:
+             return str(round(num_params / 10**3, precision)) + ' k'
+         else:
+             return str(num_params)
+     else:
+         if units == 'M':
+             return str(round(num_params / 10.**6, precision)) + ' ' + units
+         elif units == 'K':
+             return str(round(num_params / 10.**3, precision)) + ' ' + units
+         else:
+             return str(num_params)
+
+
+ def print_model_with_flops(model,
+                            total_flops,
+                            total_params,
+                            units='GFLOPs',
+                            precision=3,
+                            ost=sys.stdout,
+                            flush=False):
+     """Print a model with FLOPs for each layer.
+
+     Args:
+         model (nn.Module): The model to be printed.
+         total_flops (float): Total FLOPs of the model.
+         total_params (float): Total parameter counts of the model.
+         units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+         precision (int): Digit number after the decimal point. Default: 3.
+         ost (stream): same as `file` param in :func:`print`.
+             Default: sys.stdout.
+         flush (bool): same as that in :func:`print`. Default: False.
+
+     Example:
+         >>> class ExampleModel(nn.Module):
+
+         >>> def __init__(self):
+         >>>     super().__init__()
+         >>>     self.conv1 = nn.Conv2d(3, 8, 3)
+         >>>     self.conv2 = nn.Conv2d(8, 256, 3)
+         >>>     self.conv3 = nn.Conv2d(256, 8, 3)
+         >>>     self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+         >>>     self.flatten = nn.Flatten()
+         >>>     self.fc = nn.Linear(8, 1)
+
+         >>> def forward(self, x):
+         >>>     x = self.conv1(x)
+         >>>     x = self.conv2(x)
+         >>>     x = self.conv3(x)
+         >>>     x = self.avg_pool(x)
+         >>>     x = self.flatten(x)
+         >>>     x = self.fc(x)
+         >>>     return x
+
+         >>> model = ExampleModel()
+         >>> x = (3, 16, 16)
+         to print the complexity information state for each layer, you can use
+         >>> get_model_complexity_info(model, x)
+         or directly use
+         >>> print_model_with_flops(model, 4579784.0, 37361)
+         ExampleModel(
+           0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+           (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
+           (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+           (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+           (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+           (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+           (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+         )
+     """
+
+     def accumulate_params(self):
+         if is_supported_instance(self):
+             return self.__params__
+         else:
+             sum = 0
+             for m in self.children():
+                 sum += m.accumulate_params()
+             return sum
+
+     def accumulate_flops(self):
+         if is_supported_instance(self):
+             return self.__flops__ / model.__batch_counter__
+         else:
+             sum = 0
+             for m in self.children():
+                 sum += m.accumulate_flops()
+             return sum
+
+     def flops_repr(self):
+         accumulated_num_params = self.accumulate_params()
+         accumulated_flops_cost = self.accumulate_flops()
+         return ', '.join([
+             params_to_string(
+                 accumulated_num_params, units='M', precision=precision),
+             '{:.3%} Params'.format(accumulated_num_params / total_params),
+             flops_to_string(
+                 accumulated_flops_cost, units=units, precision=precision),
+             '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+             self.original_extra_repr()
+         ])
+
+     def add_extra_repr(m):
+         m.accumulate_flops = accumulate_flops.__get__(m)
+         m.accumulate_params = accumulate_params.__get__(m)
+         flops_extra_repr = flops_repr.__get__(m)
+         if m.extra_repr != flops_extra_repr:
+             m.original_extra_repr = m.extra_repr
+             m.extra_repr = flops_extra_repr
+             assert m.extra_repr != m.original_extra_repr
+
+     def del_extra_repr(m):
+         if hasattr(m, 'original_extra_repr'):
+             m.extra_repr = m.original_extra_repr
+             del m.original_extra_repr
+         if hasattr(m, 'accumulate_flops'):
+             del m.accumulate_flops
+
+     model.apply(add_extra_repr)
+     print(model, file=ost, flush=flush)
+     model.apply(del_extra_repr)
+
+
+ def get_model_parameters_number(model):
+     """Calculate parameter number of a model.
+
+     Args:
+         model (nn.module): The model for parameter number calculation.
+
+     Returns:
+         float: Parameter number of the model.
+     """
+     num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return num_params
+
+
+ def add_flops_counting_methods(net_main_module):
+     # adding additional methods to the existing module object,
+     # this is done this way so that each function has access to self object
+     net_main_module.start_flops_count = start_flops_count.__get__(
+         net_main_module)
+     net_main_module.stop_flops_count = stop_flops_count.__get__(
+         net_main_module)
+     net_main_module.reset_flops_count = reset_flops_count.__get__(
+         net_main_module)
+     net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
+         net_main_module)
+
+     net_main_module.reset_flops_count()
+
+     return net_main_module
+
+
+ def compute_average_flops_cost(self):
+     """Compute average FLOPs cost.
+
+     A method to compute average FLOPs cost, which will be available after
+     `add_flops_counting_methods()` is called on a desired net object.
+
+     Returns:
+         tuple[float, float]: Mean FLOPs consumption per image and the
+             parameter count of the model.
+     """
+     batches_count = self.__batch_counter__
+     flops_sum = 0
+     for module in self.modules():
+         if is_supported_instance(module):
+             flops_sum += module.__flops__
+     params_sum = get_model_parameters_number(self)
+     return flops_sum / batches_count, params_sum
+
+
+ def start_flops_count(self):
+     """Activate the computation of mean flops consumption per image.
+
+     A method to activate the computation of mean flops consumption per image,
+     which will be available after ``add_flops_counting_methods()`` is called
+     on a desired net object. It should be called before running the network.
+     """
+     add_batch_counter_hook_function(self)
+
+     def add_flops_counter_hook_function(module):
+         if is_supported_instance(module):
+             if hasattr(module, '__flops_handle__'):
+                 return
+             else:
+                 handle = module.register_forward_hook(
+                     get_modules_mapping()[type(module)])
+                 module.__flops_handle__ = handle
+
+     self.apply(partial(add_flops_counter_hook_function))
+
+
+ def stop_flops_count(self):
+     """Stop computing the mean flops consumption per image.
+
+     A method to stop computing the mean flops consumption per image, which will
+     be available after ``add_flops_counting_methods()`` is called on a desired
+     net object. It can be called to pause the computation whenever.
+     """
+     remove_batch_counter_hook_function(self)
+     self.apply(remove_flops_counter_hook_function)
+
+
+ def reset_flops_count(self):
+     """Reset statistics computed so far.
+
+     A method to reset computed statistics, which will be available after
+     `add_flops_counting_methods()` is called on a desired net object.
+     """
+     add_batch_counter_variables_or_reset(self)
+     self.apply(add_flops_counter_variable_or_reset)
+
+
+ # ---- Internal functions
+ def empty_flops_counter_hook(module, input, output):
+     module.__flops__ += 0
+
+
+ def upsample_flops_counter_hook(module, input, output):
+     output_size = output[0]
+     batch_size = output_size.shape[0]
+     output_elements_count = batch_size
+     for val in output_size.shape[1:]:
+         output_elements_count *= val
+     module.__flops__ += int(output_elements_count)
+
+
+ def relu_flops_counter_hook(module, input, output):
+     active_elements_count = output.numel()
+     module.__flops__ += int(active_elements_count)
+
+
+ def linear_flops_counter_hook(module, input, output):
+     input = input[0]
+     output_last_dim = output.shape[
+         -1]  # pytorch checks dimensions, so here we don't care much
+     module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+
+
+ def pool_flops_counter_hook(module, input, output):
+     input = input[0]
+     module.__flops__ += int(np.prod(input.shape))
+
+
+ def norm_flops_counter_hook(module, input, output):
+     input = input[0]
+
+     batch_flops = np.prod(input.shape)
+     if (getattr(module, 'affine', False)
+             or getattr(module, 'elementwise_affine', False)):
+         batch_flops *= 2
+     module.__flops__ += int(batch_flops)
+
+
+ def deconv_flops_counter_hook(conv_module, input, output):
+     # Can have multiple inputs, getting the first one
+     input = input[0]
+
+     batch_size = input.shape[0]
+     input_height, input_width = input.shape[2:]
+
+     kernel_height, kernel_width = conv_module.kernel_size
+     in_channels = conv_module.in_channels
+     out_channels = conv_module.out_channels
+     groups = conv_module.groups
+
+     filters_per_channel = out_channels // groups
+     conv_per_position_flops = (
+         kernel_height * kernel_width * in_channels * filters_per_channel)
+
+     active_elements_count = batch_size * input_height * input_width
+     overall_conv_flops = conv_per_position_flops * active_elements_count
+     bias_flops = 0
+     if conv_module.bias is not None:
+         output_height, output_width = output.shape[2:]
+         # one bias add per output position
+         bias_flops = out_channels * batch_size * output_height * output_width
+     overall_flops = overall_conv_flops + bias_flops
+
+     conv_module.__flops__ += int(overall_flops)
+
+
+ def conv_flops_counter_hook(conv_module, input, output):
+     # Can have multiple inputs, getting the first one
+     input = input[0]
+
+     batch_size = input.shape[0]
+     output_dims = list(output.shape[2:])
+
+     kernel_dims = list(conv_module.kernel_size)
+     in_channels = conv_module.in_channels
+     out_channels = conv_module.out_channels
+     groups = conv_module.groups
+
+     filters_per_channel = out_channels // groups
+     conv_per_position_flops = int(
+         np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+     active_elements_count = batch_size * int(np.prod(output_dims))
+
+     overall_conv_flops = conv_per_position_flops * active_elements_count
+
+     bias_flops = 0
+     if conv_module.bias is not None:
+         bias_flops = out_channels * active_elements_count
+
+     overall_flops = overall_conv_flops + bias_flops
+
+     conv_module.__flops__ += int(overall_flops)
+
+
+ def batch_counter_hook(module, input, output):
+     batch_size = 1
+     if len(input) > 0:
+         # Can have multiple inputs, getting the first one
+         input = input[0]
+         batch_size = len(input)
+     else:
+         print('Warning! No positional inputs found for a module, '
+               'assuming batch size is 1.')
+     module.__batch_counter__ += batch_size
+
+
+ def add_batch_counter_variables_or_reset(module):
+     module.__batch_counter__ = 0
+
+
+ def add_batch_counter_hook_function(module):
+     if hasattr(module, '__batch_counter_handle__'):
+         return
+
+     handle = module.register_forward_hook(batch_counter_hook)
+     module.__batch_counter_handle__ = handle
+
+
+ def remove_batch_counter_hook_function(module):
+     if hasattr(module, '__batch_counter_handle__'):
+         module.__batch_counter_handle__.remove()
+         del module.__batch_counter_handle__
+
+
+ def add_flops_counter_variable_or_reset(module):
+     if is_supported_instance(module):
+         if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+             print('Warning: variables __flops__ or __params__ are already '
+                   'defined for the module ' + type(module).__name__ +
+                   '. ptflops can affect your code!')
+         module.__flops__ = 0
+         module.__params__ = get_model_parameters_number(module)
+
+
+ def is_supported_instance(module):
+     if type(module) in get_modules_mapping():
+         return True
+     return False
+
+
+ def remove_flops_counter_hook_function(module):
+     if is_supported_instance(module):
+         if hasattr(module, '__flops_handle__'):
+             module.__flops_handle__.remove()
+             del module.__flops_handle__
+
+
+ def get_modules_mapping():
+     return {
+         # convolutions
+         nn.Conv1d: conv_flops_counter_hook,
+         nn.Conv2d: conv_flops_counter_hook,
+         mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+         nn.Conv3d: conv_flops_counter_hook,
+         mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+         # activations
+         nn.ReLU: relu_flops_counter_hook,
+         nn.PReLU: relu_flops_counter_hook,
+         nn.ELU: relu_flops_counter_hook,
+         nn.LeakyReLU: relu_flops_counter_hook,
+         nn.ReLU6: relu_flops_counter_hook,
+         # poolings
+         nn.MaxPool1d: pool_flops_counter_hook,
+         nn.AvgPool1d: pool_flops_counter_hook,
+         nn.AvgPool2d: pool_flops_counter_hook,
+         nn.MaxPool2d: pool_flops_counter_hook,
+         mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+         nn.MaxPool3d: pool_flops_counter_hook,
+         mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+         nn.AvgPool3d: pool_flops_counter_hook,
+         nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+         nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+         nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+         nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+         nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+         nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+         # normalizations
+         nn.BatchNorm1d: norm_flops_counter_hook,
+         nn.BatchNorm2d: norm_flops_counter_hook,
+         nn.BatchNorm3d: norm_flops_counter_hook,
+         nn.GroupNorm: norm_flops_counter_hook,
+         nn.InstanceNorm1d: norm_flops_counter_hook,
+         nn.InstanceNorm2d: norm_flops_counter_hook,
+         nn.InstanceNorm3d: norm_flops_counter_hook,
+         nn.LayerNorm: norm_flops_counter_hook,
+         # FC
+         nn.Linear: linear_flops_counter_hook,
+         mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+         # Upscale
+         nn.Upsample: upsample_flops_counter_hook,
+         # Deconvolution
+         nn.ConvTranspose2d: deconv_flops_counter_hook,
+         mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+     }
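End to end, the public entry point above can be exercised like this (a minimal sketch assuming torchvision is installed; the exact figures depend on the counter's multiply-add convention):

    import torchvision.models as models

    net = models.resnet18()
    flops, params = get_model_complexity_info(
        net, (3, 224, 224), print_per_layer_stat=False)
    print(flops, params)  # roughly '1.82 GFLOPs', '11.69 M'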
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/fuse_conv_bn.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import torch
+ import torch.nn as nn
+
+
+ def _fuse_conv_bn(conv, bn):
+     """Fuse conv and bn into one module.
+
+     Args:
+         conv (nn.Module): Conv to be fused.
+         bn (nn.Module): BN to be fused.
+
+     Returns:
+         nn.Module: Fused module.
+     """
+     conv_w = conv.weight
+     conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+         bn.running_mean)
+
+     factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+     conv.weight = nn.Parameter(conv_w *
+                                factor.reshape([conv.out_channels, 1, 1, 1]))
+     conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+     return conv
+
+
+ def fuse_conv_bn(module):
+     """Recursively fuse conv and bn in a module.
+
+     During inference, batch norm layers no longer update their statistics;
+     only the stored per-channel mean and variance are used, which makes it
+     possible to fuse them with the preceding conv layers to save computation
+     and simplify network structures.
+
+     Args:
+         module (nn.Module): Module to be fused.
+
+     Returns:
+         nn.Module: Fused module.
+     """
+     last_conv = None
+     last_conv_name = None
+
+     for name, child in module.named_children():
+         if isinstance(child,
+                       (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+             if last_conv is None:  # only fuse BN that is after Conv
+                 continue
+             fused_conv = _fuse_conv_bn(last_conv, child)
+             module._modules[last_conv_name] = fused_conv
+             # To reduce changes, set BN as Identity instead of deleting it.
+             module._modules[name] = nn.Identity()
+             last_conv = None
+         elif isinstance(child, nn.Conv2d):
+             last_conv = child
+             last_conv_name = name
+         else:
+             fuse_conv_bn(child)
+     return module
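A sketch checking that fusion preserves inference outputs (eval mode matters, since fusion bakes the running statistics into the conv weights):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
    x = torch.randn(1, 3, 32, 32)
    ref = net(x)
    fused = fuse_conv_bn(net)  # fuses in place and returns the module
    assert torch.allclose(ref, fused(x), atol=1e-5)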
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/sync_bn.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+
+ import annotator.mmpkg.mmcv as mmcv
+
+
+ class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+     """A general BatchNorm layer without input dimension check.
+
+     Reproduced from @kapily's work:
+     (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+     The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+     is `_check_input_dim`, which is designed for tensor sanity checks.
+     The check has been bypassed in this class for the convenience of
+     converting SyncBatchNorm.
+     """
+
+     def _check_input_dim(self, input):
+         return
+
+
+ def revert_sync_batchnorm(module):
+     """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+     `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+     `BatchNormXd` layers.
+
+     Adapted from @kapily's work:
+     (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+     Args:
+         module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+     Returns:
+         module_output: The converted module with `BatchNormXd` layers.
+     """
+     module_output = module
+     module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+     if hasattr(mmcv, 'ops'):
+         module_checklist.append(mmcv.ops.SyncBatchNorm)
+     if isinstance(module, tuple(module_checklist)):
+         module_output = _BatchNormXd(module.num_features, module.eps,
+                                      module.momentum, module.affine,
+                                      module.track_running_stats)
+         if module.affine:
+             # no_grad() may not be needed here but
+             # just to be consistent with `convert_sync_batchnorm()`
+             with torch.no_grad():
+                 module_output.weight = module.weight
+                 module_output.bias = module.bias
+         module_output.running_mean = module.running_mean
+         module_output.running_var = module.running_var
+         module_output.num_batches_tracked = module.num_batches_tracked
+         module_output.training = module.training
+         # qconfig exists in quantized models
+         if hasattr(module, 'qconfig'):
+             module_output.qconfig = module.qconfig
+     for name, child in module.named_children():
+         module_output.add_module(name, revert_sync_batchnorm(child))
+     del module
+     return module_output
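A sketch of the conversion (useful when running a SyncBN model on a single GPU or CPU, where SyncBatchNorm cannot forward outside a distributed process group):

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
    model = revert_sync_batchnorm(model)
    # the SyncBatchNorm is now a _BatchNormXd and runs without dist init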
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/utils/weight_init.py ADDED
@@ -0,0 +1,684 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import copy
+ import math
+ import warnings
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+ from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+ INITIALIZERS = Registry('initializer')
+
+
+ def update_init_info(module, init_info):
+     """Update the `_params_init_info` in the module if the values of the
+     parameters are changed.
+
+     Args:
+         module (obj:`nn.Module`): The module of PyTorch with a user-defined
+             attribute `_params_init_info` which records the initialization
+             information.
+         init_info (str): The string that describes the initialization.
+     """
+     assert hasattr(
+         module,
+         '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+     for name, param in module.named_parameters():
+
+         assert param in module._params_init_info, (
+             f'Find a new :obj:`Parameter` '
+             f'named `{name}` during executing the '
+             f'`init_weights` of '
+             f'`{module.__class__.__name__}`. '
+             f'Please do not add or '
+             f'replace parameters during executing '
+             f'the `init_weights`. ')
+
+         # The parameter has been changed during executing the
+         # `init_weights` of module
+         mean_value = param.data.mean()
+         if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+             module._params_init_info[param]['init_info'] = init_info
+             module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+ def constant_init(module, val, bias=0):
+     if hasattr(module, 'weight') and module.weight is not None:
+         nn.init.constant_(module.weight, val)
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)
+
+
+ def xavier_init(module, gain=1, bias=0, distribution='normal'):
+     assert distribution in ['uniform', 'normal']
+     if hasattr(module, 'weight') and module.weight is not None:
+         if distribution == 'uniform':
+             nn.init.xavier_uniform_(module.weight, gain=gain)
+         else:
+             nn.init.xavier_normal_(module.weight, gain=gain)
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)
+
+
+ def normal_init(module, mean=0, std=1, bias=0):
+     if hasattr(module, 'weight') and module.weight is not None:
+         nn.init.normal_(module.weight, mean, std)
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)
+
+
+ def trunc_normal_init(module: nn.Module,
+                       mean: float = 0,
+                       std: float = 1,
+                       a: float = -2,
+                       b: float = 2,
+                       bias: float = 0) -> None:
+     if hasattr(module, 'weight') and module.weight is not None:
+         trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)  # type: ignore
+
+
+ def uniform_init(module, a=0, b=1, bias=0):
+     if hasattr(module, 'weight') and module.weight is not None:
+         nn.init.uniform_(module.weight, a, b)
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)
+
+
+ def kaiming_init(module,
+                  a=0,
+                  mode='fan_out',
+                  nonlinearity='relu',
+                  bias=0,
+                  distribution='normal'):
+     assert distribution in ['uniform', 'normal']
+     if hasattr(module, 'weight') and module.weight is not None:
+         if distribution == 'uniform':
+             nn.init.kaiming_uniform_(
+                 module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+         else:
+             nn.init.kaiming_normal_(
+                 module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+     if hasattr(module, 'bias') and module.bias is not None:
+         nn.init.constant_(module.bias, bias)
+
+
+ def caffe2_xavier_init(module, bias=0):
+     # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+     # Acknowledgment to FAIR's internal code
+     kaiming_init(
+         module,
+         a=1,
+         mode='fan_in',
+         nonlinearity='leaky_relu',
+         bias=bias,
+         distribution='uniform')
+
+
+ def bias_init_with_prob(prior_prob):
+     """initialize conv/fc bias value according to a given probability value."""
+     bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+     return bias_init
+
+
+ def _get_bases_name(m):
+     return [b.__name__ for b in m.__class__.__bases__]
+
+
+ class BaseInit(object):
+
+     def __init__(self, *, bias=0, bias_prob=None, layer=None):
+         self.wholemodule = False
+         if not isinstance(bias, (int, float)):
+             raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+         if bias_prob is not None:
+             if not isinstance(bias_prob, float):
+                 raise TypeError(f'bias_prob type must be float, \
+                     but got {type(bias_prob)}')
+
+         if layer is not None:
+             if not isinstance(layer, (str, list)):
+                 raise TypeError(f'layer must be a str or a list of str, \
+                     but got a {type(layer)}')
+         else:
+             layer = []
+
+         if bias_prob is not None:
+             self.bias = bias_init_with_prob(bias_prob)
+         else:
+             self.bias = bias
+         self.layer = [layer] if isinstance(layer, str) else layer
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='Constant')
+ class ConstantInit(BaseInit):
+     """Initialize module parameters with constant values.
+
+     Args:
+         val (int | float): the value to fill the weights in the module with
+         bias (int | float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+     """
+
+     def __init__(self, val, **kwargs):
+         super().__init__(**kwargs)
+         self.val = val
+
+     def __call__(self, module):
+
+         def init(m):
+             if self.wholemodule:
+                 constant_init(m, self.val, self.bias)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     constant_init(m, self.val, self.bias)
+
+         module.apply(init)
+         if hasattr(module, '_params_init_info'):
+             update_init_info(module, init_info=self._get_init_info())
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='Xavier')
+ class XavierInit(BaseInit):
+     r"""Initialize module parameters with values according to the method
+     described in `Understanding the difficulty of training deep feedforward
+     neural networks - Glorot, X. & Bengio, Y. (2010).
+     <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
+
+     Args:
+         gain (int | float): an optional scaling factor. Defaults to 1.
+         bias (int | float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         distribution (str): distribution either be ``'normal'``
+             or ``'uniform'``. Defaults to ``'normal'``.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+     """
+
+     def __init__(self, gain=1, distribution='normal', **kwargs):
+         super().__init__(**kwargs)
+         self.gain = gain
+         self.distribution = distribution
+
+     def __call__(self, module):
+
+         def init(m):
+             if self.wholemodule:
+                 xavier_init(m, self.gain, self.bias, self.distribution)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     xavier_init(m, self.gain, self.bias, self.distribution)
+
+         module.apply(init)
+         if hasattr(module, '_params_init_info'):
+             update_init_info(module, init_info=self._get_init_info())
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+                f'distribution={self.distribution}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='Normal')
+ class NormalInit(BaseInit):
+     r"""Initialize module parameters with the values drawn from the normal
+     distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+     Args:
+         mean (int | float): the mean of the normal distribution. Defaults to 0.
+         std (int | float): the standard deviation of the normal distribution.
+             Defaults to 1.
+         bias (int | float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+
+     """
+
+     def __init__(self, mean=0, std=1, **kwargs):
+         super().__init__(**kwargs)
+         self.mean = mean
+         self.std = std
+
+     def __call__(self, module):
+
+         def init(m):
+             if self.wholemodule:
+                 normal_init(m, self.mean, self.std, self.bias)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     normal_init(m, self.mean, self.std, self.bias)
+
+         module.apply(init)
+         if hasattr(module, '_params_init_info'):
+             update_init_info(module, init_info=self._get_init_info())
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}: mean={self.mean},' \
+                f' std={self.std}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='TruncNormal')
+ class TruncNormalInit(BaseInit):
+     r"""Initialize module parameters with the values drawn from the normal
+     distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+     outside :math:`[a, b]`.
+
+     Args:
+         mean (float): the mean of the normal distribution. Defaults to 0.
+         std (float): the standard deviation of the normal distribution.
+             Defaults to 1.
+         a (float): The minimum cutoff value.
+         b (float): The maximum cutoff value.
+         bias (float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+
+     """
+
+     def __init__(self,
+                  mean: float = 0,
+                  std: float = 1,
+                  a: float = -2,
+                  b: float = 2,
+                  **kwargs) -> None:
+         super().__init__(**kwargs)
+         self.mean = mean
+         self.std = std
+         self.a = a
+         self.b = b
+
+     def __call__(self, module: nn.Module) -> None:
+
+         def init(m):
+             if self.wholemodule:
+                 trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                   self.bias)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                       self.bias)
+
+         module.apply(init)
+         if hasattr(module, '_params_init_info'):
+             update_init_info(module, init_info=self._get_init_info())
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+                f' mean={self.mean}, std={self.std}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='Uniform')
+ class UniformInit(BaseInit):
+     r"""Initialize module parameters with values drawn from the uniform
+     distribution :math:`\mathcal{U}(a, b)`.
+
+     Args:
+         a (int | float): the lower bound of the uniform distribution.
+             Defaults to 0.
+         b (int | float): the upper bound of the uniform distribution.
+             Defaults to 1.
+         bias (int | float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+     """
+
+     def __init__(self, a=0, b=1, **kwargs):
+         super().__init__(**kwargs)
+         self.a = a
+         self.b = b
+
+     def __call__(self, module):
+
+         def init(m):
+             if self.wholemodule:
+                 uniform_init(m, self.a, self.b, self.bias)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     uniform_init(m, self.a, self.b, self.bias)
+
+         module.apply(init)
+         if hasattr(module, '_params_init_info'):
+             update_init_info(module, init_info=self._get_init_info())
+
+     def _get_init_info(self):
+         info = f'{self.__class__.__name__}: a={self.a},' \
+                f' b={self.b}, bias={self.bias}'
+         return info
+
+
+ @INITIALIZERS.register_module(name='Kaiming')
+ class KaimingInit(BaseInit):
+     r"""Initialize module parameters with the values according to the method
+     described in `Delving deep into rectifiers: Surpassing human-level
+     performance on ImageNet classification - He, K. et al. (2015).
+     <https://www.cv-foundation.org/openaccess/content_iccv_2015/
+     papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
+
+     Args:
+         a (int | float): the negative slope of the rectifier used after this
+             layer (only used with ``'leaky_relu'``). Defaults to 0.
+         mode (str): either ``'fan_in'`` or ``'fan_out'``. Choosing
+             ``'fan_in'`` preserves the magnitude of the variance of the weights
+             in the forward pass. Choosing ``'fan_out'`` preserves the
+             magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+         nonlinearity (str): the non-linear function (`nn.functional` name),
+             recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
+             Defaults to 'relu'.
+         bias (int | float): the value to fill the bias. Defaults to 0.
+         bias_prob (float, optional): the probability for bias initialization.
+             Defaults to None.
+         distribution (str): distribution either be ``'normal'`` or
+             ``'uniform'``. Defaults to ``'normal'``.
+         layer (str | list[str], optional): the layer will be initialized.
+             Defaults to None.
+     """
+
+     def __init__(self,
+                  a=0,
+                  mode='fan_out',
+                  nonlinearity='relu',
+                  distribution='normal',
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.a = a
+         self.mode = mode
+         self.nonlinearity = nonlinearity
+         self.distribution = distribution
+
+     def __call__(self, module):
+
+         def init(m):
+             if self.wholemodule:
+                 kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                              self.bias, self.distribution)
+             else:
+                 layername = m.__class__.__name__
+                 basesname = _get_bases_name(m)
+                 if len(set(self.layer) & set([layername] + basesname)):
+                     kaiming_init(m, self.a, self.mode, self.nonlinearity,
+ self.bias, self.distribution)
435
+
436
+ module.apply(init)
437
+ if hasattr(module, '_params_init_info'):
438
+ update_init_info(module, init_info=self._get_init_info())
439
+
440
+ def _get_init_info(self):
441
+ info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
442
+ f'nonlinearity={self.nonlinearity}, ' \
443
+ f'distribution ={self.distribution}, bias={self.bias}'
444
+ return info
445
+
446
+
447
+ @INITIALIZERS.register_module(name='Caffe2Xavier')
448
+ class Caffe2XavierInit(KaimingInit):
449
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
450
+ # Acknowledgment to FAIR's internal code
451
+ def __init__(self, **kwargs):
452
+ super().__init__(
453
+ a=1,
454
+ mode='fan_in',
455
+ nonlinearity='leaky_relu',
456
+ distribution='uniform',
457
+ **kwargs)
458
+
459
+ def __call__(self, module):
460
+ super().__call__(module)
461
+
462
+
463
+ @INITIALIZERS.register_module(name='Pretrained')
464
+ class PretrainedInit(object):
465
+ """Initialize module by loading a pretrained model.
466
+
467
+ Args:
468
+ checkpoint (str): the checkpoint file of the pretrained model should
469
+ be load.
470
+ prefix (str, optional): the prefix of a sub-module in the pretrained
471
+ model. it is for loading a part of the pretrained model to
472
+ initialize. For example, if we would like to only load the
473
+ backbone of a detector model, we can set ``prefix='backbone.'``.
474
+ Defaults to None.
475
+ map_location (str): map tensors into proper locations.
476
+ """
477
+
478
+ def __init__(self, checkpoint, prefix=None, map_location=None):
479
+ self.checkpoint = checkpoint
480
+ self.prefix = prefix
481
+ self.map_location = map_location
482
+
483
+ def __call__(self, module):
484
+ from annotator.mmpkg.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
485
+ load_state_dict)
486
+ logger = get_logger('mmcv')
487
+ if self.prefix is None:
488
+ print_log(f'load model from: {self.checkpoint}', logger=logger)
489
+ load_checkpoint(
490
+ module,
491
+ self.checkpoint,
492
+ map_location=self.map_location,
493
+ strict=False,
494
+ logger=logger)
495
+ else:
496
+ print_log(
497
+ f'load {self.prefix} in model from: {self.checkpoint}',
498
+ logger=logger)
499
+ state_dict = _load_checkpoint_with_prefix(
500
+ self.prefix, self.checkpoint, map_location=self.map_location)
501
+ load_state_dict(module, state_dict, strict=False, logger=logger)
502
+
503
+ if hasattr(module, '_params_init_info'):
504
+ update_init_info(module, init_info=self._get_init_info())
505
+
506
+ def _get_init_info(self):
507
+ info = f'{self.__class__.__name__}: load from {self.checkpoint}'
508
+ return info
509
+
510
+
511
+ def _initialize(module, cfg, wholemodule=False):
512
+ func = build_from_cfg(cfg, INITIALIZERS)
513
+ # wholemodule flag is for override mode, there is no layer key in override
514
+ # and initializer will give init values for the whole module with the name
515
+ # in override.
516
+ func.wholemodule = wholemodule
517
+ func(module)
518
+
519
+
520
+ def _initialize_override(module, override, cfg):
521
+ if not isinstance(override, (dict, list)):
522
+ raise TypeError(f'override must be a dict or a list of dict, \
523
+ but got {type(override)}')
524
+
525
+ override = [override] if isinstance(override, dict) else override
526
+
527
+ for override_ in override:
528
+
529
+ cp_override = copy.deepcopy(override_)
530
+ name = cp_override.pop('name', None)
531
+ if name is None:
532
+ raise ValueError('`override` must contain the key "name",'
533
+ f'but got {cp_override}')
534
+ # if override only has name key, it means use args in init_cfg
535
+ if not cp_override:
536
+ cp_override.update(cfg)
537
+ # if override has name key and other args except type key, it will
538
+ # raise error
539
+ elif 'type' not in cp_override.keys():
540
+ raise ValueError(
541
+ f'`override` need "type" key, but got {cp_override}')
542
+
543
+ if hasattr(module, name):
544
+ _initialize(getattr(module, name), cp_override, wholemodule=True)
545
+ else:
546
+ raise RuntimeError(f'module did not have attribute {name}, '
547
+ f'but init_cfg is {cp_override}.')
548
+
549
+
550
+ def initialize(module, init_cfg):
551
+ """Initialize a module.
552
+
553
+ Args:
554
+ module (``torch.nn.Module``): the module will be initialized.
555
+ init_cfg (dict | list[dict]): initialization configuration dict to
556
+ define initializer. OpenMMLab has implemented 6 initializers
557
+ including ``Constant``, ``Xavier``, ``Normal``, ``Uniform``,
558
+ ``Kaiming``, and ``Pretrained``.
559
+ Example:
560
+ >>> module = nn.Linear(2, 3, bias=True)
561
+ >>> init_cfg = dict(type='Constant', layer='Linear', val =1 , bias =2)
562
+ >>> initialize(module, init_cfg)
563
+
564
+ >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
565
+ >>> # define key ``'layer'`` for initializing layer with different
566
+ >>> # configuration
567
+ >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
568
+ dict(type='Constant', layer='Linear', val=2)]
569
+ >>> initialize(module, init_cfg)
570
+
571
+ >>> # define key``'override'`` to initialize some specific part in
572
+ >>> # module
573
+ >>> class FooNet(nn.Module):
574
+ >>> def __init__(self):
575
+ >>> super().__init__()
576
+ >>> self.feat = nn.Conv2d(3, 16, 3)
577
+ >>> self.reg = nn.Conv2d(16, 10, 3)
578
+ >>> self.cls = nn.Conv2d(16, 5, 3)
579
+ >>> model = FooNet()
580
+ >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
581
+ >>> override=dict(type='Constant', name='reg', val=3, bias=4))
582
+ >>> initialize(model, init_cfg)
583
+
584
+ >>> model = ResNet(depth=50)
585
+ >>> # Initialize weights with the pretrained model.
586
+ >>> init_cfg = dict(type='Pretrained',
587
+ checkpoint='torchvision://resnet50')
588
+ >>> initialize(model, init_cfg)
589
+
590
+ >>> # Initialize weights of a sub-module with the specific part of
591
+ >>> # a pretrained model by using "prefix".
592
+ >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
593
+ >>> 'retinanet_r50_fpn_1x_coco/'\
594
+ >>> 'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
595
+ >>> init_cfg = dict(type='Pretrained',
596
+ checkpoint=url, prefix='backbone.')
597
+ """
598
+ if not isinstance(init_cfg, (dict, list)):
599
+ raise TypeError(f'init_cfg must be a dict or a list of dict, \
600
+ but got {type(init_cfg)}')
601
+
602
+ if isinstance(init_cfg, dict):
603
+ init_cfg = [init_cfg]
604
+
605
+ for cfg in init_cfg:
606
+ # should deeply copy the original config because cfg may be used by
607
+ # other modules, e.g., one init_cfg shared by multiple bottleneck
608
+ # blocks, the expected cfg will be changed after pop and will change
609
+ # the initialization behavior of other modules
610
+ cp_cfg = copy.deepcopy(cfg)
611
+ override = cp_cfg.pop('override', None)
612
+ _initialize(module, cp_cfg)
613
+
614
+ if override is not None:
615
+ cp_cfg.pop('layer', None)
616
+ _initialize_override(module, override, cp_cfg)
617
+ else:
618
+ # All attributes in module have same initialization.
619
+ pass
620
+
621
+
622
+ def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
623
+ b: float) -> Tensor:
624
+ # Method based on
625
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
626
+ # Modified from
627
+ # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
628
+ def norm_cdf(x):
629
+ # Computes standard normal cumulative distribution function
630
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
631
+
632
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
633
+ warnings.warn(
634
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
635
+ 'The distribution of values may be incorrect.',
636
+ stacklevel=2)
637
+
638
+ with torch.no_grad():
639
+ # Values are generated by using a truncated uniform distribution and
640
+ # then using the inverse CDF for the normal distribution.
641
+ # Get upper and lower cdf values
642
+ lower = norm_cdf((a - mean) / std)
643
+ upper = norm_cdf((b - mean) / std)
644
+
645
+ # Uniformly fill tensor with values from [lower, upper], then translate
646
+ # to [2lower-1, 2upper-1].
647
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
648
+
649
+ # Use inverse cdf transform for normal distribution to get truncated
650
+ # standard normal
651
+ tensor.erfinv_()
652
+
653
+ # Transform to proper mean, std
654
+ tensor.mul_(std * math.sqrt(2.))
655
+ tensor.add_(mean)
656
+
657
+ # Clamp to ensure it's in the proper range
658
+ tensor.clamp_(min=a, max=b)
659
+ return tensor
660
+
661
+
662
+ def trunc_normal_(tensor: Tensor,
663
+ mean: float = 0.,
664
+ std: float = 1.,
665
+ a: float = -2.,
666
+ b: float = 2.) -> Tensor:
667
+ r"""Fills the input Tensor with values drawn from a truncated
668
+ normal distribution. The values are effectively drawn from the
669
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
670
+ with values outside :math:`[a, b]` redrawn until they are within
671
+ the bounds. The method used for generating the random values works
672
+ best when :math:`a \leq \text{mean} \leq b`.
673
+
674
+ Modified from
675
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
676
+
677
+ Args:
678
+ tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
679
+ mean (float): the mean of the normal distribution.
680
+ std (float): the standard deviation of the normal distribution.
681
+ a (float): the minimum cutoff value.
682
+ b (float): the maximum cutoff value.
683
+ """
684
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
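For orientation, a minimal usage sketch of the initializers above; the import path mirrors this vendored package and is otherwise an assumption, so adapt it to your checkout:

import torch.nn as nn
from annotator.mmpkg.mmcv.cnn.utils.weight_init import initialize, trunc_normal_

model = nn.Sequential(nn.Conv1d(3, 8, 3), nn.Linear(8, 2))
# Declarative per-layer-type initialization via the registry-backed configs.
initialize(model, [dict(type='Kaiming', layer='Conv1d'),
                   dict(type='Normal', layer='Linear', std=0.01)])
# Or call the tensor-level helper directly on a single parameter.
trunc_normal_(model[1].weight, mean=0., std=0.02, a=-0.04, b=0.04)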
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/cnn/vgg.py ADDED
@@ -0,0 +1,175 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import logging
+
+ import torch.nn as nn
+
+ from .utils import constant_init, kaiming_init, normal_init
+
+
+ def conv3x3(in_planes, out_planes, dilation=1):
+     """3x3 convolution with padding."""
+     return nn.Conv2d(
+         in_planes,
+         out_planes,
+         kernel_size=3,
+         padding=dilation,
+         dilation=dilation)
+
+
+ def make_vgg_layer(inplanes,
+                    planes,
+                    num_blocks,
+                    dilation=1,
+                    with_bn=False,
+                    ceil_mode=False):
+     layers = []
+     for _ in range(num_blocks):
+         layers.append(conv3x3(inplanes, planes, dilation))
+         if with_bn:
+             layers.append(nn.BatchNorm2d(planes))
+         layers.append(nn.ReLU(inplace=True))
+         inplanes = planes
+     layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+     return layers
+
+
+ class VGG(nn.Module):
+     """VGG backbone.
+
+     Args:
+         depth (int): Depth of vgg, from {11, 13, 16, 19}.
+         with_bn (bool): Use BatchNorm or not.
+         num_classes (int): number of classes for classification.
+         num_stages (int): VGG stages, normally 5.
+         dilations (Sequence[int]): Dilation of each stage.
+         out_indices (Sequence[int]): Output from which stages.
+         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+             not freezing any parameters.
+         bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+             running stats (mean and var).
+         bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+     """
+
+     arch_settings = {
+         11: (1, 1, 2, 2, 2),
+         13: (2, 2, 2, 2, 2),
+         16: (2, 2, 3, 3, 3),
+         19: (2, 2, 4, 4, 4)
+     }
+
+     def __init__(self,
+                  depth,
+                  with_bn=False,
+                  num_classes=-1,
+                  num_stages=5,
+                  dilations=(1, 1, 1, 1, 1),
+                  out_indices=(0, 1, 2, 3, 4),
+                  frozen_stages=-1,
+                  bn_eval=True,
+                  bn_frozen=False,
+                  ceil_mode=False,
+                  with_last_pool=True):
+         super(VGG, self).__init__()
+         if depth not in self.arch_settings:
+             raise KeyError(f'invalid depth {depth} for vgg')
+         assert num_stages >= 1 and num_stages <= 5
+         stage_blocks = self.arch_settings[depth]
+         self.stage_blocks = stage_blocks[:num_stages]
+         assert len(dilations) == num_stages
+         assert max(out_indices) <= num_stages
+
+         self.num_classes = num_classes
+         self.out_indices = out_indices
+         self.frozen_stages = frozen_stages
+         self.bn_eval = bn_eval
+         self.bn_frozen = bn_frozen
+
+         self.inplanes = 3
+         start_idx = 0
+         vgg_layers = []
+         self.range_sub_modules = []
+         for i, num_blocks in enumerate(self.stage_blocks):
+             num_modules = num_blocks * (2 + with_bn) + 1
+             end_idx = start_idx + num_modules
+             dilation = dilations[i]
+             planes = 64 * 2**i if i < 4 else 512
+             vgg_layer = make_vgg_layer(
+                 self.inplanes,
+                 planes,
+                 num_blocks,
+                 dilation=dilation,
+                 with_bn=with_bn,
+                 ceil_mode=ceil_mode)
+             vgg_layers.extend(vgg_layer)
+             self.inplanes = planes
+             self.range_sub_modules.append([start_idx, end_idx])
+             start_idx = end_idx
+         if not with_last_pool:
+             vgg_layers.pop(-1)
+             self.range_sub_modules[-1][1] -= 1
+         self.module_name = 'features'
+         self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+         if self.num_classes > 0:
+             self.classifier = nn.Sequential(
+                 nn.Linear(512 * 7 * 7, 4096),
+                 nn.ReLU(True),
+                 nn.Dropout(),
+                 nn.Linear(4096, 4096),
+                 nn.ReLU(True),
+                 nn.Dropout(),
+                 nn.Linear(4096, num_classes),
+             )
+
+     def init_weights(self, pretrained=None):
+         if isinstance(pretrained, str):
+             logger = logging.getLogger()
+             from ..runner import load_checkpoint
+             load_checkpoint(self, pretrained, strict=False, logger=logger)
+         elif pretrained is None:
+             for m in self.modules():
+                 if isinstance(m, nn.Conv2d):
+                     kaiming_init(m)
+                 elif isinstance(m, nn.BatchNorm2d):
+                     constant_init(m, 1)
+                 elif isinstance(m, nn.Linear):
+                     normal_init(m, std=0.01)
+         else:
+             raise TypeError('pretrained must be a str or None')
+
+     def forward(self, x):
+         outs = []
+         vgg_layers = getattr(self, self.module_name)
+         for i in range(len(self.stage_blocks)):
+             for j in range(*self.range_sub_modules[i]):
+                 vgg_layer = vgg_layers[j]
+                 x = vgg_layer(x)
+             if i in self.out_indices:
+                 outs.append(x)
+         if self.num_classes > 0:
+             x = x.view(x.size(0), -1)
+             x = self.classifier(x)
+             outs.append(x)
+         if len(outs) == 1:
+             return outs[0]
+         else:
+             return tuple(outs)
+
+     def train(self, mode=True):
+         super(VGG, self).train(mode)
+         if self.bn_eval:
+             for m in self.modules():
+                 if isinstance(m, nn.BatchNorm2d):
+                     m.eval()
+                     if self.bn_frozen:
+                         for params in m.parameters():
+                             params.requires_grad = False
+         vgg_layers = getattr(self, self.module_name)
+         if mode and self.frozen_stages >= 0:
+             for i in range(self.frozen_stages):
+                 for j in range(*self.range_sub_modules[i]):
+                     mod = vgg_layers[j]
+                     mod.eval()
+                     for param in mod.parameters():
+                         param.requires_grad = False
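A minimal sketch of constructing this backbone and pulling multi-stage features; the import path is an assumption, and the shapes assume a 224x224 input:

import torch
from annotator.mmpkg.mmcv.cnn import VGG  # assumed re-export; adjust to your tree

backbone = VGG(depth=16, out_indices=(2, 3, 4), num_classes=-1)
backbone.init_weights()  # Kaiming for convs, N(0, 0.01) for linears
feats = backbone(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # one feature map per requested stage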
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+                    single_gpu_test)
+
+ __all__ = [
+     'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+     'single_gpu_test'
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/engine/test.py ADDED
@@ -0,0 +1,202 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import os.path as osp
+ import pickle
+ import shutil
+ import tempfile
+ import time
+
+ import torch
+ import torch.distributed as dist
+
+ import annotator.mmpkg.mmcv as mmcv
+ from annotator.mmpkg.mmcv.runner import get_dist_info
+
+
+ def single_gpu_test(model, data_loader):
+     """Test model with a single gpu.
+
+     This method tests model with a single gpu and displays a test progress
+     bar.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (nn.Dataloader): Pytorch data loader.
+
+     Returns:
+         list: The prediction results.
+     """
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     prog_bar = mmcv.ProgressBar(len(dataset))
+     for data in data_loader:
+         with torch.no_grad():
+             result = model(return_loss=False, **data)
+         results.extend(result)
+
+         # Assume result has the same length as batch_size
+         # refer to https://github.com/open-mmlab/mmcv/issues/985
+         batch_size = len(result)
+         for _ in range(batch_size):
+             prog_bar.update()
+     return results
+
+
+ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+     """Test model with multiple gpus.
+
+     This method tests model with multiple gpus and collects the results
+     under two different modes: gpu and cpu modes. By setting
+     ``gpu_collect=True``, it encodes results to gpu tensors and uses gpu
+     communication for results collection. On cpu mode it saves the results on
+     different gpus to ``tmpdir`` and collects them by the rank 0 worker.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (nn.Dataloader): Pytorch data loader.
+         tmpdir (str): Path of directory to save the temporary results from
+             different gpus under cpu mode.
+         gpu_collect (bool): Option to use either gpu or cpu to collect
+             results.
+
+     Returns:
+         list: The prediction results.
+     """
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     rank, world_size = get_dist_info()
+     if rank == 0:
+         prog_bar = mmcv.ProgressBar(len(dataset))
+     time.sleep(2)  # This line can prevent deadlock problems in some cases.
+     for i, data in enumerate(data_loader):
+         with torch.no_grad():
+             result = model(return_loss=False, **data)
+         results.extend(result)
+
+         if rank == 0:
+             batch_size = len(result)
+             batch_size_all = batch_size * world_size
+             if batch_size_all + prog_bar.completed > len(dataset):
+                 batch_size_all = len(dataset) - prog_bar.completed
+             for _ in range(batch_size_all):
+                 prog_bar.update()
+
+     # collect results from all ranks
+     if gpu_collect:
+         results = collect_results_gpu(results, len(dataset))
+     else:
+         results = collect_results_cpu(results, len(dataset), tmpdir)
+     return results
+
+
+ def collect_results_cpu(result_part, size, tmpdir=None):
+     """Collect results under cpu mode.
+
+     On cpu mode, this function will save the results on different gpus to
+     ``tmpdir`` and collect them by the rank 0 worker.
+
+     Args:
+         result_part (list): Result list containing result parts
+             to be collected.
+         size (int): Size of the results, commonly equal to length of
+             the results.
+         tmpdir (str | None): Temporary directory to store the collected
+             results. If set to None, a random temporary directory will be
+             created for it.
+
+     Returns:
+         list: The collected results.
+     """
+     rank, world_size = get_dist_info()
+     # create a tmp dir if it is not specified
+     if tmpdir is None:
+         MAX_LEN = 512
+         # 32 is whitespace
+         dir_tensor = torch.full((MAX_LEN, ),
+                                 32,
+                                 dtype=torch.uint8,
+                                 device='cuda')
+         if rank == 0:
+             mmcv.mkdir_or_exist('.dist_test')
+             tmpdir = tempfile.mkdtemp(dir='.dist_test')
+             tmpdir = torch.tensor(
+                 bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+             dir_tensor[:len(tmpdir)] = tmpdir
+         dist.broadcast(dir_tensor, 0)
+         tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+     else:
+         mmcv.mkdir_or_exist(tmpdir)
+     # dump the part result to the dir
+     mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+     dist.barrier()
+     # collect all parts
+     if rank != 0:
+         return None
+     else:
+         # load results of all parts from tmp dir
+         part_list = []
+         for i in range(world_size):
+             part_file = osp.join(tmpdir, f'part_{i}.pkl')
+             part_result = mmcv.load(part_file)
+             # When data is severely insufficient, an empty part_result
+             # on a certain gpu could make the overall outputs empty.
+             if part_result:
+                 part_list.append(part_result)
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         # remove tmp dir
+         shutil.rmtree(tmpdir)
+         return ordered_results
+
+
+ def collect_results_gpu(result_part, size):
+     """Collect results under gpu mode.
+
+     On gpu mode, this function will encode results to gpu tensors and use
+     gpu communication for results collection.
+
+     Args:
+         result_part (list): Result list containing result parts
+             to be collected.
+         size (int): Size of the results, commonly equal to length of
+             the results.
+
+     Returns:
+         list: The collected results.
+     """
+     rank, world_size = get_dist_info()
+     # dump result part to tensor with pickle
+     part_tensor = torch.tensor(
+         bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+     # gather all result part tensor shape
+     shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+     shape_list = [shape_tensor.clone() for _ in range(world_size)]
+     dist.all_gather(shape_list, shape_tensor)
+     # padding result part tensor to max length
+     shape_max = torch.tensor(shape_list).max()
+     part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+     part_send[:shape_tensor[0]] = part_tensor
+     part_recv_list = [
+         part_tensor.new_zeros(shape_max) for _ in range(world_size)
+     ]
+     # gather all result part
+     dist.all_gather(part_recv_list, part_send)
+
+     if rank == 0:
+         part_list = []
+         for recv, shape in zip(part_recv_list, shape_list):
+             part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+             # When data is severely insufficient, an empty part_result
+             # on a certain gpu could make the overall outputs empty.
+             if part_result:
+                 part_list.append(part_result)
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         return ordered_results
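A hypothetical driver sketch for the two collection modes above; `model` is assumed to be a distributed-wrapped module and `data_loader` a sampler-sharded loader, neither of which is defined here:

results = multi_gpu_test(model, data_loader,
                         tmpdir='./.eval_tmp',  # used only when gpu_collect=False
                         gpu_collect=True)
if results is not None:  # only rank 0 receives the gathered list
    print(f'gathered {len(results)} predictions')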
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .file_client import BaseStorageBackend, FileClient
+ from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+ from .io import dump, load, register_handler
+ from .parse import dict_from_file, list_from_file
+
+ __all__ = [
+     'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+     'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+     'list_from_file', 'dict_from_file'
+ ]
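A small round-trip sketch for the exported I/O helpers; the serialization format is inferred from the file suffix, and the import path mirrors this vendored package and is otherwise an assumption:

from annotator.mmpkg.mmcv.fileio import dump, load

dump({'acc': 0.9}, 'results.json')  # handler chosen by the .json suffix
assert load('results.json') == {'acc': 0.9}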
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/file_client.py ADDED
@@ -0,0 +1,1148 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import inspect
+ import os
+ import os.path as osp
+ import re
+ import tempfile
+ import warnings
+ from abc import ABCMeta, abstractmethod
+ from contextlib import contextmanager
+ from pathlib import Path
+ from typing import Iterable, Iterator, Optional, Tuple, Union
+ from urllib.request import urlopen
+
+ import annotator.mmpkg.mmcv as mmcv
+ from annotator.mmpkg.mmcv.utils.misc import has_method
+ from annotator.mmpkg.mmcv.utils.path import is_filepath
+
+
+ class BaseStorageBackend(metaclass=ABCMeta):
+     """Abstract class of storage backends.
+
+     All backends need to implement two APIs: ``get()`` and ``get_text()``.
+     ``get()`` reads the file as a byte stream and ``get_text()`` reads the
+     file as text.
+     """
+
+     # a flag to indicate whether the backend can create a symlink for a file
+     _allow_symlink = False
+
+     @property
+     def name(self):
+         return self.__class__.__name__
+
+     @property
+     def allow_symlink(self):
+         return self._allow_symlink
+
+     @abstractmethod
+     def get(self, filepath):
+         pass
+
+     @abstractmethod
+     def get_text(self, filepath):
+         pass
+
+
+ class CephBackend(BaseStorageBackend):
+     """Ceph storage backend (for internal use).
+
+     Args:
+         path_mapping (dict|None): path mapping dict from local path to Petrel
+             path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+             will be replaced by ``dst``. Default: None.
+
+     .. warning::
+         :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+         please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+     """
+
+     def __init__(self, path_mapping=None):
+         try:
+             import ceph
+         except ImportError:
+             raise ImportError('Please install ceph to enable CephBackend.')
+
+         warnings.warn(
+             'CephBackend will be deprecated, please use PetrelBackend instead')
+         self._client = ceph.S3Client()
+         assert isinstance(path_mapping, dict) or path_mapping is None
+         self.path_mapping = path_mapping
+
+     def get(self, filepath):
+         filepath = str(filepath)
+         if self.path_mapping is not None:
+             for k, v in self.path_mapping.items():
+                 filepath = filepath.replace(k, v)
+         value = self._client.Get(filepath)
+         value_buf = memoryview(value)
+         return value_buf
+
+     def get_text(self, filepath, encoding=None):
+         raise NotImplementedError
+
+
+ class PetrelBackend(BaseStorageBackend):
+     """Petrel storage backend (for internal use).
+
+     PetrelBackend supports reading and writing data to multiple clusters.
+     If the file path contains the cluster name, PetrelBackend will read data
+     from the specified cluster or write data to it. Otherwise, PetrelBackend
+     will access the default cluster.
+
+     Args:
+         path_mapping (dict, optional): Path mapping dict from local path to
+             Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+             ``filepath`` will be replaced by ``dst``. Default: None.
+         enable_mc (bool, optional): Whether to enable memcached support.
+             Default: True.
+
+     Examples:
+         >>> filepath1 = 's3://path/of/file'
+         >>> filepath2 = 'cluster-name:s3://path/of/file'
+         >>> client = PetrelBackend()
+         >>> client.get(filepath1)  # get data from default cluster
+         >>> client.get(filepath2)  # get data from 'cluster-name' cluster
+     """
+
+     def __init__(self,
+                  path_mapping: Optional[dict] = None,
+                  enable_mc: bool = True):
+         try:
+             from petrel_client import client
+         except ImportError:
+             raise ImportError('Please install petrel_client to enable '
+                               'PetrelBackend.')
+
+         self._client = client.Client(enable_mc=enable_mc)
+         assert isinstance(path_mapping, dict) or path_mapping is None
+         self.path_mapping = path_mapping
+
+     def _map_path(self, filepath: Union[str, Path]) -> str:
+         """Map ``filepath`` to a string path whose prefix will be replaced by
+         :attr:`self.path_mapping`.
+
+         Args:
+             filepath (str): Path to be mapped.
+         """
+         filepath = str(filepath)
+         if self.path_mapping is not None:
+             for k, v in self.path_mapping.items():
+                 filepath = filepath.replace(k, v)
+         return filepath
+
+     def _format_path(self, filepath: str) -> str:
+         """Convert a ``filepath`` to the standard format of petrel oss.
+
+         If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
+         environment, the ``filepath`` will be the format of
+         's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+         above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+         Args:
+             filepath (str): Path to be formatted.
+         """
+         return re.sub(r'\\+', '/', filepath)
+
+     def get(self, filepath: Union[str, Path]) -> memoryview:
+         """Read data from a given ``filepath`` with 'rb' mode.
+
+         Args:
+             filepath (str or Path): Path to read data.
+
+         Returns:
+             memoryview: A memory view of the expected bytes object to avoid
+                 copying. The memoryview object can be converted to bytes by
+                 ``value_buf.tobytes()``.
+         """
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         value = self._client.Get(filepath)
+         value_buf = memoryview(value)
+         return value_buf
+
+     def get_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+         """Read data from a given ``filepath`` with 'r' mode.
+
+         Args:
+             filepath (str or Path): Path to read data.
+             encoding (str): The encoding format used to open the ``filepath``.
+                 Default: 'utf-8'.
+
+         Returns:
+             str: Expected text reading from ``filepath``.
+         """
+         return str(self.get(filepath), encoding=encoding)
+
+     def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+         """Save data to a given ``filepath``.
+
+         Args:
+             obj (bytes): Data to be saved.
+             filepath (str or Path): Path to write data.
+         """
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         self._client.put(filepath, obj)
+
+     def put_text(self,
+                  obj: str,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> None:
+         """Save data to a given ``filepath``.
+
+         Args:
+             obj (str): Data to be written.
+             filepath (str or Path): Path to write data.
+             encoding (str): The encoding format used to encode the ``obj``.
+                 Default: 'utf-8'.
+         """
+         self.put(bytes(obj, encoding=encoding), filepath)
+
+     def remove(self, filepath: Union[str, Path]) -> None:
+         """Remove a file.
+
+         Args:
+             filepath (str or Path): Path to be removed.
+         """
+         if not has_method(self._client, 'delete'):
+             raise NotImplementedError(
+                 ('Current version of Petrel Python SDK has not supported '
+                  'the `delete` method, please use a higher version or dev'
+                  ' branch instead.'))
+
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         self._client.delete(filepath)
+
+     def exists(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path exists.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it exists.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+         """
+         if not (has_method(self._client, 'contains')
+                 and has_method(self._client, 'isdir')):
+             raise NotImplementedError(
+                 ('Current version of Petrel Python SDK has not supported '
+                  'the `contains` and `isdir` methods, please use a higher '
+                  'version or dev branch instead.'))
+
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         return self._client.contains(filepath) or self._client.isdir(filepath)
+
+     def isdir(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path is a directory.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a
+                 directory.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a directory,
+                 ``False`` otherwise.
+         """
+         if not has_method(self._client, 'isdir'):
+             raise NotImplementedError(
+                 ('Current version of Petrel Python SDK has not supported '
+                  'the `isdir` method, please use a higher version or dev'
+                  ' branch instead.'))
+
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         return self._client.isdir(filepath)
+
+     def isfile(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path is a file.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a file.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                 otherwise.
+         """
+         if not has_method(self._client, 'contains'):
+             raise NotImplementedError(
+                 ('Current version of Petrel Python SDK has not supported '
+                  'the `contains` method, please use a higher version or '
+                  'dev branch instead.'))
+
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         return self._client.contains(filepath)
+
+     def join_path(self, filepath: Union[str, Path],
+                   *filepaths: Union[str, Path]) -> str:
+         """Concatenate all file paths.
+
+         Args:
+             filepath (str or Path): Path to be concatenated.
+
+         Returns:
+             str: The result after concatenation.
+         """
+         filepath = self._format_path(self._map_path(filepath))
+         if filepath.endswith('/'):
+             filepath = filepath[:-1]
+         formatted_paths = [filepath]
+         for path in filepaths:
+             formatted_paths.append(self._format_path(self._map_path(path)))
+         return '/'.join(formatted_paths)
+
+     @contextmanager
+     def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+         """Download a file from ``filepath`` and return a temporary path.
+
+         ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+         It can be called with the ``with`` statement, and when exiting the
+         ``with`` statement, the temporary path will be released.
+
+         Args:
+             filepath (str | Path): Download a file from ``filepath``.
+
+         Examples:
+             >>> client = PetrelBackend()
+             >>> # After exiting from the ``with`` clause,
+             >>> # the path will be removed
+             >>> with client.get_local_path('s3://path/of/your/file') as path:
+             ...     # do something here
+
+         Yields:
+             Iterable[str]: Only yield one temporary path.
+         """
+         filepath = self._map_path(filepath)
+         filepath = self._format_path(filepath)
+         assert self.isfile(filepath)
+         try:
+             f = tempfile.NamedTemporaryFile(delete=False)
+             f.write(self.get(filepath))
+             f.close()
+             yield f.name
+         finally:
+             os.remove(f.name)
+
+     def list_dir_or_file(self,
+                          dir_path: Union[str, Path],
+                          list_dir: bool = True,
+                          list_file: bool = True,
+                          suffix: Optional[Union[str, Tuple[str]]] = None,
+                          recursive: bool = False) -> Iterator[str]:
+         """Scan a directory to find the interested directories or files in
+         arbitrary order.
+
+         Note:
+             Petrel has no concept of directories but it simulates the
+             directory hierarchy in the filesystem through public prefixes.
+             In addition, if the returned path ends with '/', it means the
+             path is a public prefix which is a logical directory.
+
+         Note:
+             :meth:`list_dir_or_file` returns the path relative to
+             ``dir_path``. In addition, the returned path of a directory will
+             not contain the suffix '/', which is consistent with other
+             backends.
+
+         Args:
+             dir_path (str | Path): Path of the directory.
+             list_dir (bool): List the directories. Default: True.
+             list_file (bool): List the path of files. Default: True.
+             suffix (str or tuple[str], optional): File suffix
+                 that we are interested in. Default: None.
+             recursive (bool): If set to True, recursively scan the
+                 directory. Default: False.
+
+         Yields:
+             Iterable[str]: A relative path to ``dir_path``.
+         """
+         if not has_method(self._client, 'list'):
+             raise NotImplementedError(
+                 ('Current version of Petrel Python SDK has not supported '
+                  'the `list` method, please use a higher version or dev'
+                  ' branch instead.'))
+
+         dir_path = self._map_path(dir_path)
+         dir_path = self._format_path(dir_path)
+         if list_dir and suffix is not None:
+             raise TypeError(
+                 '`list_dir` should be False when `suffix` is not None')
+
+         if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+             raise TypeError('`suffix` must be a string or tuple of strings')
+
+         # Petrel's simulated directory hierarchy assumes that directory paths
+         # should end with `/`
+         if not dir_path.endswith('/'):
+             dir_path += '/'
+
+         root = dir_path
+
+         def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                               recursive):
+             for path in self._client.list(dir_path):
+                 # the `self.isdir` is not used here to determine whether path
+                 # is a directory, because `self.isdir` relies on
+                 # `self._client.list`
+                 if path.endswith('/'):  # a directory path
+                     next_dir_path = self.join_path(dir_path, path)
+                     if list_dir:
+                         # get the relative path and exclude the last
+                         # character '/'
+                         rel_dir = next_dir_path[len(root):-1]
+                         yield rel_dir
+                     if recursive:
+                         yield from _list_dir_or_file(next_dir_path, list_dir,
+                                                      list_file, suffix,
+                                                      recursive)
+                 else:  # a file path
+                     absolute_path = self.join_path(dir_path, path)
+                     rel_path = absolute_path[len(root):]
+                     if (suffix is None
+                             or rel_path.endswith(suffix)) and list_file:
+                         yield rel_path
+
+         return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                  recursive)
+
+
+ class MemcachedBackend(BaseStorageBackend):
+     """Memcached storage backend.
+
+     Attributes:
+         server_list_cfg (str): Config file for memcached server list.
+         client_cfg (str): Config file for memcached client.
+         sys_path (str | None): Additional path to be appended to `sys.path`.
+             Default: None.
+     """
+
+     def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+         if sys_path is not None:
+             import sys
+             sys.path.append(sys_path)
+         try:
+             import mc
+         except ImportError:
+             raise ImportError(
+                 'Please install memcached to enable MemcachedBackend.')
+
+         self.server_list_cfg = server_list_cfg
+         self.client_cfg = client_cfg
+         self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+                                                       self.client_cfg)
+         # mc.pyvector serves as a pointer to a memory cache
+         self._mc_buffer = mc.pyvector()
+
+     def get(self, filepath):
+         filepath = str(filepath)
+         import mc
+         self._client.Get(filepath, self._mc_buffer)
+         value_buf = mc.ConvertBuffer(self._mc_buffer)
+         return value_buf
+
+     def get_text(self, filepath, encoding=None):
+         raise NotImplementedError
+
+
+ class LmdbBackend(BaseStorageBackend):
+     """Lmdb storage backend.
+
+     Args:
+         db_path (str): Lmdb database path.
+         readonly (bool, optional): Lmdb environment parameter. If True,
+             disallow any write operations. Default: True.
+         lock (bool, optional): Lmdb environment parameter. If False, when
+             concurrent access occurs, do not lock the database.
+             Default: False.
+         readahead (bool, optional): Lmdb environment parameter. If False,
+             disable the OS filesystem readahead mechanism, which may improve
+             random read performance when a database is larger than RAM.
+             Default: False.
+
+     Attributes:
+         db_path (str): Lmdb database path.
+     """
+
+     def __init__(self,
+                  db_path,
+                  readonly=True,
+                  lock=False,
+                  readahead=False,
+                  **kwargs):
+         try:
+             import lmdb
+         except ImportError:
+             raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+         self.db_path = str(db_path)
+         self._client = lmdb.open(
+             self.db_path,
+             readonly=readonly,
+             lock=lock,
+             readahead=readahead,
+             **kwargs)
+
+     def get(self, filepath):
+         """Get values according to the filepath.
+
+         Args:
+             filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+         """
+         filepath = str(filepath)
+         with self._client.begin(write=False) as txn:
+             value_buf = txn.get(filepath.encode('ascii'))
+         return value_buf
+
+     def get_text(self, filepath, encoding=None):
+         raise NotImplementedError
+
+
+ class HardDiskBackend(BaseStorageBackend):
+     """Raw hard disk storage backend."""
+
+     _allow_symlink = True
+
+     def get(self, filepath: Union[str, Path]) -> bytes:
+         """Read data from a given ``filepath`` with 'rb' mode.
+
+         Args:
+             filepath (str or Path): Path to read data.
+
+         Returns:
+             bytes: Expected bytes object.
+         """
+         with open(filepath, 'rb') as f:
+             value_buf = f.read()
+         return value_buf
+
+     def get_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+         """Read data from a given ``filepath`` with 'r' mode.
+
+         Args:
+             filepath (str or Path): Path to read data.
+             encoding (str): The encoding format used to open the ``filepath``.
+                 Default: 'utf-8'.
+
+         Returns:
+             str: Expected text reading from ``filepath``.
+         """
+         with open(filepath, 'r', encoding=encoding) as f:
+             value_buf = f.read()
+         return value_buf
+
+     def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+         """Write data to a given ``filepath`` with 'wb' mode.
+
+         Note:
+             ``put`` will create a directory if the directory of ``filepath``
+             does not exist.
+
+         Args:
+             obj (bytes): Data to be written.
+             filepath (str or Path): Path to write data.
+         """
+         mmcv.mkdir_or_exist(osp.dirname(filepath))
+         with open(filepath, 'wb') as f:
+             f.write(obj)
+
+     def put_text(self,
+                  obj: str,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> None:
+         """Write data to a given ``filepath`` with 'w' mode.
+
+         Note:
+             ``put_text`` will create a directory if the directory of
+             ``filepath`` does not exist.
+
+         Args:
+             obj (str): Data to be written.
+             filepath (str or Path): Path to write data.
+             encoding (str): The encoding format used to open the ``filepath``.
+                 Default: 'utf-8'.
+         """
+         mmcv.mkdir_or_exist(osp.dirname(filepath))
+         with open(filepath, 'w', encoding=encoding) as f:
+             f.write(obj)
+
+     def remove(self, filepath: Union[str, Path]) -> None:
+         """Remove a file.
+
+         Args:
+             filepath (str or Path): Path to be removed.
+         """
+         os.remove(filepath)
+
+     def exists(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path exists.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it exists.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+         """
+         return osp.exists(filepath)
+
+     def isdir(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path is a directory.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a
+                 directory.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a directory,
+                 ``False`` otherwise.
+         """
+         return osp.isdir(filepath)
+
+     def isfile(self, filepath: Union[str, Path]) -> bool:
+         """Check whether a file path is a file.
+
+         Args:
+             filepath (str or Path): Path to be checked whether it is a file.
+
+         Returns:
+             bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                 otherwise.
+         """
+         return osp.isfile(filepath)
+
+     def join_path(self, filepath: Union[str, Path],
+                   *filepaths: Union[str, Path]) -> str:
+         """Concatenate all file paths.
+
+         Join one or more filepath components intelligently. The return value
+         is the concatenation of filepath and any members of *filepaths.
+
+         Args:
+             filepath (str or Path): Path to be concatenated.
+
+         Returns:
+             str: The result of concatenation.
+         """
+         return osp.join(filepath, *filepaths)
+
+     @contextmanager
+     def get_local_path(
+             self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+         """Only for unified API and do nothing."""
+         yield filepath
+
+     def list_dir_or_file(self,
+                          dir_path: Union[str, Path],
+                          list_dir: bool = True,
+                          list_file: bool = True,
+                          suffix: Optional[Union[str, Tuple[str]]] = None,
+                          recursive: bool = False) -> Iterator[str]:
+         """Scan a directory to find the interested directories or files in
+         arbitrary order.
+
+         Note:
+             :meth:`list_dir_or_file` returns the path relative to
+             ``dir_path``.
+
+         Args:
+             dir_path (str | Path): Path of the directory.
+             list_dir (bool): List the directories. Default: True.
+             list_file (bool): List the path of files. Default: True.
+             suffix (str or tuple[str], optional): File suffix
+                 that we are interested in. Default: None.
+             recursive (bool): If set to True, recursively scan the
+                 directory. Default: False.
+
+         Yields:
+             Iterable[str]: A relative path to ``dir_path``.
+         """
+         if list_dir and suffix is not None:
+             raise TypeError('`suffix` should be None when `list_dir` is True')
+
+         if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+             raise TypeError('`suffix` must be a string or tuple of strings')
+
+         root = dir_path
+
+         def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                               recursive):
+             for entry in os.scandir(dir_path):
+                 if not entry.name.startswith('.') and entry.is_file():
+                     rel_path = osp.relpath(entry.path, root)
+                     if (suffix is None
+                             or rel_path.endswith(suffix)) and list_file:
+                         yield rel_path
+                 elif osp.isdir(entry.path):
+                     if list_dir:
+                         rel_dir = osp.relpath(entry.path, root)
+                         yield rel_dir
+                     if recursive:
+                         yield from _list_dir_or_file(entry.path, list_dir,
+                                                      list_file, suffix,
+                                                      recursive)
+
+         return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                  recursive)
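An illustrative sketch exercising HardDiskBackend directly; the paths are placeholders, and normal access goes through FileClient, defined further below:

backend = HardDiskBackend()
backend.put_text('hello', '/tmp/demo/a.txt')  # creates /tmp/demo if needed
print(backend.get_text('/tmp/demo/a.txt'))    # -> hello
# suffix filtering requires list_dir=False, per the check above
print(sorted(backend.list_dir_or_file('/tmp/demo', suffix='.txt',
                                      list_dir=False)))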
689
+
690
+
691
+ class HTTPBackend(BaseStorageBackend):
692
+ """HTTP and HTTPS storage bachend."""
693
+
694
+ def get(self, filepath):
695
+ value_buf = urlopen(filepath).read()
696
+ return value_buf
697
+
698
+ def get_text(self, filepath, encoding='utf-8'):
699
+ value_buf = urlopen(filepath).read()
700
+ return value_buf.decode(encoding)
701
+
702
+ @contextmanager
703
+ def get_local_path(self, filepath: str) -> Iterable[str]:
704
+ """Download a file from ``filepath``.
705
+
706
+ ``get_local_path`` is decorated by :meth:`contxtlib.contextmanager`. It
707
+ can be called with ``with`` statement, and when exists from the
708
+ ``with`` statement, the temporary path will be released.
709
+
710
+ Args:
711
+ filepath (str): Download a file from ``filepath``.
712
+
713
+ Examples:
714
+ >>> client = HTTPBackend()
715
+ >>> # After existing from the ``with`` clause,
716
+ >>> # the path will be removed
717
+ >>> with client.get_local_path('http://path/of/your/file') as path:
718
+ ... # do something here
719
+ """
720
+ try:
721
+ f = tempfile.NamedTemporaryFile(delete=False)
722
+ f.write(self.get(filepath))
723
+ f.close()
724
+ yield f.name
725
+ finally:
726
+ os.remove(f.name)
727
+
728
+
729
+ class FileClient:
730
+ """A general file client to access files in different backends.
731
+
732
+ The client loads a file or text in a specified backend from its path
733
+ and returns it as a binary or text file. There are two ways to choose a
734
+ backend, the name of backend and the prefix of path. Although both of them
735
+ can be used to choose a storage backend, ``backend`` has a higher priority
736
+ that is if they are all set, the storage backend will be chosen by the
737
+ backend argument. If they are all `None`, the disk backend will be chosen.
738
+ Note that It can also register other backend accessor with a given name,
739
+ prefixes, and backend class. In addition, We use the singleton pattern to
740
+ avoid repeated object creation. If the arguments are the same, the same
741
+ object will be returned.
742
+
743
+ Args:
744
+ backend (str, optional): The storage backend type. Options are "disk",
745
+ "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
746
+ prefix (str, optional): The prefix of the registered storage backend.
747
+ Options are "s3", "http", "https". Default: None.
748
+
749
+ Examples:
750
+ >>> # only set backend
751
+ >>> file_client = FileClient(backend='petrel')
752
+ >>> # only set prefix
753
+ >>> file_client = FileClient(prefix='s3')
754
+ >>> # set both backend and prefix but use backend to choose client
755
+ >>> file_client = FileClient(backend='petrel', prefix='s3')
756
+ >>> # if the arguments are the same, the same object is returned
757
+ >>> file_client1 = FileClient(backend='petrel')
758
+ >>> file_client1 is file_client
759
+ True
760
+
761
+ Attributes:
762
+ client (:obj:`BaseStorageBackend`): The backend object.
763
+ """
764
+
765
+ _backends = {
766
+ 'disk': HardDiskBackend,
767
+ 'ceph': CephBackend,
768
+ 'memcached': MemcachedBackend,
769
+ 'lmdb': LmdbBackend,
770
+ 'petrel': PetrelBackend,
771
+ 'http': HTTPBackend,
772
+ }
773
+ # This collection is used to record the overridden backends, and when a
774
+ # backend appears in the collection, the singleton pattern is disabled for
775
+ # that backend, because if the singleton pattern is used, then the object
776
+ # returned will be the backend before overwriting
777
+ _overridden_backends = set()
778
+ _prefix_to_backends = {
779
+ 's3': PetrelBackend,
780
+ 'http': HTTPBackend,
781
+ 'https': HTTPBackend,
782
+ }
783
+ _overridden_prefixes = set()
784
+
785
+ _instances = {}
786
+
787
+ def __new__(cls, backend=None, prefix=None, **kwargs):
788
+ if backend is None and prefix is None:
789
+ backend = 'disk'
790
+ if backend is not None and backend not in cls._backends:
791
+ raise ValueError(
792
+ f'Backend {backend} is not supported. Currently supported ones'
793
+ f' are {list(cls._backends.keys())}')
794
+ if prefix is not None and prefix not in cls._prefix_to_backends:
795
+ raise ValueError(
796
+ f'prefix {prefix} is not supported. Currently supported ones '
797
+ f'are {list(cls._prefix_to_backends.keys())}')
798
+
799
+ # concatenate the arguments to a unique key for determining whether
800
+ # objects with the same arguments were created
801
+ arg_key = f'{backend}:{prefix}'
802
+ for key, value in kwargs.items():
803
+ arg_key += f':{key}:{value}'
804
+
805
+ # if a backend was overridden, it will create a new object
806
+ if (arg_key in cls._instances
807
+ and backend not in cls._overridden_backends
808
+ and prefix not in cls._overridden_prefixes):
809
+ _instance = cls._instances[arg_key]
810
+ else:
811
+ # create a new object and put it into _instances
812
+ _instance = super().__new__(cls)
813
+ if backend is not None:
814
+ _instance.client = cls._backends[backend](**kwargs)
815
+ else:
816
+ _instance.client = cls._prefix_to_backends[prefix](**kwargs)
817
+
818
+ cls._instances[arg_key] = _instance
819
+
820
+ return _instance
821
+
822
+ @property
823
+ def name(self):
824
+ return self.client.name
825
+
826
+ @property
827
+ def allow_symlink(self):
828
+ return self.client.allow_symlink
829
+
830
+ @staticmethod
831
+ def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
832
+ """Parse the prefix of a uri.
833
+
834
+ Args:
835
+ uri (str | Path): Uri to be parsed that contains the file prefix.
836
+
837
+ Examples:
838
+ >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
839
+ 's3'
840
+
841
+ Returns:
842
+ str | None: Return the prefix of uri if the uri contains '://'
843
+ else ``None``.
844
+ """
845
+ assert is_filepath(uri)
846
+ uri = str(uri)
847
+ if '://' not in uri:
848
+ return None
849
+ else:
850
+ prefix, _ = uri.split('://')
851
+ # In the case of PetrelBackend, the prefix may contain the cluster
852
+ # name like clusterName:s3
853
+ if ':' in prefix:
854
+ _, prefix = prefix.split(':')
855
+ return prefix
856
+
857
+ @classmethod
858
+ def infer_client(cls,
859
+ file_client_args: Optional[dict] = None,
860
+ uri: Optional[Union[str, Path]] = None) -> 'FileClient':
861
+ """Infer a suitable file client based on the URI and arguments.
862
+
863
+ Args:
864
+ file_client_args (dict, optional): Arguments to instantiate a
865
+ FileClient. Default: None.
866
+ uri (str | Path, optional): Uri to be parsed that contains the file
867
+ prefix. Default: None.
868
+
869
+ Examples:
870
+ >>> uri = 's3://path/of/your/file'
871
+ >>> file_client = FileClient.infer_client(uri=uri)
872
+ >>> file_client_args = {'backend': 'petrel'}
873
+ >>> file_client = FileClient.infer_client(file_client_args)
874
+
875
+ Returns:
876
+ FileClient: Instantiated FileClient object.
877
+ """
878
+ assert file_client_args is not None or uri is not None
879
+ if file_client_args is None:
880
+ file_prefix = cls.parse_uri_prefix(uri) # type: ignore
881
+ return cls(prefix=file_prefix)
882
+ else:
883
+ return cls(**file_client_args)
884
+
885
+ @classmethod
886
+ def _register_backend(cls, name, backend, force=False, prefixes=None):
887
+ if not isinstance(name, str):
888
+ raise TypeError('the backend name should be a string, '
889
+ f'but got {type(name)}')
890
+ if not inspect.isclass(backend):
891
+ raise TypeError(
892
+ f'backend should be a class but got {type(backend)}')
893
+ if not issubclass(backend, BaseStorageBackend):
894
+ raise TypeError(
895
+ f'backend {backend} is not a subclass of BaseStorageBackend')
896
+ if not force and name in cls._backends:
897
+ raise KeyError(
898
+ f'{name} is already registered as a storage backend, '
899
+ 'add "force=True" if you want to override it')
900
+
901
+ if name in cls._backends and force:
902
+ cls._overridden_backends.add(name)
903
+ cls._backends[name] = backend
904
+
905
+ if prefixes is not None:
906
+ if isinstance(prefixes, str):
907
+ prefixes = [prefixes]
908
+ else:
909
+ assert isinstance(prefixes, (list, tuple))
910
+ for prefix in prefixes:
911
+ if prefix not in cls._prefix_to_backends:
912
+ cls._prefix_to_backends[prefix] = backend
913
+ elif (prefix in cls._prefix_to_backends) and force:
914
+ cls._overridden_prefixes.add(prefix)
915
+ cls._prefix_to_backends[prefix] = backend
916
+ else:
917
+ raise KeyError(
918
+ f'{prefix} is already registered as a storage backend,'
919
+ ' add "force=True" if you want to override it')
920
+
921
+ @classmethod
922
+ def register_backend(cls, name, backend=None, force=False, prefixes=None):
923
+ """Register a backend to FileClient.
924
+
925
+ This method can be used as a normal class method or a decorator.
926
+
927
+ .. code-block:: python
928
+
929
+ class NewBackend(BaseStorageBackend):
930
+
931
+ def get(self, filepath):
932
+ return filepath
933
+
934
+ def get_text(self, filepath):
935
+ return filepath
936
+
937
+ FileClient.register_backend('new', NewBackend)
938
+
939
+ or
940
+
941
+ .. code-block:: python
942
+
943
+ @FileClient.register_backend('new')
944
+ class NewBackend(BaseStorageBackend):
945
+
946
+ def get(self, filepath):
947
+ return filepath
948
+
949
+ def get_text(self, filepath):
950
+ return filepath
951
+
952
+ Args:
953
+ name (str): The name of the registered backend.
954
+ backend (class, optional): The backend class to be registered,
955
+ which must be a subclass of :class:`BaseStorageBackend`.
956
+ When this method is used as a decorator, backend is None.
957
+ Defaults to None.
958
+ force (bool, optional): Whether to override the backend if the name
959
+ has already been registered. Defaults to False.
960
+ prefixes (str or list[str] or tuple[str], optional): The prefixes
961
+ of the registered storage backend. Default: None.
962
+ `New in version 1.3.15.`
963
+ """
964
+ if backend is not None:
965
+ cls._register_backend(
966
+ name, backend, force=force, prefixes=prefixes)
967
+ return
968
+
969
+ def _register(backend_cls):
970
+ cls._register_backend(
971
+ name, backend_cls, force=force, prefixes=prefixes)
972
+ return backend_cls
973
+
974
+ return _register
975
+
976
+ def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
977
+ """Read data from a given ``filepath`` with 'rb' mode.
978
+
979
+ Note:
980
+ There are two types of return values for ``get``, one is ``bytes``
981
+ and the other is ``memoryview``. The advantage of using memoryview
982
+ is that you can avoid copying, and if you want to convert it to
983
+ ``bytes``, you can use ``.tobytes()``.
984
+
985
+ Args:
986
+ filepath (str or Path): Path to read data.
987
+
988
+ Returns:
989
+ bytes | memoryview: Expected bytes object or a memory view of the
990
+ bytes object.
991
+ """
992
+ return self.client.get(filepath)
993
+
994
+ def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
995
+ """Read data from a given ``filepath`` with 'r' mode.
996
+
997
+ Args:
998
+ filepath (str or Path): Path to read data.
999
+ encoding (str): The encoding format used to open the ``filepath``.
1000
+ Default: 'utf-8'.
1001
+
1002
+ Returns:
1003
+ str: Expected text reading from ``filepath``.
1004
+ """
1005
+ return self.client.get_text(filepath, encoding)
1006
+
1007
+ def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
1008
+ """Write data to a given ``filepath`` with 'wb' mode.
1009
+
1010
+ Note:
1011
+ ``put`` should create a directory if the directory of ``filepath``
1012
+ does not exist.
1013
+
1014
+ Args:
1015
+ obj (bytes): Data to be written.
1016
+ filepath (str or Path): Path to write data.
1017
+ """
1018
+ self.client.put(obj, filepath)
1019
+
1020
+ def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
1021
+ """Write data to a given ``filepath`` with 'w' mode.
1022
+
1023
+ Note:
1024
+ ``put_text`` should create a directory if the directory of
1025
+ ``filepath`` does not exist.
1026
+
1027
+ Args:
1028
+ obj (str): Data to be written.
1029
+ filepath (str or Path): Path to write data.
1030
1032
+ """
1033
+ self.client.put_text(obj, filepath)
1034
+
1035
+ def remove(self, filepath: Union[str, Path]) -> None:
1036
+ """Remove a file.
1037
+
1038
+ Args:
1039
+ filepath (str, Path): Path to be removed.
1040
+ """
1041
+ self.client.remove(filepath)
1042
+
1043
+ def exists(self, filepath: Union[str, Path]) -> bool:
1044
+ """Check whether a file path exists.
1045
+
1046
+ Args:
1047
+ filepath (str or Path): Path to be checked whether exists.
1048
+
1049
+ Returns:
1050
+ bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
1051
+ """
1052
+ return self.client.exists(filepath)
1053
+
1054
+ def isdir(self, filepath: Union[str, Path]) -> bool:
1055
+ """Check whether a file path is a directory.
1056
+
1057
+ Args:
1058
+ filepath (str or Path): Path to be checked whether it is a
1059
+ directory.
1060
+
1061
+ Returns:
1062
+ bool: Return ``True`` if ``filepath`` points to a directory,
1063
+ ``False`` otherwise.
1064
+ """
1065
+ return self.client.isdir(filepath)
1066
+
1067
+ def isfile(self, filepath: Union[str, Path]) -> bool:
1068
+ """Check whether a file path is a file.
1069
+
1070
+ Args:
1071
+ filepath (str or Path): Path to be checked whether it is a file.
1072
+
1073
+ Returns:
1074
+ bool: Return ``True`` if ``filepath`` points to a file, ``False``
1075
+ otherwise.
1076
+ """
1077
+ return self.client.isfile(filepath)
1078
+
1079
+ def join_path(self, filepath: Union[str, Path],
1080
+ *filepaths: Union[str, Path]) -> str:
1081
+ """Concatenate all file paths.
1082
+
1083
+ Join one or more filepath components intelligently. The return value
1084
+ is the concatenation of filepath and any members of *filepaths.
1085
+
1086
+ Args:
1087
+ filepath (str or Path): Path to be concatenated.
1088
+
1089
+ Returns:
1090
+ str: The result of concatenation.
1091
+ """
1092
+ return self.client.join_path(filepath, *filepaths)
1093
+
1094
+ @contextmanager
1095
+ def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
1096
+ """Download data from ``filepath`` and write the data to local path.
1097
+
1098
+ ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
1099
+ can be called with the ``with`` statement, and when exiting from the
1100
+ ``with`` statement, the temporary path will be released.
1101
+
1102
+ Note:
1103
+ If the ``filepath`` is a local path, just return itself.
1104
+
1105
+ .. warning::
1106
+ ``get_local_path`` is an experimental interface that may change in
1107
+ the future.
1108
+
1109
+ Args:
1110
+ filepath (str or Path): Path of the data to be read.
1111
+
1112
+ Examples:
1113
+ >>> file_client = FileClient(prefix='s3')
1114
+ >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
1115
+ ... # do something here
1116
+
1117
+ Yields:
1118
+ Iterable[str]: Only yield one path.
1119
+ """
1120
+ with self.client.get_local_path(str(filepath)) as local_path:
1121
+ yield local_path
1122
+
1123
+ def list_dir_or_file(self,
1124
+ dir_path: Union[str, Path],
1125
+ list_dir: bool = True,
1126
+ list_file: bool = True,
1127
+ suffix: Optional[Union[str, Tuple[str]]] = None,
1128
+ recursive: bool = False) -> Iterator[str]:
1129
+ """Scan a directory to find the interested directories or files in
1130
+ arbitrary order.
1131
+
1132
+ Note:
1133
+ :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
1134
+
1135
+ Args:
1136
+ dir_path (str | Path): Path of the directory.
1137
+ list_dir (bool): List the directories. Default: True.
1138
+ list_file (bool): List the path of files. Default: True.
1139
+ suffix (str or tuple[str], optional): File suffix
1140
+ that we are interested in. Default: None.
1141
+ recursive (bool): If set to True, recursively scan the
1142
+ directory. Default: False.
1143
+
1144
+ Yields:
1145
+ Iterable[str]: A relative path to ``dir_path``.
1146
+ """
1147
+ yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
1148
+ suffix, recursive)
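
For orientation, here is a minimal usage sketch of the ``FileClient`` above: it registers a toy backend and reads through it. It assumes the package is importable under the usual ``mmcv`` name (the vendored path in this diff differs); ``DummyBackend`` and the ``dummy://`` prefix are hypothetical examples, not part of the library.

# Sketch: register a hypothetical backend, then read through FileClient.
from mmcv.fileio import BaseStorageBackend, FileClient

@FileClient.register_backend('dummy', prefixes='dummy')
class DummyBackend(BaseStorageBackend):
    """Toy in-memory backend; only the two abstract methods are required."""

    _store = {'dummy://bucket/file.bin': b'hello'}

    def get(self, filepath):
        return self._store[str(filepath)]

    def get_text(self, filepath, encoding='utf-8'):
        return self._store[str(filepath)].decode(encoding)

# Same arguments -> the singleton cache hands back the same object.
assert FileClient(backend='dummy') is FileClient(backend='dummy')

# infer_client resolves the backend from the 'dummy://' prefix.
client = FileClient.infer_client(uri='dummy://bucket/file.bin')
print(client.get('dummy://bucket/file.bin'))  # b'hello'
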
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base import BaseFileHandler
3
+ from .json_handler import JsonHandler
4
+ from .pickle_handler import PickleHandler
5
+ from .yaml_handler import YamlHandler
6
+
7
+ __all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/base.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
+ class BaseFileHandler(metaclass=ABCMeta):
6
+ # `str_like` is a flag to indicate whether the type of file object is
7
+ # str-like object or bytes-like object. Pickle only processes bytes-like
8
+ # objects but json only processes str-like object. If it is str-like
9
+ # object, `StringIO` will be used to process the buffer.
10
+ str_like = True
11
+
12
+ @abstractmethod
13
+ def load_from_fileobj(self, file, **kwargs):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def dump_to_fileobj(self, obj, file, **kwargs):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def dump_to_str(self, obj, **kwargs):
22
+ pass
23
+
24
+ def load_from_path(self, filepath, mode='r', **kwargs):
25
+ with open(filepath, mode) as f:
26
+ return self.load_from_fileobj(f, **kwargs)
27
+
28
+ def dump_to_path(self, obj, filepath, mode='w', **kwargs):
29
+ with open(filepath, mode) as f:
30
+ self.dump_to_fileobj(obj, f, **kwargs)
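
To make the handler contract above concrete, here is a minimal sketch of a subclass; ``TxtHandler`` is a hypothetical example, and the ``mmcv``-style import path is an assumption about how this vendored package is exposed.

# Sketch: a minimal concrete handler; only the three abstract hooks are needed.
from mmcv.fileio.handlers import BaseFileHandler

class TxtHandler(BaseFileHandler):  # hypothetical plain-text handler

    str_like = True  # routed through StringIO rather than BytesIO

    def load_from_fileobj(self, file, **kwargs):
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)

handler = TxtHandler()
handler.dump_to_path('hello', '/tmp/demo.txt')  # inherited path helper
print(handler.load_from_path('/tmp/demo.txt'))  # 'hello'
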
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/json_handler.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+
4
+ import numpy as np
5
+
6
+ from .base import BaseFileHandler
7
+
8
+
9
+ def set_default(obj):
10
+ """Set default json values for non-serializable values.
11
+
12
+ It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
13
+ It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
14
+ etc.) into plain numbers of plain python built-in types.
15
+ """
16
+ if isinstance(obj, (set, range)):
17
+ return list(obj)
18
+ elif isinstance(obj, np.ndarray):
19
+ return obj.tolist()
20
+ elif isinstance(obj, np.generic):
21
+ return obj.item()
22
+ raise TypeError(f'{type(obj)} is unsupported for json dump')
23
+
24
+
25
+ class JsonHandler(BaseFileHandler):
26
+
27
+ def load_from_fileobj(self, file, **kwargs):
28
+ return json.load(file, **kwargs)
29
+
30
+ def dump_to_fileobj(self, obj, file, **kwargs):
31
+ kwargs.setdefault('default', set_default)
32
+ json.dump(obj, file, **kwargs)
33
+
34
+ def dump_to_str(self, obj, **kwargs):
35
+ kwargs.setdefault('default', set_default)
36
+ return json.dumps(obj, **kwargs)
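
The ``set_default`` hook above is what lets ``JsonHandler`` serialize numpy values that the stdlib ``json`` module rejects. A small sketch, assuming the module is importable from an ``mmcv``-style path:

# Sketch: json.dumps rejects numpy types unless set_default converts them.
import json

import numpy as np
from mmcv.fileio.handlers.json_handler import set_default

payload = {'arr': np.arange(3), 'score': np.float32(0.5), 'ids': {1, 2}}
print(json.dumps(payload, default=set_default))
# -> {"arr": [0, 1, 2], "score": 0.5, "ids": [1, 2]}
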
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/pickle_handler.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import pickle
3
+
4
+ from .base import BaseFileHandler
5
+
6
+
7
+ class PickleHandler(BaseFileHandler):
8
+
9
+ str_like = False
10
+
11
+ def load_from_fileobj(self, file, **kwargs):
12
+ return pickle.load(file, **kwargs)
13
+
14
+ def load_from_path(self, filepath, **kwargs):
15
+ return super(PickleHandler, self).load_from_path(
16
+ filepath, mode='rb', **kwargs)
17
+
18
+ def dump_to_str(self, obj, **kwargs):
19
+ kwargs.setdefault('protocol', 2)
20
+ return pickle.dumps(obj, **kwargs)
21
+
22
+ def dump_to_fileobj(self, obj, file, **kwargs):
23
+ kwargs.setdefault('protocol', 2)
24
+ pickle.dump(obj, file, **kwargs)
25
+
26
+ def dump_to_path(self, obj, filepath, **kwargs):
27
+ super(PickleHandler, self).dump_to_path(
28
+ obj, filepath, mode='wb', **kwargs)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/handlers/yaml_handler.py ADDED
@@ -0,0 +1,24 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import yaml
3
+
4
+ try:
5
+ from yaml import CLoader as Loader, CDumper as Dumper
6
+ except ImportError:
7
+ from yaml import Loader, Dumper
8
+
9
+ from .base import BaseFileHandler # isort:skip
10
+
11
+
12
+ class YamlHandler(BaseFileHandler):
13
+
14
+ def load_from_fileobj(self, file, **kwargs):
15
+ kwargs.setdefault('Loader', Loader)
16
+ return yaml.load(file, **kwargs)
17
+
18
+ def dump_to_fileobj(self, obj, file, **kwargs):
19
+ kwargs.setdefault('Dumper', Dumper)
20
+ yaml.dump(obj, file, **kwargs)
21
+
22
+ def dump_to_str(self, obj, **kwargs):
23
+ kwargs.setdefault('Dumper', Dumper)
24
+ return yaml.dump(obj, **kwargs)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/io.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from io import BytesIO, StringIO
3
+ from pathlib import Path
4
+
5
+ from ..utils import is_list_of, is_str
6
+ from .file_client import FileClient
7
+ from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
8
+
9
+ file_handlers = {
10
+ 'json': JsonHandler(),
11
+ 'yaml': YamlHandler(),
12
+ 'yml': YamlHandler(),
13
+ 'pickle': PickleHandler(),
14
+ 'pkl': PickleHandler()
15
+ }
16
+
17
+
18
+ def load(file, file_format=None, file_client_args=None, **kwargs):
19
+ """Load data from json/yaml/pickle files.
20
+
21
+ This method provides a unified api for loading data from serialized files.
22
+
23
+ Note:
24
+ In v1.3.16 and later, ``load`` supports loading data from serialized
25
+ files that can be stored in different backends.
26
+
27
+ Args:
28
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
29
+ object.
30
+ file_format (str, optional): If not specified, the file format will be
31
+ inferred from the file extension, otherwise use the specified one.
32
+ Currently supported formats include "json", "yaml/yml" and
33
+ "pickle/pkl".
34
+ file_client_args (dict, optional): Arguments to instantiate a
35
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
36
+ Default: None.
37
+
38
+ Examples:
39
+ >>> load('/path/of/your/file') # file is stored on disk
40
+ >>> load('https://path/of/your/file') # file is stored on the Internet
41
+ >>> load('s3://path/of/your/file') # file is stored in petrel
42
+
43
+ Returns:
44
+ The content from the file.
45
+ """
46
+ if isinstance(file, Path):
47
+ file = str(file)
48
+ if file_format is None and is_str(file):
49
+ file_format = file.split('.')[-1]
50
+ if file_format not in file_handlers:
51
+ raise TypeError(f'Unsupported format: {file_format}')
52
+
53
+ handler = file_handlers[file_format]
54
+ if is_str(file):
55
+ file_client = FileClient.infer_client(file_client_args, file)
56
+ if handler.str_like:
57
+ with StringIO(file_client.get_text(file)) as f:
58
+ obj = handler.load_from_fileobj(f, **kwargs)
59
+ else:
60
+ with BytesIO(file_client.get(file)) as f:
61
+ obj = handler.load_from_fileobj(f, **kwargs)
62
+ elif hasattr(file, 'read'):
63
+ obj = handler.load_from_fileobj(file, **kwargs)
64
+ else:
65
+ raise TypeError('"file" must be a filepath str or a file-object')
66
+ return obj
67
+
68
+
69
+ def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
70
+ """Dump data to json/yaml/pickle strings or files.
71
+
72
+ This method provides a unified api for dumping data as strings or to files,
73
+ and also supports custom arguments for each file format.
74
+
75
+ Note:
76
+ In v1.3.16 and later, ``dump`` supports dumping data as strings or to
77
+ files which can be saved to different backends.
78
+
79
+ Args:
80
+ obj (any): The python object to be dumped.
81
+ file (str or :obj:`Path` or file-like object, optional): If not
82
+ specified, then the object is dumped to a str, otherwise to a file
83
+ specified by the filename or file-like object.
84
+ file_format (str, optional): Same as :func:`load`.
85
+ file_client_args (dict, optional): Arguments to instantiate a
86
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
87
+ Default: None.
88
+
89
+ Examples:
90
+ >>> dump('hello world', '/path/of/your/file') # disk
91
+ >>> dump('hello world', 's3://path/of/your/file') # ceph or petrel
92
+
93
+ Returns:
94
+ str or None: The dumped string when ``file`` is None, otherwise None.
95
+ """
96
+ if isinstance(file, Path):
97
+ file = str(file)
98
+ if file_format is None:
99
+ if is_str(file):
100
+ file_format = file.split('.')[-1]
101
+ elif file is None:
102
+ raise ValueError(
103
+ 'file_format must be specified since file is None')
104
+ if file_format not in file_handlers:
105
+ raise TypeError(f'Unsupported format: {file_format}')
106
+
107
+ handler = file_handlers[file_format]
108
+ if file is None:
109
+ return handler.dump_to_str(obj, **kwargs)
110
+ elif is_str(file):
111
+ file_client = FileClient.infer_client(file_client_args, file)
112
+ if handler.str_like:
113
+ with StringIO() as f:
114
+ handler.dump_to_fileobj(obj, f, **kwargs)
115
+ file_client.put_text(f.getvalue(), file)
116
+ else:
117
+ with BytesIO() as f:
118
+ handler.dump_to_fileobj(obj, f, **kwargs)
119
+ file_client.put(f.getvalue(), file)
120
+ elif hasattr(file, 'write'):
121
+ handler.dump_to_fileobj(obj, file, **kwargs)
122
+ else:
123
+ raise TypeError('"file" must be a filename str or a file-object')
124
+
125
+
126
+ def _register_handler(handler, file_formats):
127
+ """Register a handler for some file extensions.
128
+
129
+ Args:
130
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
131
+ file_formats (str or list[str]): File formats to be handled by this
132
+ handler.
133
+ """
134
+ if not isinstance(handler, BaseFileHandler):
135
+ raise TypeError(
136
+ f'handler must be a child of BaseFileHandler, not {type(handler)}')
137
+ if isinstance(file_formats, str):
138
+ file_formats = [file_formats]
139
+ if not is_list_of(file_formats, str):
140
+ raise TypeError('file_formats must be a str or a list of str')
141
+ for ext in file_formats:
142
+ file_handlers[ext] = handler
143
+
144
+
145
+ def register_handler(file_formats, **kwargs):
146
+
147
+ def wrap(cls):
148
+ _register_handler(cls(**kwargs), file_formats)
149
+ return cls
150
+
151
+ return wrap
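
A short usage sketch of the ``load``/``dump`` pair and the ``register_handler`` decorator above, assuming the package is importable as ``mmcv``; ``TxtHandler`` is again a hypothetical format:

# Sketch: the format is inferred from the extension; new formats can be added.
import mmcv

mmcv.dump({'a': 1}, '/tmp/demo.json')            # handler picked via '.json'
print(mmcv.load('/tmp/demo.json'))               # {'a': 1}
print(mmcv.dump([1, 2, 3], file_format='yaml'))  # no file -> dumped to a str

@mmcv.register_handler('txt')
class TxtHandler(mmcv.BaseFileHandler):          # hypothetical '.txt' handler

    def load_from_fileobj(self, file, **kwargs):
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        return str(obj)

mmcv.dump('hello world', '/tmp/demo.txt')
print(mmcv.load('/tmp/demo.txt'))                # 'hello world'
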
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/fileio/parse.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+
3
+ from io import StringIO
4
+
5
+ from .file_client import FileClient
6
+
7
+
8
+ def list_from_file(filename,
9
+ prefix='',
10
+ offset=0,
11
+ max_num=0,
12
+ encoding='utf-8',
13
+ file_client_args=None):
14
+ """Load a text file and parse the content as a list of strings.
15
+
16
+ Note:
17
+ In v1.3.16 and later, ``list_from_file`` supports loading a text file
18
+ which can be storaged in different backends and parsing the content as
19
+ a list for strings.
20
+
21
+ Args:
22
+ filename (str): Filename.
23
+ prefix (str): The prefix to be inserted to the beginning of each item.
24
+ offset (int): The offset of lines.
25
+ max_num (int): The maximum number of lines to be read,
26
+ zeros and negatives mean no limitation.
27
+ encoding (str): Encoding used to open the file. Default utf-8.
28
+ file_client_args (dict, optional): Arguments to instantiate a
29
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
30
+ Default: None.
31
+
32
+ Examples:
33
+ >>> list_from_file('/path/of/your/file') # disk
34
+ ['hello', 'world']
35
+ >>> list_from_file('s3://path/of/your/file') # ceph or petrel
36
+ ['hello', 'world']
37
+
38
+ Returns:
39
+ list[str]: A list of strings.
40
+ """
41
+ cnt = 0
42
+ item_list = []
43
+ file_client = FileClient.infer_client(file_client_args, filename)
44
+ with StringIO(file_client.get_text(filename, encoding)) as f:
45
+ for _ in range(offset):
46
+ f.readline()
47
+ for line in f:
48
+ if 0 < max_num <= cnt:
49
+ break
50
+ item_list.append(prefix + line.rstrip('\n\r'))
51
+ cnt += 1
52
+ return item_list
53
+
54
+
55
+ def dict_from_file(filename,
56
+ key_type=str,
57
+ encoding='utf-8',
58
+ file_client_args=None):
59
+ """Load a text file and parse the content as a dict.
60
+
61
+ Each line of the text file will be two or more columns split by
62
+ whitespaces or tabs. The first column will be parsed as dict keys, and
63
+ the following columns will be parsed as dict values.
64
+
65
+ Note:
66
+ In v1.3.16 and later, ``dict_from_file`` supports loading a text file
67
+ which can be stored in different backends and parsing the content as
68
+ a dict.
69
+
70
+ Args:
71
+ filename (str): Filename.
72
+ key_type (type): Type of the dict keys. str is used by default and
73
+ type conversion will be performed if specified.
74
+ encoding (str): Encoding used to open the file. Default utf-8.
75
+ file_client_args (dict, optional): Arguments to instantiate a
76
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
77
+ Default: None.
78
+
79
+ Examples:
80
+ >>> dict_from_file('/path/of/your/file') # disk
81
+ {'key1': 'value1', 'key2': 'value2'}
82
+ >>> dict_from_file('s3://path/of/your/file') # ceph or petrel
83
+ {'key1': 'value1', 'key2': 'value2'}
84
+
85
+ Returns:
86
+ dict: The parsed contents.
87
+ """
88
+ mapping = {}
89
+ file_client = FileClient.infer_client(file_client_args, filename)
90
+ with StringIO(file_client.get_text(filename, encoding)) as f:
91
+ for line in f:
92
+ items = line.rstrip('\n').split()
93
+ assert len(items) >= 2
94
+ key = key_type(items[0])
95
+ val = items[1:] if len(items) > 2 else items[1]
96
+ mapping[key] = val
97
+ return mapping
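
A minimal sketch of the two parsers above against a throwaway local file (the path is arbitrary):

# Sketch: list_from_file keeps whole lines; dict_from_file splits on whitespace.
from mmcv import dict_from_file, list_from_file

with open('/tmp/demo_ann.txt', 'w') as f:
    f.write('cat 1\ndog 2\nbird 3\n')

print(list_from_file('/tmp/demo_ann.txt'))            # ['cat 1', 'dog 2', 'bird 3']
print(list_from_file('/tmp/demo_ann.txt', offset=1))  # ['dog 2', 'bird 3']
print(dict_from_file('/tmp/demo_ann.txt'))            # {'cat': '1', 'dog': '2', 'bird': '3'}
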
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
3
+ gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
4
+ rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
5
+ from .geometric import (cutout, imcrop, imflip, imflip_, impad,
6
+ impad_to_multiple, imrescale, imresize, imresize_like,
7
+ imresize_to_multiple, imrotate, imshear, imtranslate,
8
+ rescale_size)
9
+ from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
10
+ from .misc import tensor2imgs
11
+ from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
12
+ adjust_lighting, adjust_sharpness, auto_contrast,
13
+ clahe, imdenormalize, imequalize, iminvert,
14
+ imnormalize, imnormalize_, lut_transform, posterize,
15
+ solarize)
16
+
17
+ __all__ = [
18
+ 'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
19
+ 'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
20
+ 'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
21
+ 'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
22
+ 'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
23
+ 'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
24
+ 'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
25
+ 'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
26
+ 'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
27
+ 'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
28
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/colorspace.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import cv2
3
+ import numpy as np
4
+
5
+
6
+ def imconvert(img, src, dst):
7
+ """Convert an image from the src colorspace to dst colorspace.
8
+
9
+ Args:
10
+ img (ndarray): The input image.
11
+ src (str): The source colorspace, e.g., 'rgb', 'hsv'.
12
+ dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
13
+
14
+ Returns:
15
+ ndarray: The converted image.
16
+ """
17
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
18
+ out_img = cv2.cvtColor(img, code)
19
+ return out_img
20
+
21
+
22
+ def bgr2gray(img, keepdim=False):
23
+ """Convert a BGR image to grayscale image.
24
+
25
+ Args:
26
+ img (ndarray): The input image.
27
+ keepdim (bool): If False (by default), then return the grayscale image
28
+ with 2 dims, otherwise 3 dims.
29
+
30
+ Returns:
31
+ ndarray: The converted grayscale image.
32
+ """
33
+ out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
34
+ if keepdim:
35
+ out_img = out_img[..., None]
36
+ return out_img
37
+
38
+
39
+ def rgb2gray(img, keepdim=False):
40
+ """Convert a RGB image to grayscale image.
41
+
42
+ Args:
43
+ img (ndarray): The input image.
44
+ keepdim (bool): If False (by default), then return the grayscale image
45
+ with 2 dims, otherwise 3 dims.
46
+
47
+ Returns:
48
+ ndarray: The converted grayscale image.
49
+ """
50
+ out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
51
+ if keepdim:
52
+ out_img = out_img[..., None]
53
+ return out_img
54
+
55
+
56
+ def gray2bgr(img):
57
+ """Convert a grayscale image to BGR image.
58
+
59
+ Args:
60
+ img (ndarray): The input image.
61
+
62
+ Returns:
63
+ ndarray: The converted BGR image.
64
+ """
65
+ img = img[..., None] if img.ndim == 2 else img
66
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
67
+ return out_img
68
+
69
+
70
+ def gray2rgb(img):
71
+ """Convert a grayscale image to RGB image.
72
+
73
+ Args:
74
+ img (ndarray): The input image.
75
+
76
+ Returns:
77
+ ndarray: The converted RGB image.
78
+ """
79
+ img = img[..., None] if img.ndim == 2 else img
80
+ out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
81
+ return out_img
82
+
83
+
84
+ def _convert_input_type_range(img):
85
+ """Convert the type and range of the input image.
86
+
87
+ It converts the input image to np.float32 type and range of [0, 1].
88
+ It is mainly used for pre-processing the input image in colorspace
89
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
90
+
91
+ Args:
92
+ img (ndarray): The input image. It accepts:
93
+ 1. np.uint8 type with range [0, 255];
94
+ 2. np.float32 type with range [0, 1].
95
+
96
+ Returns:
97
+ (ndarray): The converted image with type of np.float32 and range of
98
+ [0, 1].
99
+ """
100
+ img_type = img.dtype
101
+ img = img.astype(np.float32)
102
+ if img_type == np.float32:
103
+ pass
104
+ elif img_type == np.uint8:
105
+ img /= 255.
106
+ else:
107
+ raise TypeError('The img type should be np.float32 or np.uint8, '
108
+ f'but got {img_type}')
109
+ return img
110
+
111
+
112
+ def _convert_output_type_range(img, dst_type):
113
+ """Convert the type and range of the image according to dst_type.
114
+
115
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
116
+ images will be converted to np.uint8 type with range [0, 255]. If
117
+ `dst_type` is np.float32, it converts the image to np.float32 type with
118
+ range [0, 1].
119
+ It is mainly used for post-processing images in colorspace conversion
120
+ functions such as rgb2ycbcr and ycbcr2rgb.
121
+
122
+ Args:
123
+ img (ndarray): The image to be converted with np.float32 type and
124
+ range [0, 255].
125
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
126
+ converts the image to np.uint8 type with range [0, 255]. If
127
+ dst_type is np.float32, it converts the image to np.float32 type
128
+ with range [0, 1].
129
+
130
+ Returns:
131
+ (ndarray): The converted image with desired type and range.
132
+ """
133
+ if dst_type not in (np.uint8, np.float32):
134
+ raise TypeError('The dst_type should be np.float32 or np.uint8, '
135
+ f'but got {dst_type}')
136
+ if dst_type == np.uint8:
137
+ img = img.round()
138
+ else:
139
+ img /= 255.
140
+ return img.astype(dst_type)
141
+
142
+
143
+ def rgb2ycbcr(img, y_only=False):
144
+ """Convert a RGB image to YCbCr image.
145
+
146
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
147
+ It implements the ITU-R BT.601 conversion for standard-definition
148
+ television. See more details in
149
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
150
+
151
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
152
+ In OpenCV, it implements a JPEG conversion. See more details in
153
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
154
+
155
+ Args:
156
+ img (ndarray): The input image. It accepts:
157
+ 1. np.uint8 type with range [0, 255];
158
+ 2. np.float32 type with range [0, 1].
159
+ y_only (bool): Whether to only return Y channel. Default: False.
160
+
161
+ Returns:
162
+ ndarray: The converted YCbCr image. The output image has the same type
163
+ and range as input image.
164
+ """
165
+ img_type = img.dtype
166
+ img = _convert_input_type_range(img)
167
+ if y_only:
168
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
169
+ else:
170
+ out_img = np.matmul(
171
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
172
+ [24.966, 112.0, -18.214]]) + [16, 128, 128]
173
+ out_img = _convert_output_type_range(out_img, img_type)
174
+ return out_img
175
+
176
+
177
+ def bgr2ycbcr(img, y_only=False):
178
+ """Convert a BGR image to YCbCr image.
179
+
180
+ The bgr version of rgb2ycbcr.
181
+ It implements the ITU-R BT.601 conversion for standard-definition
182
+ television. See more details in
183
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
184
+
185
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
186
+ In OpenCV, it implements a JPEG conversion. See more details in
187
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
188
+
189
+ Args:
190
+ img (ndarray): The input image. It accepts:
191
+ 1. np.uint8 type with range [0, 255];
192
+ 2. np.float32 type with range [0, 1].
193
+ y_only (bool): Whether to only return Y channel. Default: False.
194
+
195
+ Returns:
196
+ ndarray: The converted YCbCr image. The output image has the same type
197
+ and range as input image.
198
+ """
199
+ img_type = img.dtype
200
+ img = _convert_input_type_range(img)
201
+ if y_only:
202
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
203
+ else:
204
+ out_img = np.matmul(
205
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
206
+ [65.481, -37.797, 112.0]]) + [16, 128, 128]
207
+ out_img = _convert_output_type_range(out_img, img_type)
208
+ return out_img
209
+
210
+
211
+ def ycbcr2rgb(img):
212
+ """Convert a YCbCr image to RGB image.
213
+
214
+ This function produces the same results as Matlab's ycbcr2rgb function.
215
+ It implements the ITU-R BT.601 conversion for standard-definition
216
+ television. See more details in
217
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
218
+
219
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
220
+ In OpenCV, it implements a JPEG conversion. See more details in
221
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
222
+
223
+ Args:
224
+ img (ndarray): The input image. It accepts:
225
+ 1. np.uint8 type with range [0, 255];
226
+ 2. np.float32 type with range [0, 1].
227
+
228
+ Returns:
229
+ ndarray: The converted RGB image. The output image has the same type
230
+ and range as input image.
231
+ """
232
+ img_type = img.dtype
233
+ img = _convert_input_type_range(img) * 255
234
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
235
+ [0, -0.00153632, 0.00791071],
236
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [
237
+ -222.921, 135.576, -276.836
238
+ ]
239
+ out_img = _convert_output_type_range(out_img, img_type)
240
+ return out_img
241
+
242
+
243
+ def ycbcr2bgr(img):
244
+ """Convert a YCbCr image to BGR image.
245
+
246
+ The bgr version of ycbcr2rgb.
247
+ It implements the ITU-R BT.601 conversion for standard-definition
248
+ television. See more details in
249
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
250
+
251
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
252
+ In OpenCV, it implements a JPEG conversion. See more details in
253
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
254
+
255
+ Args:
256
+ img (ndarray): The input image. It accepts:
257
+ 1. np.uint8 type with range [0, 255];
258
+ 2. np.float32 type with range [0, 1].
259
+
260
+ Returns:
261
+ ndarray: The converted BGR image. The output image has the same type
262
+ and range as input image.
263
+ """
264
+ img_type = img.dtype
265
+ img = _convert_input_type_range(img) * 255
266
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
267
+ [0.00791071, -0.00153632, 0],
268
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [
269
+ -276.836, 135.576, -222.921
270
+ ]
271
+ out_img = _convert_output_type_range(out_img, img_type)
272
+ return out_img
273
+
274
+
275
+ def convert_color_factory(src, dst):
276
+
277
+ code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
278
+
279
+ def convert_color(img):
280
+ out_img = cv2.cvtColor(img, code)
281
+ return out_img
282
+
283
+ convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
284
+ image.
285
+
286
+ Args:
287
+ img (ndarray): The input image.
288
+
289
+ Returns:
290
+ ndarray: The converted {dst.upper()} image.
291
+ """
292
+
293
+ return convert_color
294
+
295
+
296
+ bgr2rgb = convert_color_factory('bgr', 'rgb')
297
+
298
+ rgb2bgr = convert_color_factory('rgb', 'bgr')
299
+
300
+ bgr2hsv = convert_color_factory('bgr', 'hsv')
301
+
302
+ hsv2bgr = convert_color_factory('hsv', 'bgr')
303
+
304
+ bgr2hls = convert_color_factory('bgr', 'hls')
305
+
306
+ hls2bgr = convert_color_factory('hls', 'bgr')
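
As a quick sanity check of the BT.601 conversions above, a sketch that round-trips a float32 image (only approximate equality is expected, since the conversion is computed in finite precision):

# Sketch: rgb2ycbcr and ycbcr2rgb are approximate inverses of each other.
import numpy as np
from mmcv.image import rgb2ycbcr, ycbcr2rgb

img = np.random.rand(4, 4, 3).astype(np.float32)  # float32 input in [0, 1]
restored = ycbcr2rgb(rgb2ycbcr(img))
print(np.abs(restored - img).max())  # tiny numerical error, not exactly 0
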
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/geometric.py ADDED
@@ -0,0 +1,728 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numbers
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from ..utils import to_2tuple
8
+ from .io import imread_backend
9
+
10
+ try:
11
+ from PIL import Image
12
+ except ImportError:
13
+ Image = None
14
+
15
+
16
+ def _scale_size(size, scale):
17
+ """Rescale a size by a ratio.
18
+
19
+ Args:
20
+ size (tuple[int]): (w, h).
21
+ scale (float | tuple[float]): Scaling factor.
22
+
23
+ Returns:
24
+ tuple[int]: scaled size.
25
+ """
26
+ if isinstance(scale, (float, int)):
27
+ scale = (scale, scale)
28
+ w, h = size
29
+ return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
30
+
31
+
32
+ cv2_interp_codes = {
33
+ 'nearest': cv2.INTER_NEAREST,
34
+ 'bilinear': cv2.INTER_LINEAR,
35
+ 'bicubic': cv2.INTER_CUBIC,
36
+ 'area': cv2.INTER_AREA,
37
+ 'lanczos': cv2.INTER_LANCZOS4
38
+ }
39
+
40
+ if Image is not None:
41
+ pillow_interp_codes = {
42
+ 'nearest': Image.NEAREST,
43
+ 'bilinear': Image.BILINEAR,
44
+ 'bicubic': Image.BICUBIC,
45
+ 'box': Image.BOX,
46
+ 'lanczos': Image.LANCZOS,
47
+ 'hamming': Image.HAMMING
48
+ }
49
+
50
+
51
+ def imresize(img,
52
+ size,
53
+ return_scale=False,
54
+ interpolation='bilinear',
55
+ out=None,
56
+ backend=None):
57
+ """Resize image to a given size.
58
+
59
+ Args:
60
+ img (ndarray): The input image.
61
+ size (tuple[int]): Target size (w, h).
62
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
63
+ interpolation (str): Interpolation method, accepted values are
64
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
65
+ backend, "nearest", "bilinear" for 'pillow' backend.
66
+ out (ndarray): The output destination.
67
+ backend (str | None): The image resize backend type. Options are `cv2`,
68
+ `pillow`, `None`. If backend is None, the global imread_backend
69
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
70
+
71
+ Returns:
72
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
73
+ `resized_img`.
74
+ """
75
+ h, w = img.shape[:2]
76
+ if backend is None:
77
+ backend = imread_backend
78
+ if backend not in ['cv2', 'pillow']:
79
+ raise ValueError(f'backend: {backend} is not supported for resize. '
80
+ f"Supported backends are 'cv2', 'pillow'")
81
+
82
+ if backend == 'pillow':
83
+ assert img.dtype == np.uint8, 'Pillow backend only supports uint8 type'
84
+ pil_image = Image.fromarray(img)
85
+ pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
86
+ resized_img = np.array(pil_image)
87
+ else:
88
+ resized_img = cv2.resize(
89
+ img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
90
+ if not return_scale:
91
+ return resized_img
92
+ else:
93
+ w_scale = size[0] / w
94
+ h_scale = size[1] / h
95
+ return resized_img, w_scale, h_scale
96
+
97
+
98
+ def imresize_to_multiple(img,
99
+ divisor,
100
+ size=None,
101
+ scale_factor=None,
102
+ keep_ratio=False,
103
+ return_scale=False,
104
+ interpolation='bilinear',
105
+ out=None,
106
+ backend=None):
107
+ """Resize image according to a given size or scale factor and then rounds
108
+ up the the resized or rescaled image size to the nearest value that can be
109
+ divided by the divisor.
110
+
111
+ Args:
112
+ img (ndarray): The input image.
113
+ divisor (int | tuple): Resized image size will be a multiple of
114
+ divisor. If divisor is a tuple, divisor should be
115
+ (w_divisor, h_divisor).
116
+ size (None | int | tuple[int]): Target size (w, h). Default: None.
117
+ scale_factor (None | float | tuple[float]): Multiplier for spatial
118
+ size. Should match input size if it is a tuple and the 2D style is
119
+ (w_scale_factor, h_scale_factor). Default: None.
120
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
121
+ image. Default: False.
122
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
123
+ interpolation (str): Interpolation method, accepted values are
124
+ "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
125
+ backend, "nearest", "bilinear" for 'pillow' backend.
126
+ out (ndarray): The output destination.
127
+ backend (str | None): The image resize backend type. Options are `cv2`,
128
+ `pillow`, `None`. If backend is None, the global imread_backend
129
+ specified by ``mmcv.use_backend()`` will be used. Default: None.
130
+
131
+ Returns:
132
+ tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
133
+ `resized_img`.
134
+ """
135
+ h, w = img.shape[:2]
136
+ if size is not None and scale_factor is not None:
137
+ raise ValueError('only one of size or scale_factor should be defined')
138
+ elif size is None and scale_factor is None:
139
+ raise ValueError('one of size or scale_factor should be defined')
140
+ elif size is not None:
141
+ size = to_2tuple(size)
142
+ if keep_ratio:
143
+ size = rescale_size((w, h), size, return_scale=False)
144
+ else:
145
+ size = _scale_size((w, h), scale_factor)
146
+
147
+ divisor = to_2tuple(divisor)
148
+ size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
149
+ resized_img, w_scale, h_scale = imresize(
150
+ img,
151
+ size,
152
+ return_scale=True,
153
+ interpolation=interpolation,
154
+ out=out,
155
+ backend=backend)
156
+ if return_scale:
157
+ return resized_img, w_scale, h_scale
158
+ else:
159
+ return resized_img
160
+
161
+
162
+ def imresize_like(img,
163
+ dst_img,
164
+ return_scale=False,
165
+ interpolation='bilinear',
166
+ backend=None):
167
+ """Resize image to the same size of a given image.
168
+
169
+ Args:
170
+ img (ndarray): The input image.
171
+ dst_img (ndarray): The target image.
172
+ return_scale (bool): Whether to return `w_scale` and `h_scale`.
173
+ interpolation (str): Same as :func:`resize`.
174
+ backend (str | None): Same as :func:`resize`.
175
+
176
+ Returns:
177
+ tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
178
+ `resized_img`.
179
+ """
180
+ h, w = dst_img.shape[:2]
181
+ return imresize(img, (w, h), return_scale, interpolation, backend=backend)
182
+
183
+
184
+ def rescale_size(old_size, scale, return_scale=False):
185
+ """Calculate the new size to be rescaled to.
186
+
187
+ Args:
188
+ old_size (tuple[int]): The old size (w, h) of image.
189
+ scale (float | tuple[int]): The scaling factor or maximum size.
190
+ If it is a float number, then the image will be rescaled by this
191
+ factor, else if it is a tuple of 2 integers, then the image will
192
+ be rescaled as large as possible within the scale.
193
+ return_scale (bool): Whether to return the scaling factor besides the
194
+ rescaled image size.
195
+
196
+ Returns:
197
+ tuple[int]: The new rescaled image size.
198
+ """
199
+ w, h = old_size
200
+ if isinstance(scale, (float, int)):
201
+ if scale <= 0:
202
+ raise ValueError(f'Invalid scale {scale}, must be positive.')
203
+ scale_factor = scale
204
+ elif isinstance(scale, tuple):
205
+ max_long_edge = max(scale)
206
+ max_short_edge = min(scale)
207
+ scale_factor = min(max_long_edge / max(h, w),
208
+ max_short_edge / min(h, w))
209
+ else:
210
+ raise TypeError(
211
+ f'Scale must be a number or tuple of int, but got {type(scale)}')
212
+
213
+ new_size = _scale_size((w, h), scale_factor)
214
+
215
+ if return_scale:
216
+ return new_size, scale_factor
217
+ else:
218
+ return new_size
219
+
220
+
221
+ def imrescale(img,
222
+ scale,
223
+ return_scale=False,
224
+ interpolation='bilinear',
225
+ backend=None):
226
+ """Resize image while keeping the aspect ratio.
227
+
228
+ Args:
229
+ img (ndarray): The input image.
230
+ scale (float | tuple[int]): The scaling factor or maximum size.
231
+ If it is a float number, then the image will be rescaled by this
232
+ factor, else if it is a tuple of 2 integers, then the image will
233
+ be rescaled as large as possible within the scale.
234
+ return_scale (bool): Whether to return the scaling factor besides the
235
+ rescaled image.
236
+ interpolation (str): Same as :func:`resize`.
237
+ backend (str | None): Same as :func:`resize`.
238
+
239
+ Returns:
240
+ ndarray: The rescaled image.
241
+ """
242
+ h, w = img.shape[:2]
243
+ new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
244
+ rescaled_img = imresize(
245
+ img, new_size, interpolation=interpolation, backend=backend)
246
+ if return_scale:
247
+ return rescaled_img, scale_factor
248
+ else:
249
+ return rescaled_img
250
+
251
+
252
+ def imflip(img, direction='horizontal'):
253
+ """Flip an image horizontally or vertically.
254
+
255
+ Args:
256
+ img (ndarray): Image to be flipped.
257
+ direction (str): The flip direction, either "horizontal" or
258
+ "vertical" or "diagonal".
259
+
260
+ Returns:
261
+ ndarray: The flipped image.
262
+ """
263
+ assert direction in ['horizontal', 'vertical', 'diagonal']
264
+ if direction == 'horizontal':
265
+ return np.flip(img, axis=1)
266
+ elif direction == 'vertical':
267
+ return np.flip(img, axis=0)
268
+ else:
269
+ return np.flip(img, axis=(0, 1))
270
+
271
+
272
+ def imflip_(img, direction='horizontal'):
273
+ """Inplace flip an image horizontally or vertically.
274
+
275
+ Args:
276
+ img (ndarray): Image to be flipped.
277
+ direction (str): The flip direction, either "horizontal" or
278
+ "vertical" or "diagonal".
279
+
280
+ Returns:
281
+ ndarray: The flipped image (inplace).
282
+ """
283
+ assert direction in ['horizontal', 'vertical', 'diagonal']
284
+ if direction == 'horizontal':
285
+ return cv2.flip(img, 1, img)
286
+ elif direction == 'vertical':
287
+ return cv2.flip(img, 0, img)
288
+ else:
289
+ return cv2.flip(img, -1, img)
290
+
291
+
292
+ def imrotate(img,
293
+ angle,
294
+ center=None,
295
+ scale=1.0,
296
+ border_value=0,
297
+ interpolation='bilinear',
298
+ auto_bound=False):
299
+ """Rotate an image.
300
+
301
+ Args:
302
+ img (ndarray): Image to be rotated.
303
+ angle (float): Rotation angle in degrees, positive values mean
304
+ clockwise rotation.
305
+ center (tuple[float], optional): Center point (w, h) of the rotation in
306
+ the source image. If not specified, the center of the image will be
307
+ used.
308
+ scale (float): Isotropic scale factor.
309
+ border_value (int): Border value.
310
+ interpolation (str): Same as :func:`resize`.
311
+ auto_bound (bool): Whether to adjust the image size to cover the whole
312
+ rotated image.
313
+
314
+ Returns:
315
+ ndarray: The rotated image.
316
+ """
317
+ if center is not None and auto_bound:
318
+ raise ValueError('`auto_bound` conflicts with `center`')
319
+ h, w = img.shape[:2]
320
+ if center is None:
321
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
322
+ assert isinstance(center, tuple)
323
+
324
+ matrix = cv2.getRotationMatrix2D(center, -angle, scale)
325
+ if auto_bound:
326
+ cos = np.abs(matrix[0, 0])
327
+ sin = np.abs(matrix[0, 1])
328
+ new_w = h * sin + w * cos
329
+ new_h = h * cos + w * sin
330
+ matrix[0, 2] += (new_w - w) * 0.5
331
+ matrix[1, 2] += (new_h - h) * 0.5
332
+ w = int(np.round(new_w))
333
+ h = int(np.round(new_h))
334
+ rotated = cv2.warpAffine(
335
+ img,
336
+ matrix, (w, h),
337
+ flags=cv2_interp_codes[interpolation],
338
+ borderValue=border_value)
339
+ return rotated
340
+
341
+
342
+ def bbox_clip(bboxes, img_shape):
343
+ """Clip bboxes to fit the image shape.
344
+
345
+ Args:
346
+ bboxes (ndarray): Shape (..., 4*k)
347
+ img_shape (tuple[int]): (height, width) of the image.
348
+
349
+ Returns:
350
+ ndarray: Clipped bboxes.
351
+ """
352
+ assert bboxes.shape[-1] % 4 == 0
353
+ cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
354
+ cmin[0::2] = img_shape[1] - 1
355
+ cmin[1::2] = img_shape[0] - 1
356
+ clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
357
+ return clipped_bboxes
358
+
359
+
360
+ def bbox_scaling(bboxes, scale, clip_shape=None):
361
+ """Scaling bboxes w.r.t the box center.
362
+
363
+ Args:
364
+ bboxes (ndarray): Shape(..., 4).
365
+ scale (float): Scaling factor.
366
+ clip_shape (tuple[int], optional): If specified, bboxes that exceed the
367
+ boundary will be clipped according to the given shape (h, w).
368
+
369
+ Returns:
370
+ ndarray: Scaled bboxes.
371
+ """
372
+ if float(scale) == 1.0:
373
+ scaled_bboxes = bboxes.copy()
374
+ else:
375
+ w = bboxes[..., 2] - bboxes[..., 0] + 1
376
+ h = bboxes[..., 3] - bboxes[..., 1] + 1
377
+ dw = (w * (scale - 1)) * 0.5
378
+ dh = (h * (scale - 1)) * 0.5
379
+ scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
380
+ if clip_shape is not None:
381
+ return bbox_clip(scaled_bboxes, clip_shape)
382
+ else:
383
+ return scaled_bboxes
384
+
385
+
386
+ def imcrop(img, bboxes, scale=1.0, pad_fill=None):
387
+ """Crop image patches.
388
+
389
+ 3 steps: scale the bboxes -> clip bboxes -> crop and pad.
390
+
391
+ Args:
392
+ img (ndarray): Image to be cropped.
393
+ bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
394
+ scale (float, optional): Scale ratio of bboxes, the default value
395
+ 1.0 means no scaling.
396
+ pad_fill (Number | list[Number]): Value to be filled for padding.
397
+ Default: None, which means no padding.
398
+
399
+ Returns:
400
+ list[ndarray] | ndarray: The cropped image patches.
401
+ """
402
+ chn = 1 if img.ndim == 2 else img.shape[2]
403
+ if pad_fill is not None:
404
+ if isinstance(pad_fill, (int, float)):
405
+ pad_fill = [pad_fill for _ in range(chn)]
406
+ assert len(pad_fill) == chn
407
+
408
+ _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
409
+ scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
410
+ clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
411
+
412
+ patches = []
413
+ for i in range(clipped_bbox.shape[0]):
414
+ x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
415
+ if pad_fill is None:
416
+ patch = img[y1:y2 + 1, x1:x2 + 1, ...]
417
+ else:
418
+ _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
419
+ if chn == 1:
420
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
421
+ else:
422
+ patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
423
+ patch = np.array(
424
+ pad_fill, dtype=img.dtype) * np.ones(
425
+ patch_shape, dtype=img.dtype)
426
+ x_start = 0 if _x1 >= 0 else -_x1
427
+ y_start = 0 if _y1 >= 0 else -_y1
428
+ w = x2 - x1 + 1
429
+ h = y2 - y1 + 1
430
+ patch[y_start:y_start + h, x_start:x_start + w,
431
+ ...] = img[y1:y1 + h, x1:x1 + w, ...]
432
+ patches.append(patch)
433
+
434
+ if bboxes.ndim == 1:
435
+ return patches[0]
436
+ else:
437
+ return patches
438
+
439
+
440
+ def impad(img,
441
+ *,
442
+ shape=None,
443
+ padding=None,
444
+ pad_val=0,
445
+ padding_mode='constant'):
446
+ """Pad the given image to a certain shape or pad on all sides with
447
+ specified padding mode and padding value.
448
+
449
+ Args:
450
+ img (ndarray): Image to be padded.
451
+ shape (tuple[int]): Expected padding shape (h, w). Default: None.
452
+ padding (int or tuple[int]): Padding on each border. If a single int is
453
+ provided this is used to pad all borders. If tuple of length 2 is
454
+ provided this is the padding on left/right and top/bottom
455
+ respectively. If a tuple of length 4 is provided this is the
456
+ padding for the left, top, right and bottom borders respectively.
457
+ Default: None. Note that `shape` and `padding` can not be both
458
+ set.
459
+ pad_val (Number | Sequence[Number]): Values to be filled in padding
460
+ areas when padding_mode is 'constant'. Default: 0.
461
+ padding_mode (str): Type of padding. Should be: constant, edge,
462
+ reflect or symmetric. Default: constant.
463
+
464
+ - constant: pads with a constant value, this value is specified
465
+ with pad_val.
466
+ - edge: pads with the last value at the edge of the image.
467
+ - reflect: pads with reflection of image without repeating the
468
+ last value on the edge. For example, padding [1, 2, 3, 4]
469
+ with 2 elements on both sides in reflect mode will result
470
+ in [3, 2, 1, 2, 3, 4, 3, 2].
471
+ - symmetric: pads with reflection of image repeating the last
472
+ value on the edge. For example, padding [1, 2, 3, 4] with
473
+ 2 elements on both sides in symmetric mode will result in
474
+ [2, 1, 1, 2, 3, 4, 4, 3]
475
+
476
+ Returns:
477
+ ndarray: The padded image.
478
+ """
479
+
480
+ assert (shape is not None) ^ (padding is not None)
481
+ if shape is not None:
482
+ padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
483
+
484
+ # check pad_val
485
+ if isinstance(pad_val, tuple):
486
+ assert len(pad_val) == img.shape[-1]
487
+ elif not isinstance(pad_val, numbers.Number):
488
+ raise TypeError('pad_val must be an int or a tuple. '
489
+ f'But received {type(pad_val)}')
490
+
491
+ # check padding
492
+ if isinstance(padding, tuple) and len(padding) in [2, 4]:
493
+ if len(padding) == 2:
494
+ padding = (padding[0], padding[1], padding[0], padding[1])
495
+ elif isinstance(padding, numbers.Number):
496
+ padding = (padding, padding, padding, padding)
497
+ else:
498
+ raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
499
+ f'But received {padding}')
500
+
501
+ # check padding mode
502
+ assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
503
+
504
+ border_type = {
505
+ 'constant': cv2.BORDER_CONSTANT,
506
+ 'edge': cv2.BORDER_REPLICATE,
507
+ 'reflect': cv2.BORDER_REFLECT_101,
508
+ 'symmetric': cv2.BORDER_REFLECT
509
+ }
510
+ img = cv2.copyMakeBorder(
511
+ img,
512
+ padding[1],
513
+ padding[3],
514
+ padding[0],
515
+ padding[2],
516
+ border_type[padding_mode],
517
+ value=pad_val)
518
+
519
+ return img
520
+
521
+
522
+ def impad_to_multiple(img, divisor, pad_val=0):
523
+ """Pad an image to ensure each edge to be multiple to some number.
524
+
525
+ Args:
526
+ img (ndarray): Image to be padded.
527
+ divisor (int): Padded image edges will be multiple to divisor.
528
+ pad_val (Number | Sequence[Number]): Same as :func:`impad`.
529
+
530
+ Returns:
531
+ ndarray: The padded image.
532
+ """
533
+ pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
534
+ pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
535
+ return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
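A quick sanity check of the two padding helpers above. This is a minimal sketch, assuming the vendored package is importable as `annotator.mmpkg.mmcv` with these functions re-exported at the top level (the same convention `misc.py` below relies on):

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.zeros((10, 12, 3), dtype=np.uint8)

# Pad to an explicit (h, w) shape; only the bottom/right borders grow.
assert mmcv.impad(img, shape=(16, 16), pad_val=255).shape == (16, 16, 3)

# Pad 2 px left/right and 4 px top/bottom with a per-channel fill value.
assert mmcv.impad(img, padding=(2, 4, 2, 4),
                  pad_val=(0, 255, 0)).shape == (18, 16, 3)

# Round each edge up to the next multiple of 32 (common before FPN-style nets).
assert mmcv.impad_to_multiple(img, divisor=32).shape == (32, 32, 3)
```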
+
+
+ def cutout(img, shape, pad_val=0):
+     """Randomly cut out a rectangle from the original img.
+
+     Args:
+         img (ndarray): Image to be cut out.
+         shape (int | tuple[int]): Expected cutout shape (h, w). If given as
+             an int, the value will be used for both h and w.
+         pad_val (int | float | tuple[int | float]): Values to be filled in
+             the cut area. Defaults to 0.
+
+     Returns:
+         ndarray: The cutout image.
+     """
+
+     channels = 1 if img.ndim == 2 else img.shape[2]
+     if isinstance(shape, int):
+         cut_h, cut_w = shape, shape
+     else:
+         assert isinstance(shape, tuple) and len(shape) == 2, \
+             f'shape must be an int or a tuple with length 2, but got type ' \
+             f'{type(shape)} instead.'
+         cut_h, cut_w = shape
+     if isinstance(pad_val, (int, float)):
+         pad_val = tuple([pad_val] * channels)
+     elif isinstance(pad_val, tuple):
+         assert len(pad_val) == channels, \
+             'Expected the number of elements in tuple to equal the channels ' \
+             'of the input image. Found {} vs {}'.format(
+                 len(pad_val), channels)
+     else:
+         raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+     img_h, img_w = img.shape[:2]
+     y0 = np.random.uniform(img_h)
+     x0 = np.random.uniform(img_w)
+
+     y1 = int(max(0, y0 - cut_h / 2.))
+     x1 = int(max(0, x0 - cut_w / 2.))
+     y2 = min(img_h, y1 + cut_h)
+     x2 = min(img_w, x1 + cut_w)
+
+     if img.ndim == 2:
+         patch_shape = (y2 - y1, x2 - x1)
+     else:
+         patch_shape = (y2 - y1, x2 - x1, channels)
+
+     img_cutout = img.copy()
+     patch = np.array(
+         pad_val, dtype=img.dtype) * np.ones(
+             patch_shape, dtype=img.dtype)
+     img_cutout[y1:y2, x1:x2, ...] = patch
+
+     return img_cutout
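For instance, under the same import assumption as above, `cutout` returns a modified copy and leaves the input untouched:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.full((32, 32, 3), 128, dtype=np.uint8)

# Cut a randomly placed (up to) 8x8 rectangle and fill it per channel.
out = mmcv.cutout(img, shape=8, pad_val=(0, 0, 255))
assert out.shape == img.shape
assert img.min() == 128  # the input image itself is not modified in place
```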
+
+
+ def _get_shear_matrix(magnitude, direction='horizontal'):
+     """Generate the shear matrix for transformation.
+
+     Args:
+         magnitude (int | float): The magnitude used for shear.
+         direction (str): The shear direction, either "horizontal"
+             or "vertical".
+
+     Returns:
+         ndarray: The shear matrix with dtype float32.
+     """
+     if direction == 'horizontal':
+         shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+     elif direction == 'vertical':
+         shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+     return shear_matrix
+
+
+ def imshear(img,
+             magnitude,
+             direction='horizontal',
+             border_value=0,
+             interpolation='bilinear'):
+     """Shear an image.
+
+     Args:
+         img (ndarray): Image to be sheared with format (h, w)
+             or (h, w, c).
+         magnitude (int | float): The magnitude used for shear.
+         direction (str): The shear direction, either "horizontal"
+             or "vertical".
+         border_value (int | tuple[int]): Value used in case of a
+             constant border.
+         interpolation (str): Same as :func:`resize`.
+
+     Returns:
+         ndarray: The sheared image.
+     """
+     assert direction in ['horizontal',
+                          'vertical'], f'Invalid direction: {direction}'
+     height, width = img.shape[:2]
+     if img.ndim == 2:
+         channels = 1
+     elif img.ndim == 3:
+         channels = img.shape[-1]
+     if isinstance(border_value, int):
+         border_value = tuple([border_value] * channels)
+     elif isinstance(border_value, tuple):
+         assert len(border_value) == channels, \
+             'Expected the number of elements in tuple to equal the channels ' \
+             'of the input image. Found {} vs {}'.format(
+                 len(border_value), channels)
+     else:
+         raise ValueError(
+             f'Invalid type {type(border_value)} for `border_value`')
+     shear_matrix = _get_shear_matrix(magnitude, direction)
+     sheared = cv2.warpAffine(
+         img,
+         shear_matrix,
+         (width, height),
+         # Note: when `border_value` has more than 3 elements (e.g. when
+         # shearing masks with more than 3 channels), `cv2.warpAffine`
+         # raises a TypeError, so only the first 3 values are passed on.
+         borderValue=border_value[:3],
+         flags=cv2_interp_codes[interpolation])
+     return sheared
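A horizontal shear maps (x, y) to (x + magnitude * y, y), so rows are shifted proportionally to their vertical position. A minimal sketch under the same import assumption:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.random.randint(0, 256, (20, 30, 3), dtype=np.uint8)

# Pixels pushed off the canvas are dropped; vacated ones take border_value.
sheared = mmcv.imshear(img, magnitude=0.3, border_value=128)
assert sheared.shape == img.shape
```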
+
+
+ def _get_translate_matrix(offset, direction='horizontal'):
+     """Generate the translate matrix.
+
+     Args:
+         offset (int | float): The offset used for translate.
+         direction (str): The translate direction, either
+             "horizontal" or "vertical".
+
+     Returns:
+         ndarray: The translate matrix with dtype float32.
+     """
+     if direction == 'horizontal':
+         translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+     elif direction == 'vertical':
+         translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+     return translate_matrix
+
+
+ def imtranslate(img,
+                 offset,
+                 direction='horizontal',
+                 border_value=0,
+                 interpolation='bilinear'):
+     """Translate an image.
+
+     Args:
+         img (ndarray): Image to be translated with format
+             (h, w) or (h, w, c).
+         offset (int | float): The offset used for translate.
+         direction (str): The translate direction, either "horizontal"
+             or "vertical".
+         border_value (int | tuple[int]): Value used in case of a
+             constant border.
+         interpolation (str): Same as :func:`resize`.
+
+     Returns:
+         ndarray: The translated image.
+     """
+     assert direction in ['horizontal',
+                          'vertical'], f'Invalid direction: {direction}'
+     height, width = img.shape[:2]
+     if img.ndim == 2:
+         channels = 1
+     elif img.ndim == 3:
+         channels = img.shape[-1]
+     if isinstance(border_value, int):
+         border_value = tuple([border_value] * channels)
+     elif isinstance(border_value, tuple):
+         assert len(border_value) == channels, \
+             'Expected the number of elements in tuple to equal the channels ' \
+             'of the input image. Found {} vs {}'.format(
+                 len(border_value), channels)
+     else:
+         raise ValueError(
+             f'Invalid type {type(border_value)} for `border_value`.')
+     translate_matrix = _get_translate_matrix(offset, direction)
+     translated = cv2.warpAffine(
+         img,
+         translate_matrix,
+         (width, height),
+         # Note: when `border_value` has more than 3 elements (e.g. when
+         # translating masks with more than 3 channels), `cv2.warpAffine`
+         # raises a TypeError, so only the first 3 values are passed on.
+         borderValue=border_value[:3],
+         flags=cv2_interp_codes[interpolation])
+     return translated
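Translation is the degenerate affine case; with an integer offset and bilinear interpolation it reduces to an exact pixel shift. A sketch under the same import assumption as above:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.zeros((8, 8), dtype=np.uint8)
img[:, :4] = 255  # left half white

# Shift right by 2 px; the two vacated columns take border_value.
out = mmcv.imtranslate(img, offset=2, direction='horizontal', border_value=7)
assert (out[:, :2] == 7).all() and (out[:, 2:6] == 255).all()
```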
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/io.py ADDED
@@ -0,0 +1,258 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import io
+ import os.path as osp
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+                  IMREAD_UNCHANGED)
+
+ from annotator.mmpkg.mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+
+ try:
+     from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+ except ImportError:
+     TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+ try:
+     from PIL import Image, ImageOps
+ except ImportError:
+     # Keep both names defined when Pillow is absent.
+     Image = ImageOps = None
+
+ try:
+     import tifffile
+ except ImportError:
+     tifffile = None
+
+ jpeg = None
+ supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+ imread_flags = {
+     'color': IMREAD_COLOR,
+     'grayscale': IMREAD_GRAYSCALE,
+     'unchanged': IMREAD_UNCHANGED,
+     'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+     'grayscale_ignore_orientation':
+     IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+ }
+
+ imread_backend = 'cv2'
+
+
+ def use_backend(backend):
+     """Select a backend for image decoding.
+
+     Args:
+         backend (str): The image decoding backend type. Options are `cv2`,
+             `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+             and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+             file format.
+     """
+     assert backend in supported_backends
+     global imread_backend
+     imread_backend = backend
+     if imread_backend == 'turbojpeg':
+         if TurboJPEG is None:
+             raise ImportError('`PyTurboJPEG` is not installed')
+         global jpeg
+         if jpeg is None:
+             jpeg = TurboJPEG()
+     elif imread_backend == 'pillow':
+         if Image is None:
+             raise ImportError('`Pillow` is not installed')
+     elif imread_backend == 'tifffile':
+         if tifffile is None:
+             raise ImportError('`tifffile` is not installed')
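Because the chosen backend is module-global state, switching it affects every later `imread`/`imfrombytes` call in the process that does not pass `backend=...` explicitly. A sketch; the dotted module path is an assumption inferred from this commit's file listing:

```python
import annotator.mmpkg.mmcv.image.io as mmcv_io  # assumed importable path

mmcv_io.use_backend('pillow')  # raises ImportError if Pillow is missing
mmcv_io.use_backend('cv2')     # restore the default backend
```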
+
+
+ def _jpegflag(flag='color', channel_order='bgr'):
+     channel_order = channel_order.lower()
+     if channel_order not in ['rgb', 'bgr']:
+         raise ValueError('channel order must be either "rgb" or "bgr"')
+
+     if flag == 'color':
+         if channel_order == 'bgr':
+             return TJPF_BGR
+         elif channel_order == 'rgb':
+             return TJCS_RGB
+     elif flag == 'grayscale':
+         return TJPF_GRAY
+     else:
+         raise ValueError('flag must be "color" or "grayscale"')
+
+
+ def _pillow2array(img, flag='color', channel_order='bgr'):
+     """Convert a pillow image to numpy array.
+
+     Args:
+         img (:obj:`PIL.Image.Image`): The image loaded using PIL.
+         flag (str): Flags specifying the color type of a loaded image,
+             candidates are 'color', 'grayscale' and 'unchanged'.
+             Default to 'color'.
+         channel_order (str): The channel order of the output image array,
+             candidates are 'bgr' and 'rgb'. Default to 'bgr'.
+
+     Returns:
+         np.ndarray: The converted numpy array.
+     """
+     channel_order = channel_order.lower()
+     if channel_order not in ['rgb', 'bgr']:
+         raise ValueError('channel order must be either "rgb" or "bgr"')
+
+     if flag == 'unchanged':
+         array = np.array(img)
+         if array.ndim >= 3 and array.shape[2] >= 3:  # color image
+             array[:, :, :3] = array[:, :, (2, 1, 0)]  # RGB to BGR
+     else:
+         # Handle exif orientation tag
+         if flag in ['color', 'grayscale']:
+             img = ImageOps.exif_transpose(img)
+         # If the image mode is not 'RGB', convert it to 'RGB' first.
+         if img.mode != 'RGB':
+             if img.mode != 'LA':
+                 # Most formats except 'LA' can be directly converted to RGB
+                 img = img.convert('RGB')
+             else:
+                 # When the mode is 'LA', the default conversion will fill in
+                 # the canvas with black, which sometimes shadows black objects
+                 # in the foreground.
+                 #
+                 # Therefore, a random color (124, 117, 104) is used for canvas
+                 img_rgba = img.convert('RGBA')
+                 img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+                 img.paste(img_rgba, mask=img_rgba.split()[3])  # 3 is alpha
+         if flag in ['color', 'color_ignore_orientation']:
+             array = np.array(img)
+             if channel_order != 'rgb':
+                 array = array[:, :, ::-1]  # RGB to BGR
+         elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+             img = img.convert('L')
+             array = np.array(img)
+         else:
+             raise ValueError(
+                 'flag must be "color", "grayscale", "unchanged", '
+                 f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+                 f' but got {flag}')
+     return array
+
+
+ def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+     """Read an image.
+
+     Args:
+         img_or_path (ndarray or str or Path): Either a numpy array or str or
+             pathlib.Path. If it is a numpy array (loaded image), then
+             it will be returned as is.
+         flag (str): Flags specifying the color type of a loaded image,
+             candidates are `color`, `grayscale`, `unchanged`,
+             `color_ignore_orientation` and `grayscale_ignore_orientation`.
+             By default, `cv2` and `pillow` backend would rotate the image
+             according to its EXIF info unless called with `unchanged` or
+             `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+             always ignore image's EXIF info regardless of the flag.
+             The `turbojpeg` backend only supports `color` and `grayscale`.
+         channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+         backend (str | None): The image decoding backend type. Options are
+             `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+             If backend is None, the global imread_backend specified by
+             ``mmcv.use_backend()`` will be used. Default: None.
+
+     Returns:
+         ndarray: Loaded image array.
+     """
+
+     if backend is None:
+         backend = imread_backend
+     if backend not in supported_backends:
+         raise ValueError(f'backend: {backend} is not supported. Supported '
+                          "backends are 'cv2', 'turbojpeg', 'pillow' and "
+                          "'tifffile'")
+     if isinstance(img_or_path, Path):
+         img_or_path = str(img_or_path)
+
+     if isinstance(img_or_path, np.ndarray):
+         return img_or_path
+     elif is_str(img_or_path):
+         check_file_exist(img_or_path,
+                          f'img file does not exist: {img_or_path}')
+         if backend == 'turbojpeg':
+             with open(img_or_path, 'rb') as in_file:
+                 img = jpeg.decode(in_file.read(),
+                                   _jpegflag(flag, channel_order))
+             if img.shape[-1] == 1:
+                 img = img[:, :, 0]
+             return img
+         elif backend == 'pillow':
+             img = Image.open(img_or_path)
+             img = _pillow2array(img, flag, channel_order)
+             return img
+         elif backend == 'tifffile':
+             img = tifffile.imread(img_or_path)
+             return img
+         else:
+             flag = imread_flags[flag] if is_str(flag) else flag
+             img = cv2.imread(img_or_path, flag)
+             if flag == IMREAD_COLOR and channel_order == 'rgb':
+                 cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+             return img
+     else:
+         raise TypeError('"img" must be a numpy array or a str or '
+                         'a pathlib.Path object')
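A short, self-contained sketch of the input forms `imread` accepts; the temporary path is illustrative only, and the top-level re-export is assumed as before:

```python
import cv2
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

cv2.imwrite('/tmp/demo.png', np.zeros((4, 4, 3), dtype=np.uint8))  # fixture

bgr = mmcv.imread('/tmp/demo.png')                       # (4, 4, 3), BGR order
rgb = mmcv.imread('/tmp/demo.png', channel_order='rgb')  # same pixels, RGB
gray = mmcv.imread('/tmp/demo.png', flag='grayscale')    # (4, 4)
assert mmcv.imread(bgr) is bgr  # ndarray input passes through unchanged
assert gray.ndim == 2
```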
+
+
+ def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+     """Read an image from bytes.
+
+     Args:
+         content (bytes): Image bytes got from files or other streams.
+         flag (str): Same as :func:`imread`.
+         channel_order (str): Same as :func:`imread`.
+         backend (str | None): The image decoding backend type. Options are
+             `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+             global imread_backend specified by ``mmcv.use_backend()`` will be
+             used. Default: None.
+
+     Returns:
+         ndarray: Loaded image array.
+     """
+
+     if backend is None:
+         backend = imread_backend
+     if backend not in supported_backends:
+         raise ValueError(f'backend: {backend} is not supported. Supported '
+                          "backends are 'cv2', 'turbojpeg', 'pillow'")
+     if backend == 'turbojpeg':
+         img = jpeg.decode(content, _jpegflag(flag, channel_order))
+         if img.shape[-1] == 1:
+             img = img[:, :, 0]
+         return img
+     elif backend == 'pillow':
+         buff = io.BytesIO(content)
+         img = Image.open(buff)
+         img = _pillow2array(img, flag, channel_order)
+         return img
+     else:
+         img_np = np.frombuffer(content, np.uint8)
+         flag = imread_flags[flag] if is_str(flag) else flag
+         img = cv2.imdecode(img_np, flag)
+         if flag == IMREAD_COLOR and channel_order == 'rgb':
+             cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+         return img
+
+
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
+     """Write image to file.
+
+     Args:
+         img (ndarray): Image array to be written.
+         file_path (str): Image file path.
+         params (None or list): Same as opencv :func:`imwrite` interface.
+         auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+             whether to create it automatically.
+
+     Returns:
+         bool: Successful or not.
+     """
+     if auto_mkdir:
+         dir_name = osp.abspath(osp.dirname(file_path))
+         mkdir_or_exist(dir_name)
+     return cv2.imwrite(file_path, img, params)
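Together the two functions above give a simple round trip: encode to disk with `imwrite`, decode raw bytes with `imfrombytes`. A sketch with a hypothetical path; PNG is lossless, so the comparison is exact:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)

# auto_mkdir=True (the default) creates the missing parent directory.
assert mmcv.imwrite(img, '/tmp/mmcv_demo/out.png')

with open('/tmp/mmcv_demo/out.png', 'rb') as f:
    decoded = mmcv.imfrombytes(f.read())
assert (decoded == img).all()
```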
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/misc.py ADDED
@@ -0,0 +1,44 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import numpy as np
+
+ import annotator.mmpkg.mmcv as mmcv
+
+ try:
+     import torch
+ except ImportError:
+     torch = None
+
+
+ def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+     """Convert tensor to 3-channel images.
+
+     Args:
+         tensor (torch.Tensor): Tensor that contains multiple images, shape
+             (N, C, H, W).
+         mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+         std (tuple[float], optional): Standard deviation of images.
+             Defaults to (1, 1, 1).
+         to_rgb (bool, optional): Whether the tensor was converted to RGB
+             format in the first place. If so, convert it back to BGR.
+             Defaults to True.
+
+     Returns:
+         list[np.ndarray]: A list that contains multiple images.
+     """
+
+     if torch is None:
+         raise RuntimeError('pytorch is not installed')
+     assert torch.is_tensor(tensor) and tensor.ndim == 4
+     assert len(mean) == 3
+     assert len(std) == 3
+
+     num_imgs = tensor.size(0)
+     mean = np.array(mean, dtype=np.float32)
+     std = np.array(std, dtype=np.float32)
+     imgs = []
+     for img_id in range(num_imgs):
+         img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+         img = mmcv.imdenormalize(
+             img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+         imgs.append(np.ascontiguousarray(img))
+     return imgs
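A sketch of the intended use: undo a (x - mean) / std normalization on a batch and get contiguous uint8 BGR arrays back. The ImageNet-style statistics here are chosen purely for illustration:

```python
import numpy as np
import torch

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

batch = torch.randn(2, 3, 32, 32)  # two normalized "images"
imgs = mmcv.tensor2imgs(
    batch, mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375))
assert len(imgs) == 2
assert imgs[0].shape == (32, 32, 3) and imgs[0].dtype == np.uint8
```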
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/image/photometric.py ADDED
@@ -0,0 +1,428 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import cv2
+ import numpy as np
+
+ from ..utils import is_tuple_of
+ from .colorspace import bgr2gray, gray2bgr
+
+
+ def imnormalize(img, mean, std, to_rgb=True):
+     """Normalize an image with mean and std.
+
+     Args:
+         img (ndarray): Image to be normalized.
+         mean (ndarray): The mean to be used for normalization.
+         std (ndarray): The std to be used for normalization.
+         to_rgb (bool): Whether to convert to rgb.
+
+     Returns:
+         ndarray: The normalized image.
+     """
+     img = img.copy().astype(np.float32)
+     return imnormalize_(img, mean, std, to_rgb)
+
+
+ def imnormalize_(img, mean, std, to_rgb=True):
+     """Inplace normalize an image with mean and std.
+
+     Args:
+         img (ndarray): Image to be normalized.
+         mean (ndarray): The mean to be used for normalization.
+         std (ndarray): The std to be used for normalization.
+         to_rgb (bool): Whether to convert to rgb.
+
+     Returns:
+         ndarray: The normalized image.
+     """
+     # cv2 inplace normalization does not accept uint8
+     assert img.dtype != np.uint8
+     mean = np.float64(mean.reshape(1, -1))
+     stdinv = 1 / np.float64(std.reshape(1, -1))
+     if to_rgb:
+         cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+     cv2.subtract(img, mean, img)  # inplace
+     cv2.multiply(img, stdinv, img)  # inplace
+     return img
+
+
+ def imdenormalize(img, mean, std, to_bgr=True):
+     assert img.dtype != np.uint8
+     mean = mean.reshape(1, -1).astype(np.float64)
+     std = std.reshape(1, -1).astype(np.float64)
+     img = cv2.multiply(img, std)  # make a copy
+     cv2.add(img, mean, img)  # inplace
+     if to_bgr:
+         cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)  # inplace
+     return img
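`imnormalize` copies and casts to float32 before delegating to the in-place variant (which is why uint8 input is rejected there), and `imdenormalize` inverts it. A round-trip sketch under the usual import assumption:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)  # illustrative
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)

img = np.random.randint(0, 256, (16, 16, 3), dtype=np.uint8)
norm = mmcv.imnormalize(img, mean, std, to_rgb=True)        # float32, RGB
restored = mmcv.imdenormalize(norm, mean, std, to_bgr=True)  # back to BGR
assert np.allclose(restored, img, atol=1e-3)
```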
+
+
+ def iminvert(img):
+     """Invert (negate) an image.
+
+     Args:
+         img (ndarray): Image to be inverted.
+
+     Returns:
+         ndarray: The inverted image.
+     """
+     return np.full_like(img, 255) - img
+
+
+ def solarize(img, thr=128):
+     """Solarize an image (invert all pixel values above a threshold).
+
+     Args:
+         img (ndarray): Image to be solarized.
+         thr (int): Threshold for solarizing (0 - 255).
+
+     Returns:
+         ndarray: The solarized image.
+     """
+     img = np.where(img < thr, img, 255 - img)
+     return img
+
+
+ def posterize(img, bits):
+     """Posterize an image (reduce the number of bits for each color channel).
+
+     Args:
+         img (ndarray): Image to be posterized.
+         bits (int): Number of bits (1 to 8) to use for posterizing.
+
+     Returns:
+         ndarray: The posterized image.
+     """
+     shift = 8 - bits
+     img = np.left_shift(np.right_shift(img, shift), shift)
+     return img
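The three point-wise operations above are easy to check exhaustively on a ramp image, under the same import assumption:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.arange(256, dtype=np.uint8).reshape(16, 16)

assert (mmcv.iminvert(img) == 255 - img).all()

# Solarize inverts only values at or above the threshold ...
sol = mmcv.solarize(img, thr=128)
assert (sol[img < 128] == img[img < 128]).all()

# ... while posterize keeps only the top `bits` bits of each value.
assert set(np.unique(mmcv.posterize(img, bits=2))) <= {0, 64, 128, 192}
```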
+
+
+ def adjust_color(img, alpha=1, beta=None, gamma=0):
+     r"""It blends the source image and its gray image:
+
+     .. math::
+         output = img * alpha + gray\_img * beta + gamma
+
+     Args:
+         img (ndarray): The input source image.
+         alpha (int | float): Weight for the source image. Default 1.
+         beta (int | float): Weight for the converted gray image.
+             If None, it's assigned the value (1 - `alpha`).
+         gamma (int | float): Scalar added to each sum.
+             Same as :func:`cv2.addWeighted`. Default 0.
+
+     Returns:
+         ndarray: Colored image which has the same size and dtype as input.
+     """
+     gray_img = bgr2gray(img)
+     gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+     if beta is None:
+         beta = 1 - alpha
+     colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+     if not colored_img.dtype == np.uint8:
+         # Note when the dtype of `img` is not the default `np.uint8`
+         # (e.g. np.float32), the value in `colored_img` got from cv2
+         # is not guaranteed to be in range [0, 255], so here clip
+         # is needed.
+         colored_img = np.clip(colored_img, 0, 255)
+     return colored_img
+
+
+ def imequalize(img):
+     """Equalize the image histogram.
+
+     This function applies a non-linear mapping to the input image,
+     in order to create a uniform distribution of grayscale values
+     in the output image.
+
+     Args:
+         img (ndarray): Image to be equalized.
+
+     Returns:
+         ndarray: The equalized image.
+     """
+
+     def _scale_channel(im, c):
+         """Scale the data in the corresponding channel."""
+         im = im[:, :, c]
+         # Compute the histogram of the image channel.
+         histo = np.histogram(im, 256, (0, 255))[0]
+         # For computing the step, filter out the nonzeros.
+         nonzero_histo = histo[histo > 0]
+         step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+         if not step:
+             lut = np.array(range(256))
+         else:
+             # Compute the cumulative sum, shifted by step // 2
+             # and then normalized by step.
+             lut = (np.cumsum(histo) + (step // 2)) // step
+             # Shift lut, prepending with 0.
+             lut = np.concatenate([[0], lut[:-1]], 0)
+             # handle potential integer overflow
+             lut[lut > 255] = 255
+         # If step is zero, return the original image.
+         # Otherwise, index from lut.
+         return np.where(np.equal(step, 0), im, lut[im])
+
+     # Scales each channel independently and then stacks
+     # the result.
+     s1 = _scale_channel(img, 0)
+     s2 = _scale_channel(img, 1)
+     s3 = _scale_channel(img, 2)
+     equalized_img = np.stack([s1, s2, s3], axis=-1)
+     return equalized_img.astype(img.dtype)
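With alpha=0 (hence beta=1), `adjust_color` reduces to a full desaturation, which makes for a quick check; `imequalize` preserves shape and dtype. A sketch under the usual import assumption:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

gray3 = mmcv.adjust_color(img, alpha=0)  # grayscale tiled to 3 channels
assert (gray3[..., 0] == gray3[..., 1]).all()

eq = mmcv.imequalize(img)  # per-channel histogram equalization
assert eq.shape == img.shape and eq.dtype == img.dtype
```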
+
+
+ def adjust_brightness(img, factor=1.):
+     """Adjust image brightness.
+
+     This function controls the brightness of an image. An
+     enhancement factor of 0.0 gives a black image.
+     A factor of 1.0 gives the original image. This function
+     blends the source image and the degenerated black image:
+
+     .. math::
+         output = img * factor + degenerated * (1 - factor)
+
+     Args:
+         img (ndarray): Image to be brightened.
+         factor (float): A value that controls the enhancement.
+             Factor 1.0 returns the original image, lower
+             factors mean less color (brightness, contrast,
+             etc), and higher values more. Default 1.
+
+     Returns:
+         ndarray: The brightened image.
+     """
+     degenerated = np.zeros_like(img)
+     # Note: manually convert the dtype to np.float32 to achieve results
+     # as close as possible to PIL.ImageEnhance.Brightness.
+     # Set beta=1-factor, and gamma=0
+     brightened_img = cv2.addWeighted(
+         img.astype(np.float32), factor, degenerated.astype(np.float32),
+         1 - factor, 0)
+     brightened_img = np.clip(brightened_img, 0, 255)
+     return brightened_img.astype(img.dtype)
+
+
+ def adjust_contrast(img, factor=1.):
+     """Adjust image contrast.
+
+     This function controls the contrast of an image. An
+     enhancement factor of 0.0 gives a solid grey
+     image. A factor of 1.0 gives the original image. It
+     blends the source image and the degenerated mean image:
+
+     .. math::
+         output = img * factor + degenerated * (1 - factor)
+
+     Args:
+         img (ndarray): Image to be contrasted. BGR order.
+         factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+     Returns:
+         ndarray: The contrasted image.
+     """
+     gray_img = bgr2gray(img)
+     hist = np.histogram(gray_img, 256, (0, 255))[0]
+     mean = round(np.sum(gray_img) / np.sum(hist))
+     degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+     degenerated = gray2bgr(degenerated)
+     contrasted_img = cv2.addWeighted(
+         img.astype(np.float32), factor, degenerated.astype(np.float32),
+         1 - factor, 0)
+     contrasted_img = np.clip(contrasted_img, 0, 255)
+     return contrasted_img.astype(img.dtype)
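The blending formula makes the endpoints easy to verify: factor 0 yields the degenerated image, factor 1 the original, and factors above 1 extrapolate past it. Under the same import assumption:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

assert (mmcv.adjust_brightness(img, 0.0) == 0).all()    # degenerated: black
assert (mmcv.adjust_brightness(img, 1.0) == img).all()  # identity
assert (mmcv.adjust_contrast(img, 1.0) == img).all()    # identity

brighter = mmcv.adjust_brightness(img, 1.5)  # extrapolate past the original
```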
+
+
+ def auto_contrast(img, cutoff=0):
+     """Auto adjust image contrast.
+
+     This function maximizes (normalizes) image contrast by first removing
+     cutoff percent of the lightest and darkest pixels from the histogram
+     and remapping the image so that the darkest pixel becomes black (0),
+     and the lightest becomes white (255).
+
+     Args:
+         img (ndarray): Image to be contrasted. BGR order.
+         cutoff (int | float | tuple): The cutoff percent of the lightest and
+             darkest pixels to be removed. If given as tuple, it shall be
+             (low, high). Otherwise, the single value will be used for both.
+             Defaults to 0.
+
+     Returns:
+         ndarray: The contrasted image.
+     """
+
+     def _auto_contrast_channel(im, c, cutoff):
+         im = im[:, :, c]
+         # Compute the histogram of the image channel.
+         histo = np.histogram(im, 256, (0, 255))[0]
+         # Remove cut-off percent pixels from histo
+         histo_sum = np.cumsum(histo)
+         cut_low = histo_sum[-1] * cutoff[0] // 100
+         cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+         histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+         histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+         # Compute mapping
+         low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+         # If all the values have been cut off, return the origin img
+         if low >= high:
+             return im
+         scale = 255.0 / (high - low)
+         offset = -low * scale
+         lut = np.array(range(256))
+         lut = lut * scale + offset
+         lut = np.clip(lut, 0, 255)
+         return lut[im]
+
+     if isinstance(cutoff, (int, float)):
+         cutoff = (cutoff, cutoff)
+     else:
+         assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+             f'float or tuple, but got {type(cutoff)} instead.'
+     # Auto adjusts contrast for each channel independently and then stacks
+     # the result.
+     s1 = _auto_contrast_channel(img, 0, cutoff)
+     s2 = _auto_contrast_channel(img, 1, cutoff)
+     s3 = _auto_contrast_channel(img, 2, cutoff)
+     contrasted_img = np.stack([s1, s2, s3], axis=-1)
+     return contrasted_img.astype(img.dtype)
+
+
+ def adjust_sharpness(img, factor=1., kernel=None):
+     """Adjust image sharpness.
+
+     This function controls the sharpness of an image. An
+     enhancement factor of 0.0 gives a blurred image. A
+     factor of 1.0 gives the original image. And a factor
+     of 2.0 gives a sharpened image. It blends the source
+     image and the degenerated mean image:
+
+     .. math::
+         output = img * factor + degenerated * (1 - factor)
+
+     Args:
+         img (ndarray): Image to be sharpened. BGR order.
+         factor (float): Same as :func:`mmcv.adjust_brightness`.
+         kernel (np.ndarray, optional): Filter kernel to be applied on the img
+             to obtain the degenerated img. Defaults to None.
+
+     Note:
+         No value sanity check is enforced on the kernel set by users. With an
+         inappropriate kernel, ``adjust_sharpness`` may fail to perform the
+         function its name indicates and instead apply whatever transform the
+         kernel defines.
+
+     Returns:
+         ndarray: The sharpened image.
+     """
+
+     if kernel is None:
+         # adopted from PIL.ImageFilter.SMOOTH
+         kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+     assert isinstance(kernel, np.ndarray), \
+         f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+     assert kernel.ndim == 2, \
+         f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+     degenerated = cv2.filter2D(img, -1, kernel)
+     sharpened_img = cv2.addWeighted(
+         img.astype(np.float32), factor, degenerated.astype(np.float32),
+         1 - factor, 0)
+     sharpened_img = np.clip(sharpened_img, 0, 255)
+     return sharpened_img.astype(img.dtype)
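The same endpoint check works for sharpness, and `auto_contrast` accepts either a single cutoff percentage or a (low, high) pair:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)

assert (mmcv.adjust_sharpness(img, 1.0) == img).all()  # identity
blurred = mmcv.adjust_sharpness(img, 0.0)  # the smoothed degenerated image
sharper = mmcv.adjust_sharpness(img, 2.0)  # extrapolate past the original

# Clip 1% at each end of the histogram before stretching to [0, 255].
stretched = mmcv.auto_contrast(img, cutoff=1)
```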
+
+
+ def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+     """AlexNet-style PCA jitter.
+
+     This data augmentation is proposed in `ImageNet Classification with Deep
+     Convolutional Neural Networks
+     <https://dl.acm.org/doi/pdf/10.1145/3065386>`_.
+
+     Args:
+         img (ndarray): Image to be adjusted for lighting. BGR order.
+         eigval (ndarray): the eigenvalues of the covariance matrix of pixel
+             values.
+         eigvec (ndarray): the eigenvectors of the covariance matrix of pixel
+             values.
+         alphastd (float): The standard deviation for distribution of alpha.
+             Defaults to 0.1.
+         to_rgb (bool): Whether to convert img to rgb.
+
+     Returns:
+         ndarray: The adjusted image.
+     """
+     assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+         f'eigval and eigvec should both be of type np.ndarray, got ' \
+         f'{type(eigval)} and {type(eigvec)} instead.'
+
+     assert eigval.ndim == 1 and eigvec.ndim == 2
+     assert eigvec.shape == (3, eigval.shape[0])
+     n_eigval = eigval.shape[0]
+     assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+         f'got {type(alphastd)} instead.'
+
+     img = img.copy().astype(np.float32)
+     if to_rgb:
+         cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+
+     alpha = np.random.normal(0, alphastd, n_eigval)
+     alter = eigvec \
+         * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+         * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+     alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+     img_adjusted = img + alter
+     return img_adjusted
+
+
+ def lut_transform(img, lut_table):
+     """Transform array by look-up table.
+
+     The function lut_transform fills the output array with values from the
+     look-up table. Indices of the entries are taken from the input array.
+
+     Args:
+         img (ndarray): Image to be transformed.
+         lut_table (ndarray): look-up table of 256 elements; in case of
+             multi-channel input array, the table should either have a single
+             channel (in this case the same table is used for all channels) or
+             the same number of channels as in the input array.
+
+     Returns:
+         ndarray: The transformed image.
+     """
+     assert isinstance(img, np.ndarray)
+     assert 0 <= np.min(img) and np.max(img) <= 255
+     assert isinstance(lut_table, np.ndarray)
+     assert lut_table.shape == (256, )
+
+     return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
+
+
+ def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+     """Use CLAHE method to process the image.
+
+     See `ZUIDERVELD, K. Contrast Limited Adaptive Histogram Equalization[J].
+     Graphics Gems, 1994:474-485.` for more information.
+
+     Args:
+         img (ndarray): Image to be processed.
+         clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+         tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+             Input image will be divided into equally sized rectangular tiles.
+             It defines the number of tiles in row and column. Default: (8, 8).
+
+     Returns:
+         ndarray: The processed image.
+     """
+     assert isinstance(img, np.ndarray)
+     assert img.ndim == 2
+     assert isinstance(clip_limit, (float, int))
+     assert is_tuple_of(tile_grid_size, int)
+     assert len(tile_grid_size) == 2
+
+     clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+     return clahe.apply(np.array(img, dtype=np.uint8))
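Finally, a sketch tying the last two helpers together: gamma correction expressed as a 256-entry LUT, then CLAHE on a single-channel image. The gamma value is chosen arbitrarily for illustration:

```python
import numpy as np

import annotator.mmpkg.mmcv as mmcv  # assumed top-level re-export

gray = np.random.randint(0, 256, (64, 64), dtype=np.uint8)

# Build a gamma-correction look-up table with exactly 256 entries.
gamma = 2.2
lut = (((np.arange(256) / 255.0) ** (1 / gamma)) * 255).astype(np.uint8)
corrected = mmcv.lut_transform(gray, lut)

# CLAHE requires a 2-D (single-channel) image; see the ndim == 2 assert above.
enhanced = mmcv.clahe(gray, clip_limit=40.0, tile_grid_size=(8, 8))
assert enhanced.shape == gray.shape
```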