toto10 commited on
Commit
5f4ea79
1 Parent(s): 373d463

4f3c990a87ec971910ff61a029b3ca21870c19ebc5abc37dc65d3e23339b5592

Browse files
Files changed (50) hide show
  1. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py +71 -0
  2. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/logging.py +110 -0
  3. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/misc.py +377 -0
  4. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py +41 -0
  5. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py +107 -0
  6. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py +101 -0
  7. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py +208 -0
  8. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py +315 -0
  9. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py +140 -0
  10. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py +118 -0
  11. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py +23 -0
  12. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py +90 -0
  13. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py +35 -0
  14. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py +11 -0
  15. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py +318 -0
  16. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py +254 -0
  17. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py +160 -0
  18. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py +9 -0
  19. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py +51 -0
  20. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py +152 -0
  21. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py +112 -0
  22. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py +9 -0
  23. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py +138 -0
  24. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py +238 -0
  25. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py +116 -0
  26. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py +3 -0
  27. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py +8 -0
  28. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py +152 -0
  29. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py +109 -0
  30. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py +326 -0
  31. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py +4 -0
  32. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py +8 -0
  33. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py +4 -0
  34. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py +12 -0
  35. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py +76 -0
  36. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py +3 -0
  37. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py +17 -0
  38. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py +19 -0
  39. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py +84 -0
  40. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py +169 -0
  41. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py +27 -0
  42. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py +217 -0
  43. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py +403 -0
  44. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py +50 -0
  45. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py +27 -0
  46. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py +27 -0
  47. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py +103 -0
  48. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py +16 -0
  49. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py +51 -0
  50. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py +288 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import importlib
3
+ import os
4
+ import pkgutil
5
+ import warnings
6
+ from collections import namedtuple
7
+
8
+ import torch
9
+
10
+ if torch.__version__ != 'parrots':
11
+
12
+ def load_ext(name, funcs):
13
+ ext = importlib.import_module('mmcv.' + name)
14
+ for fun in funcs:
15
+ assert hasattr(ext, fun), f'{fun} miss in module {name}'
16
+ return ext
17
+ else:
18
+ from parrots import extension
19
+ from parrots.base import ParrotsException
20
+
21
+ has_return_value_ops = [
22
+ 'nms',
23
+ 'softnms',
24
+ 'nms_match',
25
+ 'nms_rotated',
26
+ 'top_pool_forward',
27
+ 'top_pool_backward',
28
+ 'bottom_pool_forward',
29
+ 'bottom_pool_backward',
30
+ 'left_pool_forward',
31
+ 'left_pool_backward',
32
+ 'right_pool_forward',
33
+ 'right_pool_backward',
34
+ 'fused_bias_leakyrelu',
35
+ 'upfirdn2d',
36
+ 'ms_deform_attn_forward',
37
+ 'pixel_group',
38
+ 'contour_expand',
39
+ ]
40
+
41
+ def get_fake_func(name, e):
42
+
43
+ def fake_func(*args, **kwargs):
44
+ warnings.warn(f'{name} is not supported in parrots now')
45
+ raise e
46
+
47
+ return fake_func
48
+
49
+ def load_ext(name, funcs):
50
+ ExtModule = namedtuple('ExtModule', funcs)
51
+ ext_list = []
52
+ lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
53
+ for fun in funcs:
54
+ try:
55
+ ext_fun = extension.load(fun, name, lib_dir=lib_root)
56
+ except ParrotsException as e:
57
+ if 'No element registered' not in e.message:
58
+ warnings.warn(e.message)
59
+ ext_fun = get_fake_func(fun, e)
60
+ ext_list.append(ext_fun)
61
+ else:
62
+ if fun in has_return_value_ops:
63
+ ext_list.append(ext_fun.op)
64
+ else:
65
+ ext_list.append(ext_fun.op_)
66
+ return ExtModule(*ext_list)
67
+
68
+
69
+ def check_ops_exist():
70
+ ext_loader = pkgutil.find_loader('mmcv._ext')
71
+ return ext_loader is not None
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/logging.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+
4
+ import torch.distributed as dist
5
+
6
+ logger_initialized = {}
7
+
8
+
9
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
10
+ """Initialize and get a logger by name.
11
+
12
+ If the logger has not been initialized, this method will initialize the
13
+ logger by adding one or two handlers, otherwise the initialized logger will
14
+ be directly returned. During initialization, a StreamHandler will always be
15
+ added. If `log_file` is specified and the process rank is 0, a FileHandler
16
+ will also be added.
17
+
18
+ Args:
19
+ name (str): Logger name.
20
+ log_file (str | None): The log filename. If specified, a FileHandler
21
+ will be added to the logger.
22
+ log_level (int): The logger level. Note that only the process of
23
+ rank 0 is affected, and other processes will set the level to
24
+ "Error" thus be silent most of the time.
25
+ file_mode (str): The file mode used in opening log file.
26
+ Defaults to 'w'.
27
+
28
+ Returns:
29
+ logging.Logger: The expected logger.
30
+ """
31
+ logger = logging.getLogger(name)
32
+ if name in logger_initialized:
33
+ return logger
34
+ # handle hierarchical names
35
+ # e.g., logger "a" is initialized, then logger "a.b" will skip the
36
+ # initialization since it is a child of "a".
37
+ for logger_name in logger_initialized:
38
+ if name.startswith(logger_name):
39
+ return logger
40
+
41
+ # handle duplicate logs to the console
42
+ # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
43
+ # to the root logger. As logger.propagate is True by default, this root
44
+ # level handler causes logging messages from rank>0 processes to
45
+ # unexpectedly show up on the console, creating much unwanted clutter.
46
+ # To fix this issue, we set the root logger's StreamHandler, if any, to log
47
+ # at the ERROR level.
48
+ for handler in logger.root.handlers:
49
+ if type(handler) is logging.StreamHandler:
50
+ handler.setLevel(logging.ERROR)
51
+
52
+ stream_handler = logging.StreamHandler()
53
+ handlers = [stream_handler]
54
+
55
+ if dist.is_available() and dist.is_initialized():
56
+ rank = dist.get_rank()
57
+ else:
58
+ rank = 0
59
+
60
+ # only rank 0 will add a FileHandler
61
+ if rank == 0 and log_file is not None:
62
+ # Here, the default behaviour of the official logger is 'a'. Thus, we
63
+ # provide an interface to change the file mode to the default
64
+ # behaviour.
65
+ file_handler = logging.FileHandler(log_file, file_mode)
66
+ handlers.append(file_handler)
67
+
68
+ formatter = logging.Formatter(
69
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
70
+ for handler in handlers:
71
+ handler.setFormatter(formatter)
72
+ handler.setLevel(log_level)
73
+ logger.addHandler(handler)
74
+
75
+ if rank == 0:
76
+ logger.setLevel(log_level)
77
+ else:
78
+ logger.setLevel(logging.ERROR)
79
+
80
+ logger_initialized[name] = True
81
+
82
+ return logger
83
+
84
+
85
+ def print_log(msg, logger=None, level=logging.INFO):
86
+ """Print a log message.
87
+
88
+ Args:
89
+ msg (str): The message to be logged.
90
+ logger (logging.Logger | str | None): The logger to be used.
91
+ Some special loggers are:
92
+ - "silent": no message will be printed.
93
+ - other str: the logger obtained with `get_root_logger(logger)`.
94
+ - None: The `print()` method will be used to print log messages.
95
+ level (int): Logging level. Only available when `logger` is a Logger
96
+ object or "root".
97
+ """
98
+ if logger is None:
99
+ print(msg)
100
+ elif isinstance(logger, logging.Logger):
101
+ logger.log(level, msg)
102
+ elif logger == 'silent':
103
+ pass
104
+ elif isinstance(logger, str):
105
+ _logger = get_logger(logger)
106
+ _logger.log(level, msg)
107
+ else:
108
+ raise TypeError(
109
+ 'logger should be either a logging.Logger object, str, '
110
+ f'"silent" or None, but got {type(logger)}')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/misc.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import collections.abc
3
+ import functools
4
+ import itertools
5
+ import subprocess
6
+ import warnings
7
+ from collections import abc
8
+ from importlib import import_module
9
+ from inspect import getfullargspec
10
+ from itertools import repeat
11
+
12
+
13
+ # From PyTorch internals
14
+ def _ntuple(n):
15
+
16
+ def parse(x):
17
+ if isinstance(x, collections.abc.Iterable):
18
+ return x
19
+ return tuple(repeat(x, n))
20
+
21
+ return parse
22
+
23
+
24
+ to_1tuple = _ntuple(1)
25
+ to_2tuple = _ntuple(2)
26
+ to_3tuple = _ntuple(3)
27
+ to_4tuple = _ntuple(4)
28
+ to_ntuple = _ntuple
29
+
30
+
31
+ def is_str(x):
32
+ """Whether the input is an string instance.
33
+
34
+ Note: This method is deprecated since python 2 is no longer supported.
35
+ """
36
+ return isinstance(x, str)
37
+
38
+
39
+ def import_modules_from_strings(imports, allow_failed_imports=False):
40
+ """Import modules from the given list of strings.
41
+
42
+ Args:
43
+ imports (list | str | None): The given module names to be imported.
44
+ allow_failed_imports (bool): If True, the failed imports will return
45
+ None. Otherwise, an ImportError is raise. Default: False.
46
+
47
+ Returns:
48
+ list[module] | module | None: The imported modules.
49
+
50
+ Examples:
51
+ >>> osp, sys = import_modules_from_strings(
52
+ ... ['os.path', 'sys'])
53
+ >>> import os.path as osp_
54
+ >>> import sys as sys_
55
+ >>> assert osp == osp_
56
+ >>> assert sys == sys_
57
+ """
58
+ if not imports:
59
+ return
60
+ single_import = False
61
+ if isinstance(imports, str):
62
+ single_import = True
63
+ imports = [imports]
64
+ if not isinstance(imports, list):
65
+ raise TypeError(
66
+ f'custom_imports must be a list but got type {type(imports)}')
67
+ imported = []
68
+ for imp in imports:
69
+ if not isinstance(imp, str):
70
+ raise TypeError(
71
+ f'{imp} is of type {type(imp)} and cannot be imported.')
72
+ try:
73
+ imported_tmp = import_module(imp)
74
+ except ImportError:
75
+ if allow_failed_imports:
76
+ warnings.warn(f'{imp} failed to import and is ignored.',
77
+ UserWarning)
78
+ imported_tmp = None
79
+ else:
80
+ raise ImportError
81
+ imported.append(imported_tmp)
82
+ if single_import:
83
+ imported = imported[0]
84
+ return imported
85
+
86
+
87
+ def iter_cast(inputs, dst_type, return_type=None):
88
+ """Cast elements of an iterable object into some type.
89
+
90
+ Args:
91
+ inputs (Iterable): The input object.
92
+ dst_type (type): Destination type.
93
+ return_type (type, optional): If specified, the output object will be
94
+ converted to this type, otherwise an iterator.
95
+
96
+ Returns:
97
+ iterator or specified type: The converted object.
98
+ """
99
+ if not isinstance(inputs, abc.Iterable):
100
+ raise TypeError('inputs must be an iterable object')
101
+ if not isinstance(dst_type, type):
102
+ raise TypeError('"dst_type" must be a valid type')
103
+
104
+ out_iterable = map(dst_type, inputs)
105
+
106
+ if return_type is None:
107
+ return out_iterable
108
+ else:
109
+ return return_type(out_iterable)
110
+
111
+
112
+ def list_cast(inputs, dst_type):
113
+ """Cast elements of an iterable object into a list of some type.
114
+
115
+ A partial method of :func:`iter_cast`.
116
+ """
117
+ return iter_cast(inputs, dst_type, return_type=list)
118
+
119
+
120
+ def tuple_cast(inputs, dst_type):
121
+ """Cast elements of an iterable object into a tuple of some type.
122
+
123
+ A partial method of :func:`iter_cast`.
124
+ """
125
+ return iter_cast(inputs, dst_type, return_type=tuple)
126
+
127
+
128
+ def is_seq_of(seq, expected_type, seq_type=None):
129
+ """Check whether it is a sequence of some type.
130
+
131
+ Args:
132
+ seq (Sequence): The sequence to be checked.
133
+ expected_type (type): Expected type of sequence items.
134
+ seq_type (type, optional): Expected sequence type.
135
+
136
+ Returns:
137
+ bool: Whether the sequence is valid.
138
+ """
139
+ if seq_type is None:
140
+ exp_seq_type = abc.Sequence
141
+ else:
142
+ assert isinstance(seq_type, type)
143
+ exp_seq_type = seq_type
144
+ if not isinstance(seq, exp_seq_type):
145
+ return False
146
+ for item in seq:
147
+ if not isinstance(item, expected_type):
148
+ return False
149
+ return True
150
+
151
+
152
+ def is_list_of(seq, expected_type):
153
+ """Check whether it is a list of some type.
154
+
155
+ A partial method of :func:`is_seq_of`.
156
+ """
157
+ return is_seq_of(seq, expected_type, seq_type=list)
158
+
159
+
160
+ def is_tuple_of(seq, expected_type):
161
+ """Check whether it is a tuple of some type.
162
+
163
+ A partial method of :func:`is_seq_of`.
164
+ """
165
+ return is_seq_of(seq, expected_type, seq_type=tuple)
166
+
167
+
168
+ def slice_list(in_list, lens):
169
+ """Slice a list into several sub lists by a list of given length.
170
+
171
+ Args:
172
+ in_list (list): The list to be sliced.
173
+ lens(int or list): The expected length of each out list.
174
+
175
+ Returns:
176
+ list: A list of sliced list.
177
+ """
178
+ if isinstance(lens, int):
179
+ assert len(in_list) % lens == 0
180
+ lens = [lens] * int(len(in_list) / lens)
181
+ if not isinstance(lens, list):
182
+ raise TypeError('"indices" must be an integer or a list of integers')
183
+ elif sum(lens) != len(in_list):
184
+ raise ValueError('sum of lens and list length does not '
185
+ f'match: {sum(lens)} != {len(in_list)}')
186
+ out_list = []
187
+ idx = 0
188
+ for i in range(len(lens)):
189
+ out_list.append(in_list[idx:idx + lens[i]])
190
+ idx += lens[i]
191
+ return out_list
192
+
193
+
194
+ def concat_list(in_list):
195
+ """Concatenate a list of list into a single list.
196
+
197
+ Args:
198
+ in_list (list): The list of list to be merged.
199
+
200
+ Returns:
201
+ list: The concatenated flat list.
202
+ """
203
+ return list(itertools.chain(*in_list))
204
+
205
+
206
+ def check_prerequisites(
207
+ prerequisites,
208
+ checker,
209
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
210
+ 'found, please install them first.'): # yapf: disable
211
+ """A decorator factory to check if prerequisites are satisfied.
212
+
213
+ Args:
214
+ prerequisites (str of list[str]): Prerequisites to be checked.
215
+ checker (callable): The checker method that returns True if a
216
+ prerequisite is meet, False otherwise.
217
+ msg_tmpl (str): The message template with two variables.
218
+
219
+ Returns:
220
+ decorator: A specific decorator.
221
+ """
222
+
223
+ def wrap(func):
224
+
225
+ @functools.wraps(func)
226
+ def wrapped_func(*args, **kwargs):
227
+ requirements = [prerequisites] if isinstance(
228
+ prerequisites, str) else prerequisites
229
+ missing = []
230
+ for item in requirements:
231
+ if not checker(item):
232
+ missing.append(item)
233
+ if missing:
234
+ print(msg_tmpl.format(', '.join(missing), func.__name__))
235
+ raise RuntimeError('Prerequisites not meet.')
236
+ else:
237
+ return func(*args, **kwargs)
238
+
239
+ return wrapped_func
240
+
241
+ return wrap
242
+
243
+
244
+ def _check_py_package(package):
245
+ try:
246
+ import_module(package)
247
+ except ImportError:
248
+ return False
249
+ else:
250
+ return True
251
+
252
+
253
+ def _check_executable(cmd):
254
+ if subprocess.call(f'which {cmd}', shell=True) != 0:
255
+ return False
256
+ else:
257
+ return True
258
+
259
+
260
+ def requires_package(prerequisites):
261
+ """A decorator to check if some python packages are installed.
262
+
263
+ Example:
264
+ >>> @requires_package('numpy')
265
+ >>> func(arg1, args):
266
+ >>> return numpy.zeros(1)
267
+ array([0.])
268
+ >>> @requires_package(['numpy', 'non_package'])
269
+ >>> func(arg1, args):
270
+ >>> return numpy.zeros(1)
271
+ ImportError
272
+ """
273
+ return check_prerequisites(prerequisites, checker=_check_py_package)
274
+
275
+
276
+ def requires_executable(prerequisites):
277
+ """A decorator to check if some executable files are installed.
278
+
279
+ Example:
280
+ >>> @requires_executable('ffmpeg')
281
+ >>> func(arg1, args):
282
+ >>> print(1)
283
+ 1
284
+ """
285
+ return check_prerequisites(prerequisites, checker=_check_executable)
286
+
287
+
288
+ def deprecated_api_warning(name_dict, cls_name=None):
289
+ """A decorator to check if some arguments are deprecate and try to replace
290
+ deprecate src_arg_name to dst_arg_name.
291
+
292
+ Args:
293
+ name_dict(dict):
294
+ key (str): Deprecate argument names.
295
+ val (str): Expected argument names.
296
+
297
+ Returns:
298
+ func: New function.
299
+ """
300
+
301
+ def api_warning_wrapper(old_func):
302
+
303
+ @functools.wraps(old_func)
304
+ def new_func(*args, **kwargs):
305
+ # get the arg spec of the decorated method
306
+ args_info = getfullargspec(old_func)
307
+ # get name of the function
308
+ func_name = old_func.__name__
309
+ if cls_name is not None:
310
+ func_name = f'{cls_name}.{func_name}'
311
+ if args:
312
+ arg_names = args_info.args[:len(args)]
313
+ for src_arg_name, dst_arg_name in name_dict.items():
314
+ if src_arg_name in arg_names:
315
+ warnings.warn(
316
+ f'"{src_arg_name}" is deprecated in '
317
+ f'`{func_name}`, please use "{dst_arg_name}" '
318
+ 'instead')
319
+ arg_names[arg_names.index(src_arg_name)] = dst_arg_name
320
+ if kwargs:
321
+ for src_arg_name, dst_arg_name in name_dict.items():
322
+ if src_arg_name in kwargs:
323
+
324
+ assert dst_arg_name not in kwargs, (
325
+ f'The expected behavior is to replace '
326
+ f'the deprecated key `{src_arg_name}` to '
327
+ f'new key `{dst_arg_name}`, but got them '
328
+ f'in the arguments at the same time, which '
329
+ f'is confusing. `{src_arg_name} will be '
330
+ f'deprecated in the future, please '
331
+ f'use `{dst_arg_name}` instead.')
332
+
333
+ warnings.warn(
334
+ f'"{src_arg_name}" is deprecated in '
335
+ f'`{func_name}`, please use "{dst_arg_name}" '
336
+ 'instead')
337
+ kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
338
+
339
+ # apply converted arguments to the decorated method
340
+ output = old_func(*args, **kwargs)
341
+ return output
342
+
343
+ return new_func
344
+
345
+ return api_warning_wrapper
346
+
347
+
348
+ def is_method_overridden(method, base_class, derived_class):
349
+ """Check if a method of base class is overridden in derived class.
350
+
351
+ Args:
352
+ method (str): the method name to check.
353
+ base_class (type): the class of the base class.
354
+ derived_class (type | Any): the class or instance of the derived class.
355
+ """
356
+ assert isinstance(base_class, type), \
357
+ "base_class doesn't accept instance, Please pass class instead."
358
+
359
+ if not isinstance(derived_class, type):
360
+ derived_class = derived_class.__class__
361
+
362
+ base_method = getattr(base_class, method)
363
+ derived_method = getattr(derived_class, method)
364
+ return derived_method != base_method
365
+
366
+
367
+ def has_method(obj: object, method: str) -> bool:
368
+ """Check whether the object has a method.
369
+
370
+ Args:
371
+ method (str): The method name to check.
372
+ obj (object): The object to check.
373
+
374
+ Returns:
375
+ bool: True if the object has the method else False.
376
+ """
377
+ return hasattr(obj, method) and callable(getattr(obj, method))
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+
4
+ from .parrots_wrapper import TORCH_VERSION
5
+
6
+ parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
7
+
8
+ if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
9
+ from parrots.jit import pat as jit
10
+ else:
11
+
12
+ def jit(func=None,
13
+ check_input=None,
14
+ full_shape=True,
15
+ derivate=False,
16
+ coderize=False,
17
+ optimize=False):
18
+
19
+ def wrapper(func):
20
+
21
+ def wrapper_inner(*args, **kargs):
22
+ return func(*args, **kargs)
23
+
24
+ return wrapper_inner
25
+
26
+ if func is None:
27
+ return wrapper
28
+ else:
29
+ return func
30
+
31
+
32
+ if TORCH_VERSION == 'parrots':
33
+ from parrots.utils.tester import skip_no_elena
34
+ else:
35
+
36
+ def skip_no_elena(func):
37
+
38
+ def wrapper(*args, **kargs):
39
+ return func(*args, **kargs)
40
+
41
+ return wrapper
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from functools import partial
3
+
4
+ import torch
5
+
6
+ TORCH_VERSION = torch.__version__
7
+
8
+
9
+ def is_rocm_pytorch() -> bool:
10
+ is_rocm = False
11
+ if TORCH_VERSION != 'parrots':
12
+ try:
13
+ from torch.utils.cpp_extension import ROCM_HOME
14
+ is_rocm = True if ((torch.version.hip is not None) and
15
+ (ROCM_HOME is not None)) else False
16
+ except ImportError:
17
+ pass
18
+ return is_rocm
19
+
20
+
21
+ def _get_cuda_home():
22
+ if TORCH_VERSION == 'parrots':
23
+ from parrots.utils.build_extension import CUDA_HOME
24
+ else:
25
+ if is_rocm_pytorch():
26
+ from torch.utils.cpp_extension import ROCM_HOME
27
+ CUDA_HOME = ROCM_HOME
28
+ else:
29
+ from torch.utils.cpp_extension import CUDA_HOME
30
+ return CUDA_HOME
31
+
32
+
33
+ def get_build_config():
34
+ if TORCH_VERSION == 'parrots':
35
+ from parrots.config import get_build_info
36
+ return get_build_info()
37
+ else:
38
+ return torch.__config__.show()
39
+
40
+
41
+ def _get_conv():
42
+ if TORCH_VERSION == 'parrots':
43
+ from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
44
+ else:
45
+ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
46
+ return _ConvNd, _ConvTransposeMixin
47
+
48
+
49
+ def _get_dataloader():
50
+ if TORCH_VERSION == 'parrots':
51
+ from torch.utils.data import DataLoader, PoolDataLoader
52
+ else:
53
+ from torch.utils.data import DataLoader
54
+ PoolDataLoader = DataLoader
55
+ return DataLoader, PoolDataLoader
56
+
57
+
58
+ def _get_extension():
59
+ if TORCH_VERSION == 'parrots':
60
+ from parrots.utils.build_extension import BuildExtension, Extension
61
+ CppExtension = partial(Extension, cuda=False)
62
+ CUDAExtension = partial(Extension, cuda=True)
63
+ else:
64
+ from torch.utils.cpp_extension import (BuildExtension, CppExtension,
65
+ CUDAExtension)
66
+ return BuildExtension, CppExtension, CUDAExtension
67
+
68
+
69
+ def _get_pool():
70
+ if TORCH_VERSION == 'parrots':
71
+ from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
72
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
73
+ _MaxPoolNd)
74
+ else:
75
+ from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
76
+ _AdaptiveMaxPoolNd, _AvgPoolNd,
77
+ _MaxPoolNd)
78
+ return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
79
+
80
+
81
+ def _get_norm():
82
+ if TORCH_VERSION == 'parrots':
83
+ from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
84
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
85
+ else:
86
+ from torch.nn.modules.instancenorm import _InstanceNorm
87
+ from torch.nn.modules.batchnorm import _BatchNorm
88
+ SyncBatchNorm_ = torch.nn.SyncBatchNorm
89
+ return _BatchNorm, _InstanceNorm, SyncBatchNorm_
90
+
91
+
92
+ _ConvNd, _ConvTransposeMixin = _get_conv()
93
+ DataLoader, PoolDataLoader = _get_dataloader()
94
+ BuildExtension, CppExtension, CUDAExtension = _get_extension()
95
+ _BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
96
+ _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
97
+
98
+
99
+ class SyncBatchNorm(SyncBatchNorm_):
100
+
101
+ def _check_input_dim(self, input):
102
+ if TORCH_VERSION == 'parrots':
103
+ if input.dim() < 2:
104
+ raise ValueError(
105
+ f'expected at least 2D input (got {input.dim()}D input)')
106
+ else:
107
+ super()._check_input_dim(input)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import os.path as osp
4
+ from pathlib import Path
5
+
6
+ from .misc import is_str
7
+
8
+
9
+ def is_filepath(x):
10
+ return is_str(x) or isinstance(x, Path)
11
+
12
+
13
+ def fopen(filepath, *args, **kwargs):
14
+ if is_str(filepath):
15
+ return open(filepath, *args, **kwargs)
16
+ elif isinstance(filepath, Path):
17
+ return filepath.open(*args, **kwargs)
18
+ raise ValueError('`filepath` should be a string or a Path')
19
+
20
+
21
+ def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
22
+ if not osp.isfile(filename):
23
+ raise FileNotFoundError(msg_tmpl.format(filename))
24
+
25
+
26
+ def mkdir_or_exist(dir_name, mode=0o777):
27
+ if dir_name == '':
28
+ return
29
+ dir_name = osp.expanduser(dir_name)
30
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
31
+
32
+
33
+ def symlink(src, dst, overwrite=True, **kwargs):
34
+ if os.path.lexists(dst) and overwrite:
35
+ os.remove(dst)
36
+ os.symlink(src, dst, **kwargs)
37
+
38
+
39
+ def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
40
+ """Scan a directory to find the interested files.
41
+
42
+ Args:
43
+ dir_path (str | obj:`Path`): Path of the directory.
44
+ suffix (str | tuple(str), optional): File suffix that we are
45
+ interested in. Default: None.
46
+ recursive (bool, optional): If set to True, recursively scan the
47
+ directory. Default: False.
48
+ case_sensitive (bool, optional) : If set to False, ignore the case of
49
+ suffix. Default: True.
50
+
51
+ Returns:
52
+ A generator for all the interested files with relative paths.
53
+ """
54
+ if isinstance(dir_path, (str, Path)):
55
+ dir_path = str(dir_path)
56
+ else:
57
+ raise TypeError('"dir_path" must be a string or Path object')
58
+
59
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
60
+ raise TypeError('"suffix" must be a string or tuple of strings')
61
+
62
+ if suffix is not None and not case_sensitive:
63
+ suffix = suffix.lower() if isinstance(suffix, str) else tuple(
64
+ item.lower() for item in suffix)
65
+
66
+ root = dir_path
67
+
68
+ def _scandir(dir_path, suffix, recursive, case_sensitive):
69
+ for entry in os.scandir(dir_path):
70
+ if not entry.name.startswith('.') and entry.is_file():
71
+ rel_path = osp.relpath(entry.path, root)
72
+ _rel_path = rel_path if case_sensitive else rel_path.lower()
73
+ if suffix is None or _rel_path.endswith(suffix):
74
+ yield rel_path
75
+ elif recursive and os.path.isdir(entry.path):
76
+ # scan recursively if entry.path is a directory
77
+ yield from _scandir(entry.path, suffix, recursive,
78
+ case_sensitive)
79
+
80
+ return _scandir(dir_path, suffix, recursive, case_sensitive)
81
+
82
+
83
+ def find_vcs_root(path, markers=('.git', )):
84
+ """Finds the root directory (including itself) of specified markers.
85
+
86
+ Args:
87
+ path (str): Path of directory or file.
88
+ markers (list[str], optional): List of file or directory names.
89
+
90
+ Returns:
91
+ The directory contained one of the markers or None if not found.
92
+ """
93
+ if osp.isfile(path):
94
+ path = osp.dirname(path)
95
+
96
+ prev, cur = None, osp.abspath(osp.expanduser(path))
97
+ while cur != prev:
98
+ if any(osp.exists(osp.join(cur, marker)) for marker in markers):
99
+ return cur
100
+ prev, cur = cur, osp.split(cur)[0]
101
+ return None
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import sys
3
+ from collections.abc import Iterable
4
+ from multiprocessing import Pool
5
+ from shutil import get_terminal_size
6
+
7
+ from .timer import Timer
8
+
9
+
10
+ class ProgressBar:
11
+ """A progress bar which can print the progress."""
12
+
13
+ def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
14
+ self.task_num = task_num
15
+ self.bar_width = bar_width
16
+ self.completed = 0
17
+ self.file = file
18
+ if start:
19
+ self.start()
20
+
21
+ @property
22
+ def terminal_width(self):
23
+ width, _ = get_terminal_size()
24
+ return width
25
+
26
+ def start(self):
27
+ if self.task_num > 0:
28
+ self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
29
+ 'elapsed: 0s, ETA:')
30
+ else:
31
+ self.file.write('completed: 0, elapsed: 0s')
32
+ self.file.flush()
33
+ self.timer = Timer()
34
+
35
+ def update(self, num_tasks=1):
36
+ assert num_tasks > 0
37
+ self.completed += num_tasks
38
+ elapsed = self.timer.since_start()
39
+ if elapsed > 0:
40
+ fps = self.completed / elapsed
41
+ else:
42
+ fps = float('inf')
43
+ if self.task_num > 0:
44
+ percentage = self.completed / float(self.task_num)
45
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
46
+ msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
47
+ f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
48
+ f'ETA: {eta:5}s'
49
+
50
+ bar_width = min(self.bar_width,
51
+ int(self.terminal_width - len(msg)) + 2,
52
+ int(self.terminal_width * 0.6))
53
+ bar_width = max(2, bar_width)
54
+ mark_width = int(bar_width * percentage)
55
+ bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
56
+ self.file.write(msg.format(bar_chars))
57
+ else:
58
+ self.file.write(
59
+ f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
60
+ f' {fps:.1f} tasks/s')
61
+ self.file.flush()
62
+
63
+
64
+ def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
65
+ """Track the progress of tasks execution with a progress bar.
66
+
67
+ Tasks are done with a simple for-loop.
68
+
69
+ Args:
70
+ func (callable): The function to be applied to each task.
71
+ tasks (list or tuple[Iterable, int]): A list of tasks or
72
+ (tasks, total num).
73
+ bar_width (int): Width of progress bar.
74
+
75
+ Returns:
76
+ list: The task results.
77
+ """
78
+ if isinstance(tasks, tuple):
79
+ assert len(tasks) == 2
80
+ assert isinstance(tasks[0], Iterable)
81
+ assert isinstance(tasks[1], int)
82
+ task_num = tasks[1]
83
+ tasks = tasks[0]
84
+ elif isinstance(tasks, Iterable):
85
+ task_num = len(tasks)
86
+ else:
87
+ raise TypeError(
88
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
89
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
90
+ results = []
91
+ for task in tasks:
92
+ results.append(func(task, **kwargs))
93
+ prog_bar.update()
94
+ prog_bar.file.write('\n')
95
+ return results
96
+
97
+
98
+ def init_pool(process_num, initializer=None, initargs=None):
99
+ if initializer is None:
100
+ return Pool(process_num)
101
+ elif initargs is None:
102
+ return Pool(process_num, initializer)
103
+ else:
104
+ if not isinstance(initargs, tuple):
105
+ raise TypeError('"initargs" must be a tuple')
106
+ return Pool(process_num, initializer, initargs)
107
+
108
+
109
+ def track_parallel_progress(func,
110
+ tasks,
111
+ nproc,
112
+ initializer=None,
113
+ initargs=None,
114
+ bar_width=50,
115
+ chunksize=1,
116
+ skip_first=False,
117
+ keep_order=True,
118
+ file=sys.stdout):
119
+ """Track the progress of parallel task execution with a progress bar.
120
+
121
+ The built-in :mod:`multiprocessing` module is used for process pools and
122
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
123
+
124
+ Args:
125
+ func (callable): The function to be applied to each task.
126
+ tasks (list or tuple[Iterable, int]): A list of tasks or
127
+ (tasks, total num).
128
+ nproc (int): Process (worker) number.
129
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
130
+ for details.
131
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
132
+ details.
133
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
134
+ bar_width (int): Width of progress bar.
135
+ skip_first (bool): Whether to skip the first sample for each worker
136
+ when estimating fps, since the initialization step may takes
137
+ longer.
138
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
139
+ :func:`Pool.imap_unordered` is used.
140
+
141
+ Returns:
142
+ list: The task results.
143
+ """
144
+ if isinstance(tasks, tuple):
145
+ assert len(tasks) == 2
146
+ assert isinstance(tasks[0], Iterable)
147
+ assert isinstance(tasks[1], int)
148
+ task_num = tasks[1]
149
+ tasks = tasks[0]
150
+ elif isinstance(tasks, Iterable):
151
+ task_num = len(tasks)
152
+ else:
153
+ raise TypeError(
154
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
155
+ pool = init_pool(nproc, initializer, initargs)
156
+ start = not skip_first
157
+ task_num -= nproc * chunksize * int(skip_first)
158
+ prog_bar = ProgressBar(task_num, bar_width, start, file=file)
159
+ results = []
160
+ if keep_order:
161
+ gen = pool.imap(func, tasks, chunksize)
162
+ else:
163
+ gen = pool.imap_unordered(func, tasks, chunksize)
164
+ for result in gen:
165
+ results.append(result)
166
+ if skip_first:
167
+ if len(results) < nproc * chunksize:
168
+ continue
169
+ elif len(results) == nproc * chunksize:
170
+ prog_bar.start()
171
+ continue
172
+ prog_bar.update()
173
+ prog_bar.file.write('\n')
174
+ pool.close()
175
+ pool.join()
176
+ return results
177
+
178
+
179
+ def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
180
+ """Track the progress of tasks iteration or enumeration with a progress
181
+ bar.
182
+
183
+ Tasks are yielded with a simple for-loop.
184
+
185
+ Args:
186
+ tasks (list or tuple[Iterable, int]): A list of tasks or
187
+ (tasks, total num).
188
+ bar_width (int): Width of progress bar.
189
+
190
+ Yields:
191
+ list: The task results.
192
+ """
193
+ if isinstance(tasks, tuple):
194
+ assert len(tasks) == 2
195
+ assert isinstance(tasks[0], Iterable)
196
+ assert isinstance(tasks[1], int)
197
+ task_num = tasks[1]
198
+ tasks = tasks[0]
199
+ elif isinstance(tasks, Iterable):
200
+ task_num = len(tasks)
201
+ else:
202
+ raise TypeError(
203
+ '"tasks" must be an iterable object or a (iterator, int) tuple')
204
+ prog_bar = ProgressBar(task_num, bar_width, file=file)
205
+ for task in tasks:
206
+ yield task
207
+ prog_bar.update()
208
+ prog_bar.file.write('\n')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import inspect
3
+ import warnings
4
+ from functools import partial
5
+
6
+ from .misc import is_seq_of
7
+
8
+
9
+ def build_from_cfg(cfg, registry, default_args=None):
10
+ """Build a module from config dict.
11
+
12
+ Args:
13
+ cfg (dict): Config dict. It should at least contain the key "type".
14
+ registry (:obj:`Registry`): The registry to search the type from.
15
+ default_args (dict, optional): Default initialization arguments.
16
+
17
+ Returns:
18
+ object: The constructed object.
19
+ """
20
+ if not isinstance(cfg, dict):
21
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
22
+ if 'type' not in cfg:
23
+ if default_args is None or 'type' not in default_args:
24
+ raise KeyError(
25
+ '`cfg` or `default_args` must contain the key "type", '
26
+ f'but got {cfg}\n{default_args}')
27
+ if not isinstance(registry, Registry):
28
+ raise TypeError('registry must be an mmcv.Registry object, '
29
+ f'but got {type(registry)}')
30
+ if not (isinstance(default_args, dict) or default_args is None):
31
+ raise TypeError('default_args must be a dict or None, '
32
+ f'but got {type(default_args)}')
33
+
34
+ args = cfg.copy()
35
+
36
+ if default_args is not None:
37
+ for name, value in default_args.items():
38
+ args.setdefault(name, value)
39
+
40
+ obj_type = args.pop('type')
41
+ if isinstance(obj_type, str):
42
+ obj_cls = registry.get(obj_type)
43
+ if obj_cls is None:
44
+ raise KeyError(
45
+ f'{obj_type} is not in the {registry.name} registry')
46
+ elif inspect.isclass(obj_type):
47
+ obj_cls = obj_type
48
+ else:
49
+ raise TypeError(
50
+ f'type must be a str or valid type, but got {type(obj_type)}')
51
+ try:
52
+ return obj_cls(**args)
53
+ except Exception as e:
54
+ # Normal TypeError does not print class name.
55
+ raise type(e)(f'{obj_cls.__name__}: {e}')
56
+
57
+
58
+ class Registry:
59
+ """A registry to map strings to classes.
60
+
61
+ Registered object could be built from registry.
62
+ Example:
63
+ >>> MODELS = Registry('models')
64
+ >>> @MODELS.register_module()
65
+ >>> class ResNet:
66
+ >>> pass
67
+ >>> resnet = MODELS.build(dict(type='ResNet'))
68
+
69
+ Please refer to
70
+ https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
71
+ advanced usage.
72
+
73
+ Args:
74
+ name (str): Registry name.
75
+ build_func(func, optional): Build function to construct instance from
76
+ Registry, func:`build_from_cfg` is used if neither ``parent`` or
77
+ ``build_func`` is specified. If ``parent`` is specified and
78
+ ``build_func`` is not given, ``build_func`` will be inherited
79
+ from ``parent``. Default: None.
80
+ parent (Registry, optional): Parent registry. The class registered in
81
+ children registry could be built from parent. Default: None.
82
+ scope (str, optional): The scope of registry. It is the key to search
83
+ for children registry. If not specified, scope will be the name of
84
+ the package where class is defined, e.g. mmdet, mmcls, mmseg.
85
+ Default: None.
86
+ """
87
+
88
+ def __init__(self, name, build_func=None, parent=None, scope=None):
89
+ self._name = name
90
+ self._module_dict = dict()
91
+ self._children = dict()
92
+ self._scope = self.infer_scope() if scope is None else scope
93
+
94
+ # self.build_func will be set with the following priority:
95
+ # 1. build_func
96
+ # 2. parent.build_func
97
+ # 3. build_from_cfg
98
+ if build_func is None:
99
+ if parent is not None:
100
+ self.build_func = parent.build_func
101
+ else:
102
+ self.build_func = build_from_cfg
103
+ else:
104
+ self.build_func = build_func
105
+ if parent is not None:
106
+ assert isinstance(parent, Registry)
107
+ parent._add_children(self)
108
+ self.parent = parent
109
+ else:
110
+ self.parent = None
111
+
112
+ def __len__(self):
113
+ return len(self._module_dict)
114
+
115
+ def __contains__(self, key):
116
+ return self.get(key) is not None
117
+
118
+ def __repr__(self):
119
+ format_str = self.__class__.__name__ + \
120
+ f'(name={self._name}, ' \
121
+ f'items={self._module_dict})'
122
+ return format_str
123
+
124
+ @staticmethod
125
+ def infer_scope():
126
+ """Infer the scope of registry.
127
+
128
+ The name of the package where registry is defined will be returned.
129
+
130
+ Example:
131
+ # in mmdet/models/backbone/resnet.py
132
+ >>> MODELS = Registry('models')
133
+ >>> @MODELS.register_module()
134
+ >>> class ResNet:
135
+ >>> pass
136
+ The scope of ``ResNet`` will be ``mmdet``.
137
+
138
+
139
+ Returns:
140
+ scope (str): The inferred scope name.
141
+ """
142
+ # inspect.stack() trace where this function is called, the index-2
143
+ # indicates the frame where `infer_scope()` is called
144
+ filename = inspect.getmodule(inspect.stack()[2][0]).__name__
145
+ split_filename = filename.split('.')
146
+ return split_filename[0]
147
+
148
+ @staticmethod
149
+ def split_scope_key(key):
150
+ """Split scope and key.
151
+
152
+ The first scope will be split from key.
153
+
154
+ Examples:
155
+ >>> Registry.split_scope_key('mmdet.ResNet')
156
+ 'mmdet', 'ResNet'
157
+ >>> Registry.split_scope_key('ResNet')
158
+ None, 'ResNet'
159
+
160
+ Return:
161
+ scope (str, None): The first scope.
162
+ key (str): The remaining key.
163
+ """
164
+ split_index = key.find('.')
165
+ if split_index != -1:
166
+ return key[:split_index], key[split_index + 1:]
167
+ else:
168
+ return None, key
169
+
170
+ @property
171
+ def name(self):
172
+ return self._name
173
+
174
+ @property
175
+ def scope(self):
176
+ return self._scope
177
+
178
+ @property
179
+ def module_dict(self):
180
+ return self._module_dict
181
+
182
+ @property
183
+ def children(self):
184
+ return self._children
185
+
186
+ def get(self, key):
187
+ """Get the registry record.
188
+
189
+ Args:
190
+ key (str): The class name in string format.
191
+
192
+ Returns:
193
+ class: The corresponding class.
194
+ """
195
+ scope, real_key = self.split_scope_key(key)
196
+ if scope is None or scope == self._scope:
197
+ # get from self
198
+ if real_key in self._module_dict:
199
+ return self._module_dict[real_key]
200
+ else:
201
+ # get from self._children
202
+ if scope in self._children:
203
+ return self._children[scope].get(real_key)
204
+ else:
205
+ # goto root
206
+ parent = self.parent
207
+ while parent.parent is not None:
208
+ parent = parent.parent
209
+ return parent.get(key)
210
+
211
+ def build(self, *args, **kwargs):
212
+ return self.build_func(*args, **kwargs, registry=self)
213
+
214
+ def _add_children(self, registry):
215
+ """Add children for a registry.
216
+
217
+ The ``registry`` will be added as children based on its scope.
218
+ The parent registry could build objects from children registry.
219
+
220
+ Example:
221
+ >>> models = Registry('models')
222
+ >>> mmdet_models = Registry('models', parent=models)
223
+ >>> @mmdet_models.register_module()
224
+ >>> class ResNet:
225
+ >>> pass
226
+ >>> resnet = models.build(dict(type='mmdet.ResNet'))
227
+ """
228
+
229
+ assert isinstance(registry, Registry)
230
+ assert registry.scope is not None
231
+ assert registry.scope not in self.children, \
232
+ f'scope {registry.scope} exists in {self.name} registry'
233
+ self.children[registry.scope] = registry
234
+
235
+ def _register_module(self, module_class, module_name=None, force=False):
236
+ if not inspect.isclass(module_class):
237
+ raise TypeError('module must be a class, '
238
+ f'but got {type(module_class)}')
239
+
240
+ if module_name is None:
241
+ module_name = module_class.__name__
242
+ if isinstance(module_name, str):
243
+ module_name = [module_name]
244
+ for name in module_name:
245
+ if not force and name in self._module_dict:
246
+ raise KeyError(f'{name} is already registered '
247
+ f'in {self.name}')
248
+ self._module_dict[name] = module_class
249
+
250
+ def deprecated_register_module(self, cls=None, force=False):
251
+ warnings.warn(
252
+ 'The old API of register_module(module, force=False) '
253
+ 'is deprecated and will be removed, please use the new API '
254
+ 'register_module(name=None, force=False, module=None) instead.')
255
+ if cls is None:
256
+ return partial(self.deprecated_register_module, force=force)
257
+ self._register_module(cls, force=force)
258
+ return cls
259
+
260
+ def register_module(self, name=None, force=False, module=None):
261
+ """Register a module.
262
+
263
+ A record will be added to `self._module_dict`, whose key is the class
264
+ name or the specified name, and value is the class itself.
265
+ It can be used as a decorator or a normal function.
266
+
267
+ Example:
268
+ >>> backbones = Registry('backbone')
269
+ >>> @backbones.register_module()
270
+ >>> class ResNet:
271
+ >>> pass
272
+
273
+ >>> backbones = Registry('backbone')
274
+ >>> @backbones.register_module(name='mnet')
275
+ >>> class MobileNet:
276
+ >>> pass
277
+
278
+ >>> backbones = Registry('backbone')
279
+ >>> class ResNet:
280
+ >>> pass
281
+ >>> backbones.register_module(ResNet)
282
+
283
+ Args:
284
+ name (str | None): The module name to be registered. If not
285
+ specified, the class name will be used.
286
+ force (bool, optional): Whether to override an existing class with
287
+ the same name. Default: False.
288
+ module (type): Module class to be registered.
289
+ """
290
+ if not isinstance(force, bool):
291
+ raise TypeError(f'force must be a boolean, but got {type(force)}')
292
+ # NOTE: This is a walkaround to be compatible with the old api,
293
+ # while it may introduce unexpected bugs.
294
+ if isinstance(name, type):
295
+ return self.deprecated_register_module(name, force=force)
296
+
297
+ # raise the error ahead of time
298
+ if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
299
+ raise TypeError(
300
+ 'name must be either of None, an instance of str or a sequence'
301
+ f' of str, but got {type(name)}')
302
+
303
+ # use it as a normal method: x.register_module(module=SomeClass)
304
+ if module is not None:
305
+ self._register_module(
306
+ module_class=module, module_name=name, force=force)
307
+ return module
308
+
309
+ # use it as a decorator: @x.register_module()
310
+ def _register(cls):
311
+ self._register_module(
312
+ module_class=cls, module_name=name, force=force)
313
+ return cls
314
+
315
+ return _register
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Open-MMLab.
2
+ import sys
3
+ from collections.abc import Iterable
4
+ from runpy import run_path
5
+ from shlex import split
6
+ from typing import Any, Dict, List
7
+ from unittest.mock import patch
8
+
9
+
10
+ def check_python_script(cmd):
11
+ """Run the python cmd script with `__main__`. The difference between
12
+ `os.system` is that, this function exectues code in the current process, so
13
+ that it can be tracked by coverage tools. Currently it supports two forms:
14
+
15
+ - ./tests/data/scripts/hello.py zz
16
+ - python tests/data/scripts/hello.py zz
17
+ """
18
+ args = split(cmd)
19
+ if args[0] == 'python':
20
+ args = args[1:]
21
+ with patch.object(sys, 'argv', args):
22
+ run_path(args[0], run_name='__main__')
23
+
24
+
25
+ def _any(judge_result):
26
+ """Since built-in ``any`` works only when the element of iterable is not
27
+ iterable, implement the function."""
28
+ if not isinstance(judge_result, Iterable):
29
+ return judge_result
30
+
31
+ try:
32
+ for element in judge_result:
33
+ if _any(element):
34
+ return True
35
+ except TypeError:
36
+ # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
37
+ if judge_result:
38
+ return True
39
+ return False
40
+
41
+
42
+ def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
43
+ expected_subset: Dict[Any, Any]) -> bool:
44
+ """Check if the dict_obj contains the expected_subset.
45
+
46
+ Args:
47
+ dict_obj (Dict[Any, Any]): Dict object to be checked.
48
+ expected_subset (Dict[Any, Any]): Subset expected to be contained in
49
+ dict_obj.
50
+
51
+ Returns:
52
+ bool: Whether the dict_obj contains the expected_subset.
53
+ """
54
+
55
+ for key, value in expected_subset.items():
56
+ if key not in dict_obj.keys() or _any(dict_obj[key] != value):
57
+ return False
58
+ return True
59
+
60
+
61
+ def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
62
+ """Check if attribute of class object is correct.
63
+
64
+ Args:
65
+ obj (object): Class object to be checked.
66
+ expected_attrs (Dict[str, Any]): Dict of the expected attrs.
67
+
68
+ Returns:
69
+ bool: Whether the attribute of class object is correct.
70
+ """
71
+ for attr, value in expected_attrs.items():
72
+ if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
73
+ return False
74
+ return True
75
+
76
+
77
+ def assert_dict_has_keys(obj: Dict[str, Any],
78
+ expected_keys: List[str]) -> bool:
79
+ """Check if the obj has all the expected_keys.
80
+
81
+ Args:
82
+ obj (Dict[str, Any]): Object to be checked.
83
+ expected_keys (List[str]): Keys expected to contained in the keys of
84
+ the obj.
85
+
86
+ Returns:
87
+ bool: Whether the obj has the expected keys.
88
+ """
89
+ return set(expected_keys).issubset(set(obj.keys()))
90
+
91
+
92
+ def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
93
+ """Check if target_keys is equal to result_keys.
94
+
95
+ Args:
96
+ result_keys (List[str]): Result keys to be checked.
97
+ target_keys (List[str]): Target keys to be checked.
98
+
99
+ Returns:
100
+ bool: Whether target_keys is equal to result_keys.
101
+ """
102
+ return set(result_keys) == set(target_keys)
103
+
104
+
105
+ def assert_is_norm_layer(module) -> bool:
106
+ """Check if the module is a norm layer.
107
+
108
+ Args:
109
+ module (nn.Module): The module to be checked.
110
+
111
+ Returns:
112
+ bool: Whether the module is a norm layer.
113
+ """
114
+ from .parrots_wrapper import _BatchNorm, _InstanceNorm
115
+ from torch.nn import GroupNorm, LayerNorm
116
+ norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
117
+ return isinstance(module, norm_layer_candidates)
118
+
119
+
120
+ def assert_params_all_zeros(module) -> bool:
121
+ """Check if the parameters of the module is all zeros.
122
+
123
+ Args:
124
+ module (nn.Module): The module to be checked.
125
+
126
+ Returns:
127
+ bool: Whether the parameters of the module is all zeros.
128
+ """
129
+ weight_data = module.weight.data
130
+ is_weight_zero = weight_data.allclose(
131
+ weight_data.new_zeros(weight_data.size()))
132
+
133
+ if hasattr(module, 'bias') and module.bias is not None:
134
+ bias_data = module.bias.data
135
+ is_bias_zero = bias_data.allclose(
136
+ bias_data.new_zeros(bias_data.size()))
137
+ else:
138
+ is_bias_zero = True
139
+
140
+ return is_weight_zero and is_bias_zero
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from time import time
3
+
4
+
5
+ class TimerError(Exception):
6
+
7
+ def __init__(self, message):
8
+ self.message = message
9
+ super(TimerError, self).__init__(message)
10
+
11
+
12
+ class Timer:
13
+ """A flexible Timer class.
14
+
15
+ :Example:
16
+
17
+ >>> import time
18
+ >>> import annotator.mmpkg.mmcv as mmcv
19
+ >>> with mmcv.Timer():
20
+ >>> # simulate a code block that will run for 1s
21
+ >>> time.sleep(1)
22
+ 1.000
23
+ >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
24
+ >>> # simulate a code block that will run for 1s
25
+ >>> time.sleep(1)
26
+ it takes 1.0 seconds
27
+ >>> timer = mmcv.Timer()
28
+ >>> time.sleep(0.5)
29
+ >>> print(timer.since_start())
30
+ 0.500
31
+ >>> time.sleep(0.5)
32
+ >>> print(timer.since_last_check())
33
+ 0.500
34
+ >>> print(timer.since_start())
35
+ 1.000
36
+ """
37
+
38
+ def __init__(self, start=True, print_tmpl=None):
39
+ self._is_running = False
40
+ self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
41
+ if start:
42
+ self.start()
43
+
44
+ @property
45
+ def is_running(self):
46
+ """bool: indicate whether the timer is running"""
47
+ return self._is_running
48
+
49
+ def __enter__(self):
50
+ self.start()
51
+ return self
52
+
53
+ def __exit__(self, type, value, traceback):
54
+ print(self.print_tmpl.format(self.since_last_check()))
55
+ self._is_running = False
56
+
57
+ def start(self):
58
+ """Start the timer."""
59
+ if not self._is_running:
60
+ self._t_start = time()
61
+ self._is_running = True
62
+ self._t_last = time()
63
+
64
+ def since_start(self):
65
+ """Total time since the timer is started.
66
+
67
+ Returns (float): Time in seconds.
68
+ """
69
+ if not self._is_running:
70
+ raise TimerError('timer is not running')
71
+ self._t_last = time()
72
+ return self._t_last - self._t_start
73
+
74
+ def since_last_check(self):
75
+ """Time since the last checking.
76
+
77
+ Either :func:`since_start` or :func:`since_last_check` is a checking
78
+ operation.
79
+
80
+ Returns (float): Time in seconds.
81
+ """
82
+ if not self._is_running:
83
+ raise TimerError('timer is not running')
84
+ dur = time() - self._t_last
85
+ self._t_last = time()
86
+ return dur
87
+
88
+
89
+ _g_timers = {} # global timers
90
+
91
+
92
+ def check_time(timer_id):
93
+ """Add check points in a single line.
94
+
95
+ This method is suitable for running a task on a list of items. A timer will
96
+ be registered when the method is called for the first time.
97
+
98
+ :Example:
99
+
100
+ >>> import time
101
+ >>> import annotator.mmpkg.mmcv as mmcv
102
+ >>> for i in range(1, 6):
103
+ >>> # simulate a code block
104
+ >>> time.sleep(i)
105
+ >>> mmcv.check_time('task1')
106
+ 2.000
107
+ 3.000
108
+ 4.000
109
+ 5.000
110
+
111
+ Args:
112
+ timer_id (str): Timer identifier.
113
+ """
114
+ if timer_id not in _g_timers:
115
+ _g_timers[timer_id] = Timer()
116
+ return 0
117
+ else:
118
+ return _g_timers[timer_id].since_last_check()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import torch
4
+
5
+ from annotator.mmpkg.mmcv.utils import digit_version
6
+
7
+
8
+ def is_jit_tracing() -> bool:
9
+ if (torch.__version__ != 'parrots'
10
+ and digit_version(torch.__version__) >= digit_version('1.6.0')):
11
+ on_trace = torch.jit.is_tracing()
12
+ # In PyTorch 1.6, torch.jit.is_tracing has a bug.
13
+ # Refers to https://github.com/pytorch/pytorch/issues/42448
14
+ if isinstance(on_trace, bool):
15
+ return on_trace
16
+ else:
17
+ return torch._C._is_tracing()
18
+ else:
19
+ warnings.warn(
20
+ 'torch.jit.is_tracing is only supported after v1.6.0. '
21
+ 'Therefore is_tracing returns False automatically. Please '
22
+ 'set on_trace manually if you are using trace.', UserWarning)
23
+ return False
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import subprocess
4
+ import warnings
5
+
6
+ from packaging.version import parse
7
+
8
+
9
+ def digit_version(version_str: str, length: int = 4):
10
+ """Convert a version string into a tuple of integers.
11
+
12
+ This method is usually used for comparing two versions. For pre-release
13
+ versions: alpha < beta < rc.
14
+
15
+ Args:
16
+ version_str (str): The version string.
17
+ length (int): The maximum number of version levels. Default: 4.
18
+
19
+ Returns:
20
+ tuple[int]: The version info in digits (integers).
21
+ """
22
+ assert 'parrots' not in version_str
23
+ version = parse(version_str)
24
+ assert version.release, f'failed to parse version {version_str}'
25
+ release = list(version.release)
26
+ release = release[:length]
27
+ if len(release) < length:
28
+ release = release + [0] * (length - len(release))
29
+ if version.is_prerelease:
30
+ mapping = {'a': -3, 'b': -2, 'rc': -1}
31
+ val = -4
32
+ # version.pre can be None
33
+ if version.pre:
34
+ if version.pre[0] not in mapping:
35
+ warnings.warn(f'unknown prerelease version {version.pre[0]}, '
36
+ 'version checking may go wrong')
37
+ else:
38
+ val = mapping[version.pre[0]]
39
+ release.extend([val, version.pre[-1]])
40
+ else:
41
+ release.extend([val, 0])
42
+
43
+ elif version.is_postrelease:
44
+ release.extend([1, version.post])
45
+ else:
46
+ release.extend([0, 0])
47
+ return tuple(release)
48
+
49
+
50
+ def _minimal_ext_cmd(cmd):
51
+ # construct minimal environment
52
+ env = {}
53
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
54
+ v = os.environ.get(k)
55
+ if v is not None:
56
+ env[k] = v
57
+ # LANGUAGE is used on win32
58
+ env['LANGUAGE'] = 'C'
59
+ env['LANG'] = 'C'
60
+ env['LC_ALL'] = 'C'
61
+ out = subprocess.Popen(
62
+ cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
63
+ return out
64
+
65
+
66
+ def get_git_hash(fallback='unknown', digits=None):
67
+ """Get the git hash of the current repo.
68
+
69
+ Args:
70
+ fallback (str, optional): The fallback string when git hash is
71
+ unavailable. Defaults to 'unknown'.
72
+ digits (int, optional): kept digits of the hash. Defaults to None,
73
+ meaning all digits are kept.
74
+
75
+ Returns:
76
+ str: Git commit hash.
77
+ """
78
+
79
+ if digits is not None and not isinstance(digits, int):
80
+ raise TypeError('digits must be None or an integer')
81
+
82
+ try:
83
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
84
+ sha = out.strip().decode('ascii')
85
+ if digits is not None:
86
+ sha = sha[:digits]
87
+ except OSError:
88
+ sha = fallback
89
+
90
+ return sha
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ __version__ = '1.3.17'
3
+
4
+
5
+ def parse_version_info(version_str: str, length: int = 4) -> tuple:
6
+ """Parse a version string into a tuple.
7
+
8
+ Args:
9
+ version_str (str): The version string.
10
+ length (int): The maximum number of version levels. Default: 4.
11
+
12
+ Returns:
13
+ tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
14
+ (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
15
+ (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
16
+ """
17
+ from packaging.version import parse
18
+ version = parse(version_str)
19
+ assert version.release, f'failed to parse version {version_str}'
20
+ release = list(version.release)
21
+ release = release[:length]
22
+ if len(release) < length:
23
+ release = release + [0] * (length - len(release))
24
+ if version.is_prerelease:
25
+ release.extend(list(version.pre))
26
+ elif version.is_postrelease:
27
+ release.extend(list(version.post))
28
+ else:
29
+ release.extend([0, 0])
30
+ return tuple(release)
31
+
32
+
33
+ version_info = tuple(int(x) for x in __version__.split('.')[:3])
34
+
35
+ __all__ = ['__version__', 'version_info', 'parse_version_info']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .io import Cache, VideoReader, frames2video
3
+ from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
4
+ flowwrite, quantize_flow, sparse_flow_from_bytes)
5
+ from .processing import concat_video, convert_video, cut_video, resize_video
6
+
7
+ __all__ = [
8
+ 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
9
+ 'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
10
+ 'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
11
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ from collections import OrderedDict
4
+
5
+ import cv2
6
+ from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
7
+ CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
8
+ CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
9
+
10
+ from annotator.mmpkg.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
11
+ track_progress)
12
+
13
+
14
+ class Cache:
15
+
16
+ def __init__(self, capacity):
17
+ self._cache = OrderedDict()
18
+ self._capacity = int(capacity)
19
+ if capacity <= 0:
20
+ raise ValueError('capacity must be a positive integer')
21
+
22
+ @property
23
+ def capacity(self):
24
+ return self._capacity
25
+
26
+ @property
27
+ def size(self):
28
+ return len(self._cache)
29
+
30
+ def put(self, key, val):
31
+ if key in self._cache:
32
+ return
33
+ if len(self._cache) >= self.capacity:
34
+ self._cache.popitem(last=False)
35
+ self._cache[key] = val
36
+
37
+ def get(self, key, default=None):
38
+ val = self._cache[key] if key in self._cache else default
39
+ return val
40
+
41
+
42
+ class VideoReader:
43
+ """Video class with similar usage to a list object.
44
+
45
+ This video warpper class provides convenient apis to access frames.
46
+ There exists an issue of OpenCV's VideoCapture class that jumping to a
47
+ certain frame may be inaccurate. It is fixed in this class by checking
48
+ the position after jumping each time.
49
+ Cache is used when decoding videos. So if the same frame is visited for
50
+ the second time, there is no need to decode again if it is stored in the
51
+ cache.
52
+
53
+ :Example:
54
+
55
+ >>> import annotator.mmpkg.mmcv as mmcv
56
+ >>> v = mmcv.VideoReader('sample.mp4')
57
+ >>> len(v) # get the total frame number with `len()`
58
+ 120
59
+ >>> for img in v: # v is iterable
60
+ >>> mmcv.imshow(img)
61
+ >>> v[5] # get the 6th frame
62
+ """
63
+
64
+ def __init__(self, filename, cache_capacity=10):
65
+ # Check whether the video path is a url
66
+ if not filename.startswith(('https://', 'http://')):
67
+ check_file_exist(filename, 'Video file not found: ' + filename)
68
+ self._vcap = cv2.VideoCapture(filename)
69
+ assert cache_capacity > 0
70
+ self._cache = Cache(cache_capacity)
71
+ self._position = 0
72
+ # get basic info
73
+ self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
74
+ self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
75
+ self._fps = self._vcap.get(CAP_PROP_FPS)
76
+ self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
77
+ self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
78
+
79
+ @property
80
+ def vcap(self):
81
+ """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
82
+ return self._vcap
83
+
84
+ @property
85
+ def opened(self):
86
+ """bool: Indicate whether the video is opened."""
87
+ return self._vcap.isOpened()
88
+
89
+ @property
90
+ def width(self):
91
+ """int: Width of video frames."""
92
+ return self._width
93
+
94
+ @property
95
+ def height(self):
96
+ """int: Height of video frames."""
97
+ return self._height
98
+
99
+ @property
100
+ def resolution(self):
101
+ """tuple: Video resolution (width, height)."""
102
+ return (self._width, self._height)
103
+
104
+ @property
105
+ def fps(self):
106
+ """float: FPS of the video."""
107
+ return self._fps
108
+
109
+ @property
110
+ def frame_cnt(self):
111
+ """int: Total frames of the video."""
112
+ return self._frame_cnt
113
+
114
+ @property
115
+ def fourcc(self):
116
+ """str: "Four character code" of the video."""
117
+ return self._fourcc
118
+
119
+ @property
120
+ def position(self):
121
+ """int: Current cursor position, indicating frame decoded."""
122
+ return self._position
123
+
124
+ def _get_real_position(self):
125
+ return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
126
+
127
+ def _set_real_position(self, frame_id):
128
+ self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
129
+ pos = self._get_real_position()
130
+ for _ in range(frame_id - pos):
131
+ self._vcap.read()
132
+ self._position = frame_id
133
+
134
+ def read(self):
135
+ """Read the next frame.
136
+
137
+ If the next frame have been decoded before and in the cache, then
138
+ return it directly, otherwise decode, cache and return it.
139
+
140
+ Returns:
141
+ ndarray or None: Return the frame if successful, otherwise None.
142
+ """
143
+ # pos = self._position
144
+ if self._cache:
145
+ img = self._cache.get(self._position)
146
+ if img is not None:
147
+ ret = True
148
+ else:
149
+ if self._position != self._get_real_position():
150
+ self._set_real_position(self._position)
151
+ ret, img = self._vcap.read()
152
+ if ret:
153
+ self._cache.put(self._position, img)
154
+ else:
155
+ ret, img = self._vcap.read()
156
+ if ret:
157
+ self._position += 1
158
+ return img
159
+
160
+ def get_frame(self, frame_id):
161
+ """Get frame by index.
162
+
163
+ Args:
164
+ frame_id (int): Index of the expected frame, 0-based.
165
+
166
+ Returns:
167
+ ndarray or None: Return the frame if successful, otherwise None.
168
+ """
169
+ if frame_id < 0 or frame_id >= self._frame_cnt:
170
+ raise IndexError(
171
+ f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
172
+ if frame_id == self._position:
173
+ return self.read()
174
+ if self._cache:
175
+ img = self._cache.get(frame_id)
176
+ if img is not None:
177
+ self._position = frame_id + 1
178
+ return img
179
+ self._set_real_position(frame_id)
180
+ ret, img = self._vcap.read()
181
+ if ret:
182
+ if self._cache:
183
+ self._cache.put(self._position, img)
184
+ self._position += 1
185
+ return img
186
+
187
+ def current_frame(self):
188
+ """Get the current frame (frame that is just visited).
189
+
190
+ Returns:
191
+ ndarray or None: If the video is fresh, return None, otherwise
192
+ return the frame.
193
+ """
194
+ if self._position == 0:
195
+ return None
196
+ return self._cache.get(self._position - 1)
197
+
198
+ def cvt2frames(self,
199
+ frame_dir,
200
+ file_start=0,
201
+ filename_tmpl='{:06d}.jpg',
202
+ start=0,
203
+ max_num=0,
204
+ show_progress=True):
205
+ """Convert a video to frame images.
206
+
207
+ Args:
208
+ frame_dir (str): Output directory to store all the frame images.
209
+ file_start (int): Filenames will start from the specified number.
210
+ filename_tmpl (str): Filename template with the index as the
211
+ placeholder.
212
+ start (int): The starting frame index.
213
+ max_num (int): Maximum number of frames to be written.
214
+ show_progress (bool): Whether to show a progress bar.
215
+ """
216
+ mkdir_or_exist(frame_dir)
217
+ if max_num == 0:
218
+ task_num = self.frame_cnt - start
219
+ else:
220
+ task_num = min(self.frame_cnt - start, max_num)
221
+ if task_num <= 0:
222
+ raise ValueError('start must be less than total frame number')
223
+ if start > 0:
224
+ self._set_real_position(start)
225
+
226
+ def write_frame(file_idx):
227
+ img = self.read()
228
+ if img is None:
229
+ return
230
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
231
+ cv2.imwrite(filename, img)
232
+
233
+ if show_progress:
234
+ track_progress(write_frame, range(file_start,
235
+ file_start + task_num))
236
+ else:
237
+ for i in range(task_num):
238
+ write_frame(file_start + i)
239
+
240
+ def __len__(self):
241
+ return self.frame_cnt
242
+
243
+ def __getitem__(self, index):
244
+ if isinstance(index, slice):
245
+ return [
246
+ self.get_frame(i)
247
+ for i in range(*index.indices(self.frame_cnt))
248
+ ]
249
+ # support negative indexing
250
+ if index < 0:
251
+ index += self.frame_cnt
252
+ if index < 0:
253
+ raise IndexError('index out of range')
254
+ return self.get_frame(index)
255
+
256
+ def __iter__(self):
257
+ self._set_real_position(0)
258
+ return self
259
+
260
+ def __next__(self):
261
+ img = self.read()
262
+ if img is not None:
263
+ return img
264
+ else:
265
+ raise StopIteration
266
+
267
+ next = __next__
268
+
269
+ def __enter__(self):
270
+ return self
271
+
272
+ def __exit__(self, exc_type, exc_value, traceback):
273
+ self._vcap.release()
274
+
275
+
276
+ def frames2video(frame_dir,
277
+ video_file,
278
+ fps=30,
279
+ fourcc='XVID',
280
+ filename_tmpl='{:06d}.jpg',
281
+ start=0,
282
+ end=0,
283
+ show_progress=True):
284
+ """Read the frame images from a directory and join them as a video.
285
+
286
+ Args:
287
+ frame_dir (str): The directory containing video frames.
288
+ video_file (str): Output filename.
289
+ fps (float): FPS of the output video.
290
+ fourcc (str): Fourcc of the output video, this should be compatible
291
+ with the output file type.
292
+ filename_tmpl (str): Filename template with the index as the variable.
293
+ start (int): Starting frame index.
294
+ end (int): Ending frame index.
295
+ show_progress (bool): Whether to show a progress bar.
296
+ """
297
+ if end == 0:
298
+ ext = filename_tmpl.split('.')[-1]
299
+ end = len([name for name in scandir(frame_dir, ext)])
300
+ first_file = osp.join(frame_dir, filename_tmpl.format(start))
301
+ check_file_exist(first_file, 'The start frame not found: ' + first_file)
302
+ img = cv2.imread(first_file)
303
+ height, width = img.shape[:2]
304
+ resolution = (width, height)
305
+ vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
306
+ resolution)
307
+
308
+ def write_frame(file_idx):
309
+ filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
310
+ img = cv2.imread(filename)
311
+ vwriter.write(img)
312
+
313
+ if show_progress:
314
+ track_progress(write_frame, range(start, end))
315
+ else:
316
+ for i in range(start, end):
317
+ write_frame(i)
318
+ vwriter.release()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from annotator.mmpkg.mmcv.arraymisc import dequantize, quantize
8
+ from annotator.mmpkg.mmcv.image import imread, imwrite
9
+ from annotator.mmpkg.mmcv.utils import is_str
10
+
11
+
12
+ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
13
+ """Read an optical flow map.
14
+
15
+ Args:
16
+ flow_or_path (ndarray or str): A flow map or filepath.
17
+ quantize (bool): whether to read quantized pair, if set to True,
18
+ remaining args will be passed to :func:`dequantize_flow`.
19
+ concat_axis (int): The axis that dx and dy are concatenated,
20
+ can be either 0 or 1. Ignored if quantize is False.
21
+
22
+ Returns:
23
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
24
+ """
25
+ if isinstance(flow_or_path, np.ndarray):
26
+ if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
27
+ raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
28
+ return flow_or_path
29
+ elif not is_str(flow_or_path):
30
+ raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
31
+ f'not {type(flow_or_path)}')
32
+
33
+ if not quantize:
34
+ with open(flow_or_path, 'rb') as f:
35
+ try:
36
+ header = f.read(4).decode('utf-8')
37
+ except Exception:
38
+ raise IOError(f'Invalid flow file: {flow_or_path}')
39
+ else:
40
+ if header != 'PIEH':
41
+ raise IOError(f'Invalid flow file: {flow_or_path}, '
42
+ 'header does not contain PIEH')
43
+
44
+ w = np.fromfile(f, np.int32, 1).squeeze()
45
+ h = np.fromfile(f, np.int32, 1).squeeze()
46
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
47
+ else:
48
+ assert concat_axis in [0, 1]
49
+ cat_flow = imread(flow_or_path, flag='unchanged')
50
+ if cat_flow.ndim != 2:
51
+ raise IOError(
52
+ f'{flow_or_path} is not a valid quantized flow file, '
53
+ f'its dimension is {cat_flow.ndim}.')
54
+ assert cat_flow.shape[concat_axis] % 2 == 0
55
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
56
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
57
+
58
+ return flow.astype(np.float32)
59
+
60
+
61
+ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
62
+ """Write optical flow to file.
63
+
64
+ If the flow is not quantized, it will be saved as a .flo file losslessly,
65
+ otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
66
+ will be concatenated horizontally into a single image if quantize is True.)
67
+
68
+ Args:
69
+ flow (ndarray): (h, w, 2) array of optical flow.
70
+ filename (str): Output filepath.
71
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
72
+ images. If set to True, remaining args will be passed to
73
+ :func:`quantize_flow`.
74
+ concat_axis (int): The axis that dx and dy are concatenated,
75
+ can be either 0 or 1. Ignored if quantize is False.
76
+ """
77
+ if not quantize:
78
+ with open(filename, 'wb') as f:
79
+ f.write('PIEH'.encode('utf-8'))
80
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
81
+ flow = flow.astype(np.float32)
82
+ flow.tofile(f)
83
+ f.flush()
84
+ else:
85
+ assert concat_axis in [0, 1]
86
+ dx, dy = quantize_flow(flow, *args, **kwargs)
87
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
88
+ imwrite(dxdy, filename)
89
+
90
+
91
+ def quantize_flow(flow, max_val=0.02, norm=True):
92
+ """Quantize flow to [0, 255].
93
+
94
+ After this step, the size of flow will be much smaller, and can be
95
+ dumped as jpeg images.
96
+
97
+ Args:
98
+ flow (ndarray): (h, w, 2) array of optical flow.
99
+ max_val (float): Maximum value of flow, values beyond
100
+ [-max_val, max_val] will be truncated.
101
+ norm (bool): Whether to divide flow values by image width/height.
102
+
103
+ Returns:
104
+ tuple[ndarray]: Quantized dx and dy.
105
+ """
106
+ h, w, _ = flow.shape
107
+ dx = flow[..., 0]
108
+ dy = flow[..., 1]
109
+ if norm:
110
+ dx = dx / w # avoid inplace operations
111
+ dy = dy / h
112
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
113
+ flow_comps = [
114
+ quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
115
+ ]
116
+ return tuple(flow_comps)
117
+
118
+
119
+ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
120
+ """Recover from quantized flow.
121
+
122
+ Args:
123
+ dx (ndarray): Quantized dx.
124
+ dy (ndarray): Quantized dy.
125
+ max_val (float): Maximum value used when quantizing.
126
+ denorm (bool): Whether to multiply flow values with width/height.
127
+
128
+ Returns:
129
+ ndarray: Dequantized flow.
130
+ """
131
+ assert dx.shape == dy.shape
132
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
133
+
134
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
135
+
136
+ if denorm:
137
+ dx *= dx.shape[1]
138
+ dy *= dx.shape[0]
139
+ flow = np.dstack((dx, dy))
140
+ return flow
141
+
142
+
143
+ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
144
+ """Use flow to warp img.
145
+
146
+ Args:
147
+ img (ndarray, float or uint8): Image to be warped.
148
+ flow (ndarray, float): Optical Flow.
149
+ filling_value (int): The missing pixels will be set with filling_value.
150
+ interpolate_mode (str): bilinear -> Bilinear Interpolation;
151
+ nearest -> Nearest Neighbor.
152
+
153
+ Returns:
154
+ ndarray: Warped image with the same shape of img
155
+ """
156
+ warnings.warn('This function is just for prototyping and cannot '
157
+ 'guarantee the computational efficiency.')
158
+ assert flow.ndim == 3, 'Flow must be in 3D arrays.'
159
+ height = flow.shape[0]
160
+ width = flow.shape[1]
161
+ channels = img.shape[2]
162
+
163
+ output = np.ones(
164
+ (height, width, channels), dtype=img.dtype) * filling_value
165
+
166
+ grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
167
+ dx = grid[:, :, 0] + flow[:, :, 1]
168
+ dy = grid[:, :, 1] + flow[:, :, 0]
169
+ sx = np.floor(dx).astype(int)
170
+ sy = np.floor(dy).astype(int)
171
+ valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
172
+
173
+ if interpolate_mode == 'nearest':
174
+ output[valid, :] = img[dx[valid].round().astype(int),
175
+ dy[valid].round().astype(int), :]
176
+ elif interpolate_mode == 'bilinear':
177
+ # dirty walkround for integer positions
178
+ eps_ = 1e-6
179
+ dx, dy = dx + eps_, dy + eps_
180
+ left_top_ = img[np.floor(dx[valid]).astype(int),
181
+ np.floor(dy[valid]).astype(int), :] * (
182
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
183
+ np.ceil(dy[valid]) - dy[valid])[:, None]
184
+ left_down_ = img[np.ceil(dx[valid]).astype(int),
185
+ np.floor(dy[valid]).astype(int), :] * (
186
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
187
+ np.ceil(dy[valid]) - dy[valid])[:, None]
188
+ right_top_ = img[np.floor(dx[valid]).astype(int),
189
+ np.ceil(dy[valid]).astype(int), :] * (
190
+ np.ceil(dx[valid]) - dx[valid])[:, None] * (
191
+ dy[valid] - np.floor(dy[valid]))[:, None]
192
+ right_down_ = img[np.ceil(dx[valid]).astype(int),
193
+ np.ceil(dy[valid]).astype(int), :] * (
194
+ dx[valid] - np.floor(dx[valid]))[:, None] * (
195
+ dy[valid] - np.floor(dy[valid]))[:, None]
196
+ output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
197
+ else:
198
+ raise NotImplementedError(
199
+ 'We only support interpolation modes of nearest and bilinear, '
200
+ f'but got {interpolate_mode}.')
201
+ return output.astype(img.dtype)
202
+
203
+
204
+ def flow_from_bytes(content):
205
+ """Read dense optical flow from bytes.
206
+
207
+ .. note::
208
+ This load optical flow function works for FlyingChairs, FlyingThings3D,
209
+ Sintel, FlyingChairsOcc datasets, but cannot load the data from
210
+ ChairsSDHom.
211
+
212
+ Args:
213
+ content (bytes): Optical flow bytes got from files or other streams.
214
+
215
+ Returns:
216
+ ndarray: Loaded optical flow with the shape (H, W, 2).
217
+ """
218
+
219
+ # header in first 4 bytes
220
+ header = content[:4]
221
+ if header.decode('utf-8') != 'PIEH':
222
+ raise Exception('Flow file header does not contain PIEH')
223
+ # width in second 4 bytes
224
+ width = np.frombuffer(content[4:], np.int32, 1).squeeze()
225
+ # height in third 4 bytes
226
+ height = np.frombuffer(content[8:], np.int32, 1).squeeze()
227
+ # after first 12 bytes, all bytes are flow
228
+ flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
229
+ (height, width, 2))
230
+
231
+ return flow
232
+
233
+
234
+ def sparse_flow_from_bytes(content):
235
+ """Read the optical flow in KITTI datasets from bytes.
236
+
237
+ This function is modified from RAFT load the `KITTI datasets
238
+ <https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/core/utils/frame_utils.py#L102>`_.
239
+
240
+ Args:
241
+ content (bytes): Optical flow bytes got from files or other streams.
242
+
243
+ Returns:
244
+ Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
245
+ and flow valid mask with the shape (H, W).
246
+ """ # nopa
247
+
248
+ content = np.frombuffer(content, np.uint8)
249
+ flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
250
+ flow = flow[:, :, ::-1].astype(np.float32)
251
+ # flow shape (H, W, 2) valid shape (H, W)
252
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
253
+ flow = (flow - 2**15) / 64.0
254
+ return flow, valid
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import os.path as osp
4
+ import subprocess
5
+ import tempfile
6
+
7
+ from annotator.mmpkg.mmcv.utils import requires_executable
8
+
9
+
10
+ @requires_executable('ffmpeg')
11
+ def convert_video(in_file,
12
+ out_file,
13
+ print_cmd=False,
14
+ pre_options='',
15
+ **kwargs):
16
+ """Convert a video with ffmpeg.
17
+
18
+ This provides a general api to ffmpeg, the executed command is::
19
+
20
+ `ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`
21
+
22
+ Options(kwargs) are mapped to ffmpeg commands with the following rules:
23
+
24
+ - key=val: "-key val"
25
+ - key=True: "-key"
26
+ - key=False: ""
27
+
28
+ Args:
29
+ in_file (str): Input video filename.
30
+ out_file (str): Output video filename.
31
+ pre_options (str): Options appears before "-i <in_file>".
32
+ print_cmd (bool): Whether to print the final ffmpeg command.
33
+ """
34
+ options = []
35
+ for k, v in kwargs.items():
36
+ if isinstance(v, bool):
37
+ if v:
38
+ options.append(f'-{k}')
39
+ elif k == 'log_level':
40
+ assert v in [
41
+ 'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
42
+ 'verbose', 'debug', 'trace'
43
+ ]
44
+ options.append(f'-loglevel {v}')
45
+ else:
46
+ options.append(f'-{k} {v}')
47
+ cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
48
+ f'{out_file}'
49
+ if print_cmd:
50
+ print(cmd)
51
+ subprocess.call(cmd, shell=True)
52
+
53
+
54
+ @requires_executable('ffmpeg')
55
+ def resize_video(in_file,
56
+ out_file,
57
+ size=None,
58
+ ratio=None,
59
+ keep_ar=False,
60
+ log_level='info',
61
+ print_cmd=False):
62
+ """Resize a video.
63
+
64
+ Args:
65
+ in_file (str): Input video filename.
66
+ out_file (str): Output video filename.
67
+ size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
68
+ ratio (tuple or float): Expected resize ratio, (2, 0.5) means
69
+ (w*2, h*0.5).
70
+ keep_ar (bool): Whether to keep original aspect ratio.
71
+ log_level (str): Logging level of ffmpeg.
72
+ print_cmd (bool): Whether to print the final ffmpeg command.
73
+ """
74
+ if size is None and ratio is None:
75
+ raise ValueError('expected size or ratio must be specified')
76
+ if size is not None and ratio is not None:
77
+ raise ValueError('size and ratio cannot be specified at the same time')
78
+ options = {'log_level': log_level}
79
+ if size:
80
+ if not keep_ar:
81
+ options['vf'] = f'scale={size[0]}:{size[1]}'
82
+ else:
83
+ options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
84
+ 'force_original_aspect_ratio=decrease'
85
+ else:
86
+ if not isinstance(ratio, tuple):
87
+ ratio = (ratio, ratio)
88
+ options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
89
+ convert_video(in_file, out_file, print_cmd, **options)
90
+
91
+
92
+ @requires_executable('ffmpeg')
93
+ def cut_video(in_file,
94
+ out_file,
95
+ start=None,
96
+ end=None,
97
+ vcodec=None,
98
+ acodec=None,
99
+ log_level='info',
100
+ print_cmd=False):
101
+ """Cut a clip from a video.
102
+
103
+ Args:
104
+ in_file (str): Input video filename.
105
+ out_file (str): Output video filename.
106
+ start (None or float): Start time (in seconds).
107
+ end (None or float): End time (in seconds).
108
+ vcodec (None or str): Output video codec, None for unchanged.
109
+ acodec (None or str): Output audio codec, None for unchanged.
110
+ log_level (str): Logging level of ffmpeg.
111
+ print_cmd (bool): Whether to print the final ffmpeg command.
112
+ """
113
+ options = {'log_level': log_level}
114
+ if vcodec is None:
115
+ options['vcodec'] = 'copy'
116
+ if acodec is None:
117
+ options['acodec'] = 'copy'
118
+ if start:
119
+ options['ss'] = start
120
+ else:
121
+ start = 0
122
+ if end:
123
+ options['t'] = end - start
124
+ convert_video(in_file, out_file, print_cmd, **options)
125
+
126
+
127
+ @requires_executable('ffmpeg')
128
+ def concat_video(video_list,
129
+ out_file,
130
+ vcodec=None,
131
+ acodec=None,
132
+ log_level='info',
133
+ print_cmd=False):
134
+ """Concatenate multiple videos into a single one.
135
+
136
+ Args:
137
+ video_list (list): A list of video filenames
138
+ out_file (str): Output video filename
139
+ vcodec (None or str): Output video codec, None for unchanged
140
+ acodec (None or str): Output audio codec, None for unchanged
141
+ log_level (str): Logging level of ffmpeg.
142
+ print_cmd (bool): Whether to print the final ffmpeg command.
143
+ """
144
+ tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
145
+ with open(tmp_filename, 'w') as f:
146
+ for filename in video_list:
147
+ f.write(f'file {osp.abspath(filename)}\n')
148
+ options = {'log_level': log_level}
149
+ if vcodec is None:
150
+ options['vcodec'] = 'copy'
151
+ if acodec is None:
152
+ options['acodec'] = 'copy'
153
+ convert_video(
154
+ tmp_filename,
155
+ out_file,
156
+ print_cmd,
157
+ pre_options='-f concat -safe 0',
158
+ **options)
159
+ os.close(tmp_filehandler)
160
+ os.remove(tmp_filename)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .color import Color, color_val
3
+ from .image import imshow, imshow_bboxes, imshow_det_bboxes
4
+ from .optflow import flow2rgb, flowshow, make_color_wheel
5
+
6
+ __all__ = [
7
+ 'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
8
+ 'flowshow', 'flow2rgb', 'make_color_wheel'
9
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from enum import Enum
3
+
4
+ import numpy as np
5
+
6
+ from annotator.mmpkg.mmcv.utils import is_str
7
+
8
+
9
+ class Color(Enum):
10
+ """An enum that defines common colors.
11
+
12
+ Contains red, green, blue, cyan, yellow, magenta, white and black.
13
+ """
14
+ red = (0, 0, 255)
15
+ green = (0, 255, 0)
16
+ blue = (255, 0, 0)
17
+ cyan = (255, 255, 0)
18
+ yellow = (0, 255, 255)
19
+ magenta = (255, 0, 255)
20
+ white = (255, 255, 255)
21
+ black = (0, 0, 0)
22
+
23
+
24
+ def color_val(color):
25
+ """Convert various input to color tuples.
26
+
27
+ Args:
28
+ color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
29
+
30
+ Returns:
31
+ tuple[int]: A tuple of 3 integers indicating BGR channels.
32
+ """
33
+ if is_str(color):
34
+ return Color[color].value
35
+ elif isinstance(color, Color):
36
+ return color.value
37
+ elif isinstance(color, tuple):
38
+ assert len(color) == 3
39
+ for channel in color:
40
+ assert 0 <= channel <= 255
41
+ return color
42
+ elif isinstance(color, int):
43
+ assert 0 <= color <= 255
44
+ return color, color, color
45
+ elif isinstance(color, np.ndarray):
46
+ assert color.ndim == 1 and color.size == 3
47
+ assert np.all((color >= 0) & (color <= 255))
48
+ color = color.astype(np.uint8)
49
+ return tuple(color)
50
+ else:
51
+ raise TypeError(f'Invalid type for color: {type(color)}')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import cv2
3
+ import numpy as np
4
+
5
+ from annotator.mmpkg.mmcv.image import imread, imwrite
6
+ from .color import color_val
7
+
8
+
9
+ def imshow(img, win_name='', wait_time=0):
10
+ """Show an image.
11
+
12
+ Args:
13
+ img (str or ndarray): The image to be displayed.
14
+ win_name (str): The window name.
15
+ wait_time (int): Value of waitKey param.
16
+ """
17
+ cv2.imshow(win_name, imread(img))
18
+ if wait_time == 0: # prevent from hanging if windows was closed
19
+ while True:
20
+ ret = cv2.waitKey(1)
21
+
22
+ closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
23
+ # if user closed window or if some key pressed
24
+ if closed or ret != -1:
25
+ break
26
+ else:
27
+ ret = cv2.waitKey(wait_time)
28
+
29
+
30
+ def imshow_bboxes(img,
31
+ bboxes,
32
+ colors='green',
33
+ top_k=-1,
34
+ thickness=1,
35
+ show=True,
36
+ win_name='',
37
+ wait_time=0,
38
+ out_file=None):
39
+ """Draw bboxes on an image.
40
+
41
+ Args:
42
+ img (str or ndarray): The image to be displayed.
43
+ bboxes (list or ndarray): A list of ndarray of shape (k, 4).
44
+ colors (list[str or tuple or Color]): A list of colors.
45
+ top_k (int): Plot the first k bboxes only if set positive.
46
+ thickness (int): Thickness of lines.
47
+ show (bool): Whether to show the image.
48
+ win_name (str): The window name.
49
+ wait_time (int): Value of waitKey param.
50
+ out_file (str, optional): The filename to write the image.
51
+
52
+ Returns:
53
+ ndarray: The image with bboxes drawn on it.
54
+ """
55
+ img = imread(img)
56
+ img = np.ascontiguousarray(img)
57
+
58
+ if isinstance(bboxes, np.ndarray):
59
+ bboxes = [bboxes]
60
+ if not isinstance(colors, list):
61
+ colors = [colors for _ in range(len(bboxes))]
62
+ colors = [color_val(c) for c in colors]
63
+ assert len(bboxes) == len(colors)
64
+
65
+ for i, _bboxes in enumerate(bboxes):
66
+ _bboxes = _bboxes.astype(np.int32)
67
+ if top_k <= 0:
68
+ _top_k = _bboxes.shape[0]
69
+ else:
70
+ _top_k = min(top_k, _bboxes.shape[0])
71
+ for j in range(_top_k):
72
+ left_top = (_bboxes[j, 0], _bboxes[j, 1])
73
+ right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
74
+ cv2.rectangle(
75
+ img, left_top, right_bottom, colors[i], thickness=thickness)
76
+
77
+ if show:
78
+ imshow(img, win_name, wait_time)
79
+ if out_file is not None:
80
+ imwrite(img, out_file)
81
+ return img
82
+
83
+
84
+ def imshow_det_bboxes(img,
85
+ bboxes,
86
+ labels,
87
+ class_names=None,
88
+ score_thr=0,
89
+ bbox_color='green',
90
+ text_color='green',
91
+ thickness=1,
92
+ font_scale=0.5,
93
+ show=True,
94
+ win_name='',
95
+ wait_time=0,
96
+ out_file=None):
97
+ """Draw bboxes and class labels (with scores) on an image.
98
+
99
+ Args:
100
+ img (str or ndarray): The image to be displayed.
101
+ bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
102
+ (n, 5).
103
+ labels (ndarray): Labels of bboxes.
104
+ class_names (list[str]): Names of each classes.
105
+ score_thr (float): Minimum score of bboxes to be shown.
106
+ bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
107
+ text_color (str or tuple or :obj:`Color`): Color of texts.
108
+ thickness (int): Thickness of lines.
109
+ font_scale (float): Font scales of texts.
110
+ show (bool): Whether to show the image.
111
+ win_name (str): The window name.
112
+ wait_time (int): Value of waitKey param.
113
+ out_file (str or None): The filename to write the image.
114
+
115
+ Returns:
116
+ ndarray: The image with bboxes drawn on it.
117
+ """
118
+ assert bboxes.ndim == 2
119
+ assert labels.ndim == 1
120
+ assert bboxes.shape[0] == labels.shape[0]
121
+ assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
122
+ img = imread(img)
123
+ img = np.ascontiguousarray(img)
124
+
125
+ if score_thr > 0:
126
+ assert bboxes.shape[1] == 5
127
+ scores = bboxes[:, -1]
128
+ inds = scores > score_thr
129
+ bboxes = bboxes[inds, :]
130
+ labels = labels[inds]
131
+
132
+ bbox_color = color_val(bbox_color)
133
+ text_color = color_val(text_color)
134
+
135
+ for bbox, label in zip(bboxes, labels):
136
+ bbox_int = bbox.astype(np.int32)
137
+ left_top = (bbox_int[0], bbox_int[1])
138
+ right_bottom = (bbox_int[2], bbox_int[3])
139
+ cv2.rectangle(
140
+ img, left_top, right_bottom, bbox_color, thickness=thickness)
141
+ label_text = class_names[
142
+ label] if class_names is not None else f'cls {label}'
143
+ if len(bbox) > 4:
144
+ label_text += f'|{bbox[-1]:.02f}'
145
+ cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
146
+ cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
147
+
148
+ if show:
149
+ imshow(img, win_name, wait_time)
150
+ if out_file is not None:
151
+ imwrite(img, out_file)
152
+ return img
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from __future__ import division
3
+
4
+ import numpy as np
5
+
6
+ from annotator.mmpkg.mmcv.image import rgb2bgr
7
+ from annotator.mmpkg.mmcv.video import flowread
8
+ from .image import imshow
9
+
10
+
11
+ def flowshow(flow, win_name='', wait_time=0):
12
+ """Show optical flow.
13
+
14
+ Args:
15
+ flow (ndarray or str): The optical flow to be displayed.
16
+ win_name (str): The window name.
17
+ wait_time (int): Value of waitKey param.
18
+ """
19
+ flow = flowread(flow)
20
+ flow_img = flow2rgb(flow)
21
+ imshow(rgb2bgr(flow_img), win_name, wait_time)
22
+
23
+
24
+ def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
25
+ """Convert flow map to RGB image.
26
+
27
+ Args:
28
+ flow (ndarray): Array of optical flow.
29
+ color_wheel (ndarray or None): Color wheel used to map flow field to
30
+ RGB colorspace. Default color wheel will be used if not specified.
31
+ unknown_thr (str): Values above this threshold will be marked as
32
+ unknown and thus ignored.
33
+
34
+ Returns:
35
+ ndarray: RGB image that can be visualized.
36
+ """
37
+ assert flow.ndim == 3 and flow.shape[-1] == 2
38
+ if color_wheel is None:
39
+ color_wheel = make_color_wheel()
40
+ assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
41
+ num_bins = color_wheel.shape[0]
42
+
43
+ dx = flow[:, :, 0].copy()
44
+ dy = flow[:, :, 1].copy()
45
+
46
+ ignore_inds = (
47
+ np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
48
+ (np.abs(dy) > unknown_thr))
49
+ dx[ignore_inds] = 0
50
+ dy[ignore_inds] = 0
51
+
52
+ rad = np.sqrt(dx**2 + dy**2)
53
+ if np.any(rad > np.finfo(float).eps):
54
+ max_rad = np.max(rad)
55
+ dx /= max_rad
56
+ dy /= max_rad
57
+
58
+ rad = np.sqrt(dx**2 + dy**2)
59
+ angle = np.arctan2(-dy, -dx) / np.pi
60
+
61
+ bin_real = (angle + 1) / 2 * (num_bins - 1)
62
+ bin_left = np.floor(bin_real).astype(int)
63
+ bin_right = (bin_left + 1) % num_bins
64
+ w = (bin_real - bin_left.astype(np.float32))[..., None]
65
+ flow_img = (1 -
66
+ w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
67
+ small_ind = rad <= 1
68
+ flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
69
+ flow_img[np.logical_not(small_ind)] *= 0.75
70
+
71
+ flow_img[ignore_inds, :] = 0
72
+
73
+ return flow_img
74
+
75
+
76
+ def make_color_wheel(bins=None):
77
+ """Build a color wheel.
78
+
79
+ Args:
80
+ bins(list or tuple, optional): Specify the number of bins for each
81
+ color range, corresponding to six ranges: red -> yellow,
82
+ yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
83
+ magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
84
+ (see Middlebury).
85
+
86
+ Returns:
87
+ ndarray: Color wheel of shape (total_bins, 3).
88
+ """
89
+ if bins is None:
90
+ bins = [15, 6, 4, 11, 13, 6]
91
+ assert len(bins) == 6
92
+
93
+ RY, YG, GC, CB, BM, MR = tuple(bins)
94
+
95
+ ry = [1, np.arange(RY) / RY, 0]
96
+ yg = [1 - np.arange(YG) / YG, 1, 0]
97
+ gc = [0, 1, np.arange(GC) / GC]
98
+ cb = [0, 1 - np.arange(CB) / CB, 1]
99
+ bm = [np.arange(BM) / BM, 0, 1]
100
+ mr = [1, 0, 1 - np.arange(MR) / MR]
101
+
102
+ num_bins = RY + YG + GC + CB + BM + MR
103
+
104
+ color_wheel = np.zeros((3, num_bins), dtype=np.float32)
105
+
106
+ col = 0
107
+ for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
108
+ for j in range(3):
109
+ color_wheel[j, col:col + bins[i]] = color[j]
110
+ col += bins[i]
111
+
112
+ return color_wheel.T
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .inference import inference_segmentor, init_segmentor, show_result_pyplot
2
+ from .test import multi_gpu_test, single_gpu_test
3
+ from .train import get_root_logger, set_random_seed, train_segmentor
4
+
5
+ __all__ = [
6
+ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
7
+ 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
8
+ 'show_result_pyplot'
9
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import annotator.mmpkg.mmcv as mmcv
3
+ import torch
4
+ from annotator.mmpkg.mmcv.parallel import collate, scatter
5
+ from annotator.mmpkg.mmcv.runner import load_checkpoint
6
+
7
+ from annotator.mmpkg.mmseg.datasets.pipelines import Compose
8
+ from annotator.mmpkg.mmseg.models import build_segmentor
9
+ from modules import devices
10
+
11
+
12
+ def init_segmentor(config, checkpoint=None, device=devices.get_device_for("controlnet")):
13
+ """Initialize a segmentor from config file.
14
+
15
+ Args:
16
+ config (str or :obj:`mmcv.Config`): Config file path or the config
17
+ object.
18
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
19
+ will not load any weights.
20
+ device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
21
+ Use 'cpu' for loading model on CPU.
22
+ Returns:
23
+ nn.Module: The constructed segmentor.
24
+ """
25
+ if isinstance(config, str):
26
+ config = mmcv.Config.fromfile(config)
27
+ elif not isinstance(config, mmcv.Config):
28
+ raise TypeError('config must be a filename or Config object, '
29
+ 'but got {}'.format(type(config)))
30
+ config.model.pretrained = None
31
+ config.model.train_cfg = None
32
+ model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
33
+ if checkpoint is not None:
34
+ checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
35
+ model.CLASSES = checkpoint['meta']['CLASSES']
36
+ model.PALETTE = checkpoint['meta']['PALETTE']
37
+ model.cfg = config # save the config in the model for convenience
38
+ model.to(device)
39
+ model.eval()
40
+ return model
41
+
42
+
43
+ class LoadImage:
44
+ """A simple pipeline to load image."""
45
+
46
+ def __call__(self, results):
47
+ """Call function to load images into results.
48
+
49
+ Args:
50
+ results (dict): A result dict contains the file name
51
+ of the image to be read.
52
+
53
+ Returns:
54
+ dict: ``results`` will be returned containing loaded image.
55
+ """
56
+
57
+ if isinstance(results['img'], str):
58
+ results['filename'] = results['img']
59
+ results['ori_filename'] = results['img']
60
+ else:
61
+ results['filename'] = None
62
+ results['ori_filename'] = None
63
+ img = mmcv.imread(results['img'])
64
+ results['img'] = img
65
+ results['img_shape'] = img.shape
66
+ results['ori_shape'] = img.shape
67
+ return results
68
+
69
+
70
+ def inference_segmentor(model, img):
71
+ """Inference image(s) with the segmentor.
72
+
73
+ Args:
74
+ model (nn.Module): The loaded segmentor.
75
+ imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
76
+ images.
77
+
78
+ Returns:
79
+ (list[Tensor]): The segmentation result.
80
+ """
81
+ cfg = model.cfg
82
+ device = next(model.parameters()).device # model device
83
+ # build the data pipeline
84
+ test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
85
+ test_pipeline = Compose(test_pipeline)
86
+ # prepare data
87
+ data = dict(img=img)
88
+ data = test_pipeline(data)
89
+ data = collate([data], samples_per_gpu=1)
90
+ if next(model.parameters()).is_cuda:
91
+ # scatter to specified GPU
92
+ data = scatter(data, [device])[0]
93
+ else:
94
+ data['img'][0] = data['img'][0].to(devices.get_device_for("controlnet"))
95
+ data['img_metas'] = [i.data[0] for i in data['img_metas']]
96
+
97
+ # forward the model
98
+ with torch.no_grad():
99
+ result = model(return_loss=False, rescale=True, **data)
100
+ return result
101
+
102
+
103
+ def show_result_pyplot(model,
104
+ img,
105
+ result,
106
+ palette=None,
107
+ fig_size=(15, 10),
108
+ opacity=0.5,
109
+ title='',
110
+ block=True):
111
+ """Visualize the segmentation results on the image.
112
+
113
+ Args:
114
+ model (nn.Module): The loaded segmentor.
115
+ img (str or np.ndarray): Image filename or loaded image.
116
+ result (list): The segmentation result.
117
+ palette (list[list[int]]] | None): The palette of segmentation
118
+ map. If None is given, random palette will be generated.
119
+ Default: None
120
+ fig_size (tuple): Figure size of the pyplot figure.
121
+ opacity(float): Opacity of painted segmentation map.
122
+ Default 0.5.
123
+ Must be in (0, 1] range.
124
+ title (str): The title of pyplot figure.
125
+ Default is ''.
126
+ block (bool): Whether to block the pyplot figure.
127
+ Default is True.
128
+ """
129
+ if hasattr(model, 'module'):
130
+ model = model.module
131
+ img = model.show_result(
132
+ img, result, palette=palette, show=False, opacity=opacity)
133
+ # plt.figure(figsize=fig_size)
134
+ # plt.imshow(mmcv.bgr2rgb(img))
135
+ # plt.title(title)
136
+ # plt.tight_layout()
137
+ # plt.show(block=block)
138
+ return mmcv.bgr2rgb(img)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import pickle
3
+ import shutil
4
+ import tempfile
5
+
6
+ import annotator.mmpkg.mmcv as mmcv
7
+ import numpy as np
8
+ import torch
9
+ import torch.distributed as dist
10
+ from annotator.mmpkg.mmcv.image import tensor2imgs
11
+ from annotator.mmpkg.mmcv.runner import get_dist_info
12
+
13
+
14
+ def np2tmp(array, temp_file_name=None):
15
+ """Save ndarray to local numpy file.
16
+
17
+ Args:
18
+ array (ndarray): Ndarray to save.
19
+ temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
20
+ function will generate a file name with tempfile.NamedTemporaryFile
21
+ to save ndarray. Default: None.
22
+
23
+ Returns:
24
+ str: The numpy file name.
25
+ """
26
+
27
+ if temp_file_name is None:
28
+ temp_file_name = tempfile.NamedTemporaryFile(
29
+ suffix='.npy', delete=False).name
30
+ np.save(temp_file_name, array)
31
+ return temp_file_name
32
+
33
+
34
+ def single_gpu_test(model,
35
+ data_loader,
36
+ show=False,
37
+ out_dir=None,
38
+ efficient_test=False,
39
+ opacity=0.5):
40
+ """Test with single GPU.
41
+
42
+ Args:
43
+ model (nn.Module): Model to be tested.
44
+ data_loader (utils.data.Dataloader): Pytorch data loader.
45
+ show (bool): Whether show results during inference. Default: False.
46
+ out_dir (str, optional): If specified, the results will be dumped into
47
+ the directory to save output results.
48
+ efficient_test (bool): Whether save the results as local numpy files to
49
+ save CPU memory during evaluation. Default: False.
50
+ opacity(float): Opacity of painted segmentation map.
51
+ Default 0.5.
52
+ Must be in (0, 1] range.
53
+ Returns:
54
+ list: The prediction results.
55
+ """
56
+
57
+ model.eval()
58
+ results = []
59
+ dataset = data_loader.dataset
60
+ prog_bar = mmcv.ProgressBar(len(dataset))
61
+ for i, data in enumerate(data_loader):
62
+ with torch.no_grad():
63
+ result = model(return_loss=False, **data)
64
+
65
+ if show or out_dir:
66
+ img_tensor = data['img'][0]
67
+ img_metas = data['img_metas'][0].data[0]
68
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
69
+ assert len(imgs) == len(img_metas)
70
+
71
+ for img, img_meta in zip(imgs, img_metas):
72
+ h, w, _ = img_meta['img_shape']
73
+ img_show = img[:h, :w, :]
74
+
75
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
76
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
77
+
78
+ if out_dir:
79
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
80
+ else:
81
+ out_file = None
82
+
83
+ model.module.show_result(
84
+ img_show,
85
+ result,
86
+ palette=dataset.PALETTE,
87
+ show=show,
88
+ out_file=out_file,
89
+ opacity=opacity)
90
+
91
+ if isinstance(result, list):
92
+ if efficient_test:
93
+ result = [np2tmp(_) for _ in result]
94
+ results.extend(result)
95
+ else:
96
+ if efficient_test:
97
+ result = np2tmp(result)
98
+ results.append(result)
99
+
100
+ batch_size = len(result)
101
+ for _ in range(batch_size):
102
+ prog_bar.update()
103
+ return results
104
+
105
+
106
+ def multi_gpu_test(model,
107
+ data_loader,
108
+ tmpdir=None,
109
+ gpu_collect=False,
110
+ efficient_test=False):
111
+ """Test model with multiple gpus.
112
+
113
+ This method tests model with multiple gpus and collects the results
114
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
115
+ it encodes results to gpu tensors and use gpu communication for results
116
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
117
+ and collects them by the rank 0 worker.
118
+
119
+ Args:
120
+ model (nn.Module): Model to be tested.
121
+ data_loader (utils.data.Dataloader): Pytorch data loader.
122
+ tmpdir (str): Path of directory to save the temporary results from
123
+ different gpus under cpu mode.
124
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
125
+ efficient_test (bool): Whether save the results as local numpy files to
126
+ save CPU memory during evaluation. Default: False.
127
+
128
+ Returns:
129
+ list: The prediction results.
130
+ """
131
+
132
+ model.eval()
133
+ results = []
134
+ dataset = data_loader.dataset
135
+ rank, world_size = get_dist_info()
136
+ if rank == 0:
137
+ prog_bar = mmcv.ProgressBar(len(dataset))
138
+ for i, data in enumerate(data_loader):
139
+ with torch.no_grad():
140
+ result = model(return_loss=False, rescale=True, **data)
141
+
142
+ if isinstance(result, list):
143
+ if efficient_test:
144
+ result = [np2tmp(_) for _ in result]
145
+ results.extend(result)
146
+ else:
147
+ if efficient_test:
148
+ result = np2tmp(result)
149
+ results.append(result)
150
+
151
+ if rank == 0:
152
+ batch_size = data['img'][0].size(0)
153
+ for _ in range(batch_size * world_size):
154
+ prog_bar.update()
155
+
156
+ # collect results from all ranks
157
+ if gpu_collect:
158
+ results = collect_results_gpu(results, len(dataset))
159
+ else:
160
+ results = collect_results_cpu(results, len(dataset), tmpdir)
161
+ return results
162
+
163
+
164
+ def collect_results_cpu(result_part, size, tmpdir=None):
165
+ """Collect results with CPU."""
166
+ rank, world_size = get_dist_info()
167
+ # create a tmp dir if it is not specified
168
+ if tmpdir is None:
169
+ MAX_LEN = 512
170
+ # 32 is whitespace
171
+ dir_tensor = torch.full((MAX_LEN, ),
172
+ 32,
173
+ dtype=torch.uint8,
174
+ device='cuda')
175
+ if rank == 0:
176
+ tmpdir = tempfile.mkdtemp()
177
+ tmpdir = torch.tensor(
178
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
179
+ dir_tensor[:len(tmpdir)] = tmpdir
180
+ dist.broadcast(dir_tensor, 0)
181
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
182
+ else:
183
+ mmcv.mkdir_or_exist(tmpdir)
184
+ # dump the part result to the dir
185
+ mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
186
+ dist.barrier()
187
+ # collect all parts
188
+ if rank != 0:
189
+ return None
190
+ else:
191
+ # load results of all parts from tmp dir
192
+ part_list = []
193
+ for i in range(world_size):
194
+ part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
195
+ part_list.append(mmcv.load(part_file))
196
+ # sort the results
197
+ ordered_results = []
198
+ for res in zip(*part_list):
199
+ ordered_results.extend(list(res))
200
+ # the dataloader may pad some samples
201
+ ordered_results = ordered_results[:size]
202
+ # remove tmp dir
203
+ shutil.rmtree(tmpdir)
204
+ return ordered_results
205
+
206
+
207
+ def collect_results_gpu(result_part, size):
208
+ """Collect results with GPU."""
209
+ rank, world_size = get_dist_info()
210
+ # dump result part to tensor with pickle
211
+ part_tensor = torch.tensor(
212
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
213
+ # gather all result part tensor shape
214
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
215
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
216
+ dist.all_gather(shape_list, shape_tensor)
217
+ # padding result part tensor to max length
218
+ shape_max = torch.tensor(shape_list).max()
219
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
220
+ part_send[:shape_tensor[0]] = part_tensor
221
+ part_recv_list = [
222
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
223
+ ]
224
+ # gather all result part
225
+ dist.all_gather(part_recv_list, part_send)
226
+
227
+ if rank == 0:
228
+ part_list = []
229
+ for recv, shape in zip(part_recv_list, shape_list):
230
+ part_list.append(
231
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
232
+ # sort the results
233
+ ordered_results = []
234
+ for res in zip(*part_list):
235
+ ordered_results.extend(list(res))
236
+ # the dataloader may pad some samples
237
+ ordered_results = ordered_results[:size]
238
+ return ordered_results
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import warnings
3
+
4
+ import numpy as np
5
+ import torch
6
+ from annotator.mmpkg.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7
+ from annotator.mmpkg.mmcv.runner import build_optimizer, build_runner
8
+
9
+ from annotator.mmpkg.mmseg.core import DistEvalHook, EvalHook
10
+ from annotator.mmpkg.mmseg.datasets import build_dataloader, build_dataset
11
+ from annotator.mmpkg.mmseg.utils import get_root_logger
12
+
13
+
14
+ def set_random_seed(seed, deterministic=False):
15
+ """Set random seed.
16
+
17
+ Args:
18
+ seed (int): Seed to be used.
19
+ deterministic (bool): Whether to set the deterministic option for
20
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
21
+ to True and `torch.backends.cudnn.benchmark` to False.
22
+ Default: False.
23
+ """
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed_all(seed)
28
+ if deterministic:
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+
32
+
33
+ def train_segmentor(model,
34
+ dataset,
35
+ cfg,
36
+ distributed=False,
37
+ validate=False,
38
+ timestamp=None,
39
+ meta=None):
40
+ """Launch segmentor training."""
41
+ logger = get_root_logger(cfg.log_level)
42
+
43
+ # prepare data loaders
44
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
45
+ data_loaders = [
46
+ build_dataloader(
47
+ ds,
48
+ cfg.data.samples_per_gpu,
49
+ cfg.data.workers_per_gpu,
50
+ # cfg.gpus will be ignored if distributed
51
+ len(cfg.gpu_ids),
52
+ dist=distributed,
53
+ seed=cfg.seed,
54
+ drop_last=True) for ds in dataset
55
+ ]
56
+
57
+ # put model on gpus
58
+ if distributed:
59
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
60
+ # Sets the `find_unused_parameters` parameter in
61
+ # torch.nn.parallel.DistributedDataParallel
62
+ model = MMDistributedDataParallel(
63
+ model.cuda(),
64
+ device_ids=[torch.cuda.current_device()],
65
+ broadcast_buffers=False,
66
+ find_unused_parameters=find_unused_parameters)
67
+ else:
68
+ model = MMDataParallel(
69
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
70
+
71
+ # build runner
72
+ optimizer = build_optimizer(model, cfg.optimizer)
73
+
74
+ if cfg.get('runner') is None:
75
+ cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
76
+ warnings.warn(
77
+ 'config is now expected to have a `runner` section, '
78
+ 'please set `runner` in your config.', UserWarning)
79
+
80
+ runner = build_runner(
81
+ cfg.runner,
82
+ default_args=dict(
83
+ model=model,
84
+ batch_processor=None,
85
+ optimizer=optimizer,
86
+ work_dir=cfg.work_dir,
87
+ logger=logger,
88
+ meta=meta))
89
+
90
+ # register hooks
91
+ runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
92
+ cfg.checkpoint_config, cfg.log_config,
93
+ cfg.get('momentum_config', None))
94
+
95
+ # an ugly walkaround to make the .log and .log.json filenames the same
96
+ runner.timestamp = timestamp
97
+
98
+ # register eval hooks
99
+ if validate:
100
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
101
+ val_dataloader = build_dataloader(
102
+ val_dataset,
103
+ samples_per_gpu=1,
104
+ workers_per_gpu=cfg.data.workers_per_gpu,
105
+ dist=distributed,
106
+ shuffle=False)
107
+ eval_cfg = cfg.get('evaluation', {})
108
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
109
+ eval_hook = DistEvalHook if distributed else EvalHook
110
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
111
+
112
+ if cfg.resume_from:
113
+ runner.resume(cfg.resume_from)
114
+ elif cfg.load_from:
115
+ runner.load_checkpoint(cfg.load_from)
116
+ runner.run(data_loaders, cfg.workflow)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .evaluation import * # noqa: F401, F403
2
+ from .seg import * # noqa: F401, F403
3
+ from .utils import * # noqa: F401, F403
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .class_names import get_classes, get_palette
2
+ from .eval_hooks import DistEvalHook, EvalHook
3
+ from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
4
+
5
+ __all__ = [
6
+ 'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
7
+ 'eval_metrics', 'get_classes', 'get_palette'
8
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import annotator.mmpkg.mmcv as mmcv
2
+
3
+
4
+ def cityscapes_classes():
5
+ """Cityscapes class names for external use."""
6
+ return [
7
+ 'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
8
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
9
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
10
+ 'bicycle'
11
+ ]
12
+
13
+
14
+ def ade_classes():
15
+ """ADE20K class names for external use."""
16
+ return [
17
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
18
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
19
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
20
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
21
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
22
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
23
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
24
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
25
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
26
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
27
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
28
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
29
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
30
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
31
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
32
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
33
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
34
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
35
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
36
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
37
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
38
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
39
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
40
+ 'clock', 'flag'
41
+ ]
42
+
43
+
44
+ def voc_classes():
45
+ """Pascal VOC class names for external use."""
46
+ return [
47
+ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
48
+ 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
49
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
50
+ 'tvmonitor'
51
+ ]
52
+
53
+
54
+ def cityscapes_palette():
55
+ """Cityscapes palette for external use."""
56
+ return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
57
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
58
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
59
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
60
+ [0, 0, 230], [119, 11, 32]]
61
+
62
+
63
+ def ade_palette():
64
+ """ADE20K palette for external use."""
65
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
66
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
67
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
68
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
69
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
70
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
71
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
72
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
73
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
74
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
75
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
76
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
77
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
78
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
79
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
80
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
81
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
82
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
83
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
84
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
85
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
86
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
87
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
88
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
89
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
90
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
91
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
92
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
93
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
94
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
95
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
96
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
97
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
98
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
99
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
100
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
101
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
102
+ [102, 255, 0], [92, 0, 255]]
103
+
104
+
105
+ def voc_palette():
106
+ """Pascal VOC palette for external use."""
107
+ return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
108
+ [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
109
+ [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
110
+ [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
111
+ [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
112
+
113
+
114
+ dataset_aliases = {
115
+ 'cityscapes': ['cityscapes'],
116
+ 'ade': ['ade', 'ade20k'],
117
+ 'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
118
+ }
119
+
120
+
121
+ def get_classes(dataset):
122
+ """Get class names of a dataset."""
123
+ alias2name = {}
124
+ for name, aliases in dataset_aliases.items():
125
+ for alias in aliases:
126
+ alias2name[alias] = name
127
+
128
+ if mmcv.is_str(dataset):
129
+ if dataset in alias2name:
130
+ labels = eval(alias2name[dataset] + '_classes()')
131
+ else:
132
+ raise ValueError(f'Unrecognized dataset: {dataset}')
133
+ else:
134
+ raise TypeError(f'dataset must a str, but got {type(dataset)}')
135
+ return labels
136
+
137
+
138
+ def get_palette(dataset):
139
+ """Get class palette (RGB) of a dataset."""
140
+ alias2name = {}
141
+ for name, aliases in dataset_aliases.items():
142
+ for alias in aliases:
143
+ alias2name[alias] = name
144
+
145
+ if mmcv.is_str(dataset):
146
+ if dataset in alias2name:
147
+ labels = eval(alias2name[dataset] + '_palette()')
148
+ else:
149
+ raise ValueError(f'Unrecognized dataset: {dataset}')
150
+ else:
151
+ raise TypeError(f'dataset must a str, but got {type(dataset)}')
152
+ return labels
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ from annotator.mmpkg.mmcv.runner import DistEvalHook as _DistEvalHook
4
+ from annotator.mmpkg.mmcv.runner import EvalHook as _EvalHook
5
+
6
+
7
+ class EvalHook(_EvalHook):
8
+ """Single GPU EvalHook, with efficient test support.
9
+
10
+ Args:
11
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
12
+ If set to True, it will perform by epoch. Otherwise, by iteration.
13
+ Default: False.
14
+ efficient_test (bool): Whether save the results as local numpy files to
15
+ save CPU memory during evaluation. Default: False.
16
+ Returns:
17
+ list: The prediction results.
18
+ """
19
+
20
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
21
+
22
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
23
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
24
+ self.efficient_test = efficient_test
25
+
26
+ def after_train_iter(self, runner):
27
+ """After train epoch hook.
28
+
29
+ Override default ``single_gpu_test``.
30
+ """
31
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
32
+ return
33
+ from annotator.mmpkg.mmseg.apis import single_gpu_test
34
+ runner.log_buffer.clear()
35
+ results = single_gpu_test(
36
+ runner.model,
37
+ self.dataloader,
38
+ show=False,
39
+ efficient_test=self.efficient_test)
40
+ self.evaluate(runner, results)
41
+
42
+ def after_train_epoch(self, runner):
43
+ """After train epoch hook.
44
+
45
+ Override default ``single_gpu_test``.
46
+ """
47
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
48
+ return
49
+ from annotator.mmpkg.mmseg.apis import single_gpu_test
50
+ runner.log_buffer.clear()
51
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
52
+ self.evaluate(runner, results)
53
+
54
+
55
+ class DistEvalHook(_DistEvalHook):
56
+ """Distributed EvalHook, with efficient test support.
57
+
58
+ Args:
59
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
60
+ If set to True, it will perform by epoch. Otherwise, by iteration.
61
+ Default: False.
62
+ efficient_test (bool): Whether save the results as local numpy files to
63
+ save CPU memory during evaluation. Default: False.
64
+ Returns:
65
+ list: The prediction results.
66
+ """
67
+
68
+ greater_keys = ['mIoU', 'mAcc', 'aAcc']
69
+
70
+ def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
71
+ super().__init__(*args, by_epoch=by_epoch, **kwargs)
72
+ self.efficient_test = efficient_test
73
+
74
+ def after_train_iter(self, runner):
75
+ """After train epoch hook.
76
+
77
+ Override default ``multi_gpu_test``.
78
+ """
79
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
80
+ return
81
+ from annotator.mmpkg.mmseg.apis import multi_gpu_test
82
+ runner.log_buffer.clear()
83
+ results = multi_gpu_test(
84
+ runner.model,
85
+ self.dataloader,
86
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
87
+ gpu_collect=self.gpu_collect,
88
+ efficient_test=self.efficient_test)
89
+ if runner.rank == 0:
90
+ print('\n')
91
+ self.evaluate(runner, results)
92
+
93
+ def after_train_epoch(self, runner):
94
+ """After train epoch hook.
95
+
96
+ Override default ``multi_gpu_test``.
97
+ """
98
+ if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
99
+ return
100
+ from annotator.mmpkg.mmseg.apis import multi_gpu_test
101
+ runner.log_buffer.clear()
102
+ results = multi_gpu_test(
103
+ runner.model,
104
+ self.dataloader,
105
+ tmpdir=osp.join(runner.work_dir, '.eval_hook'),
106
+ gpu_collect=self.gpu_collect)
107
+ if runner.rank == 0:
108
+ print('\n')
109
+ self.evaluate(runner, results)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import annotator.mmpkg.mmcv as mmcv
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def f_score(precision, recall, beta=1):
9
+ """calcuate the f-score value.
10
+
11
+ Args:
12
+ precision (float | torch.Tensor): The precision value.
13
+ recall (float | torch.Tensor): The recall value.
14
+ beta (int): Determines the weight of recall in the combined score.
15
+ Default: False.
16
+
17
+ Returns:
18
+ [torch.tensor]: The f-score value.
19
+ """
20
+ score = (1 + beta**2) * (precision * recall) / (
21
+ (beta**2 * precision) + recall)
22
+ return score
23
+
24
+
25
+ def intersect_and_union(pred_label,
26
+ label,
27
+ num_classes,
28
+ ignore_index,
29
+ label_map=dict(),
30
+ reduce_zero_label=False):
31
+ """Calculate intersection and Union.
32
+
33
+ Args:
34
+ pred_label (ndarray | str): Prediction segmentation map
35
+ or predict result filename.
36
+ label (ndarray | str): Ground truth segmentation map
37
+ or label filename.
38
+ num_classes (int): Number of categories.
39
+ ignore_index (int): Index that will be ignored in evaluation.
40
+ label_map (dict): Mapping old labels to new labels. The parameter will
41
+ work only when label is str. Default: dict().
42
+ reduce_zero_label (bool): Whether ignore zero label. The parameter will
43
+ work only when label is str. Default: False.
44
+
45
+ Returns:
46
+ torch.Tensor: The intersection of prediction and ground truth
47
+ histogram on all classes.
48
+ torch.Tensor: The union of prediction and ground truth histogram on
49
+ all classes.
50
+ torch.Tensor: The prediction histogram on all classes.
51
+ torch.Tensor: The ground truth histogram on all classes.
52
+ """
53
+
54
+ if isinstance(pred_label, str):
55
+ pred_label = torch.from_numpy(np.load(pred_label))
56
+ else:
57
+ pred_label = torch.from_numpy((pred_label))
58
+
59
+ if isinstance(label, str):
60
+ label = torch.from_numpy(
61
+ mmcv.imread(label, flag='unchanged', backend='pillow'))
62
+ else:
63
+ label = torch.from_numpy(label)
64
+
65
+ if label_map is not None:
66
+ for old_id, new_id in label_map.items():
67
+ label[label == old_id] = new_id
68
+ if reduce_zero_label:
69
+ label[label == 0] = 255
70
+ label = label - 1
71
+ label[label == 254] = 255
72
+
73
+ mask = (label != ignore_index)
74
+ pred_label = pred_label[mask]
75
+ label = label[mask]
76
+
77
+ intersect = pred_label[pred_label == label]
78
+ area_intersect = torch.histc(
79
+ intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
80
+ area_pred_label = torch.histc(
81
+ pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
82
+ area_label = torch.histc(
83
+ label.float(), bins=(num_classes), min=0, max=num_classes - 1)
84
+ area_union = area_pred_label + area_label - area_intersect
85
+ return area_intersect, area_union, area_pred_label, area_label
86
+
87
+
88
+ def total_intersect_and_union(results,
89
+ gt_seg_maps,
90
+ num_classes,
91
+ ignore_index,
92
+ label_map=dict(),
93
+ reduce_zero_label=False):
94
+ """Calculate Total Intersection and Union.
95
+
96
+ Args:
97
+ results (list[ndarray] | list[str]): List of prediction segmentation
98
+ maps or list of prediction result filenames.
99
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
100
+ segmentation maps or list of label filenames.
101
+ num_classes (int): Number of categories.
102
+ ignore_index (int): Index that will be ignored in evaluation.
103
+ label_map (dict): Mapping old labels to new labels. Default: dict().
104
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
105
+
106
+ Returns:
107
+ ndarray: The intersection of prediction and ground truth histogram
108
+ on all classes.
109
+ ndarray: The union of prediction and ground truth histogram on all
110
+ classes.
111
+ ndarray: The prediction histogram on all classes.
112
+ ndarray: The ground truth histogram on all classes.
113
+ """
114
+ num_imgs = len(results)
115
+ assert len(gt_seg_maps) == num_imgs
116
+ total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
117
+ total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
118
+ total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
119
+ total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
120
+ for i in range(num_imgs):
121
+ area_intersect, area_union, area_pred_label, area_label = \
122
+ intersect_and_union(
123
+ results[i], gt_seg_maps[i], num_classes, ignore_index,
124
+ label_map, reduce_zero_label)
125
+ total_area_intersect += area_intersect
126
+ total_area_union += area_union
127
+ total_area_pred_label += area_pred_label
128
+ total_area_label += area_label
129
+ return total_area_intersect, total_area_union, total_area_pred_label, \
130
+ total_area_label
131
+
132
+
133
+ def mean_iou(results,
134
+ gt_seg_maps,
135
+ num_classes,
136
+ ignore_index,
137
+ nan_to_num=None,
138
+ label_map=dict(),
139
+ reduce_zero_label=False):
140
+ """Calculate Mean Intersection and Union (mIoU)
141
+
142
+ Args:
143
+ results (list[ndarray] | list[str]): List of prediction segmentation
144
+ maps or list of prediction result filenames.
145
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
146
+ segmentation maps or list of label filenames.
147
+ num_classes (int): Number of categories.
148
+ ignore_index (int): Index that will be ignored in evaluation.
149
+ nan_to_num (int, optional): If specified, NaN values will be replaced
150
+ by the numbers defined by the user. Default: None.
151
+ label_map (dict): Mapping old labels to new labels. Default: dict().
152
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
153
+
154
+ Returns:
155
+ dict[str, float | ndarray]:
156
+ <aAcc> float: Overall accuracy on all images.
157
+ <Acc> ndarray: Per category accuracy, shape (num_classes, ).
158
+ <IoU> ndarray: Per category IoU, shape (num_classes, ).
159
+ """
160
+ iou_result = eval_metrics(
161
+ results=results,
162
+ gt_seg_maps=gt_seg_maps,
163
+ num_classes=num_classes,
164
+ ignore_index=ignore_index,
165
+ metrics=['mIoU'],
166
+ nan_to_num=nan_to_num,
167
+ label_map=label_map,
168
+ reduce_zero_label=reduce_zero_label)
169
+ return iou_result
170
+
171
+
172
+ def mean_dice(results,
173
+ gt_seg_maps,
174
+ num_classes,
175
+ ignore_index,
176
+ nan_to_num=None,
177
+ label_map=dict(),
178
+ reduce_zero_label=False):
179
+ """Calculate Mean Dice (mDice)
180
+
181
+ Args:
182
+ results (list[ndarray] | list[str]): List of prediction segmentation
183
+ maps or list of prediction result filenames.
184
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
185
+ segmentation maps or list of label filenames.
186
+ num_classes (int): Number of categories.
187
+ ignore_index (int): Index that will be ignored in evaluation.
188
+ nan_to_num (int, optional): If specified, NaN values will be replaced
189
+ by the numbers defined by the user. Default: None.
190
+ label_map (dict): Mapping old labels to new labels. Default: dict().
191
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
192
+
193
+ Returns:
194
+ dict[str, float | ndarray]: Default metrics.
195
+ <aAcc> float: Overall accuracy on all images.
196
+ <Acc> ndarray: Per category accuracy, shape (num_classes, ).
197
+ <Dice> ndarray: Per category dice, shape (num_classes, ).
198
+ """
199
+
200
+ dice_result = eval_metrics(
201
+ results=results,
202
+ gt_seg_maps=gt_seg_maps,
203
+ num_classes=num_classes,
204
+ ignore_index=ignore_index,
205
+ metrics=['mDice'],
206
+ nan_to_num=nan_to_num,
207
+ label_map=label_map,
208
+ reduce_zero_label=reduce_zero_label)
209
+ return dice_result
210
+
211
+
212
+ def mean_fscore(results,
213
+ gt_seg_maps,
214
+ num_classes,
215
+ ignore_index,
216
+ nan_to_num=None,
217
+ label_map=dict(),
218
+ reduce_zero_label=False,
219
+ beta=1):
220
+ """Calculate Mean Intersection and Union (mIoU)
221
+
222
+ Args:
223
+ results (list[ndarray] | list[str]): List of prediction segmentation
224
+ maps or list of prediction result filenames.
225
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
226
+ segmentation maps or list of label filenames.
227
+ num_classes (int): Number of categories.
228
+ ignore_index (int): Index that will be ignored in evaluation.
229
+ nan_to_num (int, optional): If specified, NaN values will be replaced
230
+ by the numbers defined by the user. Default: None.
231
+ label_map (dict): Mapping old labels to new labels. Default: dict().
232
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
233
+ beta (int): Determines the weight of recall in the combined score.
234
+ Default: False.
235
+
236
+
237
+ Returns:
238
+ dict[str, float | ndarray]: Default metrics.
239
+ <aAcc> float: Overall accuracy on all images.
240
+ <Fscore> ndarray: Per category recall, shape (num_classes, ).
241
+ <Precision> ndarray: Per category precision, shape (num_classes, ).
242
+ <Recall> ndarray: Per category f-score, shape (num_classes, ).
243
+ """
244
+ fscore_result = eval_metrics(
245
+ results=results,
246
+ gt_seg_maps=gt_seg_maps,
247
+ num_classes=num_classes,
248
+ ignore_index=ignore_index,
249
+ metrics=['mFscore'],
250
+ nan_to_num=nan_to_num,
251
+ label_map=label_map,
252
+ reduce_zero_label=reduce_zero_label,
253
+ beta=beta)
254
+ return fscore_result
255
+
256
+
257
+ def eval_metrics(results,
258
+ gt_seg_maps,
259
+ num_classes,
260
+ ignore_index,
261
+ metrics=['mIoU'],
262
+ nan_to_num=None,
263
+ label_map=dict(),
264
+ reduce_zero_label=False,
265
+ beta=1):
266
+ """Calculate evaluation metrics
267
+ Args:
268
+ results (list[ndarray] | list[str]): List of prediction segmentation
269
+ maps or list of prediction result filenames.
270
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
271
+ segmentation maps or list of label filenames.
272
+ num_classes (int): Number of categories.
273
+ ignore_index (int): Index that will be ignored in evaluation.
274
+ metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
275
+ nan_to_num (int, optional): If specified, NaN values will be replaced
276
+ by the numbers defined by the user. Default: None.
277
+ label_map (dict): Mapping old labels to new labels. Default: dict().
278
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
279
+ Returns:
280
+ float: Overall accuracy on all images.
281
+ ndarray: Per category accuracy, shape (num_classes, ).
282
+ ndarray: Per category evaluation metrics, shape (num_classes, ).
283
+ """
284
+ if isinstance(metrics, str):
285
+ metrics = [metrics]
286
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
287
+ if not set(metrics).issubset(set(allowed_metrics)):
288
+ raise KeyError('metrics {} is not supported'.format(metrics))
289
+
290
+ total_area_intersect, total_area_union, total_area_pred_label, \
291
+ total_area_label = total_intersect_and_union(
292
+ results, gt_seg_maps, num_classes, ignore_index, label_map,
293
+ reduce_zero_label)
294
+ all_acc = total_area_intersect.sum() / total_area_label.sum()
295
+ ret_metrics = OrderedDict({'aAcc': all_acc})
296
+ for metric in metrics:
297
+ if metric == 'mIoU':
298
+ iou = total_area_intersect / total_area_union
299
+ acc = total_area_intersect / total_area_label
300
+ ret_metrics['IoU'] = iou
301
+ ret_metrics['Acc'] = acc
302
+ elif metric == 'mDice':
303
+ dice = 2 * total_area_intersect / (
304
+ total_area_pred_label + total_area_label)
305
+ acc = total_area_intersect / total_area_label
306
+ ret_metrics['Dice'] = dice
307
+ ret_metrics['Acc'] = acc
308
+ elif metric == 'mFscore':
309
+ precision = total_area_intersect / total_area_pred_label
310
+ recall = total_area_intersect / total_area_label
311
+ f_value = torch.tensor(
312
+ [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
313
+ ret_metrics['Fscore'] = f_value
314
+ ret_metrics['Precision'] = precision
315
+ ret_metrics['Recall'] = recall
316
+
317
+ ret_metrics = {
318
+ metric: value.numpy()
319
+ for metric, value in ret_metrics.items()
320
+ }
321
+ if nan_to_num is not None:
322
+ ret_metrics = OrderedDict({
323
+ metric: np.nan_to_num(metric_value, nan=nan_to_num)
324
+ for metric, metric_value in ret_metrics.items()
325
+ })
326
+ return ret_metrics
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .builder import build_pixel_sampler
2
+ from .sampler import BasePixelSampler, OHEMPixelSampler
3
+
4
+ __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
2
+
3
+ PIXEL_SAMPLERS = Registry('pixel sampler')
4
+
5
+
6
+ def build_pixel_sampler(cfg, **default_args):
7
+ """Build pixel sampler for segmentation map."""
8
+ return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base_pixel_sampler import BasePixelSampler
2
+ from .ohem_pixel_sampler import OHEMPixelSampler
3
+
4
+ __all__ = ['BasePixelSampler', 'OHEMPixelSampler']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+
4
+ class BasePixelSampler(metaclass=ABCMeta):
5
+ """Base class of pixel sampler."""
6
+
7
+ def __init__(self, **kwargs):
8
+ pass
9
+
10
+ @abstractmethod
11
+ def sample(self, seg_logit, seg_label):
12
+ """Placeholder for sample function."""
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from ..builder import PIXEL_SAMPLERS
5
+ from .base_pixel_sampler import BasePixelSampler
6
+
7
+
8
+ @PIXEL_SAMPLERS.register_module()
9
+ class OHEMPixelSampler(BasePixelSampler):
10
+ """Online Hard Example Mining Sampler for segmentation.
11
+
12
+ Args:
13
+ context (nn.Module): The context of sampler, subclass of
14
+ :obj:`BaseDecodeHead`.
15
+ thresh (float, optional): The threshold for hard example selection.
16
+ Below which, are prediction with low confidence. If not
17
+ specified, the hard examples will be pixels of top ``min_kept``
18
+ loss. Default: None.
19
+ min_kept (int, optional): The minimum number of predictions to keep.
20
+ Default: 100000.
21
+ """
22
+
23
+ def __init__(self, context, thresh=None, min_kept=100000):
24
+ super(OHEMPixelSampler, self).__init__()
25
+ self.context = context
26
+ assert min_kept > 1
27
+ self.thresh = thresh
28
+ self.min_kept = min_kept
29
+
30
+ def sample(self, seg_logit, seg_label):
31
+ """Sample pixels that have high loss or with low prediction confidence.
32
+
33
+ Args:
34
+ seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
35
+ seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
36
+
37
+ Returns:
38
+ torch.Tensor: segmentation weight, shape (N, H, W)
39
+ """
40
+ with torch.no_grad():
41
+ assert seg_logit.shape[2:] == seg_label.shape[2:]
42
+ assert seg_label.shape[1] == 1
43
+ seg_label = seg_label.squeeze(1).long()
44
+ batch_kept = self.min_kept * seg_label.size(0)
45
+ valid_mask = seg_label != self.context.ignore_index
46
+ seg_weight = seg_logit.new_zeros(size=seg_label.size())
47
+ valid_seg_weight = seg_weight[valid_mask]
48
+ if self.thresh is not None:
49
+ seg_prob = F.softmax(seg_logit, dim=1)
50
+
51
+ tmp_seg_label = seg_label.clone().unsqueeze(1)
52
+ tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
53
+ seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
54
+ sort_prob, sort_indices = seg_prob[valid_mask].sort()
55
+
56
+ if sort_prob.numel() > 0:
57
+ min_threshold = sort_prob[min(batch_kept,
58
+ sort_prob.numel() - 1)]
59
+ else:
60
+ min_threshold = 0.0
61
+ threshold = max(min_threshold, self.thresh)
62
+ valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
63
+ else:
64
+ losses = self.context.loss_decode(
65
+ seg_logit,
66
+ seg_label,
67
+ weight=None,
68
+ ignore_index=self.context.ignore_index,
69
+ reduction_override='none')
70
+ # faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
71
+ _, sort_indices = losses[valid_mask].sort(descending=True)
72
+ valid_seg_weight[sort_indices[:batch_kept]] = 1.
73
+
74
+ seg_weight[valid_mask] = valid_seg_weight
75
+
76
+ return seg_weight
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .misc import add_prefix
2
+
3
+ __all__ = ['add_prefix']
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def add_prefix(inputs, prefix):
2
+ """Add prefix for dict.
3
+
4
+ Args:
5
+ inputs (dict): The input dict with str keys.
6
+ prefix (str): The prefix to add.
7
+
8
+ Returns:
9
+
10
+ dict: The dict with keys updated with ``prefix``.
11
+ """
12
+
13
+ outputs = dict()
14
+ for name, value in inputs.items():
15
+ outputs[f'{prefix}.{name}'] = value
16
+
17
+ return outputs
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .ade import ADE20KDataset
2
+ from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
3
+ from .chase_db1 import ChaseDB1Dataset
4
+ from .cityscapes import CityscapesDataset
5
+ from .custom import CustomDataset
6
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
7
+ from .drive import DRIVEDataset
8
+ from .hrf import HRFDataset
9
+ from .pascal_context import PascalContextDataset, PascalContextDataset59
10
+ from .stare import STAREDataset
11
+ from .voc import PascalVOCDataset
12
+
13
+ __all__ = [
14
+ 'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
15
+ 'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
16
+ 'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
17
+ 'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
18
+ 'STAREDataset'
19
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .builder import DATASETS
2
+ from .custom import CustomDataset
3
+
4
+
5
+ @DATASETS.register_module()
6
+ class ADE20KDataset(CustomDataset):
7
+ """ADE20K dataset.
8
+
9
+ In segmentation map annotation for ADE20K, 0 stands for background, which
10
+ is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
11
+ The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
12
+ '.png'.
13
+ """
14
+ CLASSES = (
15
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
16
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
17
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
18
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
19
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
20
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
21
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
22
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
23
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
24
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
25
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
26
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
27
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
28
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
29
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
30
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
31
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
32
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
33
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
34
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
35
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
36
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
37
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
38
+ 'clock', 'flag')
39
+
40
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
41
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
42
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
43
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
44
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
45
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
46
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
47
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
48
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
49
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
50
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
51
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
52
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
53
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
54
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
55
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
56
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
57
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
58
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
59
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
60
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
61
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
62
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
63
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
64
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
65
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
66
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
67
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
68
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
69
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
70
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
71
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
72
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
73
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
74
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
75
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
76
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
77
+ [102, 255, 0], [92, 0, 255]]
78
+
79
+ def __init__(self, **kwargs):
80
+ super(ADE20KDataset, self).__init__(
81
+ img_suffix='.jpg',
82
+ seg_map_suffix='.png',
83
+ reduce_zero_label=True,
84
+ **kwargs)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import platform
3
+ import random
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ from annotator.mmpkg.mmcv.parallel import collate
8
+ from annotator.mmpkg.mmcv.runner import get_dist_info
9
+ from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
10
+ from annotator.mmpkg.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
11
+ from torch.utils.data import DistributedSampler
12
+
13
+ if platform.system() != 'Windows':
14
+ # https://github.com/pytorch/pytorch/issues/973
15
+ import resource
16
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
17
+ hard_limit = rlimit[1]
18
+ soft_limit = min(4096, hard_limit)
19
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
20
+
21
+ DATASETS = Registry('dataset')
22
+ PIPELINES = Registry('pipeline')
23
+
24
+
25
+ def _concat_dataset(cfg, default_args=None):
26
+ """Build :obj:`ConcatDataset by."""
27
+ from .dataset_wrappers import ConcatDataset
28
+ img_dir = cfg['img_dir']
29
+ ann_dir = cfg.get('ann_dir', None)
30
+ split = cfg.get('split', None)
31
+ num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
32
+ if ann_dir is not None:
33
+ num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
34
+ else:
35
+ num_ann_dir = 0
36
+ if split is not None:
37
+ num_split = len(split) if isinstance(split, (list, tuple)) else 1
38
+ else:
39
+ num_split = 0
40
+ if num_img_dir > 1:
41
+ assert num_img_dir == num_ann_dir or num_ann_dir == 0
42
+ assert num_img_dir == num_split or num_split == 0
43
+ else:
44
+ assert num_split == num_ann_dir or num_ann_dir <= 1
45
+ num_dset = max(num_split, num_img_dir)
46
+
47
+ datasets = []
48
+ for i in range(num_dset):
49
+ data_cfg = copy.deepcopy(cfg)
50
+ if isinstance(img_dir, (list, tuple)):
51
+ data_cfg['img_dir'] = img_dir[i]
52
+ if isinstance(ann_dir, (list, tuple)):
53
+ data_cfg['ann_dir'] = ann_dir[i]
54
+ if isinstance(split, (list, tuple)):
55
+ data_cfg['split'] = split[i]
56
+ datasets.append(build_dataset(data_cfg, default_args))
57
+
58
+ return ConcatDataset(datasets)
59
+
60
+
61
+ def build_dataset(cfg, default_args=None):
62
+ """Build datasets."""
63
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
64
+ if isinstance(cfg, (list, tuple)):
65
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
66
+ elif cfg['type'] == 'RepeatDataset':
67
+ dataset = RepeatDataset(
68
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
69
+ elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
70
+ cfg.get('split', None), (list, tuple)):
71
+ dataset = _concat_dataset(cfg, default_args)
72
+ else:
73
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
74
+
75
+ return dataset
76
+
77
+
78
+ def build_dataloader(dataset,
79
+ samples_per_gpu,
80
+ workers_per_gpu,
81
+ num_gpus=1,
82
+ dist=True,
83
+ shuffle=True,
84
+ seed=None,
85
+ drop_last=False,
86
+ pin_memory=True,
87
+ dataloader_type='PoolDataLoader',
88
+ **kwargs):
89
+ """Build PyTorch DataLoader.
90
+
91
+ In distributed training, each GPU/process has a dataloader.
92
+ In non-distributed training, there is only one dataloader for all GPUs.
93
+
94
+ Args:
95
+ dataset (Dataset): A PyTorch dataset.
96
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
97
+ batch size of each GPU.
98
+ workers_per_gpu (int): How many subprocesses to use for data loading
99
+ for each GPU.
100
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
101
+ dist (bool): Distributed training/test or not. Default: True.
102
+ shuffle (bool): Whether to shuffle the data at every epoch.
103
+ Default: True.
104
+ seed (int | None): Seed to be used. Default: None.
105
+ drop_last (bool): Whether to drop the last incomplete batch in epoch.
106
+ Default: False
107
+ pin_memory (bool): Whether to use pin_memory in DataLoader.
108
+ Default: True
109
+ dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
110
+ kwargs: any keyword argument to be used to initialize DataLoader
111
+
112
+ Returns:
113
+ DataLoader: A PyTorch dataloader.
114
+ """
115
+ rank, world_size = get_dist_info()
116
+ if dist:
117
+ sampler = DistributedSampler(
118
+ dataset, world_size, rank, shuffle=shuffle)
119
+ shuffle = False
120
+ batch_size = samples_per_gpu
121
+ num_workers = workers_per_gpu
122
+ else:
123
+ sampler = None
124
+ batch_size = num_gpus * samples_per_gpu
125
+ num_workers = num_gpus * workers_per_gpu
126
+
127
+ init_fn = partial(
128
+ worker_init_fn, num_workers=num_workers, rank=rank,
129
+ seed=seed) if seed is not None else None
130
+
131
+ assert dataloader_type in (
132
+ 'DataLoader',
133
+ 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
134
+
135
+ if dataloader_type == 'PoolDataLoader':
136
+ dataloader = PoolDataLoader
137
+ elif dataloader_type == 'DataLoader':
138
+ dataloader = DataLoader
139
+
140
+ data_loader = dataloader(
141
+ dataset,
142
+ batch_size=batch_size,
143
+ sampler=sampler,
144
+ num_workers=num_workers,
145
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
146
+ pin_memory=pin_memory,
147
+ shuffle=shuffle,
148
+ worker_init_fn=init_fn,
149
+ drop_last=drop_last,
150
+ **kwargs)
151
+
152
+ return data_loader
153
+
154
+
155
+ def worker_init_fn(worker_id, num_workers, rank, seed):
156
+ """Worker init func for dataloader.
157
+
158
+ The seed of each worker equals to num_worker * rank + worker_id + user_seed
159
+
160
+ Args:
161
+ worker_id (int): Worker id.
162
+ num_workers (int): Number of workers.
163
+ rank (int): The rank of current process.
164
+ seed (int): The random seed to use.
165
+ """
166
+
167
+ worker_seed = num_workers * rank + worker_id + seed
168
+ np.random.seed(worker_seed)
169
+ random.seed(worker_seed)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ from .builder import DATASETS
4
+ from .custom import CustomDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class ChaseDB1Dataset(CustomDataset):
9
+ """Chase_db1 dataset.
10
+
11
+ In segmentation map annotation for Chase_db1, 0 stands for background,
12
+ which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
13
+ The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
14
+ '_1stHO.png'.
15
+ """
16
+
17
+ CLASSES = ('background', 'vessel')
18
+
19
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
20
+
21
+ def __init__(self, **kwargs):
22
+ super(ChaseDB1Dataset, self).__init__(
23
+ img_suffix='.png',
24
+ seg_map_suffix='_1stHO.png',
25
+ reduce_zero_label=False,
26
+ **kwargs)
27
+ assert osp.exists(self.img_dir)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import tempfile
3
+
4
+ import annotator.mmpkg.mmcv as mmcv
5
+ import numpy as np
6
+ from annotator.mmpkg.mmcv.utils import print_log
7
+ from PIL import Image
8
+
9
+ from .builder import DATASETS
10
+ from .custom import CustomDataset
11
+
12
+
13
+ @DATASETS.register_module()
14
+ class CityscapesDataset(CustomDataset):
15
+ """Cityscapes dataset.
16
+
17
+ The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
18
+ fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
19
+ """
20
+
21
+ CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
22
+ 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
23
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
24
+ 'bicycle')
25
+
26
+ PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
27
+ [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
28
+ [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
29
+ [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
30
+ [0, 80, 100], [0, 0, 230], [119, 11, 32]]
31
+
32
+ def __init__(self, **kwargs):
33
+ super(CityscapesDataset, self).__init__(
34
+ img_suffix='_leftImg8bit.png',
35
+ seg_map_suffix='_gtFine_labelTrainIds.png',
36
+ **kwargs)
37
+
38
+ @staticmethod
39
+ def _convert_to_label_id(result):
40
+ """Convert trainId to id for cityscapes."""
41
+ if isinstance(result, str):
42
+ result = np.load(result)
43
+ import cityscapesscripts.helpers.labels as CSLabels
44
+ result_copy = result.copy()
45
+ for trainId, label in CSLabels.trainId2label.items():
46
+ result_copy[result == trainId] = label.id
47
+
48
+ return result_copy
49
+
50
+ def results2img(self, results, imgfile_prefix, to_label_id):
51
+ """Write the segmentation results to images.
52
+
53
+ Args:
54
+ results (list[list | tuple | ndarray]): Testing results of the
55
+ dataset.
56
+ imgfile_prefix (str): The filename prefix of the png files.
57
+ If the prefix is "somepath/xxx",
58
+ the png files will be named "somepath/xxx.png".
59
+ to_label_id (bool): whether convert output to label_id for
60
+ submission
61
+
62
+ Returns:
63
+ list[str: str]: result txt files which contains corresponding
64
+ semantic segmentation images.
65
+ """
66
+ mmcv.mkdir_or_exist(imgfile_prefix)
67
+ result_files = []
68
+ prog_bar = mmcv.ProgressBar(len(self))
69
+ for idx in range(len(self)):
70
+ result = results[idx]
71
+ if to_label_id:
72
+ result = self._convert_to_label_id(result)
73
+ filename = self.img_infos[idx]['filename']
74
+ basename = osp.splitext(osp.basename(filename))[0]
75
+
76
+ png_filename = osp.join(imgfile_prefix, f'{basename}.png')
77
+
78
+ output = Image.fromarray(result.astype(np.uint8)).convert('P')
79
+ import cityscapesscripts.helpers.labels as CSLabels
80
+ palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
81
+ for label_id, label in CSLabels.id2label.items():
82
+ palette[label_id] = label.color
83
+
84
+ output.putpalette(palette)
85
+ output.save(png_filename)
86
+ result_files.append(png_filename)
87
+ prog_bar.update()
88
+
89
+ return result_files
90
+
91
+ def format_results(self, results, imgfile_prefix=None, to_label_id=True):
92
+ """Format the results into dir (standard format for Cityscapes
93
+ evaluation).
94
+
95
+ Args:
96
+ results (list): Testing results of the dataset.
97
+ imgfile_prefix (str | None): The prefix of images files. It
98
+ includes the file path and the prefix of filename, e.g.,
99
+ "a/b/prefix". If not specified, a temp file will be created.
100
+ Default: None.
101
+ to_label_id (bool): whether convert output to label_id for
102
+ submission. Default: False
103
+
104
+ Returns:
105
+ tuple: (result_files, tmp_dir), result_files is a list containing
106
+ the image paths, tmp_dir is the temporal directory created
107
+ for saving json/png files when img_prefix is not specified.
108
+ """
109
+
110
+ assert isinstance(results, list), 'results must be a list'
111
+ assert len(results) == len(self), (
112
+ 'The length of results is not equal to the dataset len: '
113
+ f'{len(results)} != {len(self)}')
114
+
115
+ if imgfile_prefix is None:
116
+ tmp_dir = tempfile.TemporaryDirectory()
117
+ imgfile_prefix = tmp_dir.name
118
+ else:
119
+ tmp_dir = None
120
+ result_files = self.results2img(results, imgfile_prefix, to_label_id)
121
+
122
+ return result_files, tmp_dir
123
+
124
+ def evaluate(self,
125
+ results,
126
+ metric='mIoU',
127
+ logger=None,
128
+ imgfile_prefix=None,
129
+ efficient_test=False):
130
+ """Evaluation in Cityscapes/default protocol.
131
+
132
+ Args:
133
+ results (list): Testing results of the dataset.
134
+ metric (str | list[str]): Metrics to be evaluated.
135
+ logger (logging.Logger | None | str): Logger used for printing
136
+ related information during evaluation. Default: None.
137
+ imgfile_prefix (str | None): The prefix of output image file,
138
+ for cityscapes evaluation only. It includes the file path and
139
+ the prefix of filename, e.g., "a/b/prefix".
140
+ If results are evaluated with cityscapes protocol, it would be
141
+ the prefix of output png files. The output files would be
142
+ png images under folder "a/b/prefix/xxx.png", where "xxx" is
143
+ the image name of cityscapes. If not specified, a temp file
144
+ will be created for evaluation.
145
+ Default: None.
146
+
147
+ Returns:
148
+ dict[str, float]: Cityscapes/default metrics.
149
+ """
150
+
151
+ eval_results = dict()
152
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
153
+ if 'cityscapes' in metrics:
154
+ eval_results.update(
155
+ self._evaluate_cityscapes(results, logger, imgfile_prefix))
156
+ metrics.remove('cityscapes')
157
+ if len(metrics) > 0:
158
+ eval_results.update(
159
+ super(CityscapesDataset,
160
+ self).evaluate(results, metrics, logger, efficient_test))
161
+
162
+ return eval_results
163
+
164
+ def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
165
+ """Evaluation in Cityscapes protocol.
166
+
167
+ Args:
168
+ results (list): Testing results of the dataset.
169
+ logger (logging.Logger | str | None): Logger used for printing
170
+ related information during evaluation. Default: None.
171
+ imgfile_prefix (str | None): The prefix of output image file
172
+
173
+ Returns:
174
+ dict[str: float]: Cityscapes evaluation results.
175
+ """
176
+ try:
177
+ import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
178
+ except ImportError:
179
+ raise ImportError('Please run "pip install cityscapesscripts" to '
180
+ 'install cityscapesscripts first.')
181
+ msg = 'Evaluating in Cityscapes style'
182
+ if logger is None:
183
+ msg = '\n' + msg
184
+ print_log(msg, logger=logger)
185
+
186
+ result_files, tmp_dir = self.format_results(results, imgfile_prefix)
187
+
188
+ if tmp_dir is None:
189
+ result_dir = imgfile_prefix
190
+ else:
191
+ result_dir = tmp_dir.name
192
+
193
+ eval_results = dict()
194
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
195
+
196
+ CSEval.args.evalInstLevelScore = True
197
+ CSEval.args.predictionPath = osp.abspath(result_dir)
198
+ CSEval.args.evalPixelAccuracy = True
199
+ CSEval.args.JSONOutput = False
200
+
201
+ seg_map_list = []
202
+ pred_list = []
203
+
204
+ # when evaluating with official cityscapesscripts,
205
+ # **_gtFine_labelIds.png is used
206
+ for seg_map in mmcv.scandir(
207
+ self.ann_dir, 'gtFine_labelIds.png', recursive=True):
208
+ seg_map_list.append(osp.join(self.ann_dir, seg_map))
209
+ pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
210
+
211
+ eval_results.update(
212
+ CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
213
+
214
+ if tmp_dir is not None:
215
+ tmp_dir.cleanup()
216
+
217
+ return eval_results
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ from collections import OrderedDict
4
+ from functools import reduce
5
+
6
+ import annotator.mmpkg.mmcv as mmcv
7
+ import numpy as np
8
+ from annotator.mmpkg.mmcv.utils import print_log
9
+ from torch.utils.data import Dataset
10
+
11
+ from annotator.mmpkg.mmseg.core import eval_metrics
12
+ from annotator.mmpkg.mmseg.utils import get_root_logger
13
+ from .builder import DATASETS
14
+ from .pipelines import Compose
15
+
16
+
17
+ @DATASETS.register_module()
18
+ class CustomDataset(Dataset):
19
+ """Custom dataset for semantic segmentation. An example of file structure
20
+ is as followed.
21
+
22
+ .. code-block:: none
23
+
24
+ ├── data
25
+ │ ├── my_dataset
26
+ │ │ ├── img_dir
27
+ │ │ │ ├── train
28
+ │ │ │ │ ├── xxx{img_suffix}
29
+ │ │ │ │ ├── yyy{img_suffix}
30
+ │ │ │ │ ├── zzz{img_suffix}
31
+ │ │ │ ├── val
32
+ │ │ ├── ann_dir
33
+ │ │ │ ├── train
34
+ │ │ │ │ ├── xxx{seg_map_suffix}
35
+ │ │ │ │ ├── yyy{seg_map_suffix}
36
+ │ │ │ │ ├── zzz{seg_map_suffix}
37
+ │ │ │ ├── val
38
+
39
+ The img/gt_semantic_seg pair of CustomDataset should be of the same
40
+ except suffix. A valid img/gt_semantic_seg filename pair should be like
41
+ ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
42
+ in the suffix). If split is given, then ``xxx`` is specified in txt file.
43
+ Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
44
+ Please refer to ``docs/tutorials/new_dataset.md`` for more details.
45
+
46
+
47
+ Args:
48
+ pipeline (list[dict]): Processing pipeline
49
+ img_dir (str): Path to image directory
50
+ img_suffix (str): Suffix of images. Default: '.jpg'
51
+ ann_dir (str, optional): Path to annotation directory. Default: None
52
+ seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
53
+ split (str, optional): Split txt file. If split is specified, only
54
+ file with suffix in the splits will be loaded. Otherwise, all
55
+ images in img_dir/ann_dir will be loaded. Default: None
56
+ data_root (str, optional): Data root for img_dir/ann_dir. Default:
57
+ None.
58
+ test_mode (bool): If test_mode=True, gt wouldn't be loaded.
59
+ ignore_index (int): The label index to be ignored. Default: 255
60
+ reduce_zero_label (bool): Whether to mark label zero as ignored.
61
+ Default: False
62
+ classes (str | Sequence[str], optional): Specify classes to load.
63
+ If is None, ``cls.CLASSES`` will be used. Default: None.
64
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
65
+ The palette of segmentation map. If None is given, and
66
+ self.PALETTE is None, random palette will be generated.
67
+ Default: None
68
+ """
69
+
70
+ CLASSES = None
71
+
72
+ PALETTE = None
73
+
74
+ def __init__(self,
75
+ pipeline,
76
+ img_dir,
77
+ img_suffix='.jpg',
78
+ ann_dir=None,
79
+ seg_map_suffix='.png',
80
+ split=None,
81
+ data_root=None,
82
+ test_mode=False,
83
+ ignore_index=255,
84
+ reduce_zero_label=False,
85
+ classes=None,
86
+ palette=None):
87
+ self.pipeline = Compose(pipeline)
88
+ self.img_dir = img_dir
89
+ self.img_suffix = img_suffix
90
+ self.ann_dir = ann_dir
91
+ self.seg_map_suffix = seg_map_suffix
92
+ self.split = split
93
+ self.data_root = data_root
94
+ self.test_mode = test_mode
95
+ self.ignore_index = ignore_index
96
+ self.reduce_zero_label = reduce_zero_label
97
+ self.label_map = None
98
+ self.CLASSES, self.PALETTE = self.get_classes_and_palette(
99
+ classes, palette)
100
+
101
+ # join paths if data_root is specified
102
+ if self.data_root is not None:
103
+ if not osp.isabs(self.img_dir):
104
+ self.img_dir = osp.join(self.data_root, self.img_dir)
105
+ if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
106
+ self.ann_dir = osp.join(self.data_root, self.ann_dir)
107
+ if not (self.split is None or osp.isabs(self.split)):
108
+ self.split = osp.join(self.data_root, self.split)
109
+
110
+ # load annotations
111
+ self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
112
+ self.ann_dir,
113
+ self.seg_map_suffix, self.split)
114
+
115
+ def __len__(self):
116
+ """Total number of samples of data."""
117
+ return len(self.img_infos)
118
+
119
+ def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
120
+ split):
121
+ """Load annotation from directory.
122
+
123
+ Args:
124
+ img_dir (str): Path to image directory
125
+ img_suffix (str): Suffix of images.
126
+ ann_dir (str|None): Path to annotation directory.
127
+ seg_map_suffix (str|None): Suffix of segmentation maps.
128
+ split (str|None): Split txt file. If split is specified, only file
129
+ with suffix in the splits will be loaded. Otherwise, all images
130
+ in img_dir/ann_dir will be loaded. Default: None
131
+
132
+ Returns:
133
+ list[dict]: All image info of dataset.
134
+ """
135
+
136
+ img_infos = []
137
+ if split is not None:
138
+ with open(split) as f:
139
+ for line in f:
140
+ img_name = line.strip()
141
+ img_info = dict(filename=img_name + img_suffix)
142
+ if ann_dir is not None:
143
+ seg_map = img_name + seg_map_suffix
144
+ img_info['ann'] = dict(seg_map=seg_map)
145
+ img_infos.append(img_info)
146
+ else:
147
+ for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
148
+ img_info = dict(filename=img)
149
+ if ann_dir is not None:
150
+ seg_map = img.replace(img_suffix, seg_map_suffix)
151
+ img_info['ann'] = dict(seg_map=seg_map)
152
+ img_infos.append(img_info)
153
+
154
+ print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
155
+ return img_infos
156
+
157
+ def get_ann_info(self, idx):
158
+ """Get annotation by index.
159
+
160
+ Args:
161
+ idx (int): Index of data.
162
+
163
+ Returns:
164
+ dict: Annotation info of specified index.
165
+ """
166
+
167
+ return self.img_infos[idx]['ann']
168
+
169
+ def pre_pipeline(self, results):
170
+ """Prepare results dict for pipeline."""
171
+ results['seg_fields'] = []
172
+ results['img_prefix'] = self.img_dir
173
+ results['seg_prefix'] = self.ann_dir
174
+ if self.custom_classes:
175
+ results['label_map'] = self.label_map
176
+
177
+ def __getitem__(self, idx):
178
+ """Get training/test data after pipeline.
179
+
180
+ Args:
181
+ idx (int): Index of data.
182
+
183
+ Returns:
184
+ dict: Training/test data (with annotation if `test_mode` is set
185
+ False).
186
+ """
187
+
188
+ if self.test_mode:
189
+ return self.prepare_test_img(idx)
190
+ else:
191
+ return self.prepare_train_img(idx)
192
+
193
+ def prepare_train_img(self, idx):
194
+ """Get training data and annotations after pipeline.
195
+
196
+ Args:
197
+ idx (int): Index of data.
198
+
199
+ Returns:
200
+ dict: Training data and annotation after pipeline with new keys
201
+ introduced by pipeline.
202
+ """
203
+
204
+ img_info = self.img_infos[idx]
205
+ ann_info = self.get_ann_info(idx)
206
+ results = dict(img_info=img_info, ann_info=ann_info)
207
+ self.pre_pipeline(results)
208
+ return self.pipeline(results)
209
+
210
+ def prepare_test_img(self, idx):
211
+ """Get testing data after pipeline.
212
+
213
+ Args:
214
+ idx (int): Index of data.
215
+
216
+ Returns:
217
+ dict: Testing data after pipeline with new keys introduced by
218
+ pipeline.
219
+ """
220
+
221
+ img_info = self.img_infos[idx]
222
+ results = dict(img_info=img_info)
223
+ self.pre_pipeline(results)
224
+ return self.pipeline(results)
225
+
226
+ def format_results(self, results, **kwargs):
227
+ """Place holder to format result to dataset specific output."""
228
+
229
+ def get_gt_seg_maps(self, efficient_test=False):
230
+ """Get ground truth segmentation maps for evaluation."""
231
+ gt_seg_maps = []
232
+ for img_info in self.img_infos:
233
+ seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
234
+ if efficient_test:
235
+ gt_seg_map = seg_map
236
+ else:
237
+ gt_seg_map = mmcv.imread(
238
+ seg_map, flag='unchanged', backend='pillow')
239
+ gt_seg_maps.append(gt_seg_map)
240
+ return gt_seg_maps
241
+
242
+ def get_classes_and_palette(self, classes=None, palette=None):
243
+ """Get class names of current dataset.
244
+
245
+ Args:
246
+ classes (Sequence[str] | str | None): If classes is None, use
247
+ default CLASSES defined by builtin dataset. If classes is a
248
+ string, take it as a file name. The file contains the name of
249
+ classes where each line contains one class name. If classes is
250
+ a tuple or list, override the CLASSES defined by the dataset.
251
+ palette (Sequence[Sequence[int]]] | np.ndarray | None):
252
+ The palette of segmentation map. If None is given, random
253
+ palette will be generated. Default: None
254
+ """
255
+ if classes is None:
256
+ self.custom_classes = False
257
+ return self.CLASSES, self.PALETTE
258
+
259
+ self.custom_classes = True
260
+ if isinstance(classes, str):
261
+ # take it as a file path
262
+ class_names = mmcv.list_from_file(classes)
263
+ elif isinstance(classes, (tuple, list)):
264
+ class_names = classes
265
+ else:
266
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
267
+
268
+ if self.CLASSES:
269
+ if not set(classes).issubset(self.CLASSES):
270
+ raise ValueError('classes is not a subset of CLASSES.')
271
+
272
+ # dictionary, its keys are the old label ids and its values
273
+ # are the new label ids.
274
+ # used for changing pixel labels in load_annotations.
275
+ self.label_map = {}
276
+ for i, c in enumerate(self.CLASSES):
277
+ if c not in class_names:
278
+ self.label_map[i] = -1
279
+ else:
280
+ self.label_map[i] = classes.index(c)
281
+
282
+ palette = self.get_palette_for_custom_classes(class_names, palette)
283
+
284
+ return class_names, palette
285
+
286
+ def get_palette_for_custom_classes(self, class_names, palette=None):
287
+
288
+ if self.label_map is not None:
289
+ # return subset of palette
290
+ palette = []
291
+ for old_id, new_id in sorted(
292
+ self.label_map.items(), key=lambda x: x[1]):
293
+ if new_id != -1:
294
+ palette.append(self.PALETTE[old_id])
295
+ palette = type(self.PALETTE)(palette)
296
+
297
+ elif palette is None:
298
+ if self.PALETTE is None:
299
+ palette = np.random.randint(0, 255, size=(len(class_names), 3))
300
+ else:
301
+ palette = self.PALETTE
302
+
303
+ return palette
304
+
305
+ def evaluate(self,
306
+ results,
307
+ metric='mIoU',
308
+ logger=None,
309
+ efficient_test=False,
310
+ **kwargs):
311
+ """Evaluate the dataset.
312
+
313
+ Args:
314
+ results (list): Testing results of the dataset.
315
+ metric (str | list[str]): Metrics to be evaluated. 'mIoU',
316
+ 'mDice' and 'mFscore' are supported.
317
+ logger (logging.Logger | None | str): Logger used for printing
318
+ related information during evaluation. Default: None.
319
+
320
+ Returns:
321
+ dict[str, float]: Default metrics.
322
+ """
323
+
324
+ if isinstance(metric, str):
325
+ metric = [metric]
326
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
327
+ if not set(metric).issubset(set(allowed_metrics)):
328
+ raise KeyError('metric {} is not supported'.format(metric))
329
+ eval_results = {}
330
+ gt_seg_maps = self.get_gt_seg_maps(efficient_test)
331
+ if self.CLASSES is None:
332
+ num_classes = len(
333
+ reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
334
+ else:
335
+ num_classes = len(self.CLASSES)
336
+ ret_metrics = eval_metrics(
337
+ results,
338
+ gt_seg_maps,
339
+ num_classes,
340
+ self.ignore_index,
341
+ metric,
342
+ label_map=self.label_map,
343
+ reduce_zero_label=self.reduce_zero_label)
344
+
345
+ if self.CLASSES is None:
346
+ class_names = tuple(range(num_classes))
347
+ else:
348
+ class_names = self.CLASSES
349
+
350
+ # summary table
351
+ ret_metrics_summary = OrderedDict({
352
+ ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
353
+ for ret_metric, ret_metric_value in ret_metrics.items()
354
+ })
355
+
356
+ # each class table
357
+ ret_metrics.pop('aAcc', None)
358
+ ret_metrics_class = OrderedDict({
359
+ ret_metric: np.round(ret_metric_value * 100, 2)
360
+ for ret_metric, ret_metric_value in ret_metrics.items()
361
+ })
362
+ ret_metrics_class.update({'Class': class_names})
363
+ ret_metrics_class.move_to_end('Class', last=False)
364
+
365
+ try:
366
+ from prettytable import PrettyTable
367
+ # for logger
368
+ class_table_data = PrettyTable()
369
+ for key, val in ret_metrics_class.items():
370
+ class_table_data.add_column(key, val)
371
+
372
+ summary_table_data = PrettyTable()
373
+ for key, val in ret_metrics_summary.items():
374
+ if key == 'aAcc':
375
+ summary_table_data.add_column(key, [val])
376
+ else:
377
+ summary_table_data.add_column('m' + key, [val])
378
+
379
+ print_log('per class results:', logger)
380
+ print_log('\n' + class_table_data.get_string(), logger=logger)
381
+ print_log('Summary:', logger)
382
+ print_log('\n' + summary_table_data.get_string(), logger=logger)
383
+ except ImportError: # prettytable is not installed
384
+ pass
385
+
386
+ # each metric dict
387
+ for key, value in ret_metrics_summary.items():
388
+ if key == 'aAcc':
389
+ eval_results[key] = value / 100.0
390
+ else:
391
+ eval_results['m' + key] = value / 100.0
392
+
393
+ ret_metrics_class.pop('Class', None)
394
+ for key, value in ret_metrics_class.items():
395
+ eval_results.update({
396
+ key + '.' + str(name): value[idx] / 100.0
397
+ for idx, name in enumerate(class_names)
398
+ })
399
+
400
+ if mmcv.is_list_of(results, str):
401
+ for file_name in results:
402
+ os.remove(file_name)
403
+ return eval_results
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
2
+
3
+ from .builder import DATASETS
4
+
5
+
6
+ @DATASETS.register_module()
7
+ class ConcatDataset(_ConcatDataset):
8
+ """A wrapper of concatenated dataset.
9
+
10
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
11
+ concat the group flag for image aspect ratio.
12
+
13
+ Args:
14
+ datasets (list[:obj:`Dataset`]): A list of datasets.
15
+ """
16
+
17
+ def __init__(self, datasets):
18
+ super(ConcatDataset, self).__init__(datasets)
19
+ self.CLASSES = datasets[0].CLASSES
20
+ self.PALETTE = datasets[0].PALETTE
21
+
22
+
23
+ @DATASETS.register_module()
24
+ class RepeatDataset(object):
25
+ """A wrapper of repeated dataset.
26
+
27
+ The length of repeated dataset will be `times` larger than the original
28
+ dataset. This is useful when the data loading time is long but the dataset
29
+ is small. Using RepeatDataset can reduce the data loading time between
30
+ epochs.
31
+
32
+ Args:
33
+ dataset (:obj:`Dataset`): The dataset to be repeated.
34
+ times (int): Repeat times.
35
+ """
36
+
37
+ def __init__(self, dataset, times):
38
+ self.dataset = dataset
39
+ self.times = times
40
+ self.CLASSES = dataset.CLASSES
41
+ self.PALETTE = dataset.PALETTE
42
+ self._ori_len = len(self.dataset)
43
+
44
+ def __getitem__(self, idx):
45
+ """Get item from original dataset."""
46
+ return self.dataset[idx % self._ori_len]
47
+
48
+ def __len__(self):
49
+ """The length is multiplied by ``times``"""
50
+ return self.times * self._ori_len
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ from .builder import DATASETS
4
+ from .custom import CustomDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class DRIVEDataset(CustomDataset):
9
+ """DRIVE dataset.
10
+
11
+ In segmentation map annotation for DRIVE, 0 stands for background, which is
12
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
13
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
14
+ '_manual1.png'.
15
+ """
16
+
17
+ CLASSES = ('background', 'vessel')
18
+
19
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
20
+
21
+ def __init__(self, **kwargs):
22
+ super(DRIVEDataset, self).__init__(
23
+ img_suffix='.png',
24
+ seg_map_suffix='_manual1.png',
25
+ reduce_zero_label=False,
26
+ **kwargs)
27
+ assert osp.exists(self.img_dir)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ from .builder import DATASETS
4
+ from .custom import CustomDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class HRFDataset(CustomDataset):
9
+ """HRF dataset.
10
+
11
+ In segmentation map annotation for HRF, 0 stands for background, which is
12
+ included in 2 categories. ``reduce_zero_label`` is fixed to False. The
13
+ ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
14
+ '.png'.
15
+ """
16
+
17
+ CLASSES = ('background', 'vessel')
18
+
19
+ PALETTE = [[120, 120, 120], [6, 230, 230]]
20
+
21
+ def __init__(self, **kwargs):
22
+ super(HRFDataset, self).__init__(
23
+ img_suffix='.png',
24
+ seg_map_suffix='.png',
25
+ reduce_zero_label=False,
26
+ **kwargs)
27
+ assert osp.exists(self.img_dir)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+
3
+ from .builder import DATASETS
4
+ from .custom import CustomDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class PascalContextDataset(CustomDataset):
9
+ """PascalContext dataset.
10
+
11
+ In segmentation map annotation for PascalContext, 0 stands for background,
12
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
13
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
14
+ fixed to '.png'.
15
+
16
+ Args:
17
+ split (str): Split txt file for PascalContext.
18
+ """
19
+
20
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
21
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
22
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
23
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
24
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
25
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
26
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
27
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
28
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
29
+ 'window', 'wood')
30
+
31
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
32
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
33
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
34
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
35
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
36
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
37
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
38
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
39
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
40
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
41
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
42
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
43
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
44
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
45
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
46
+
47
+ def __init__(self, split, **kwargs):
48
+ super(PascalContextDataset, self).__init__(
49
+ img_suffix='.jpg',
50
+ seg_map_suffix='.png',
51
+ split=split,
52
+ reduce_zero_label=False,
53
+ **kwargs)
54
+ assert osp.exists(self.img_dir) and self.split is not None
55
+
56
+
57
+ @DATASETS.register_module()
58
+ class PascalContextDataset59(CustomDataset):
59
+ """PascalContext dataset.
60
+
61
+ In segmentation map annotation for PascalContext, 0 stands for background,
62
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
63
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
64
+ fixed to '.png'.
65
+
66
+ Args:
67
+ split (str): Split txt file for PascalContext.
68
+ """
69
+
70
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
71
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
72
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
73
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
74
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
75
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
76
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
77
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
78
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
79
+
80
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
81
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
82
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
83
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
84
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
85
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
86
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
87
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
88
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
89
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
90
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
91
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
92
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
93
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
94
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
95
+
96
+ def __init__(self, split, **kwargs):
97
+ super(PascalContextDataset59, self).__init__(
98
+ img_suffix='.jpg',
99
+ seg_map_suffix='.png',
100
+ split=split,
101
+ reduce_zero_label=True,
102
+ **kwargs)
103
+ assert osp.exists(self.img_dir) and self.split is not None
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .compose import Compose
2
+ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
3
+ Transpose, to_tensor)
4
+ from .loading import LoadAnnotations, LoadImageFromFile
5
+ from .test_time_aug import MultiScaleFlipAug
6
+ from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
7
+ PhotoMetricDistortion, RandomCrop, RandomFlip,
8
+ RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
9
+
10
+ __all__ = [
11
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
12
+ 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
13
+ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
14
+ 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
15
+ 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
16
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+
3
+ from annotator.mmpkg.mmcv.utils import build_from_cfg
4
+
5
+ from ..builder import PIPELINES
6
+
7
+
8
+ @PIPELINES.register_module()
9
+ class Compose(object):
10
+ """Compose multiple transforms sequentially.
11
+
12
+ Args:
13
+ transforms (Sequence[dict | callable]): Sequence of transform object or
14
+ config dict to be composed.
15
+ """
16
+
17
+ def __init__(self, transforms):
18
+ assert isinstance(transforms, collections.abc.Sequence)
19
+ self.transforms = []
20
+ for transform in transforms:
21
+ if isinstance(transform, dict):
22
+ transform = build_from_cfg(transform, PIPELINES)
23
+ self.transforms.append(transform)
24
+ elif callable(transform):
25
+ self.transforms.append(transform)
26
+ else:
27
+ raise TypeError('transform must be callable or a dict')
28
+
29
+ def __call__(self, data):
30
+ """Call function to apply transforms sequentially.
31
+
32
+ Args:
33
+ data (dict): A result dict contains the data to transform.
34
+
35
+ Returns:
36
+ dict: Transformed data.
37
+ """
38
+
39
+ for t in self.transforms:
40
+ data = t(data)
41
+ if data is None:
42
+ return None
43
+ return data
44
+
45
+ def __repr__(self):
46
+ format_string = self.__class__.__name__ + '('
47
+ for t in self.transforms:
48
+ format_string += '\n'
49
+ format_string += f' {t}'
50
+ format_string += '\n)'
51
+ return format_string
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+
3
+ import annotator.mmpkg.mmcv as mmcv
4
+ import numpy as np
5
+ import torch
6
+ from annotator.mmpkg.mmcv.parallel import DataContainer as DC
7
+
8
+ from ..builder import PIPELINES
9
+
10
+
11
+ def to_tensor(data):
12
+ """Convert objects of various python types to :obj:`torch.Tensor`.
13
+
14
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
15
+ :class:`Sequence`, :class:`int` and :class:`float`.
16
+
17
+ Args:
18
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
19
+ be converted.
20
+ """
21
+
22
+ if isinstance(data, torch.Tensor):
23
+ return data
24
+ elif isinstance(data, np.ndarray):
25
+ return torch.from_numpy(data)
26
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
27
+ return torch.tensor(data)
28
+ elif isinstance(data, int):
29
+ return torch.LongTensor([data])
30
+ elif isinstance(data, float):
31
+ return torch.FloatTensor([data])
32
+ else:
33
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
34
+
35
+
36
+ @PIPELINES.register_module()
37
+ class ToTensor(object):
38
+ """Convert some results to :obj:`torch.Tensor` by given keys.
39
+
40
+ Args:
41
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
42
+ """
43
+
44
+ def __init__(self, keys):
45
+ self.keys = keys
46
+
47
+ def __call__(self, results):
48
+ """Call function to convert data in results to :obj:`torch.Tensor`.
49
+
50
+ Args:
51
+ results (dict): Result dict contains the data to convert.
52
+
53
+ Returns:
54
+ dict: The result dict contains the data converted
55
+ to :obj:`torch.Tensor`.
56
+ """
57
+
58
+ for key in self.keys:
59
+ results[key] = to_tensor(results[key])
60
+ return results
61
+
62
+ def __repr__(self):
63
+ return self.__class__.__name__ + f'(keys={self.keys})'
64
+
65
+
66
+ @PIPELINES.register_module()
67
+ class ImageToTensor(object):
68
+ """Convert image to :obj:`torch.Tensor` by given keys.
69
+
70
+ The dimension order of input image is (H, W, C). The pipeline will convert
71
+ it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
72
+ (1, H, W).
73
+
74
+ Args:
75
+ keys (Sequence[str]): Key of images to be converted to Tensor.
76
+ """
77
+
78
+ def __init__(self, keys):
79
+ self.keys = keys
80
+
81
+ def __call__(self, results):
82
+ """Call function to convert image in results to :obj:`torch.Tensor` and
83
+ transpose the channel order.
84
+
85
+ Args:
86
+ results (dict): Result dict contains the image data to convert.
87
+
88
+ Returns:
89
+ dict: The result dict contains the image converted
90
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
91
+ """
92
+
93
+ for key in self.keys:
94
+ img = results[key]
95
+ if len(img.shape) < 3:
96
+ img = np.expand_dims(img, -1)
97
+ results[key] = to_tensor(img.transpose(2, 0, 1))
98
+ return results
99
+
100
+ def __repr__(self):
101
+ return self.__class__.__name__ + f'(keys={self.keys})'
102
+
103
+
104
+ @PIPELINES.register_module()
105
+ class Transpose(object):
106
+ """Transpose some results by given keys.
107
+
108
+ Args:
109
+ keys (Sequence[str]): Keys of results to be transposed.
110
+ order (Sequence[int]): Order of transpose.
111
+ """
112
+
113
+ def __init__(self, keys, order):
114
+ self.keys = keys
115
+ self.order = order
116
+
117
+ def __call__(self, results):
118
+ """Call function to convert image in results to :obj:`torch.Tensor` and
119
+ transpose the channel order.
120
+
121
+ Args:
122
+ results (dict): Result dict contains the image data to convert.
123
+
124
+ Returns:
125
+ dict: The result dict contains the image converted
126
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
127
+ """
128
+
129
+ for key in self.keys:
130
+ results[key] = results[key].transpose(self.order)
131
+ return results
132
+
133
+ def __repr__(self):
134
+ return self.__class__.__name__ + \
135
+ f'(keys={self.keys}, order={self.order})'
136
+
137
+
138
+ @PIPELINES.register_module()
139
+ class ToDataContainer(object):
140
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
141
+
142
+ Args:
143
+ fields (Sequence[dict]): Each field is a dict like
144
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
145
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
146
+ Default: ``(dict(key='img', stack=True),
147
+ dict(key='gt_semantic_seg'))``.
148
+ """
149
+
150
+ def __init__(self,
151
+ fields=(dict(key='img',
152
+ stack=True), dict(key='gt_semantic_seg'))):
153
+ self.fields = fields
154
+
155
+ def __call__(self, results):
156
+ """Call function to convert data in results to
157
+ :obj:`mmcv.DataContainer`.
158
+
159
+ Args:
160
+ results (dict): Result dict contains the data to convert.
161
+
162
+ Returns:
163
+ dict: The result dict contains the data converted to
164
+ :obj:`mmcv.DataContainer`.
165
+ """
166
+
167
+ for field in self.fields:
168
+ field = field.copy()
169
+ key = field.pop('key')
170
+ results[key] = DC(results[key], **field)
171
+ return results
172
+
173
+ def __repr__(self):
174
+ return self.__class__.__name__ + f'(fields={self.fields})'
175
+
176
+
177
+ @PIPELINES.register_module()
178
+ class DefaultFormatBundle(object):
179
+ """Default formatting bundle.
180
+
181
+ It simplifies the pipeline of formatting common fields, including "img"
182
+ and "gt_semantic_seg". These fields are formatted as follows.
183
+
184
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
185
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
186
+ (3)to DataContainer (stack=True)
187
+ """
188
+
189
+ def __call__(self, results):
190
+ """Call function to transform and format common fields in results.
191
+
192
+ Args:
193
+ results (dict): Result dict contains the data to convert.
194
+
195
+ Returns:
196
+ dict: The result dict contains the data that is formatted with
197
+ default bundle.
198
+ """
199
+
200
+ if 'img' in results:
201
+ img = results['img']
202
+ if len(img.shape) < 3:
203
+ img = np.expand_dims(img, -1)
204
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
205
+ results['img'] = DC(to_tensor(img), stack=True)
206
+ if 'gt_semantic_seg' in results:
207
+ # convert to long
208
+ results['gt_semantic_seg'] = DC(
209
+ to_tensor(results['gt_semantic_seg'][None,
210
+ ...].astype(np.int64)),
211
+ stack=True)
212
+ return results
213
+
214
+ def __repr__(self):
215
+ return self.__class__.__name__
216
+
217
+
218
+ @PIPELINES.register_module()
219
+ class Collect(object):
220
+ """Collect data from the loader relevant to the specific task.
221
+
222
+ This is usually the last stage of the data loader pipeline. Typically keys
223
+ is set to some subset of "img", "gt_semantic_seg".
224
+
225
+ The "img_meta" item is always populated. The contents of the "img_meta"
226
+ dictionary depends on "meta_keys". By default this includes:
227
+
228
+ - "img_shape": shape of the image input to the network as a tuple
229
+ (h, w, c). Note that images may be zero padded on the bottom/right
230
+ if the batch tensor is larger than this shape.
231
+
232
+ - "scale_factor": a float indicating the preprocessing scale
233
+
234
+ - "flip": a boolean indicating if image flip transform was used
235
+
236
+ - "filename": path to the image file
237
+
238
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
239
+
240
+ - "pad_shape": image shape after padding
241
+
242
+ - "img_norm_cfg": a dict of normalization information:
243
+ - mean - per channel mean subtraction
244
+ - std - per channel std divisor
245
+ - to_rgb - bool indicating if bgr was converted to rgb
246
+
247
+ Args:
248
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
249
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
250
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
251
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
252
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
253
+ 'img_norm_cfg')``
254
+ """
255
+
256
+ def __init__(self,
257
+ keys,
258
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
259
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
260
+ 'flip_direction', 'img_norm_cfg')):
261
+ self.keys = keys
262
+ self.meta_keys = meta_keys
263
+
264
+ def __call__(self, results):
265
+ """Call function to collect keys in results. The keys in ``meta_keys``
266
+ will be converted to :obj:mmcv.DataContainer.
267
+
268
+ Args:
269
+ results (dict): Result dict contains the data to collect.
270
+
271
+ Returns:
272
+ dict: The result dict contains the following keys
273
+ - keys in``self.keys``
274
+ - ``img_metas``
275
+ """
276
+
277
+ data = {}
278
+ img_meta = {}
279
+ for key in self.meta_keys:
280
+ img_meta[key] = results[key]
281
+ data['img_metas'] = DC(img_meta, cpu_only=True)
282
+ for key in self.keys:
283
+ data[key] = results[key]
284
+ return data
285
+
286
+ def __repr__(self):
287
+ return self.__class__.__name__ + \
288
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'