4f3c990a87ec971910ff61a029b3ca21870c19ebc5abc37dc65d3e23339b5592
Browse files- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py +71 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/logging.py +110 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/misc.py +377 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py +41 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py +107 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py +101 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py +208 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py +315 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py +140 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py +118 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py +23 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py +90 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py +35 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py +11 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py +318 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py +254 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py +160 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py +9 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py +51 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py +152 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py +112 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py +9 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py +138 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py +238 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py +116 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py +3 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py +8 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py +152 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py +109 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py +326 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py +4 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py +8 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py +4 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py +12 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py +76 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py +3 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py +17 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py +19 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py +84 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py +169 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py +27 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py +217 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py +403 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py +50 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py +27 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py +27 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py +103 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py +16 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py +51 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py +288 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/ext_loader.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import importlib
|
3 |
+
import os
|
4 |
+
import pkgutil
|
5 |
+
import warnings
|
6 |
+
from collections import namedtuple
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
if torch.__version__ != 'parrots':
|
11 |
+
|
12 |
+
def load_ext(name, funcs):
|
13 |
+
ext = importlib.import_module('mmcv.' + name)
|
14 |
+
for fun in funcs:
|
15 |
+
assert hasattr(ext, fun), f'{fun} miss in module {name}'
|
16 |
+
return ext
|
17 |
+
else:
|
18 |
+
from parrots import extension
|
19 |
+
from parrots.base import ParrotsException
|
20 |
+
|
21 |
+
has_return_value_ops = [
|
22 |
+
'nms',
|
23 |
+
'softnms',
|
24 |
+
'nms_match',
|
25 |
+
'nms_rotated',
|
26 |
+
'top_pool_forward',
|
27 |
+
'top_pool_backward',
|
28 |
+
'bottom_pool_forward',
|
29 |
+
'bottom_pool_backward',
|
30 |
+
'left_pool_forward',
|
31 |
+
'left_pool_backward',
|
32 |
+
'right_pool_forward',
|
33 |
+
'right_pool_backward',
|
34 |
+
'fused_bias_leakyrelu',
|
35 |
+
'upfirdn2d',
|
36 |
+
'ms_deform_attn_forward',
|
37 |
+
'pixel_group',
|
38 |
+
'contour_expand',
|
39 |
+
]
|
40 |
+
|
41 |
+
def get_fake_func(name, e):
|
42 |
+
|
43 |
+
def fake_func(*args, **kwargs):
|
44 |
+
warnings.warn(f'{name} is not supported in parrots now')
|
45 |
+
raise e
|
46 |
+
|
47 |
+
return fake_func
|
48 |
+
|
49 |
+
def load_ext(name, funcs):
|
50 |
+
ExtModule = namedtuple('ExtModule', funcs)
|
51 |
+
ext_list = []
|
52 |
+
lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
53 |
+
for fun in funcs:
|
54 |
+
try:
|
55 |
+
ext_fun = extension.load(fun, name, lib_dir=lib_root)
|
56 |
+
except ParrotsException as e:
|
57 |
+
if 'No element registered' not in e.message:
|
58 |
+
warnings.warn(e.message)
|
59 |
+
ext_fun = get_fake_func(fun, e)
|
60 |
+
ext_list.append(ext_fun)
|
61 |
+
else:
|
62 |
+
if fun in has_return_value_ops:
|
63 |
+
ext_list.append(ext_fun.op)
|
64 |
+
else:
|
65 |
+
ext_list.append(ext_fun.op_)
|
66 |
+
return ExtModule(*ext_list)
|
67 |
+
|
68 |
+
|
69 |
+
def check_ops_exist():
|
70 |
+
ext_loader = pkgutil.find_loader('mmcv._ext')
|
71 |
+
return ext_loader is not None
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/logging.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import logging
|
3 |
+
|
4 |
+
import torch.distributed as dist
|
5 |
+
|
6 |
+
logger_initialized = {}
|
7 |
+
|
8 |
+
|
9 |
+
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
|
10 |
+
"""Initialize and get a logger by name.
|
11 |
+
|
12 |
+
If the logger has not been initialized, this method will initialize the
|
13 |
+
logger by adding one or two handlers, otherwise the initialized logger will
|
14 |
+
be directly returned. During initialization, a StreamHandler will always be
|
15 |
+
added. If `log_file` is specified and the process rank is 0, a FileHandler
|
16 |
+
will also be added.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
name (str): Logger name.
|
20 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
21 |
+
will be added to the logger.
|
22 |
+
log_level (int): The logger level. Note that only the process of
|
23 |
+
rank 0 is affected, and other processes will set the level to
|
24 |
+
"Error" thus be silent most of the time.
|
25 |
+
file_mode (str): The file mode used in opening log file.
|
26 |
+
Defaults to 'w'.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
logging.Logger: The expected logger.
|
30 |
+
"""
|
31 |
+
logger = logging.getLogger(name)
|
32 |
+
if name in logger_initialized:
|
33 |
+
return logger
|
34 |
+
# handle hierarchical names
|
35 |
+
# e.g., logger "a" is initialized, then logger "a.b" will skip the
|
36 |
+
# initialization since it is a child of "a".
|
37 |
+
for logger_name in logger_initialized:
|
38 |
+
if name.startswith(logger_name):
|
39 |
+
return logger
|
40 |
+
|
41 |
+
# handle duplicate logs to the console
|
42 |
+
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
|
43 |
+
# to the root logger. As logger.propagate is True by default, this root
|
44 |
+
# level handler causes logging messages from rank>0 processes to
|
45 |
+
# unexpectedly show up on the console, creating much unwanted clutter.
|
46 |
+
# To fix this issue, we set the root logger's StreamHandler, if any, to log
|
47 |
+
# at the ERROR level.
|
48 |
+
for handler in logger.root.handlers:
|
49 |
+
if type(handler) is logging.StreamHandler:
|
50 |
+
handler.setLevel(logging.ERROR)
|
51 |
+
|
52 |
+
stream_handler = logging.StreamHandler()
|
53 |
+
handlers = [stream_handler]
|
54 |
+
|
55 |
+
if dist.is_available() and dist.is_initialized():
|
56 |
+
rank = dist.get_rank()
|
57 |
+
else:
|
58 |
+
rank = 0
|
59 |
+
|
60 |
+
# only rank 0 will add a FileHandler
|
61 |
+
if rank == 0 and log_file is not None:
|
62 |
+
# Here, the default behaviour of the official logger is 'a'. Thus, we
|
63 |
+
# provide an interface to change the file mode to the default
|
64 |
+
# behaviour.
|
65 |
+
file_handler = logging.FileHandler(log_file, file_mode)
|
66 |
+
handlers.append(file_handler)
|
67 |
+
|
68 |
+
formatter = logging.Formatter(
|
69 |
+
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
70 |
+
for handler in handlers:
|
71 |
+
handler.setFormatter(formatter)
|
72 |
+
handler.setLevel(log_level)
|
73 |
+
logger.addHandler(handler)
|
74 |
+
|
75 |
+
if rank == 0:
|
76 |
+
logger.setLevel(log_level)
|
77 |
+
else:
|
78 |
+
logger.setLevel(logging.ERROR)
|
79 |
+
|
80 |
+
logger_initialized[name] = True
|
81 |
+
|
82 |
+
return logger
|
83 |
+
|
84 |
+
|
85 |
+
def print_log(msg, logger=None, level=logging.INFO):
|
86 |
+
"""Print a log message.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
msg (str): The message to be logged.
|
90 |
+
logger (logging.Logger | str | None): The logger to be used.
|
91 |
+
Some special loggers are:
|
92 |
+
- "silent": no message will be printed.
|
93 |
+
- other str: the logger obtained with `get_root_logger(logger)`.
|
94 |
+
- None: The `print()` method will be used to print log messages.
|
95 |
+
level (int): Logging level. Only available when `logger` is a Logger
|
96 |
+
object or "root".
|
97 |
+
"""
|
98 |
+
if logger is None:
|
99 |
+
print(msg)
|
100 |
+
elif isinstance(logger, logging.Logger):
|
101 |
+
logger.log(level, msg)
|
102 |
+
elif logger == 'silent':
|
103 |
+
pass
|
104 |
+
elif isinstance(logger, str):
|
105 |
+
_logger = get_logger(logger)
|
106 |
+
_logger.log(level, msg)
|
107 |
+
else:
|
108 |
+
raise TypeError(
|
109 |
+
'logger should be either a logging.Logger object, str, '
|
110 |
+
f'"silent" or None, but got {type(logger)}')
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/misc.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import collections.abc
|
3 |
+
import functools
|
4 |
+
import itertools
|
5 |
+
import subprocess
|
6 |
+
import warnings
|
7 |
+
from collections import abc
|
8 |
+
from importlib import import_module
|
9 |
+
from inspect import getfullargspec
|
10 |
+
from itertools import repeat
|
11 |
+
|
12 |
+
|
13 |
+
# From PyTorch internals
|
14 |
+
def _ntuple(n):
|
15 |
+
|
16 |
+
def parse(x):
|
17 |
+
if isinstance(x, collections.abc.Iterable):
|
18 |
+
return x
|
19 |
+
return tuple(repeat(x, n))
|
20 |
+
|
21 |
+
return parse
|
22 |
+
|
23 |
+
|
24 |
+
to_1tuple = _ntuple(1)
|
25 |
+
to_2tuple = _ntuple(2)
|
26 |
+
to_3tuple = _ntuple(3)
|
27 |
+
to_4tuple = _ntuple(4)
|
28 |
+
to_ntuple = _ntuple
|
29 |
+
|
30 |
+
|
31 |
+
def is_str(x):
|
32 |
+
"""Whether the input is an string instance.
|
33 |
+
|
34 |
+
Note: This method is deprecated since python 2 is no longer supported.
|
35 |
+
"""
|
36 |
+
return isinstance(x, str)
|
37 |
+
|
38 |
+
|
39 |
+
def import_modules_from_strings(imports, allow_failed_imports=False):
|
40 |
+
"""Import modules from the given list of strings.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
imports (list | str | None): The given module names to be imported.
|
44 |
+
allow_failed_imports (bool): If True, the failed imports will return
|
45 |
+
None. Otherwise, an ImportError is raise. Default: False.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
list[module] | module | None: The imported modules.
|
49 |
+
|
50 |
+
Examples:
|
51 |
+
>>> osp, sys = import_modules_from_strings(
|
52 |
+
... ['os.path', 'sys'])
|
53 |
+
>>> import os.path as osp_
|
54 |
+
>>> import sys as sys_
|
55 |
+
>>> assert osp == osp_
|
56 |
+
>>> assert sys == sys_
|
57 |
+
"""
|
58 |
+
if not imports:
|
59 |
+
return
|
60 |
+
single_import = False
|
61 |
+
if isinstance(imports, str):
|
62 |
+
single_import = True
|
63 |
+
imports = [imports]
|
64 |
+
if not isinstance(imports, list):
|
65 |
+
raise TypeError(
|
66 |
+
f'custom_imports must be a list but got type {type(imports)}')
|
67 |
+
imported = []
|
68 |
+
for imp in imports:
|
69 |
+
if not isinstance(imp, str):
|
70 |
+
raise TypeError(
|
71 |
+
f'{imp} is of type {type(imp)} and cannot be imported.')
|
72 |
+
try:
|
73 |
+
imported_tmp = import_module(imp)
|
74 |
+
except ImportError:
|
75 |
+
if allow_failed_imports:
|
76 |
+
warnings.warn(f'{imp} failed to import and is ignored.',
|
77 |
+
UserWarning)
|
78 |
+
imported_tmp = None
|
79 |
+
else:
|
80 |
+
raise ImportError
|
81 |
+
imported.append(imported_tmp)
|
82 |
+
if single_import:
|
83 |
+
imported = imported[0]
|
84 |
+
return imported
|
85 |
+
|
86 |
+
|
87 |
+
def iter_cast(inputs, dst_type, return_type=None):
|
88 |
+
"""Cast elements of an iterable object into some type.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
inputs (Iterable): The input object.
|
92 |
+
dst_type (type): Destination type.
|
93 |
+
return_type (type, optional): If specified, the output object will be
|
94 |
+
converted to this type, otherwise an iterator.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
iterator or specified type: The converted object.
|
98 |
+
"""
|
99 |
+
if not isinstance(inputs, abc.Iterable):
|
100 |
+
raise TypeError('inputs must be an iterable object')
|
101 |
+
if not isinstance(dst_type, type):
|
102 |
+
raise TypeError('"dst_type" must be a valid type')
|
103 |
+
|
104 |
+
out_iterable = map(dst_type, inputs)
|
105 |
+
|
106 |
+
if return_type is None:
|
107 |
+
return out_iterable
|
108 |
+
else:
|
109 |
+
return return_type(out_iterable)
|
110 |
+
|
111 |
+
|
112 |
+
def list_cast(inputs, dst_type):
|
113 |
+
"""Cast elements of an iterable object into a list of some type.
|
114 |
+
|
115 |
+
A partial method of :func:`iter_cast`.
|
116 |
+
"""
|
117 |
+
return iter_cast(inputs, dst_type, return_type=list)
|
118 |
+
|
119 |
+
|
120 |
+
def tuple_cast(inputs, dst_type):
|
121 |
+
"""Cast elements of an iterable object into a tuple of some type.
|
122 |
+
|
123 |
+
A partial method of :func:`iter_cast`.
|
124 |
+
"""
|
125 |
+
return iter_cast(inputs, dst_type, return_type=tuple)
|
126 |
+
|
127 |
+
|
128 |
+
def is_seq_of(seq, expected_type, seq_type=None):
|
129 |
+
"""Check whether it is a sequence of some type.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
seq (Sequence): The sequence to be checked.
|
133 |
+
expected_type (type): Expected type of sequence items.
|
134 |
+
seq_type (type, optional): Expected sequence type.
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
bool: Whether the sequence is valid.
|
138 |
+
"""
|
139 |
+
if seq_type is None:
|
140 |
+
exp_seq_type = abc.Sequence
|
141 |
+
else:
|
142 |
+
assert isinstance(seq_type, type)
|
143 |
+
exp_seq_type = seq_type
|
144 |
+
if not isinstance(seq, exp_seq_type):
|
145 |
+
return False
|
146 |
+
for item in seq:
|
147 |
+
if not isinstance(item, expected_type):
|
148 |
+
return False
|
149 |
+
return True
|
150 |
+
|
151 |
+
|
152 |
+
def is_list_of(seq, expected_type):
|
153 |
+
"""Check whether it is a list of some type.
|
154 |
+
|
155 |
+
A partial method of :func:`is_seq_of`.
|
156 |
+
"""
|
157 |
+
return is_seq_of(seq, expected_type, seq_type=list)
|
158 |
+
|
159 |
+
|
160 |
+
def is_tuple_of(seq, expected_type):
|
161 |
+
"""Check whether it is a tuple of some type.
|
162 |
+
|
163 |
+
A partial method of :func:`is_seq_of`.
|
164 |
+
"""
|
165 |
+
return is_seq_of(seq, expected_type, seq_type=tuple)
|
166 |
+
|
167 |
+
|
168 |
+
def slice_list(in_list, lens):
|
169 |
+
"""Slice a list into several sub lists by a list of given length.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
in_list (list): The list to be sliced.
|
173 |
+
lens(int or list): The expected length of each out list.
|
174 |
+
|
175 |
+
Returns:
|
176 |
+
list: A list of sliced list.
|
177 |
+
"""
|
178 |
+
if isinstance(lens, int):
|
179 |
+
assert len(in_list) % lens == 0
|
180 |
+
lens = [lens] * int(len(in_list) / lens)
|
181 |
+
if not isinstance(lens, list):
|
182 |
+
raise TypeError('"indices" must be an integer or a list of integers')
|
183 |
+
elif sum(lens) != len(in_list):
|
184 |
+
raise ValueError('sum of lens and list length does not '
|
185 |
+
f'match: {sum(lens)} != {len(in_list)}')
|
186 |
+
out_list = []
|
187 |
+
idx = 0
|
188 |
+
for i in range(len(lens)):
|
189 |
+
out_list.append(in_list[idx:idx + lens[i]])
|
190 |
+
idx += lens[i]
|
191 |
+
return out_list
|
192 |
+
|
193 |
+
|
194 |
+
def concat_list(in_list):
|
195 |
+
"""Concatenate a list of list into a single list.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
in_list (list): The list of list to be merged.
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
list: The concatenated flat list.
|
202 |
+
"""
|
203 |
+
return list(itertools.chain(*in_list))
|
204 |
+
|
205 |
+
|
206 |
+
def check_prerequisites(
|
207 |
+
prerequisites,
|
208 |
+
checker,
|
209 |
+
msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
|
210 |
+
'found, please install them first.'): # yapf: disable
|
211 |
+
"""A decorator factory to check if prerequisites are satisfied.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
prerequisites (str of list[str]): Prerequisites to be checked.
|
215 |
+
checker (callable): The checker method that returns True if a
|
216 |
+
prerequisite is meet, False otherwise.
|
217 |
+
msg_tmpl (str): The message template with two variables.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
decorator: A specific decorator.
|
221 |
+
"""
|
222 |
+
|
223 |
+
def wrap(func):
|
224 |
+
|
225 |
+
@functools.wraps(func)
|
226 |
+
def wrapped_func(*args, **kwargs):
|
227 |
+
requirements = [prerequisites] if isinstance(
|
228 |
+
prerequisites, str) else prerequisites
|
229 |
+
missing = []
|
230 |
+
for item in requirements:
|
231 |
+
if not checker(item):
|
232 |
+
missing.append(item)
|
233 |
+
if missing:
|
234 |
+
print(msg_tmpl.format(', '.join(missing), func.__name__))
|
235 |
+
raise RuntimeError('Prerequisites not meet.')
|
236 |
+
else:
|
237 |
+
return func(*args, **kwargs)
|
238 |
+
|
239 |
+
return wrapped_func
|
240 |
+
|
241 |
+
return wrap
|
242 |
+
|
243 |
+
|
244 |
+
def _check_py_package(package):
|
245 |
+
try:
|
246 |
+
import_module(package)
|
247 |
+
except ImportError:
|
248 |
+
return False
|
249 |
+
else:
|
250 |
+
return True
|
251 |
+
|
252 |
+
|
253 |
+
def _check_executable(cmd):
|
254 |
+
if subprocess.call(f'which {cmd}', shell=True) != 0:
|
255 |
+
return False
|
256 |
+
else:
|
257 |
+
return True
|
258 |
+
|
259 |
+
|
260 |
+
def requires_package(prerequisites):
|
261 |
+
"""A decorator to check if some python packages are installed.
|
262 |
+
|
263 |
+
Example:
|
264 |
+
>>> @requires_package('numpy')
|
265 |
+
>>> func(arg1, args):
|
266 |
+
>>> return numpy.zeros(1)
|
267 |
+
array([0.])
|
268 |
+
>>> @requires_package(['numpy', 'non_package'])
|
269 |
+
>>> func(arg1, args):
|
270 |
+
>>> return numpy.zeros(1)
|
271 |
+
ImportError
|
272 |
+
"""
|
273 |
+
return check_prerequisites(prerequisites, checker=_check_py_package)
|
274 |
+
|
275 |
+
|
276 |
+
def requires_executable(prerequisites):
|
277 |
+
"""A decorator to check if some executable files are installed.
|
278 |
+
|
279 |
+
Example:
|
280 |
+
>>> @requires_executable('ffmpeg')
|
281 |
+
>>> func(arg1, args):
|
282 |
+
>>> print(1)
|
283 |
+
1
|
284 |
+
"""
|
285 |
+
return check_prerequisites(prerequisites, checker=_check_executable)
|
286 |
+
|
287 |
+
|
288 |
+
def deprecated_api_warning(name_dict, cls_name=None):
|
289 |
+
"""A decorator to check if some arguments are deprecate and try to replace
|
290 |
+
deprecate src_arg_name to dst_arg_name.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
name_dict(dict):
|
294 |
+
key (str): Deprecate argument names.
|
295 |
+
val (str): Expected argument names.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
func: New function.
|
299 |
+
"""
|
300 |
+
|
301 |
+
def api_warning_wrapper(old_func):
|
302 |
+
|
303 |
+
@functools.wraps(old_func)
|
304 |
+
def new_func(*args, **kwargs):
|
305 |
+
# get the arg spec of the decorated method
|
306 |
+
args_info = getfullargspec(old_func)
|
307 |
+
# get name of the function
|
308 |
+
func_name = old_func.__name__
|
309 |
+
if cls_name is not None:
|
310 |
+
func_name = f'{cls_name}.{func_name}'
|
311 |
+
if args:
|
312 |
+
arg_names = args_info.args[:len(args)]
|
313 |
+
for src_arg_name, dst_arg_name in name_dict.items():
|
314 |
+
if src_arg_name in arg_names:
|
315 |
+
warnings.warn(
|
316 |
+
f'"{src_arg_name}" is deprecated in '
|
317 |
+
f'`{func_name}`, please use "{dst_arg_name}" '
|
318 |
+
'instead')
|
319 |
+
arg_names[arg_names.index(src_arg_name)] = dst_arg_name
|
320 |
+
if kwargs:
|
321 |
+
for src_arg_name, dst_arg_name in name_dict.items():
|
322 |
+
if src_arg_name in kwargs:
|
323 |
+
|
324 |
+
assert dst_arg_name not in kwargs, (
|
325 |
+
f'The expected behavior is to replace '
|
326 |
+
f'the deprecated key `{src_arg_name}` to '
|
327 |
+
f'new key `{dst_arg_name}`, but got them '
|
328 |
+
f'in the arguments at the same time, which '
|
329 |
+
f'is confusing. `{src_arg_name} will be '
|
330 |
+
f'deprecated in the future, please '
|
331 |
+
f'use `{dst_arg_name}` instead.')
|
332 |
+
|
333 |
+
warnings.warn(
|
334 |
+
f'"{src_arg_name}" is deprecated in '
|
335 |
+
f'`{func_name}`, please use "{dst_arg_name}" '
|
336 |
+
'instead')
|
337 |
+
kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
|
338 |
+
|
339 |
+
# apply converted arguments to the decorated method
|
340 |
+
output = old_func(*args, **kwargs)
|
341 |
+
return output
|
342 |
+
|
343 |
+
return new_func
|
344 |
+
|
345 |
+
return api_warning_wrapper
|
346 |
+
|
347 |
+
|
348 |
+
def is_method_overridden(method, base_class, derived_class):
|
349 |
+
"""Check if a method of base class is overridden in derived class.
|
350 |
+
|
351 |
+
Args:
|
352 |
+
method (str): the method name to check.
|
353 |
+
base_class (type): the class of the base class.
|
354 |
+
derived_class (type | Any): the class or instance of the derived class.
|
355 |
+
"""
|
356 |
+
assert isinstance(base_class, type), \
|
357 |
+
"base_class doesn't accept instance, Please pass class instead."
|
358 |
+
|
359 |
+
if not isinstance(derived_class, type):
|
360 |
+
derived_class = derived_class.__class__
|
361 |
+
|
362 |
+
base_method = getattr(base_class, method)
|
363 |
+
derived_method = getattr(derived_class, method)
|
364 |
+
return derived_method != base_method
|
365 |
+
|
366 |
+
|
367 |
+
def has_method(obj: object, method: str) -> bool:
|
368 |
+
"""Check whether the object has a method.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
method (str): The method name to check.
|
372 |
+
obj (object): The object to check.
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
bool: True if the object has the method else False.
|
376 |
+
"""
|
377 |
+
return hasattr(obj, method) and callable(getattr(obj, method))
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
|
4 |
+
from .parrots_wrapper import TORCH_VERSION
|
5 |
+
|
6 |
+
parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
|
7 |
+
|
8 |
+
if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
|
9 |
+
from parrots.jit import pat as jit
|
10 |
+
else:
|
11 |
+
|
12 |
+
def jit(func=None,
|
13 |
+
check_input=None,
|
14 |
+
full_shape=True,
|
15 |
+
derivate=False,
|
16 |
+
coderize=False,
|
17 |
+
optimize=False):
|
18 |
+
|
19 |
+
def wrapper(func):
|
20 |
+
|
21 |
+
def wrapper_inner(*args, **kargs):
|
22 |
+
return func(*args, **kargs)
|
23 |
+
|
24 |
+
return wrapper_inner
|
25 |
+
|
26 |
+
if func is None:
|
27 |
+
return wrapper
|
28 |
+
else:
|
29 |
+
return func
|
30 |
+
|
31 |
+
|
32 |
+
if TORCH_VERSION == 'parrots':
|
33 |
+
from parrots.utils.tester import skip_no_elena
|
34 |
+
else:
|
35 |
+
|
36 |
+
def skip_no_elena(func):
|
37 |
+
|
38 |
+
def wrapper(*args, **kargs):
|
39 |
+
return func(*args, **kargs)
|
40 |
+
|
41 |
+
return wrapper
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
TORCH_VERSION = torch.__version__
|
7 |
+
|
8 |
+
|
9 |
+
def is_rocm_pytorch() -> bool:
|
10 |
+
is_rocm = False
|
11 |
+
if TORCH_VERSION != 'parrots':
|
12 |
+
try:
|
13 |
+
from torch.utils.cpp_extension import ROCM_HOME
|
14 |
+
is_rocm = True if ((torch.version.hip is not None) and
|
15 |
+
(ROCM_HOME is not None)) else False
|
16 |
+
except ImportError:
|
17 |
+
pass
|
18 |
+
return is_rocm
|
19 |
+
|
20 |
+
|
21 |
+
def _get_cuda_home():
|
22 |
+
if TORCH_VERSION == 'parrots':
|
23 |
+
from parrots.utils.build_extension import CUDA_HOME
|
24 |
+
else:
|
25 |
+
if is_rocm_pytorch():
|
26 |
+
from torch.utils.cpp_extension import ROCM_HOME
|
27 |
+
CUDA_HOME = ROCM_HOME
|
28 |
+
else:
|
29 |
+
from torch.utils.cpp_extension import CUDA_HOME
|
30 |
+
return CUDA_HOME
|
31 |
+
|
32 |
+
|
33 |
+
def get_build_config():
|
34 |
+
if TORCH_VERSION == 'parrots':
|
35 |
+
from parrots.config import get_build_info
|
36 |
+
return get_build_info()
|
37 |
+
else:
|
38 |
+
return torch.__config__.show()
|
39 |
+
|
40 |
+
|
41 |
+
def _get_conv():
|
42 |
+
if TORCH_VERSION == 'parrots':
|
43 |
+
from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
44 |
+
else:
|
45 |
+
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
46 |
+
return _ConvNd, _ConvTransposeMixin
|
47 |
+
|
48 |
+
|
49 |
+
def _get_dataloader():
|
50 |
+
if TORCH_VERSION == 'parrots':
|
51 |
+
from torch.utils.data import DataLoader, PoolDataLoader
|
52 |
+
else:
|
53 |
+
from torch.utils.data import DataLoader
|
54 |
+
PoolDataLoader = DataLoader
|
55 |
+
return DataLoader, PoolDataLoader
|
56 |
+
|
57 |
+
|
58 |
+
def _get_extension():
|
59 |
+
if TORCH_VERSION == 'parrots':
|
60 |
+
from parrots.utils.build_extension import BuildExtension, Extension
|
61 |
+
CppExtension = partial(Extension, cuda=False)
|
62 |
+
CUDAExtension = partial(Extension, cuda=True)
|
63 |
+
else:
|
64 |
+
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
|
65 |
+
CUDAExtension)
|
66 |
+
return BuildExtension, CppExtension, CUDAExtension
|
67 |
+
|
68 |
+
|
69 |
+
def _get_pool():
|
70 |
+
if TORCH_VERSION == 'parrots':
|
71 |
+
from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
|
72 |
+
_AdaptiveMaxPoolNd, _AvgPoolNd,
|
73 |
+
_MaxPoolNd)
|
74 |
+
else:
|
75 |
+
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
|
76 |
+
_AdaptiveMaxPoolNd, _AvgPoolNd,
|
77 |
+
_MaxPoolNd)
|
78 |
+
return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
|
79 |
+
|
80 |
+
|
81 |
+
def _get_norm():
|
82 |
+
if TORCH_VERSION == 'parrots':
|
83 |
+
from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
|
84 |
+
SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
|
85 |
+
else:
|
86 |
+
from torch.nn.modules.instancenorm import _InstanceNorm
|
87 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
88 |
+
SyncBatchNorm_ = torch.nn.SyncBatchNorm
|
89 |
+
return _BatchNorm, _InstanceNorm, SyncBatchNorm_
|
90 |
+
|
91 |
+
|
92 |
+
_ConvNd, _ConvTransposeMixin = _get_conv()
|
93 |
+
DataLoader, PoolDataLoader = _get_dataloader()
|
94 |
+
BuildExtension, CppExtension, CUDAExtension = _get_extension()
|
95 |
+
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
|
96 |
+
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
|
97 |
+
|
98 |
+
|
99 |
+
class SyncBatchNorm(SyncBatchNorm_):
|
100 |
+
|
101 |
+
def _check_input_dim(self, input):
|
102 |
+
if TORCH_VERSION == 'parrots':
|
103 |
+
if input.dim() < 2:
|
104 |
+
raise ValueError(
|
105 |
+
f'expected at least 2D input (got {input.dim()}D input)')
|
106 |
+
else:
|
107 |
+
super()._check_input_dim(input)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from .misc import is_str
|
7 |
+
|
8 |
+
|
9 |
+
def is_filepath(x):
|
10 |
+
return is_str(x) or isinstance(x, Path)
|
11 |
+
|
12 |
+
|
13 |
+
def fopen(filepath, *args, **kwargs):
|
14 |
+
if is_str(filepath):
|
15 |
+
return open(filepath, *args, **kwargs)
|
16 |
+
elif isinstance(filepath, Path):
|
17 |
+
return filepath.open(*args, **kwargs)
|
18 |
+
raise ValueError('`filepath` should be a string or a Path')
|
19 |
+
|
20 |
+
|
21 |
+
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
|
22 |
+
if not osp.isfile(filename):
|
23 |
+
raise FileNotFoundError(msg_tmpl.format(filename))
|
24 |
+
|
25 |
+
|
26 |
+
def mkdir_or_exist(dir_name, mode=0o777):
|
27 |
+
if dir_name == '':
|
28 |
+
return
|
29 |
+
dir_name = osp.expanduser(dir_name)
|
30 |
+
os.makedirs(dir_name, mode=mode, exist_ok=True)
|
31 |
+
|
32 |
+
|
33 |
+
def symlink(src, dst, overwrite=True, **kwargs):
|
34 |
+
if os.path.lexists(dst) and overwrite:
|
35 |
+
os.remove(dst)
|
36 |
+
os.symlink(src, dst, **kwargs)
|
37 |
+
|
38 |
+
|
39 |
+
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
|
40 |
+
"""Scan a directory to find the interested files.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
dir_path (str | obj:`Path`): Path of the directory.
|
44 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
45 |
+
interested in. Default: None.
|
46 |
+
recursive (bool, optional): If set to True, recursively scan the
|
47 |
+
directory. Default: False.
|
48 |
+
case_sensitive (bool, optional) : If set to False, ignore the case of
|
49 |
+
suffix. Default: True.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
A generator for all the interested files with relative paths.
|
53 |
+
"""
|
54 |
+
if isinstance(dir_path, (str, Path)):
|
55 |
+
dir_path = str(dir_path)
|
56 |
+
else:
|
57 |
+
raise TypeError('"dir_path" must be a string or Path object')
|
58 |
+
|
59 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
60 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
61 |
+
|
62 |
+
if suffix is not None and not case_sensitive:
|
63 |
+
suffix = suffix.lower() if isinstance(suffix, str) else tuple(
|
64 |
+
item.lower() for item in suffix)
|
65 |
+
|
66 |
+
root = dir_path
|
67 |
+
|
68 |
+
def _scandir(dir_path, suffix, recursive, case_sensitive):
|
69 |
+
for entry in os.scandir(dir_path):
|
70 |
+
if not entry.name.startswith('.') and entry.is_file():
|
71 |
+
rel_path = osp.relpath(entry.path, root)
|
72 |
+
_rel_path = rel_path if case_sensitive else rel_path.lower()
|
73 |
+
if suffix is None or _rel_path.endswith(suffix):
|
74 |
+
yield rel_path
|
75 |
+
elif recursive and os.path.isdir(entry.path):
|
76 |
+
# scan recursively if entry.path is a directory
|
77 |
+
yield from _scandir(entry.path, suffix, recursive,
|
78 |
+
case_sensitive)
|
79 |
+
|
80 |
+
return _scandir(dir_path, suffix, recursive, case_sensitive)
|
81 |
+
|
82 |
+
|
83 |
+
def find_vcs_root(path, markers=('.git', )):
|
84 |
+
"""Finds the root directory (including itself) of specified markers.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
path (str): Path of directory or file.
|
88 |
+
markers (list[str], optional): List of file or directory names.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
The directory contained one of the markers or None if not found.
|
92 |
+
"""
|
93 |
+
if osp.isfile(path):
|
94 |
+
path = osp.dirname(path)
|
95 |
+
|
96 |
+
prev, cur = None, osp.abspath(osp.expanduser(path))
|
97 |
+
while cur != prev:
|
98 |
+
if any(osp.exists(osp.join(cur, marker)) for marker in markers):
|
99 |
+
return cur
|
100 |
+
prev, cur = cur, osp.split(cur)[0]
|
101 |
+
return None
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import sys
|
3 |
+
from collections.abc import Iterable
|
4 |
+
from multiprocessing import Pool
|
5 |
+
from shutil import get_terminal_size
|
6 |
+
|
7 |
+
from .timer import Timer
|
8 |
+
|
9 |
+
|
10 |
+
class ProgressBar:
|
11 |
+
"""A progress bar which can print the progress."""
|
12 |
+
|
13 |
+
def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
|
14 |
+
self.task_num = task_num
|
15 |
+
self.bar_width = bar_width
|
16 |
+
self.completed = 0
|
17 |
+
self.file = file
|
18 |
+
if start:
|
19 |
+
self.start()
|
20 |
+
|
21 |
+
@property
|
22 |
+
def terminal_width(self):
|
23 |
+
width, _ = get_terminal_size()
|
24 |
+
return width
|
25 |
+
|
26 |
+
def start(self):
|
27 |
+
if self.task_num > 0:
|
28 |
+
self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
|
29 |
+
'elapsed: 0s, ETA:')
|
30 |
+
else:
|
31 |
+
self.file.write('completed: 0, elapsed: 0s')
|
32 |
+
self.file.flush()
|
33 |
+
self.timer = Timer()
|
34 |
+
|
35 |
+
def update(self, num_tasks=1):
|
36 |
+
assert num_tasks > 0
|
37 |
+
self.completed += num_tasks
|
38 |
+
elapsed = self.timer.since_start()
|
39 |
+
if elapsed > 0:
|
40 |
+
fps = self.completed / elapsed
|
41 |
+
else:
|
42 |
+
fps = float('inf')
|
43 |
+
if self.task_num > 0:
|
44 |
+
percentage = self.completed / float(self.task_num)
|
45 |
+
eta = int(elapsed * (1 - percentage) / percentage + 0.5)
|
46 |
+
msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
|
47 |
+
f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
|
48 |
+
f'ETA: {eta:5}s'
|
49 |
+
|
50 |
+
bar_width = min(self.bar_width,
|
51 |
+
int(self.terminal_width - len(msg)) + 2,
|
52 |
+
int(self.terminal_width * 0.6))
|
53 |
+
bar_width = max(2, bar_width)
|
54 |
+
mark_width = int(bar_width * percentage)
|
55 |
+
bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
|
56 |
+
self.file.write(msg.format(bar_chars))
|
57 |
+
else:
|
58 |
+
self.file.write(
|
59 |
+
f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
|
60 |
+
f' {fps:.1f} tasks/s')
|
61 |
+
self.file.flush()
|
62 |
+
|
63 |
+
|
64 |
+
def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
|
65 |
+
"""Track the progress of tasks execution with a progress bar.
|
66 |
+
|
67 |
+
Tasks are done with a simple for-loop.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
func (callable): The function to be applied to each task.
|
71 |
+
tasks (list or tuple[Iterable, int]): A list of tasks or
|
72 |
+
(tasks, total num).
|
73 |
+
bar_width (int): Width of progress bar.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
list: The task results.
|
77 |
+
"""
|
78 |
+
if isinstance(tasks, tuple):
|
79 |
+
assert len(tasks) == 2
|
80 |
+
assert isinstance(tasks[0], Iterable)
|
81 |
+
assert isinstance(tasks[1], int)
|
82 |
+
task_num = tasks[1]
|
83 |
+
tasks = tasks[0]
|
84 |
+
elif isinstance(tasks, Iterable):
|
85 |
+
task_num = len(tasks)
|
86 |
+
else:
|
87 |
+
raise TypeError(
|
88 |
+
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
89 |
+
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
90 |
+
results = []
|
91 |
+
for task in tasks:
|
92 |
+
results.append(func(task, **kwargs))
|
93 |
+
prog_bar.update()
|
94 |
+
prog_bar.file.write('\n')
|
95 |
+
return results
|
96 |
+
|
97 |
+
|
98 |
+
def init_pool(process_num, initializer=None, initargs=None):
|
99 |
+
if initializer is None:
|
100 |
+
return Pool(process_num)
|
101 |
+
elif initargs is None:
|
102 |
+
return Pool(process_num, initializer)
|
103 |
+
else:
|
104 |
+
if not isinstance(initargs, tuple):
|
105 |
+
raise TypeError('"initargs" must be a tuple')
|
106 |
+
return Pool(process_num, initializer, initargs)
|
107 |
+
|
108 |
+
|
109 |
+
def track_parallel_progress(func,
|
110 |
+
tasks,
|
111 |
+
nproc,
|
112 |
+
initializer=None,
|
113 |
+
initargs=None,
|
114 |
+
bar_width=50,
|
115 |
+
chunksize=1,
|
116 |
+
skip_first=False,
|
117 |
+
keep_order=True,
|
118 |
+
file=sys.stdout):
|
119 |
+
"""Track the progress of parallel task execution with a progress bar.
|
120 |
+
|
121 |
+
The built-in :mod:`multiprocessing` module is used for process pools and
|
122 |
+
tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
func (callable): The function to be applied to each task.
|
126 |
+
tasks (list or tuple[Iterable, int]): A list of tasks or
|
127 |
+
(tasks, total num).
|
128 |
+
nproc (int): Process (worker) number.
|
129 |
+
initializer (None or callable): Refer to :class:`multiprocessing.Pool`
|
130 |
+
for details.
|
131 |
+
initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
|
132 |
+
details.
|
133 |
+
chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
|
134 |
+
bar_width (int): Width of progress bar.
|
135 |
+
skip_first (bool): Whether to skip the first sample for each worker
|
136 |
+
when estimating fps, since the initialization step may takes
|
137 |
+
longer.
|
138 |
+
keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
|
139 |
+
:func:`Pool.imap_unordered` is used.
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
list: The task results.
|
143 |
+
"""
|
144 |
+
if isinstance(tasks, tuple):
|
145 |
+
assert len(tasks) == 2
|
146 |
+
assert isinstance(tasks[0], Iterable)
|
147 |
+
assert isinstance(tasks[1], int)
|
148 |
+
task_num = tasks[1]
|
149 |
+
tasks = tasks[0]
|
150 |
+
elif isinstance(tasks, Iterable):
|
151 |
+
task_num = len(tasks)
|
152 |
+
else:
|
153 |
+
raise TypeError(
|
154 |
+
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
155 |
+
pool = init_pool(nproc, initializer, initargs)
|
156 |
+
start = not skip_first
|
157 |
+
task_num -= nproc * chunksize * int(skip_first)
|
158 |
+
prog_bar = ProgressBar(task_num, bar_width, start, file=file)
|
159 |
+
results = []
|
160 |
+
if keep_order:
|
161 |
+
gen = pool.imap(func, tasks, chunksize)
|
162 |
+
else:
|
163 |
+
gen = pool.imap_unordered(func, tasks, chunksize)
|
164 |
+
for result in gen:
|
165 |
+
results.append(result)
|
166 |
+
if skip_first:
|
167 |
+
if len(results) < nproc * chunksize:
|
168 |
+
continue
|
169 |
+
elif len(results) == nproc * chunksize:
|
170 |
+
prog_bar.start()
|
171 |
+
continue
|
172 |
+
prog_bar.update()
|
173 |
+
prog_bar.file.write('\n')
|
174 |
+
pool.close()
|
175 |
+
pool.join()
|
176 |
+
return results
|
177 |
+
|
178 |
+
|
179 |
+
def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
|
180 |
+
"""Track the progress of tasks iteration or enumeration with a progress
|
181 |
+
bar.
|
182 |
+
|
183 |
+
Tasks are yielded with a simple for-loop.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
tasks (list or tuple[Iterable, int]): A list of tasks or
|
187 |
+
(tasks, total num).
|
188 |
+
bar_width (int): Width of progress bar.
|
189 |
+
|
190 |
+
Yields:
|
191 |
+
list: The task results.
|
192 |
+
"""
|
193 |
+
if isinstance(tasks, tuple):
|
194 |
+
assert len(tasks) == 2
|
195 |
+
assert isinstance(tasks[0], Iterable)
|
196 |
+
assert isinstance(tasks[1], int)
|
197 |
+
task_num = tasks[1]
|
198 |
+
tasks = tasks[0]
|
199 |
+
elif isinstance(tasks, Iterable):
|
200 |
+
task_num = len(tasks)
|
201 |
+
else:
|
202 |
+
raise TypeError(
|
203 |
+
'"tasks" must be an iterable object or a (iterator, int) tuple')
|
204 |
+
prog_bar = ProgressBar(task_num, bar_width, file=file)
|
205 |
+
for task in tasks:
|
206 |
+
yield task
|
207 |
+
prog_bar.update()
|
208 |
+
prog_bar.file.write('\n')
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import inspect
|
3 |
+
import warnings
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
from .misc import is_seq_of
|
7 |
+
|
8 |
+
|
9 |
+
def build_from_cfg(cfg, registry, default_args=None):
|
10 |
+
"""Build a module from config dict.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
cfg (dict): Config dict. It should at least contain the key "type".
|
14 |
+
registry (:obj:`Registry`): The registry to search the type from.
|
15 |
+
default_args (dict, optional): Default initialization arguments.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
object: The constructed object.
|
19 |
+
"""
|
20 |
+
if not isinstance(cfg, dict):
|
21 |
+
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
22 |
+
if 'type' not in cfg:
|
23 |
+
if default_args is None or 'type' not in default_args:
|
24 |
+
raise KeyError(
|
25 |
+
'`cfg` or `default_args` must contain the key "type", '
|
26 |
+
f'but got {cfg}\n{default_args}')
|
27 |
+
if not isinstance(registry, Registry):
|
28 |
+
raise TypeError('registry must be an mmcv.Registry object, '
|
29 |
+
f'but got {type(registry)}')
|
30 |
+
if not (isinstance(default_args, dict) or default_args is None):
|
31 |
+
raise TypeError('default_args must be a dict or None, '
|
32 |
+
f'but got {type(default_args)}')
|
33 |
+
|
34 |
+
args = cfg.copy()
|
35 |
+
|
36 |
+
if default_args is not None:
|
37 |
+
for name, value in default_args.items():
|
38 |
+
args.setdefault(name, value)
|
39 |
+
|
40 |
+
obj_type = args.pop('type')
|
41 |
+
if isinstance(obj_type, str):
|
42 |
+
obj_cls = registry.get(obj_type)
|
43 |
+
if obj_cls is None:
|
44 |
+
raise KeyError(
|
45 |
+
f'{obj_type} is not in the {registry.name} registry')
|
46 |
+
elif inspect.isclass(obj_type):
|
47 |
+
obj_cls = obj_type
|
48 |
+
else:
|
49 |
+
raise TypeError(
|
50 |
+
f'type must be a str or valid type, but got {type(obj_type)}')
|
51 |
+
try:
|
52 |
+
return obj_cls(**args)
|
53 |
+
except Exception as e:
|
54 |
+
# Normal TypeError does not print class name.
|
55 |
+
raise type(e)(f'{obj_cls.__name__}: {e}')
|
56 |
+
|
57 |
+
|
58 |
+
class Registry:
|
59 |
+
"""A registry to map strings to classes.
|
60 |
+
|
61 |
+
Registered object could be built from registry.
|
62 |
+
Example:
|
63 |
+
>>> MODELS = Registry('models')
|
64 |
+
>>> @MODELS.register_module()
|
65 |
+
>>> class ResNet:
|
66 |
+
>>> pass
|
67 |
+
>>> resnet = MODELS.build(dict(type='ResNet'))
|
68 |
+
|
69 |
+
Please refer to
|
70 |
+
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
71 |
+
advanced usage.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
name (str): Registry name.
|
75 |
+
build_func(func, optional): Build function to construct instance from
|
76 |
+
Registry, func:`build_from_cfg` is used if neither ``parent`` or
|
77 |
+
``build_func`` is specified. If ``parent`` is specified and
|
78 |
+
``build_func`` is not given, ``build_func`` will be inherited
|
79 |
+
from ``parent``. Default: None.
|
80 |
+
parent (Registry, optional): Parent registry. The class registered in
|
81 |
+
children registry could be built from parent. Default: None.
|
82 |
+
scope (str, optional): The scope of registry. It is the key to search
|
83 |
+
for children registry. If not specified, scope will be the name of
|
84 |
+
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
85 |
+
Default: None.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, name, build_func=None, parent=None, scope=None):
|
89 |
+
self._name = name
|
90 |
+
self._module_dict = dict()
|
91 |
+
self._children = dict()
|
92 |
+
self._scope = self.infer_scope() if scope is None else scope
|
93 |
+
|
94 |
+
# self.build_func will be set with the following priority:
|
95 |
+
# 1. build_func
|
96 |
+
# 2. parent.build_func
|
97 |
+
# 3. build_from_cfg
|
98 |
+
if build_func is None:
|
99 |
+
if parent is not None:
|
100 |
+
self.build_func = parent.build_func
|
101 |
+
else:
|
102 |
+
self.build_func = build_from_cfg
|
103 |
+
else:
|
104 |
+
self.build_func = build_func
|
105 |
+
if parent is not None:
|
106 |
+
assert isinstance(parent, Registry)
|
107 |
+
parent._add_children(self)
|
108 |
+
self.parent = parent
|
109 |
+
else:
|
110 |
+
self.parent = None
|
111 |
+
|
112 |
+
def __len__(self):
|
113 |
+
return len(self._module_dict)
|
114 |
+
|
115 |
+
def __contains__(self, key):
|
116 |
+
return self.get(key) is not None
|
117 |
+
|
118 |
+
def __repr__(self):
|
119 |
+
format_str = self.__class__.__name__ + \
|
120 |
+
f'(name={self._name}, ' \
|
121 |
+
f'items={self._module_dict})'
|
122 |
+
return format_str
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def infer_scope():
|
126 |
+
"""Infer the scope of registry.
|
127 |
+
|
128 |
+
The name of the package where registry is defined will be returned.
|
129 |
+
|
130 |
+
Example:
|
131 |
+
# in mmdet/models/backbone/resnet.py
|
132 |
+
>>> MODELS = Registry('models')
|
133 |
+
>>> @MODELS.register_module()
|
134 |
+
>>> class ResNet:
|
135 |
+
>>> pass
|
136 |
+
The scope of ``ResNet`` will be ``mmdet``.
|
137 |
+
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
scope (str): The inferred scope name.
|
141 |
+
"""
|
142 |
+
# inspect.stack() trace where this function is called, the index-2
|
143 |
+
# indicates the frame where `infer_scope()` is called
|
144 |
+
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
|
145 |
+
split_filename = filename.split('.')
|
146 |
+
return split_filename[0]
|
147 |
+
|
148 |
+
@staticmethod
|
149 |
+
def split_scope_key(key):
|
150 |
+
"""Split scope and key.
|
151 |
+
|
152 |
+
The first scope will be split from key.
|
153 |
+
|
154 |
+
Examples:
|
155 |
+
>>> Registry.split_scope_key('mmdet.ResNet')
|
156 |
+
'mmdet', 'ResNet'
|
157 |
+
>>> Registry.split_scope_key('ResNet')
|
158 |
+
None, 'ResNet'
|
159 |
+
|
160 |
+
Return:
|
161 |
+
scope (str, None): The first scope.
|
162 |
+
key (str): The remaining key.
|
163 |
+
"""
|
164 |
+
split_index = key.find('.')
|
165 |
+
if split_index != -1:
|
166 |
+
return key[:split_index], key[split_index + 1:]
|
167 |
+
else:
|
168 |
+
return None, key
|
169 |
+
|
170 |
+
@property
|
171 |
+
def name(self):
|
172 |
+
return self._name
|
173 |
+
|
174 |
+
@property
|
175 |
+
def scope(self):
|
176 |
+
return self._scope
|
177 |
+
|
178 |
+
@property
|
179 |
+
def module_dict(self):
|
180 |
+
return self._module_dict
|
181 |
+
|
182 |
+
@property
|
183 |
+
def children(self):
|
184 |
+
return self._children
|
185 |
+
|
186 |
+
def get(self, key):
|
187 |
+
"""Get the registry record.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
key (str): The class name in string format.
|
191 |
+
|
192 |
+
Returns:
|
193 |
+
class: The corresponding class.
|
194 |
+
"""
|
195 |
+
scope, real_key = self.split_scope_key(key)
|
196 |
+
if scope is None or scope == self._scope:
|
197 |
+
# get from self
|
198 |
+
if real_key in self._module_dict:
|
199 |
+
return self._module_dict[real_key]
|
200 |
+
else:
|
201 |
+
# get from self._children
|
202 |
+
if scope in self._children:
|
203 |
+
return self._children[scope].get(real_key)
|
204 |
+
else:
|
205 |
+
# goto root
|
206 |
+
parent = self.parent
|
207 |
+
while parent.parent is not None:
|
208 |
+
parent = parent.parent
|
209 |
+
return parent.get(key)
|
210 |
+
|
211 |
+
def build(self, *args, **kwargs):
|
212 |
+
return self.build_func(*args, **kwargs, registry=self)
|
213 |
+
|
214 |
+
def _add_children(self, registry):
|
215 |
+
"""Add children for a registry.
|
216 |
+
|
217 |
+
The ``registry`` will be added as children based on its scope.
|
218 |
+
The parent registry could build objects from children registry.
|
219 |
+
|
220 |
+
Example:
|
221 |
+
>>> models = Registry('models')
|
222 |
+
>>> mmdet_models = Registry('models', parent=models)
|
223 |
+
>>> @mmdet_models.register_module()
|
224 |
+
>>> class ResNet:
|
225 |
+
>>> pass
|
226 |
+
>>> resnet = models.build(dict(type='mmdet.ResNet'))
|
227 |
+
"""
|
228 |
+
|
229 |
+
assert isinstance(registry, Registry)
|
230 |
+
assert registry.scope is not None
|
231 |
+
assert registry.scope not in self.children, \
|
232 |
+
f'scope {registry.scope} exists in {self.name} registry'
|
233 |
+
self.children[registry.scope] = registry
|
234 |
+
|
235 |
+
def _register_module(self, module_class, module_name=None, force=False):
|
236 |
+
if not inspect.isclass(module_class):
|
237 |
+
raise TypeError('module must be a class, '
|
238 |
+
f'but got {type(module_class)}')
|
239 |
+
|
240 |
+
if module_name is None:
|
241 |
+
module_name = module_class.__name__
|
242 |
+
if isinstance(module_name, str):
|
243 |
+
module_name = [module_name]
|
244 |
+
for name in module_name:
|
245 |
+
if not force and name in self._module_dict:
|
246 |
+
raise KeyError(f'{name} is already registered '
|
247 |
+
f'in {self.name}')
|
248 |
+
self._module_dict[name] = module_class
|
249 |
+
|
250 |
+
def deprecated_register_module(self, cls=None, force=False):
|
251 |
+
warnings.warn(
|
252 |
+
'The old API of register_module(module, force=False) '
|
253 |
+
'is deprecated and will be removed, please use the new API '
|
254 |
+
'register_module(name=None, force=False, module=None) instead.')
|
255 |
+
if cls is None:
|
256 |
+
return partial(self.deprecated_register_module, force=force)
|
257 |
+
self._register_module(cls, force=force)
|
258 |
+
return cls
|
259 |
+
|
260 |
+
def register_module(self, name=None, force=False, module=None):
|
261 |
+
"""Register a module.
|
262 |
+
|
263 |
+
A record will be added to `self._module_dict`, whose key is the class
|
264 |
+
name or the specified name, and value is the class itself.
|
265 |
+
It can be used as a decorator or a normal function.
|
266 |
+
|
267 |
+
Example:
|
268 |
+
>>> backbones = Registry('backbone')
|
269 |
+
>>> @backbones.register_module()
|
270 |
+
>>> class ResNet:
|
271 |
+
>>> pass
|
272 |
+
|
273 |
+
>>> backbones = Registry('backbone')
|
274 |
+
>>> @backbones.register_module(name='mnet')
|
275 |
+
>>> class MobileNet:
|
276 |
+
>>> pass
|
277 |
+
|
278 |
+
>>> backbones = Registry('backbone')
|
279 |
+
>>> class ResNet:
|
280 |
+
>>> pass
|
281 |
+
>>> backbones.register_module(ResNet)
|
282 |
+
|
283 |
+
Args:
|
284 |
+
name (str | None): The module name to be registered. If not
|
285 |
+
specified, the class name will be used.
|
286 |
+
force (bool, optional): Whether to override an existing class with
|
287 |
+
the same name. Default: False.
|
288 |
+
module (type): Module class to be registered.
|
289 |
+
"""
|
290 |
+
if not isinstance(force, bool):
|
291 |
+
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
292 |
+
# NOTE: This is a walkaround to be compatible with the old api,
|
293 |
+
# while it may introduce unexpected bugs.
|
294 |
+
if isinstance(name, type):
|
295 |
+
return self.deprecated_register_module(name, force=force)
|
296 |
+
|
297 |
+
# raise the error ahead of time
|
298 |
+
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
|
299 |
+
raise TypeError(
|
300 |
+
'name must be either of None, an instance of str or a sequence'
|
301 |
+
f' of str, but got {type(name)}')
|
302 |
+
|
303 |
+
# use it as a normal method: x.register_module(module=SomeClass)
|
304 |
+
if module is not None:
|
305 |
+
self._register_module(
|
306 |
+
module_class=module, module_name=name, force=force)
|
307 |
+
return module
|
308 |
+
|
309 |
+
# use it as a decorator: @x.register_module()
|
310 |
+
def _register(cls):
|
311 |
+
self._register_module(
|
312 |
+
module_class=cls, module_name=name, force=force)
|
313 |
+
return cls
|
314 |
+
|
315 |
+
return _register
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Open-MMLab.
|
2 |
+
import sys
|
3 |
+
from collections.abc import Iterable
|
4 |
+
from runpy import run_path
|
5 |
+
from shlex import split
|
6 |
+
from typing import Any, Dict, List
|
7 |
+
from unittest.mock import patch
|
8 |
+
|
9 |
+
|
10 |
+
def check_python_script(cmd):
|
11 |
+
"""Run the python cmd script with `__main__`. The difference between
|
12 |
+
`os.system` is that, this function exectues code in the current process, so
|
13 |
+
that it can be tracked by coverage tools. Currently it supports two forms:
|
14 |
+
|
15 |
+
- ./tests/data/scripts/hello.py zz
|
16 |
+
- python tests/data/scripts/hello.py zz
|
17 |
+
"""
|
18 |
+
args = split(cmd)
|
19 |
+
if args[0] == 'python':
|
20 |
+
args = args[1:]
|
21 |
+
with patch.object(sys, 'argv', args):
|
22 |
+
run_path(args[0], run_name='__main__')
|
23 |
+
|
24 |
+
|
25 |
+
def _any(judge_result):
|
26 |
+
"""Since built-in ``any`` works only when the element of iterable is not
|
27 |
+
iterable, implement the function."""
|
28 |
+
if not isinstance(judge_result, Iterable):
|
29 |
+
return judge_result
|
30 |
+
|
31 |
+
try:
|
32 |
+
for element in judge_result:
|
33 |
+
if _any(element):
|
34 |
+
return True
|
35 |
+
except TypeError:
|
36 |
+
# Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
|
37 |
+
if judge_result:
|
38 |
+
return True
|
39 |
+
return False
|
40 |
+
|
41 |
+
|
42 |
+
def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
|
43 |
+
expected_subset: Dict[Any, Any]) -> bool:
|
44 |
+
"""Check if the dict_obj contains the expected_subset.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
dict_obj (Dict[Any, Any]): Dict object to be checked.
|
48 |
+
expected_subset (Dict[Any, Any]): Subset expected to be contained in
|
49 |
+
dict_obj.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
bool: Whether the dict_obj contains the expected_subset.
|
53 |
+
"""
|
54 |
+
|
55 |
+
for key, value in expected_subset.items():
|
56 |
+
if key not in dict_obj.keys() or _any(dict_obj[key] != value):
|
57 |
+
return False
|
58 |
+
return True
|
59 |
+
|
60 |
+
|
61 |
+
def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
|
62 |
+
"""Check if attribute of class object is correct.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
obj (object): Class object to be checked.
|
66 |
+
expected_attrs (Dict[str, Any]): Dict of the expected attrs.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
bool: Whether the attribute of class object is correct.
|
70 |
+
"""
|
71 |
+
for attr, value in expected_attrs.items():
|
72 |
+
if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
|
73 |
+
return False
|
74 |
+
return True
|
75 |
+
|
76 |
+
|
77 |
+
def assert_dict_has_keys(obj: Dict[str, Any],
|
78 |
+
expected_keys: List[str]) -> bool:
|
79 |
+
"""Check if the obj has all the expected_keys.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
obj (Dict[str, Any]): Object to be checked.
|
83 |
+
expected_keys (List[str]): Keys expected to contained in the keys of
|
84 |
+
the obj.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
bool: Whether the obj has the expected keys.
|
88 |
+
"""
|
89 |
+
return set(expected_keys).issubset(set(obj.keys()))
|
90 |
+
|
91 |
+
|
92 |
+
def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
|
93 |
+
"""Check if target_keys is equal to result_keys.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
result_keys (List[str]): Result keys to be checked.
|
97 |
+
target_keys (List[str]): Target keys to be checked.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
bool: Whether target_keys is equal to result_keys.
|
101 |
+
"""
|
102 |
+
return set(result_keys) == set(target_keys)
|
103 |
+
|
104 |
+
|
105 |
+
def assert_is_norm_layer(module) -> bool:
|
106 |
+
"""Check if the module is a norm layer.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
module (nn.Module): The module to be checked.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
bool: Whether the module is a norm layer.
|
113 |
+
"""
|
114 |
+
from .parrots_wrapper import _BatchNorm, _InstanceNorm
|
115 |
+
from torch.nn import GroupNorm, LayerNorm
|
116 |
+
norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
|
117 |
+
return isinstance(module, norm_layer_candidates)
|
118 |
+
|
119 |
+
|
120 |
+
def assert_params_all_zeros(module) -> bool:
|
121 |
+
"""Check if the parameters of the module is all zeros.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
module (nn.Module): The module to be checked.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
bool: Whether the parameters of the module is all zeros.
|
128 |
+
"""
|
129 |
+
weight_data = module.weight.data
|
130 |
+
is_weight_zero = weight_data.allclose(
|
131 |
+
weight_data.new_zeros(weight_data.size()))
|
132 |
+
|
133 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
134 |
+
bias_data = module.bias.data
|
135 |
+
is_bias_zero = bias_data.allclose(
|
136 |
+
bias_data.new_zeros(bias_data.size()))
|
137 |
+
else:
|
138 |
+
is_bias_zero = True
|
139 |
+
|
140 |
+
return is_weight_zero and is_bias_zero
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from time import time
|
3 |
+
|
4 |
+
|
5 |
+
class TimerError(Exception):
|
6 |
+
|
7 |
+
def __init__(self, message):
|
8 |
+
self.message = message
|
9 |
+
super(TimerError, self).__init__(message)
|
10 |
+
|
11 |
+
|
12 |
+
class Timer:
|
13 |
+
"""A flexible Timer class.
|
14 |
+
|
15 |
+
:Example:
|
16 |
+
|
17 |
+
>>> import time
|
18 |
+
>>> import annotator.mmpkg.mmcv as mmcv
|
19 |
+
>>> with mmcv.Timer():
|
20 |
+
>>> # simulate a code block that will run for 1s
|
21 |
+
>>> time.sleep(1)
|
22 |
+
1.000
|
23 |
+
>>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
|
24 |
+
>>> # simulate a code block that will run for 1s
|
25 |
+
>>> time.sleep(1)
|
26 |
+
it takes 1.0 seconds
|
27 |
+
>>> timer = mmcv.Timer()
|
28 |
+
>>> time.sleep(0.5)
|
29 |
+
>>> print(timer.since_start())
|
30 |
+
0.500
|
31 |
+
>>> time.sleep(0.5)
|
32 |
+
>>> print(timer.since_last_check())
|
33 |
+
0.500
|
34 |
+
>>> print(timer.since_start())
|
35 |
+
1.000
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, start=True, print_tmpl=None):
|
39 |
+
self._is_running = False
|
40 |
+
self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
|
41 |
+
if start:
|
42 |
+
self.start()
|
43 |
+
|
44 |
+
@property
|
45 |
+
def is_running(self):
|
46 |
+
"""bool: indicate whether the timer is running"""
|
47 |
+
return self._is_running
|
48 |
+
|
49 |
+
def __enter__(self):
|
50 |
+
self.start()
|
51 |
+
return self
|
52 |
+
|
53 |
+
def __exit__(self, type, value, traceback):
|
54 |
+
print(self.print_tmpl.format(self.since_last_check()))
|
55 |
+
self._is_running = False
|
56 |
+
|
57 |
+
def start(self):
|
58 |
+
"""Start the timer."""
|
59 |
+
if not self._is_running:
|
60 |
+
self._t_start = time()
|
61 |
+
self._is_running = True
|
62 |
+
self._t_last = time()
|
63 |
+
|
64 |
+
def since_start(self):
|
65 |
+
"""Total time since the timer is started.
|
66 |
+
|
67 |
+
Returns (float): Time in seconds.
|
68 |
+
"""
|
69 |
+
if not self._is_running:
|
70 |
+
raise TimerError('timer is not running')
|
71 |
+
self._t_last = time()
|
72 |
+
return self._t_last - self._t_start
|
73 |
+
|
74 |
+
def since_last_check(self):
|
75 |
+
"""Time since the last checking.
|
76 |
+
|
77 |
+
Either :func:`since_start` or :func:`since_last_check` is a checking
|
78 |
+
operation.
|
79 |
+
|
80 |
+
Returns (float): Time in seconds.
|
81 |
+
"""
|
82 |
+
if not self._is_running:
|
83 |
+
raise TimerError('timer is not running')
|
84 |
+
dur = time() - self._t_last
|
85 |
+
self._t_last = time()
|
86 |
+
return dur
|
87 |
+
|
88 |
+
|
89 |
+
_g_timers = {} # global timers
|
90 |
+
|
91 |
+
|
92 |
+
def check_time(timer_id):
|
93 |
+
"""Add check points in a single line.
|
94 |
+
|
95 |
+
This method is suitable for running a task on a list of items. A timer will
|
96 |
+
be registered when the method is called for the first time.
|
97 |
+
|
98 |
+
:Example:
|
99 |
+
|
100 |
+
>>> import time
|
101 |
+
>>> import annotator.mmpkg.mmcv as mmcv
|
102 |
+
>>> for i in range(1, 6):
|
103 |
+
>>> # simulate a code block
|
104 |
+
>>> time.sleep(i)
|
105 |
+
>>> mmcv.check_time('task1')
|
106 |
+
2.000
|
107 |
+
3.000
|
108 |
+
4.000
|
109 |
+
5.000
|
110 |
+
|
111 |
+
Args:
|
112 |
+
timer_id (str): Timer identifier.
|
113 |
+
"""
|
114 |
+
if timer_id not in _g_timers:
|
115 |
+
_g_timers[timer_id] = Timer()
|
116 |
+
return 0
|
117 |
+
else:
|
118 |
+
return _g_timers[timer_id].since_last_check()
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from annotator.mmpkg.mmcv.utils import digit_version
|
6 |
+
|
7 |
+
|
8 |
+
def is_jit_tracing() -> bool:
|
9 |
+
if (torch.__version__ != 'parrots'
|
10 |
+
and digit_version(torch.__version__) >= digit_version('1.6.0')):
|
11 |
+
on_trace = torch.jit.is_tracing()
|
12 |
+
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
|
13 |
+
# Refers to https://github.com/pytorch/pytorch/issues/42448
|
14 |
+
if isinstance(on_trace, bool):
|
15 |
+
return on_trace
|
16 |
+
else:
|
17 |
+
return torch._C._is_tracing()
|
18 |
+
else:
|
19 |
+
warnings.warn(
|
20 |
+
'torch.jit.is_tracing is only supported after v1.6.0. '
|
21 |
+
'Therefore is_tracing returns False automatically. Please '
|
22 |
+
'set on_trace manually if you are using trace.', UserWarning)
|
23 |
+
return False
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
import subprocess
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
from packaging.version import parse
|
7 |
+
|
8 |
+
|
9 |
+
def digit_version(version_str: str, length: int = 4):
|
10 |
+
"""Convert a version string into a tuple of integers.
|
11 |
+
|
12 |
+
This method is usually used for comparing two versions. For pre-release
|
13 |
+
versions: alpha < beta < rc.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
version_str (str): The version string.
|
17 |
+
length (int): The maximum number of version levels. Default: 4.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
tuple[int]: The version info in digits (integers).
|
21 |
+
"""
|
22 |
+
assert 'parrots' not in version_str
|
23 |
+
version = parse(version_str)
|
24 |
+
assert version.release, f'failed to parse version {version_str}'
|
25 |
+
release = list(version.release)
|
26 |
+
release = release[:length]
|
27 |
+
if len(release) < length:
|
28 |
+
release = release + [0] * (length - len(release))
|
29 |
+
if version.is_prerelease:
|
30 |
+
mapping = {'a': -3, 'b': -2, 'rc': -1}
|
31 |
+
val = -4
|
32 |
+
# version.pre can be None
|
33 |
+
if version.pre:
|
34 |
+
if version.pre[0] not in mapping:
|
35 |
+
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
|
36 |
+
'version checking may go wrong')
|
37 |
+
else:
|
38 |
+
val = mapping[version.pre[0]]
|
39 |
+
release.extend([val, version.pre[-1]])
|
40 |
+
else:
|
41 |
+
release.extend([val, 0])
|
42 |
+
|
43 |
+
elif version.is_postrelease:
|
44 |
+
release.extend([1, version.post])
|
45 |
+
else:
|
46 |
+
release.extend([0, 0])
|
47 |
+
return tuple(release)
|
48 |
+
|
49 |
+
|
50 |
+
def _minimal_ext_cmd(cmd):
|
51 |
+
# construct minimal environment
|
52 |
+
env = {}
|
53 |
+
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
54 |
+
v = os.environ.get(k)
|
55 |
+
if v is not None:
|
56 |
+
env[k] = v
|
57 |
+
# LANGUAGE is used on win32
|
58 |
+
env['LANGUAGE'] = 'C'
|
59 |
+
env['LANG'] = 'C'
|
60 |
+
env['LC_ALL'] = 'C'
|
61 |
+
out = subprocess.Popen(
|
62 |
+
cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
|
63 |
+
return out
|
64 |
+
|
65 |
+
|
66 |
+
def get_git_hash(fallback='unknown', digits=None):
|
67 |
+
"""Get the git hash of the current repo.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
fallback (str, optional): The fallback string when git hash is
|
71 |
+
unavailable. Defaults to 'unknown'.
|
72 |
+
digits (int, optional): kept digits of the hash. Defaults to None,
|
73 |
+
meaning all digits are kept.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
str: Git commit hash.
|
77 |
+
"""
|
78 |
+
|
79 |
+
if digits is not None and not isinstance(digits, int):
|
80 |
+
raise TypeError('digits must be None or an integer')
|
81 |
+
|
82 |
+
try:
|
83 |
+
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
|
84 |
+
sha = out.strip().decode('ascii')
|
85 |
+
if digits is not None:
|
86 |
+
sha = sha[:digits]
|
87 |
+
except OSError:
|
88 |
+
sha = fallback
|
89 |
+
|
90 |
+
return sha
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
__version__ = '1.3.17'
|
3 |
+
|
4 |
+
|
5 |
+
def parse_version_info(version_str: str, length: int = 4) -> tuple:
|
6 |
+
"""Parse a version string into a tuple.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
version_str (str): The version string.
|
10 |
+
length (int): The maximum number of version levels. Default: 4.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
|
14 |
+
(1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
|
15 |
+
(2, 0, 0, 0, 'rc', 1) (when length is set to 4).
|
16 |
+
"""
|
17 |
+
from packaging.version import parse
|
18 |
+
version = parse(version_str)
|
19 |
+
assert version.release, f'failed to parse version {version_str}'
|
20 |
+
release = list(version.release)
|
21 |
+
release = release[:length]
|
22 |
+
if len(release) < length:
|
23 |
+
release = release + [0] * (length - len(release))
|
24 |
+
if version.is_prerelease:
|
25 |
+
release.extend(list(version.pre))
|
26 |
+
elif version.is_postrelease:
|
27 |
+
release.extend(list(version.post))
|
28 |
+
else:
|
29 |
+
release.extend([0, 0])
|
30 |
+
return tuple(release)
|
31 |
+
|
32 |
+
|
33 |
+
version_info = tuple(int(x) for x in __version__.split('.')[:3])
|
34 |
+
|
35 |
+
__all__ = ['__version__', 'version_info', 'parse_version_info']
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .io import Cache, VideoReader, frames2video
|
3 |
+
from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
|
4 |
+
flowwrite, quantize_flow, sparse_flow_from_bytes)
|
5 |
+
from .processing import concat_video, convert_video, cut_video, resize_video
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
|
9 |
+
'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
|
10 |
+
'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
|
11 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os.path as osp
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
|
7 |
+
CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
|
8 |
+
CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
|
9 |
+
|
10 |
+
from annotator.mmpkg.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
|
11 |
+
track_progress)
|
12 |
+
|
13 |
+
|
14 |
+
class Cache:
|
15 |
+
|
16 |
+
def __init__(self, capacity):
|
17 |
+
self._cache = OrderedDict()
|
18 |
+
self._capacity = int(capacity)
|
19 |
+
if capacity <= 0:
|
20 |
+
raise ValueError('capacity must be a positive integer')
|
21 |
+
|
22 |
+
@property
|
23 |
+
def capacity(self):
|
24 |
+
return self._capacity
|
25 |
+
|
26 |
+
@property
|
27 |
+
def size(self):
|
28 |
+
return len(self._cache)
|
29 |
+
|
30 |
+
def put(self, key, val):
|
31 |
+
if key in self._cache:
|
32 |
+
return
|
33 |
+
if len(self._cache) >= self.capacity:
|
34 |
+
self._cache.popitem(last=False)
|
35 |
+
self._cache[key] = val
|
36 |
+
|
37 |
+
def get(self, key, default=None):
|
38 |
+
val = self._cache[key] if key in self._cache else default
|
39 |
+
return val
|
40 |
+
|
41 |
+
|
42 |
+
class VideoReader:
|
43 |
+
"""Video class with similar usage to a list object.
|
44 |
+
|
45 |
+
This video warpper class provides convenient apis to access frames.
|
46 |
+
There exists an issue of OpenCV's VideoCapture class that jumping to a
|
47 |
+
certain frame may be inaccurate. It is fixed in this class by checking
|
48 |
+
the position after jumping each time.
|
49 |
+
Cache is used when decoding videos. So if the same frame is visited for
|
50 |
+
the second time, there is no need to decode again if it is stored in the
|
51 |
+
cache.
|
52 |
+
|
53 |
+
:Example:
|
54 |
+
|
55 |
+
>>> import annotator.mmpkg.mmcv as mmcv
|
56 |
+
>>> v = mmcv.VideoReader('sample.mp4')
|
57 |
+
>>> len(v) # get the total frame number with `len()`
|
58 |
+
120
|
59 |
+
>>> for img in v: # v is iterable
|
60 |
+
>>> mmcv.imshow(img)
|
61 |
+
>>> v[5] # get the 6th frame
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(self, filename, cache_capacity=10):
|
65 |
+
# Check whether the video path is a url
|
66 |
+
if not filename.startswith(('https://', 'http://')):
|
67 |
+
check_file_exist(filename, 'Video file not found: ' + filename)
|
68 |
+
self._vcap = cv2.VideoCapture(filename)
|
69 |
+
assert cache_capacity > 0
|
70 |
+
self._cache = Cache(cache_capacity)
|
71 |
+
self._position = 0
|
72 |
+
# get basic info
|
73 |
+
self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
|
74 |
+
self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
|
75 |
+
self._fps = self._vcap.get(CAP_PROP_FPS)
|
76 |
+
self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
|
77 |
+
self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def vcap(self):
|
81 |
+
""":obj:`cv2.VideoCapture`: The raw VideoCapture object."""
|
82 |
+
return self._vcap
|
83 |
+
|
84 |
+
@property
|
85 |
+
def opened(self):
|
86 |
+
"""bool: Indicate whether the video is opened."""
|
87 |
+
return self._vcap.isOpened()
|
88 |
+
|
89 |
+
@property
|
90 |
+
def width(self):
|
91 |
+
"""int: Width of video frames."""
|
92 |
+
return self._width
|
93 |
+
|
94 |
+
@property
|
95 |
+
def height(self):
|
96 |
+
"""int: Height of video frames."""
|
97 |
+
return self._height
|
98 |
+
|
99 |
+
@property
|
100 |
+
def resolution(self):
|
101 |
+
"""tuple: Video resolution (width, height)."""
|
102 |
+
return (self._width, self._height)
|
103 |
+
|
104 |
+
@property
|
105 |
+
def fps(self):
|
106 |
+
"""float: FPS of the video."""
|
107 |
+
return self._fps
|
108 |
+
|
109 |
+
@property
|
110 |
+
def frame_cnt(self):
|
111 |
+
"""int: Total frames of the video."""
|
112 |
+
return self._frame_cnt
|
113 |
+
|
114 |
+
@property
|
115 |
+
def fourcc(self):
|
116 |
+
"""str: "Four character code" of the video."""
|
117 |
+
return self._fourcc
|
118 |
+
|
119 |
+
@property
|
120 |
+
def position(self):
|
121 |
+
"""int: Current cursor position, indicating frame decoded."""
|
122 |
+
return self._position
|
123 |
+
|
124 |
+
def _get_real_position(self):
|
125 |
+
return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
|
126 |
+
|
127 |
+
def _set_real_position(self, frame_id):
|
128 |
+
self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
|
129 |
+
pos = self._get_real_position()
|
130 |
+
for _ in range(frame_id - pos):
|
131 |
+
self._vcap.read()
|
132 |
+
self._position = frame_id
|
133 |
+
|
134 |
+
def read(self):
|
135 |
+
"""Read the next frame.
|
136 |
+
|
137 |
+
If the next frame have been decoded before and in the cache, then
|
138 |
+
return it directly, otherwise decode, cache and return it.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
ndarray or None: Return the frame if successful, otherwise None.
|
142 |
+
"""
|
143 |
+
# pos = self._position
|
144 |
+
if self._cache:
|
145 |
+
img = self._cache.get(self._position)
|
146 |
+
if img is not None:
|
147 |
+
ret = True
|
148 |
+
else:
|
149 |
+
if self._position != self._get_real_position():
|
150 |
+
self._set_real_position(self._position)
|
151 |
+
ret, img = self._vcap.read()
|
152 |
+
if ret:
|
153 |
+
self._cache.put(self._position, img)
|
154 |
+
else:
|
155 |
+
ret, img = self._vcap.read()
|
156 |
+
if ret:
|
157 |
+
self._position += 1
|
158 |
+
return img
|
159 |
+
|
160 |
+
def get_frame(self, frame_id):
|
161 |
+
"""Get frame by index.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
frame_id (int): Index of the expected frame, 0-based.
|
165 |
+
|
166 |
+
Returns:
|
167 |
+
ndarray or None: Return the frame if successful, otherwise None.
|
168 |
+
"""
|
169 |
+
if frame_id < 0 or frame_id >= self._frame_cnt:
|
170 |
+
raise IndexError(
|
171 |
+
f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
|
172 |
+
if frame_id == self._position:
|
173 |
+
return self.read()
|
174 |
+
if self._cache:
|
175 |
+
img = self._cache.get(frame_id)
|
176 |
+
if img is not None:
|
177 |
+
self._position = frame_id + 1
|
178 |
+
return img
|
179 |
+
self._set_real_position(frame_id)
|
180 |
+
ret, img = self._vcap.read()
|
181 |
+
if ret:
|
182 |
+
if self._cache:
|
183 |
+
self._cache.put(self._position, img)
|
184 |
+
self._position += 1
|
185 |
+
return img
|
186 |
+
|
187 |
+
def current_frame(self):
|
188 |
+
"""Get the current frame (frame that is just visited).
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
ndarray or None: If the video is fresh, return None, otherwise
|
192 |
+
return the frame.
|
193 |
+
"""
|
194 |
+
if self._position == 0:
|
195 |
+
return None
|
196 |
+
return self._cache.get(self._position - 1)
|
197 |
+
|
198 |
+
def cvt2frames(self,
|
199 |
+
frame_dir,
|
200 |
+
file_start=0,
|
201 |
+
filename_tmpl='{:06d}.jpg',
|
202 |
+
start=0,
|
203 |
+
max_num=0,
|
204 |
+
show_progress=True):
|
205 |
+
"""Convert a video to frame images.
|
206 |
+
|
207 |
+
Args:
|
208 |
+
frame_dir (str): Output directory to store all the frame images.
|
209 |
+
file_start (int): Filenames will start from the specified number.
|
210 |
+
filename_tmpl (str): Filename template with the index as the
|
211 |
+
placeholder.
|
212 |
+
start (int): The starting frame index.
|
213 |
+
max_num (int): Maximum number of frames to be written.
|
214 |
+
show_progress (bool): Whether to show a progress bar.
|
215 |
+
"""
|
216 |
+
mkdir_or_exist(frame_dir)
|
217 |
+
if max_num == 0:
|
218 |
+
task_num = self.frame_cnt - start
|
219 |
+
else:
|
220 |
+
task_num = min(self.frame_cnt - start, max_num)
|
221 |
+
if task_num <= 0:
|
222 |
+
raise ValueError('start must be less than total frame number')
|
223 |
+
if start > 0:
|
224 |
+
self._set_real_position(start)
|
225 |
+
|
226 |
+
def write_frame(file_idx):
|
227 |
+
img = self.read()
|
228 |
+
if img is None:
|
229 |
+
return
|
230 |
+
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
|
231 |
+
cv2.imwrite(filename, img)
|
232 |
+
|
233 |
+
if show_progress:
|
234 |
+
track_progress(write_frame, range(file_start,
|
235 |
+
file_start + task_num))
|
236 |
+
else:
|
237 |
+
for i in range(task_num):
|
238 |
+
write_frame(file_start + i)
|
239 |
+
|
240 |
+
def __len__(self):
|
241 |
+
return self.frame_cnt
|
242 |
+
|
243 |
+
def __getitem__(self, index):
|
244 |
+
if isinstance(index, slice):
|
245 |
+
return [
|
246 |
+
self.get_frame(i)
|
247 |
+
for i in range(*index.indices(self.frame_cnt))
|
248 |
+
]
|
249 |
+
# support negative indexing
|
250 |
+
if index < 0:
|
251 |
+
index += self.frame_cnt
|
252 |
+
if index < 0:
|
253 |
+
raise IndexError('index out of range')
|
254 |
+
return self.get_frame(index)
|
255 |
+
|
256 |
+
def __iter__(self):
|
257 |
+
self._set_real_position(0)
|
258 |
+
return self
|
259 |
+
|
260 |
+
def __next__(self):
|
261 |
+
img = self.read()
|
262 |
+
if img is not None:
|
263 |
+
return img
|
264 |
+
else:
|
265 |
+
raise StopIteration
|
266 |
+
|
267 |
+
next = __next__
|
268 |
+
|
269 |
+
def __enter__(self):
|
270 |
+
return self
|
271 |
+
|
272 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
273 |
+
self._vcap.release()
|
274 |
+
|
275 |
+
|
276 |
+
def frames2video(frame_dir,
|
277 |
+
video_file,
|
278 |
+
fps=30,
|
279 |
+
fourcc='XVID',
|
280 |
+
filename_tmpl='{:06d}.jpg',
|
281 |
+
start=0,
|
282 |
+
end=0,
|
283 |
+
show_progress=True):
|
284 |
+
"""Read the frame images from a directory and join them as a video.
|
285 |
+
|
286 |
+
Args:
|
287 |
+
frame_dir (str): The directory containing video frames.
|
288 |
+
video_file (str): Output filename.
|
289 |
+
fps (float): FPS of the output video.
|
290 |
+
fourcc (str): Fourcc of the output video, this should be compatible
|
291 |
+
with the output file type.
|
292 |
+
filename_tmpl (str): Filename template with the index as the variable.
|
293 |
+
start (int): Starting frame index.
|
294 |
+
end (int): Ending frame index.
|
295 |
+
show_progress (bool): Whether to show a progress bar.
|
296 |
+
"""
|
297 |
+
if end == 0:
|
298 |
+
ext = filename_tmpl.split('.')[-1]
|
299 |
+
end = len([name for name in scandir(frame_dir, ext)])
|
300 |
+
first_file = osp.join(frame_dir, filename_tmpl.format(start))
|
301 |
+
check_file_exist(first_file, 'The start frame not found: ' + first_file)
|
302 |
+
img = cv2.imread(first_file)
|
303 |
+
height, width = img.shape[:2]
|
304 |
+
resolution = (width, height)
|
305 |
+
vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
|
306 |
+
resolution)
|
307 |
+
|
308 |
+
def write_frame(file_idx):
|
309 |
+
filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
|
310 |
+
img = cv2.imread(filename)
|
311 |
+
vwriter.write(img)
|
312 |
+
|
313 |
+
if show_progress:
|
314 |
+
track_progress(write_frame, range(start, end))
|
315 |
+
else:
|
316 |
+
for i in range(start, end):
|
317 |
+
write_frame(i)
|
318 |
+
vwriter.release()
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from annotator.mmpkg.mmcv.arraymisc import dequantize, quantize
|
8 |
+
from annotator.mmpkg.mmcv.image import imread, imwrite
|
9 |
+
from annotator.mmpkg.mmcv.utils import is_str
|
10 |
+
|
11 |
+
|
12 |
+
def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
|
13 |
+
"""Read an optical flow map.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
flow_or_path (ndarray or str): A flow map or filepath.
|
17 |
+
quantize (bool): whether to read quantized pair, if set to True,
|
18 |
+
remaining args will be passed to :func:`dequantize_flow`.
|
19 |
+
concat_axis (int): The axis that dx and dy are concatenated,
|
20 |
+
can be either 0 or 1. Ignored if quantize is False.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
ndarray: Optical flow represented as a (h, w, 2) numpy array
|
24 |
+
"""
|
25 |
+
if isinstance(flow_or_path, np.ndarray):
|
26 |
+
if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
|
27 |
+
raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
|
28 |
+
return flow_or_path
|
29 |
+
elif not is_str(flow_or_path):
|
30 |
+
raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
|
31 |
+
f'not {type(flow_or_path)}')
|
32 |
+
|
33 |
+
if not quantize:
|
34 |
+
with open(flow_or_path, 'rb') as f:
|
35 |
+
try:
|
36 |
+
header = f.read(4).decode('utf-8')
|
37 |
+
except Exception:
|
38 |
+
raise IOError(f'Invalid flow file: {flow_or_path}')
|
39 |
+
else:
|
40 |
+
if header != 'PIEH':
|
41 |
+
raise IOError(f'Invalid flow file: {flow_or_path}, '
|
42 |
+
'header does not contain PIEH')
|
43 |
+
|
44 |
+
w = np.fromfile(f, np.int32, 1).squeeze()
|
45 |
+
h = np.fromfile(f, np.int32, 1).squeeze()
|
46 |
+
flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
|
47 |
+
else:
|
48 |
+
assert concat_axis in [0, 1]
|
49 |
+
cat_flow = imread(flow_or_path, flag='unchanged')
|
50 |
+
if cat_flow.ndim != 2:
|
51 |
+
raise IOError(
|
52 |
+
f'{flow_or_path} is not a valid quantized flow file, '
|
53 |
+
f'its dimension is {cat_flow.ndim}.')
|
54 |
+
assert cat_flow.shape[concat_axis] % 2 == 0
|
55 |
+
dx, dy = np.split(cat_flow, 2, axis=concat_axis)
|
56 |
+
flow = dequantize_flow(dx, dy, *args, **kwargs)
|
57 |
+
|
58 |
+
return flow.astype(np.float32)
|
59 |
+
|
60 |
+
|
61 |
+
def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
|
62 |
+
"""Write optical flow to file.
|
63 |
+
|
64 |
+
If the flow is not quantized, it will be saved as a .flo file losslessly,
|
65 |
+
otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
|
66 |
+
will be concatenated horizontally into a single image if quantize is True.)
|
67 |
+
|
68 |
+
Args:
|
69 |
+
flow (ndarray): (h, w, 2) array of optical flow.
|
70 |
+
filename (str): Output filepath.
|
71 |
+
quantize (bool): Whether to quantize the flow and save it to 2 jpeg
|
72 |
+
images. If set to True, remaining args will be passed to
|
73 |
+
:func:`quantize_flow`.
|
74 |
+
concat_axis (int): The axis that dx and dy are concatenated,
|
75 |
+
can be either 0 or 1. Ignored if quantize is False.
|
76 |
+
"""
|
77 |
+
if not quantize:
|
78 |
+
with open(filename, 'wb') as f:
|
79 |
+
f.write('PIEH'.encode('utf-8'))
|
80 |
+
np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
|
81 |
+
flow = flow.astype(np.float32)
|
82 |
+
flow.tofile(f)
|
83 |
+
f.flush()
|
84 |
+
else:
|
85 |
+
assert concat_axis in [0, 1]
|
86 |
+
dx, dy = quantize_flow(flow, *args, **kwargs)
|
87 |
+
dxdy = np.concatenate((dx, dy), axis=concat_axis)
|
88 |
+
imwrite(dxdy, filename)
|
89 |
+
|
90 |
+
|
91 |
+
def quantize_flow(flow, max_val=0.02, norm=True):
|
92 |
+
"""Quantize flow to [0, 255].
|
93 |
+
|
94 |
+
After this step, the size of flow will be much smaller, and can be
|
95 |
+
dumped as jpeg images.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
flow (ndarray): (h, w, 2) array of optical flow.
|
99 |
+
max_val (float): Maximum value of flow, values beyond
|
100 |
+
[-max_val, max_val] will be truncated.
|
101 |
+
norm (bool): Whether to divide flow values by image width/height.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
tuple[ndarray]: Quantized dx and dy.
|
105 |
+
"""
|
106 |
+
h, w, _ = flow.shape
|
107 |
+
dx = flow[..., 0]
|
108 |
+
dy = flow[..., 1]
|
109 |
+
if norm:
|
110 |
+
dx = dx / w # avoid inplace operations
|
111 |
+
dy = dy / h
|
112 |
+
# use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
|
113 |
+
flow_comps = [
|
114 |
+
quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
|
115 |
+
]
|
116 |
+
return tuple(flow_comps)
|
117 |
+
|
118 |
+
|
119 |
+
def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
|
120 |
+
"""Recover from quantized flow.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
dx (ndarray): Quantized dx.
|
124 |
+
dy (ndarray): Quantized dy.
|
125 |
+
max_val (float): Maximum value used when quantizing.
|
126 |
+
denorm (bool): Whether to multiply flow values with width/height.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
ndarray: Dequantized flow.
|
130 |
+
"""
|
131 |
+
assert dx.shape == dy.shape
|
132 |
+
assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
|
133 |
+
|
134 |
+
dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
|
135 |
+
|
136 |
+
if denorm:
|
137 |
+
dx *= dx.shape[1]
|
138 |
+
dy *= dx.shape[0]
|
139 |
+
flow = np.dstack((dx, dy))
|
140 |
+
return flow
|
141 |
+
|
142 |
+
|
143 |
+
def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
|
144 |
+
"""Use flow to warp img.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
img (ndarray, float or uint8): Image to be warped.
|
148 |
+
flow (ndarray, float): Optical Flow.
|
149 |
+
filling_value (int): The missing pixels will be set with filling_value.
|
150 |
+
interpolate_mode (str): bilinear -> Bilinear Interpolation;
|
151 |
+
nearest -> Nearest Neighbor.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
ndarray: Warped image with the same shape of img
|
155 |
+
"""
|
156 |
+
warnings.warn('This function is just for prototyping and cannot '
|
157 |
+
'guarantee the computational efficiency.')
|
158 |
+
assert flow.ndim == 3, 'Flow must be in 3D arrays.'
|
159 |
+
height = flow.shape[0]
|
160 |
+
width = flow.shape[1]
|
161 |
+
channels = img.shape[2]
|
162 |
+
|
163 |
+
output = np.ones(
|
164 |
+
(height, width, channels), dtype=img.dtype) * filling_value
|
165 |
+
|
166 |
+
grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
|
167 |
+
dx = grid[:, :, 0] + flow[:, :, 1]
|
168 |
+
dy = grid[:, :, 1] + flow[:, :, 0]
|
169 |
+
sx = np.floor(dx).astype(int)
|
170 |
+
sy = np.floor(dy).astype(int)
|
171 |
+
valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
|
172 |
+
|
173 |
+
if interpolate_mode == 'nearest':
|
174 |
+
output[valid, :] = img[dx[valid].round().astype(int),
|
175 |
+
dy[valid].round().astype(int), :]
|
176 |
+
elif interpolate_mode == 'bilinear':
|
177 |
+
# dirty walkround for integer positions
|
178 |
+
eps_ = 1e-6
|
179 |
+
dx, dy = dx + eps_, dy + eps_
|
180 |
+
left_top_ = img[np.floor(dx[valid]).astype(int),
|
181 |
+
np.floor(dy[valid]).astype(int), :] * (
|
182 |
+
np.ceil(dx[valid]) - dx[valid])[:, None] * (
|
183 |
+
np.ceil(dy[valid]) - dy[valid])[:, None]
|
184 |
+
left_down_ = img[np.ceil(dx[valid]).astype(int),
|
185 |
+
np.floor(dy[valid]).astype(int), :] * (
|
186 |
+
dx[valid] - np.floor(dx[valid]))[:, None] * (
|
187 |
+
np.ceil(dy[valid]) - dy[valid])[:, None]
|
188 |
+
right_top_ = img[np.floor(dx[valid]).astype(int),
|
189 |
+
np.ceil(dy[valid]).astype(int), :] * (
|
190 |
+
np.ceil(dx[valid]) - dx[valid])[:, None] * (
|
191 |
+
dy[valid] - np.floor(dy[valid]))[:, None]
|
192 |
+
right_down_ = img[np.ceil(dx[valid]).astype(int),
|
193 |
+
np.ceil(dy[valid]).astype(int), :] * (
|
194 |
+
dx[valid] - np.floor(dx[valid]))[:, None] * (
|
195 |
+
dy[valid] - np.floor(dy[valid]))[:, None]
|
196 |
+
output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
|
197 |
+
else:
|
198 |
+
raise NotImplementedError(
|
199 |
+
'We only support interpolation modes of nearest and bilinear, '
|
200 |
+
f'but got {interpolate_mode}.')
|
201 |
+
return output.astype(img.dtype)
|
202 |
+
|
203 |
+
|
204 |
+
def flow_from_bytes(content):
|
205 |
+
"""Read dense optical flow from bytes.
|
206 |
+
|
207 |
+
.. note::
|
208 |
+
This load optical flow function works for FlyingChairs, FlyingThings3D,
|
209 |
+
Sintel, FlyingChairsOcc datasets, but cannot load the data from
|
210 |
+
ChairsSDHom.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
content (bytes): Optical flow bytes got from files or other streams.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
ndarray: Loaded optical flow with the shape (H, W, 2).
|
217 |
+
"""
|
218 |
+
|
219 |
+
# header in first 4 bytes
|
220 |
+
header = content[:4]
|
221 |
+
if header.decode('utf-8') != 'PIEH':
|
222 |
+
raise Exception('Flow file header does not contain PIEH')
|
223 |
+
# width in second 4 bytes
|
224 |
+
width = np.frombuffer(content[4:], np.int32, 1).squeeze()
|
225 |
+
# height in third 4 bytes
|
226 |
+
height = np.frombuffer(content[8:], np.int32, 1).squeeze()
|
227 |
+
# after first 12 bytes, all bytes are flow
|
228 |
+
flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
|
229 |
+
(height, width, 2))
|
230 |
+
|
231 |
+
return flow
|
232 |
+
|
233 |
+
|
234 |
+
def sparse_flow_from_bytes(content):
|
235 |
+
"""Read the optical flow in KITTI datasets from bytes.
|
236 |
+
|
237 |
+
This function is modified from RAFT load the `KITTI datasets
|
238 |
+
<https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/core/utils/frame_utils.py#L102>`_.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
content (bytes): Optical flow bytes got from files or other streams.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
|
245 |
+
and flow valid mask with the shape (H, W).
|
246 |
+
""" # nopa
|
247 |
+
|
248 |
+
content = np.frombuffer(content, np.uint8)
|
249 |
+
flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
|
250 |
+
flow = flow[:, :, ::-1].astype(np.float32)
|
251 |
+
# flow shape (H, W, 2) valid shape (H, W)
|
252 |
+
flow, valid = flow[:, :, :2], flow[:, :, 2]
|
253 |
+
flow = (flow - 2**15) / 64.0
|
254 |
+
return flow, valid
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import subprocess
|
5 |
+
import tempfile
|
6 |
+
|
7 |
+
from annotator.mmpkg.mmcv.utils import requires_executable
|
8 |
+
|
9 |
+
|
10 |
+
@requires_executable('ffmpeg')
|
11 |
+
def convert_video(in_file,
|
12 |
+
out_file,
|
13 |
+
print_cmd=False,
|
14 |
+
pre_options='',
|
15 |
+
**kwargs):
|
16 |
+
"""Convert a video with ffmpeg.
|
17 |
+
|
18 |
+
This provides a general api to ffmpeg, the executed command is::
|
19 |
+
|
20 |
+
`ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`
|
21 |
+
|
22 |
+
Options(kwargs) are mapped to ffmpeg commands with the following rules:
|
23 |
+
|
24 |
+
- key=val: "-key val"
|
25 |
+
- key=True: "-key"
|
26 |
+
- key=False: ""
|
27 |
+
|
28 |
+
Args:
|
29 |
+
in_file (str): Input video filename.
|
30 |
+
out_file (str): Output video filename.
|
31 |
+
pre_options (str): Options appears before "-i <in_file>".
|
32 |
+
print_cmd (bool): Whether to print the final ffmpeg command.
|
33 |
+
"""
|
34 |
+
options = []
|
35 |
+
for k, v in kwargs.items():
|
36 |
+
if isinstance(v, bool):
|
37 |
+
if v:
|
38 |
+
options.append(f'-{k}')
|
39 |
+
elif k == 'log_level':
|
40 |
+
assert v in [
|
41 |
+
'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
|
42 |
+
'verbose', 'debug', 'trace'
|
43 |
+
]
|
44 |
+
options.append(f'-loglevel {v}')
|
45 |
+
else:
|
46 |
+
options.append(f'-{k} {v}')
|
47 |
+
cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
|
48 |
+
f'{out_file}'
|
49 |
+
if print_cmd:
|
50 |
+
print(cmd)
|
51 |
+
subprocess.call(cmd, shell=True)
|
52 |
+
|
53 |
+
|
54 |
+
@requires_executable('ffmpeg')
|
55 |
+
def resize_video(in_file,
|
56 |
+
out_file,
|
57 |
+
size=None,
|
58 |
+
ratio=None,
|
59 |
+
keep_ar=False,
|
60 |
+
log_level='info',
|
61 |
+
print_cmd=False):
|
62 |
+
"""Resize a video.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
in_file (str): Input video filename.
|
66 |
+
out_file (str): Output video filename.
|
67 |
+
size (tuple): Expected size (w, h), eg, (320, 240) or (320, -1).
|
68 |
+
ratio (tuple or float): Expected resize ratio, (2, 0.5) means
|
69 |
+
(w*2, h*0.5).
|
70 |
+
keep_ar (bool): Whether to keep original aspect ratio.
|
71 |
+
log_level (str): Logging level of ffmpeg.
|
72 |
+
print_cmd (bool): Whether to print the final ffmpeg command.
|
73 |
+
"""
|
74 |
+
if size is None and ratio is None:
|
75 |
+
raise ValueError('expected size or ratio must be specified')
|
76 |
+
if size is not None and ratio is not None:
|
77 |
+
raise ValueError('size and ratio cannot be specified at the same time')
|
78 |
+
options = {'log_level': log_level}
|
79 |
+
if size:
|
80 |
+
if not keep_ar:
|
81 |
+
options['vf'] = f'scale={size[0]}:{size[1]}'
|
82 |
+
else:
|
83 |
+
options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
|
84 |
+
'force_original_aspect_ratio=decrease'
|
85 |
+
else:
|
86 |
+
if not isinstance(ratio, tuple):
|
87 |
+
ratio = (ratio, ratio)
|
88 |
+
options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
|
89 |
+
convert_video(in_file, out_file, print_cmd, **options)
|
90 |
+
|
91 |
+
|
92 |
+
@requires_executable('ffmpeg')
|
93 |
+
def cut_video(in_file,
|
94 |
+
out_file,
|
95 |
+
start=None,
|
96 |
+
end=None,
|
97 |
+
vcodec=None,
|
98 |
+
acodec=None,
|
99 |
+
log_level='info',
|
100 |
+
print_cmd=False):
|
101 |
+
"""Cut a clip from a video.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
in_file (str): Input video filename.
|
105 |
+
out_file (str): Output video filename.
|
106 |
+
start (None or float): Start time (in seconds).
|
107 |
+
end (None or float): End time (in seconds).
|
108 |
+
vcodec (None or str): Output video codec, None for unchanged.
|
109 |
+
acodec (None or str): Output audio codec, None for unchanged.
|
110 |
+
log_level (str): Logging level of ffmpeg.
|
111 |
+
print_cmd (bool): Whether to print the final ffmpeg command.
|
112 |
+
"""
|
113 |
+
options = {'log_level': log_level}
|
114 |
+
if vcodec is None:
|
115 |
+
options['vcodec'] = 'copy'
|
116 |
+
if acodec is None:
|
117 |
+
options['acodec'] = 'copy'
|
118 |
+
if start:
|
119 |
+
options['ss'] = start
|
120 |
+
else:
|
121 |
+
start = 0
|
122 |
+
if end:
|
123 |
+
options['t'] = end - start
|
124 |
+
convert_video(in_file, out_file, print_cmd, **options)
|
125 |
+
|
126 |
+
|
127 |
+
@requires_executable('ffmpeg')
|
128 |
+
def concat_video(video_list,
|
129 |
+
out_file,
|
130 |
+
vcodec=None,
|
131 |
+
acodec=None,
|
132 |
+
log_level='info',
|
133 |
+
print_cmd=False):
|
134 |
+
"""Concatenate multiple videos into a single one.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
video_list (list): A list of video filenames
|
138 |
+
out_file (str): Output video filename
|
139 |
+
vcodec (None or str): Output video codec, None for unchanged
|
140 |
+
acodec (None or str): Output audio codec, None for unchanged
|
141 |
+
log_level (str): Logging level of ffmpeg.
|
142 |
+
print_cmd (bool): Whether to print the final ffmpeg command.
|
143 |
+
"""
|
144 |
+
tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
|
145 |
+
with open(tmp_filename, 'w') as f:
|
146 |
+
for filename in video_list:
|
147 |
+
f.write(f'file {osp.abspath(filename)}\n')
|
148 |
+
options = {'log_level': log_level}
|
149 |
+
if vcodec is None:
|
150 |
+
options['vcodec'] = 'copy'
|
151 |
+
if acodec is None:
|
152 |
+
options['acodec'] = 'copy'
|
153 |
+
convert_video(
|
154 |
+
tmp_filename,
|
155 |
+
out_file,
|
156 |
+
print_cmd,
|
157 |
+
pre_options='-f concat -safe 0',
|
158 |
+
**options)
|
159 |
+
os.close(tmp_filehandler)
|
160 |
+
os.remove(tmp_filename)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from .color import Color, color_val
|
3 |
+
from .image import imshow, imshow_bboxes, imshow_det_bboxes
|
4 |
+
from .optflow import flow2rgb, flowshow, make_color_wheel
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
|
8 |
+
'flowshow', 'flow2rgb', 'make_color_wheel'
|
9 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from annotator.mmpkg.mmcv.utils import is_str
|
7 |
+
|
8 |
+
|
9 |
+
class Color(Enum):
|
10 |
+
"""An enum that defines common colors.
|
11 |
+
|
12 |
+
Contains red, green, blue, cyan, yellow, magenta, white and black.
|
13 |
+
"""
|
14 |
+
red = (0, 0, 255)
|
15 |
+
green = (0, 255, 0)
|
16 |
+
blue = (255, 0, 0)
|
17 |
+
cyan = (255, 255, 0)
|
18 |
+
yellow = (0, 255, 255)
|
19 |
+
magenta = (255, 0, 255)
|
20 |
+
white = (255, 255, 255)
|
21 |
+
black = (0, 0, 0)
|
22 |
+
|
23 |
+
|
24 |
+
def color_val(color):
|
25 |
+
"""Convert various input to color tuples.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
tuple[int]: A tuple of 3 integers indicating BGR channels.
|
32 |
+
"""
|
33 |
+
if is_str(color):
|
34 |
+
return Color[color].value
|
35 |
+
elif isinstance(color, Color):
|
36 |
+
return color.value
|
37 |
+
elif isinstance(color, tuple):
|
38 |
+
assert len(color) == 3
|
39 |
+
for channel in color:
|
40 |
+
assert 0 <= channel <= 255
|
41 |
+
return color
|
42 |
+
elif isinstance(color, int):
|
43 |
+
assert 0 <= color <= 255
|
44 |
+
return color, color, color
|
45 |
+
elif isinstance(color, np.ndarray):
|
46 |
+
assert color.ndim == 1 and color.size == 3
|
47 |
+
assert np.all((color >= 0) & (color <= 255))
|
48 |
+
color = color.astype(np.uint8)
|
49 |
+
return tuple(color)
|
50 |
+
else:
|
51 |
+
raise TypeError(f'Invalid type for color: {type(color)}')
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from annotator.mmpkg.mmcv.image import imread, imwrite
|
6 |
+
from .color import color_val
|
7 |
+
|
8 |
+
|
9 |
+
def imshow(img, win_name='', wait_time=0):
|
10 |
+
"""Show an image.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
img (str or ndarray): The image to be displayed.
|
14 |
+
win_name (str): The window name.
|
15 |
+
wait_time (int): Value of waitKey param.
|
16 |
+
"""
|
17 |
+
cv2.imshow(win_name, imread(img))
|
18 |
+
if wait_time == 0: # prevent from hanging if windows was closed
|
19 |
+
while True:
|
20 |
+
ret = cv2.waitKey(1)
|
21 |
+
|
22 |
+
closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
|
23 |
+
# if user closed window or if some key pressed
|
24 |
+
if closed or ret != -1:
|
25 |
+
break
|
26 |
+
else:
|
27 |
+
ret = cv2.waitKey(wait_time)
|
28 |
+
|
29 |
+
|
30 |
+
def imshow_bboxes(img,
|
31 |
+
bboxes,
|
32 |
+
colors='green',
|
33 |
+
top_k=-1,
|
34 |
+
thickness=1,
|
35 |
+
show=True,
|
36 |
+
win_name='',
|
37 |
+
wait_time=0,
|
38 |
+
out_file=None):
|
39 |
+
"""Draw bboxes on an image.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
img (str or ndarray): The image to be displayed.
|
43 |
+
bboxes (list or ndarray): A list of ndarray of shape (k, 4).
|
44 |
+
colors (list[str or tuple or Color]): A list of colors.
|
45 |
+
top_k (int): Plot the first k bboxes only if set positive.
|
46 |
+
thickness (int): Thickness of lines.
|
47 |
+
show (bool): Whether to show the image.
|
48 |
+
win_name (str): The window name.
|
49 |
+
wait_time (int): Value of waitKey param.
|
50 |
+
out_file (str, optional): The filename to write the image.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
ndarray: The image with bboxes drawn on it.
|
54 |
+
"""
|
55 |
+
img = imread(img)
|
56 |
+
img = np.ascontiguousarray(img)
|
57 |
+
|
58 |
+
if isinstance(bboxes, np.ndarray):
|
59 |
+
bboxes = [bboxes]
|
60 |
+
if not isinstance(colors, list):
|
61 |
+
colors = [colors for _ in range(len(bboxes))]
|
62 |
+
colors = [color_val(c) for c in colors]
|
63 |
+
assert len(bboxes) == len(colors)
|
64 |
+
|
65 |
+
for i, _bboxes in enumerate(bboxes):
|
66 |
+
_bboxes = _bboxes.astype(np.int32)
|
67 |
+
if top_k <= 0:
|
68 |
+
_top_k = _bboxes.shape[0]
|
69 |
+
else:
|
70 |
+
_top_k = min(top_k, _bboxes.shape[0])
|
71 |
+
for j in range(_top_k):
|
72 |
+
left_top = (_bboxes[j, 0], _bboxes[j, 1])
|
73 |
+
right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
|
74 |
+
cv2.rectangle(
|
75 |
+
img, left_top, right_bottom, colors[i], thickness=thickness)
|
76 |
+
|
77 |
+
if show:
|
78 |
+
imshow(img, win_name, wait_time)
|
79 |
+
if out_file is not None:
|
80 |
+
imwrite(img, out_file)
|
81 |
+
return img
|
82 |
+
|
83 |
+
|
84 |
+
def imshow_det_bboxes(img,
|
85 |
+
bboxes,
|
86 |
+
labels,
|
87 |
+
class_names=None,
|
88 |
+
score_thr=0,
|
89 |
+
bbox_color='green',
|
90 |
+
text_color='green',
|
91 |
+
thickness=1,
|
92 |
+
font_scale=0.5,
|
93 |
+
show=True,
|
94 |
+
win_name='',
|
95 |
+
wait_time=0,
|
96 |
+
out_file=None):
|
97 |
+
"""Draw bboxes and class labels (with scores) on an image.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
img (str or ndarray): The image to be displayed.
|
101 |
+
bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
|
102 |
+
(n, 5).
|
103 |
+
labels (ndarray): Labels of bboxes.
|
104 |
+
class_names (list[str]): Names of each classes.
|
105 |
+
score_thr (float): Minimum score of bboxes to be shown.
|
106 |
+
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
|
107 |
+
text_color (str or tuple or :obj:`Color`): Color of texts.
|
108 |
+
thickness (int): Thickness of lines.
|
109 |
+
font_scale (float): Font scales of texts.
|
110 |
+
show (bool): Whether to show the image.
|
111 |
+
win_name (str): The window name.
|
112 |
+
wait_time (int): Value of waitKey param.
|
113 |
+
out_file (str or None): The filename to write the image.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
ndarray: The image with bboxes drawn on it.
|
117 |
+
"""
|
118 |
+
assert bboxes.ndim == 2
|
119 |
+
assert labels.ndim == 1
|
120 |
+
assert bboxes.shape[0] == labels.shape[0]
|
121 |
+
assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
|
122 |
+
img = imread(img)
|
123 |
+
img = np.ascontiguousarray(img)
|
124 |
+
|
125 |
+
if score_thr > 0:
|
126 |
+
assert bboxes.shape[1] == 5
|
127 |
+
scores = bboxes[:, -1]
|
128 |
+
inds = scores > score_thr
|
129 |
+
bboxes = bboxes[inds, :]
|
130 |
+
labels = labels[inds]
|
131 |
+
|
132 |
+
bbox_color = color_val(bbox_color)
|
133 |
+
text_color = color_val(text_color)
|
134 |
+
|
135 |
+
for bbox, label in zip(bboxes, labels):
|
136 |
+
bbox_int = bbox.astype(np.int32)
|
137 |
+
left_top = (bbox_int[0], bbox_int[1])
|
138 |
+
right_bottom = (bbox_int[2], bbox_int[3])
|
139 |
+
cv2.rectangle(
|
140 |
+
img, left_top, right_bottom, bbox_color, thickness=thickness)
|
141 |
+
label_text = class_names[
|
142 |
+
label] if class_names is not None else f'cls {label}'
|
143 |
+
if len(bbox) > 4:
|
144 |
+
label_text += f'|{bbox[-1]:.02f}'
|
145 |
+
cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
|
146 |
+
cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
|
147 |
+
|
148 |
+
if show:
|
149 |
+
imshow(img, win_name, wait_time)
|
150 |
+
if out_file is not None:
|
151 |
+
imwrite(img, out_file)
|
152 |
+
return img
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from __future__ import division
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from annotator.mmpkg.mmcv.image import rgb2bgr
|
7 |
+
from annotator.mmpkg.mmcv.video import flowread
|
8 |
+
from .image import imshow
|
9 |
+
|
10 |
+
|
11 |
+
def flowshow(flow, win_name='', wait_time=0):
|
12 |
+
"""Show optical flow.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
flow (ndarray or str): The optical flow to be displayed.
|
16 |
+
win_name (str): The window name.
|
17 |
+
wait_time (int): Value of waitKey param.
|
18 |
+
"""
|
19 |
+
flow = flowread(flow)
|
20 |
+
flow_img = flow2rgb(flow)
|
21 |
+
imshow(rgb2bgr(flow_img), win_name, wait_time)
|
22 |
+
|
23 |
+
|
24 |
+
def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
|
25 |
+
"""Convert flow map to RGB image.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
flow (ndarray): Array of optical flow.
|
29 |
+
color_wheel (ndarray or None): Color wheel used to map flow field to
|
30 |
+
RGB colorspace. Default color wheel will be used if not specified.
|
31 |
+
unknown_thr (str): Values above this threshold will be marked as
|
32 |
+
unknown and thus ignored.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
ndarray: RGB image that can be visualized.
|
36 |
+
"""
|
37 |
+
assert flow.ndim == 3 and flow.shape[-1] == 2
|
38 |
+
if color_wheel is None:
|
39 |
+
color_wheel = make_color_wheel()
|
40 |
+
assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
|
41 |
+
num_bins = color_wheel.shape[0]
|
42 |
+
|
43 |
+
dx = flow[:, :, 0].copy()
|
44 |
+
dy = flow[:, :, 1].copy()
|
45 |
+
|
46 |
+
ignore_inds = (
|
47 |
+
np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
|
48 |
+
(np.abs(dy) > unknown_thr))
|
49 |
+
dx[ignore_inds] = 0
|
50 |
+
dy[ignore_inds] = 0
|
51 |
+
|
52 |
+
rad = np.sqrt(dx**2 + dy**2)
|
53 |
+
if np.any(rad > np.finfo(float).eps):
|
54 |
+
max_rad = np.max(rad)
|
55 |
+
dx /= max_rad
|
56 |
+
dy /= max_rad
|
57 |
+
|
58 |
+
rad = np.sqrt(dx**2 + dy**2)
|
59 |
+
angle = np.arctan2(-dy, -dx) / np.pi
|
60 |
+
|
61 |
+
bin_real = (angle + 1) / 2 * (num_bins - 1)
|
62 |
+
bin_left = np.floor(bin_real).astype(int)
|
63 |
+
bin_right = (bin_left + 1) % num_bins
|
64 |
+
w = (bin_real - bin_left.astype(np.float32))[..., None]
|
65 |
+
flow_img = (1 -
|
66 |
+
w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
|
67 |
+
small_ind = rad <= 1
|
68 |
+
flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
|
69 |
+
flow_img[np.logical_not(small_ind)] *= 0.75
|
70 |
+
|
71 |
+
flow_img[ignore_inds, :] = 0
|
72 |
+
|
73 |
+
return flow_img
|
74 |
+
|
75 |
+
|
76 |
+
def make_color_wheel(bins=None):
|
77 |
+
"""Build a color wheel.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
bins(list or tuple, optional): Specify the number of bins for each
|
81 |
+
color range, corresponding to six ranges: red -> yellow,
|
82 |
+
yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
|
83 |
+
magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
|
84 |
+
(see Middlebury).
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
ndarray: Color wheel of shape (total_bins, 3).
|
88 |
+
"""
|
89 |
+
if bins is None:
|
90 |
+
bins = [15, 6, 4, 11, 13, 6]
|
91 |
+
assert len(bins) == 6
|
92 |
+
|
93 |
+
RY, YG, GC, CB, BM, MR = tuple(bins)
|
94 |
+
|
95 |
+
ry = [1, np.arange(RY) / RY, 0]
|
96 |
+
yg = [1 - np.arange(YG) / YG, 1, 0]
|
97 |
+
gc = [0, 1, np.arange(GC) / GC]
|
98 |
+
cb = [0, 1 - np.arange(CB) / CB, 1]
|
99 |
+
bm = [np.arange(BM) / BM, 0, 1]
|
100 |
+
mr = [1, 0, 1 - np.arange(MR) / MR]
|
101 |
+
|
102 |
+
num_bins = RY + YG + GC + CB + BM + MR
|
103 |
+
|
104 |
+
color_wheel = np.zeros((3, num_bins), dtype=np.float32)
|
105 |
+
|
106 |
+
col = 0
|
107 |
+
for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
|
108 |
+
for j in range(3):
|
109 |
+
color_wheel[j, col:col + bins[i]] = color[j]
|
110 |
+
col += bins[i]
|
111 |
+
|
112 |
+
return color_wheel.T
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
|
2 |
+
from .test import multi_gpu_test, single_gpu_test
|
3 |
+
from .train import get_root_logger, set_random_seed, train_segmentor
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
|
7 |
+
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
|
8 |
+
'show_result_pyplot'
|
9 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import annotator.mmpkg.mmcv as mmcv
|
3 |
+
import torch
|
4 |
+
from annotator.mmpkg.mmcv.parallel import collate, scatter
|
5 |
+
from annotator.mmpkg.mmcv.runner import load_checkpoint
|
6 |
+
|
7 |
+
from annotator.mmpkg.mmseg.datasets.pipelines import Compose
|
8 |
+
from annotator.mmpkg.mmseg.models import build_segmentor
|
9 |
+
from modules import devices
|
10 |
+
|
11 |
+
|
12 |
+
def init_segmentor(config, checkpoint=None, device=devices.get_device_for("controlnet")):
|
13 |
+
"""Initialize a segmentor from config file.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
config (str or :obj:`mmcv.Config`): Config file path or the config
|
17 |
+
object.
|
18 |
+
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
19 |
+
will not load any weights.
|
20 |
+
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
21 |
+
Use 'cpu' for loading model on CPU.
|
22 |
+
Returns:
|
23 |
+
nn.Module: The constructed segmentor.
|
24 |
+
"""
|
25 |
+
if isinstance(config, str):
|
26 |
+
config = mmcv.Config.fromfile(config)
|
27 |
+
elif not isinstance(config, mmcv.Config):
|
28 |
+
raise TypeError('config must be a filename or Config object, '
|
29 |
+
'but got {}'.format(type(config)))
|
30 |
+
config.model.pretrained = None
|
31 |
+
config.model.train_cfg = None
|
32 |
+
model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
|
33 |
+
if checkpoint is not None:
|
34 |
+
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
35 |
+
model.CLASSES = checkpoint['meta']['CLASSES']
|
36 |
+
model.PALETTE = checkpoint['meta']['PALETTE']
|
37 |
+
model.cfg = config # save the config in the model for convenience
|
38 |
+
model.to(device)
|
39 |
+
model.eval()
|
40 |
+
return model
|
41 |
+
|
42 |
+
|
43 |
+
class LoadImage:
|
44 |
+
"""A simple pipeline to load image."""
|
45 |
+
|
46 |
+
def __call__(self, results):
|
47 |
+
"""Call function to load images into results.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
results (dict): A result dict contains the file name
|
51 |
+
of the image to be read.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
dict: ``results`` will be returned containing loaded image.
|
55 |
+
"""
|
56 |
+
|
57 |
+
if isinstance(results['img'], str):
|
58 |
+
results['filename'] = results['img']
|
59 |
+
results['ori_filename'] = results['img']
|
60 |
+
else:
|
61 |
+
results['filename'] = None
|
62 |
+
results['ori_filename'] = None
|
63 |
+
img = mmcv.imread(results['img'])
|
64 |
+
results['img'] = img
|
65 |
+
results['img_shape'] = img.shape
|
66 |
+
results['ori_shape'] = img.shape
|
67 |
+
return results
|
68 |
+
|
69 |
+
|
70 |
+
def inference_segmentor(model, img):
|
71 |
+
"""Inference image(s) with the segmentor.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
model (nn.Module): The loaded segmentor.
|
75 |
+
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
|
76 |
+
images.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
(list[Tensor]): The segmentation result.
|
80 |
+
"""
|
81 |
+
cfg = model.cfg
|
82 |
+
device = next(model.parameters()).device # model device
|
83 |
+
# build the data pipeline
|
84 |
+
test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
85 |
+
test_pipeline = Compose(test_pipeline)
|
86 |
+
# prepare data
|
87 |
+
data = dict(img=img)
|
88 |
+
data = test_pipeline(data)
|
89 |
+
data = collate([data], samples_per_gpu=1)
|
90 |
+
if next(model.parameters()).is_cuda:
|
91 |
+
# scatter to specified GPU
|
92 |
+
data = scatter(data, [device])[0]
|
93 |
+
else:
|
94 |
+
data['img'][0] = data['img'][0].to(devices.get_device_for("controlnet"))
|
95 |
+
data['img_metas'] = [i.data[0] for i in data['img_metas']]
|
96 |
+
|
97 |
+
# forward the model
|
98 |
+
with torch.no_grad():
|
99 |
+
result = model(return_loss=False, rescale=True, **data)
|
100 |
+
return result
|
101 |
+
|
102 |
+
|
103 |
+
def show_result_pyplot(model,
|
104 |
+
img,
|
105 |
+
result,
|
106 |
+
palette=None,
|
107 |
+
fig_size=(15, 10),
|
108 |
+
opacity=0.5,
|
109 |
+
title='',
|
110 |
+
block=True):
|
111 |
+
"""Visualize the segmentation results on the image.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
model (nn.Module): The loaded segmentor.
|
115 |
+
img (str or np.ndarray): Image filename or loaded image.
|
116 |
+
result (list): The segmentation result.
|
117 |
+
palette (list[list[int]]] | None): The palette of segmentation
|
118 |
+
map. If None is given, random palette will be generated.
|
119 |
+
Default: None
|
120 |
+
fig_size (tuple): Figure size of the pyplot figure.
|
121 |
+
opacity(float): Opacity of painted segmentation map.
|
122 |
+
Default 0.5.
|
123 |
+
Must be in (0, 1] range.
|
124 |
+
title (str): The title of pyplot figure.
|
125 |
+
Default is ''.
|
126 |
+
block (bool): Whether to block the pyplot figure.
|
127 |
+
Default is True.
|
128 |
+
"""
|
129 |
+
if hasattr(model, 'module'):
|
130 |
+
model = model.module
|
131 |
+
img = model.show_result(
|
132 |
+
img, result, palette=palette, show=False, opacity=opacity)
|
133 |
+
# plt.figure(figsize=fig_size)
|
134 |
+
# plt.imshow(mmcv.bgr2rgb(img))
|
135 |
+
# plt.title(title)
|
136 |
+
# plt.tight_layout()
|
137 |
+
# plt.show(block=block)
|
138 |
+
return mmcv.bgr2rgb(img)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import pickle
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
|
6 |
+
import annotator.mmpkg.mmcv as mmcv
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from annotator.mmpkg.mmcv.image import tensor2imgs
|
11 |
+
from annotator.mmpkg.mmcv.runner import get_dist_info
|
12 |
+
|
13 |
+
|
14 |
+
def np2tmp(array, temp_file_name=None):
|
15 |
+
"""Save ndarray to local numpy file.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
array (ndarray): Ndarray to save.
|
19 |
+
temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
|
20 |
+
function will generate a file name with tempfile.NamedTemporaryFile
|
21 |
+
to save ndarray. Default: None.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
str: The numpy file name.
|
25 |
+
"""
|
26 |
+
|
27 |
+
if temp_file_name is None:
|
28 |
+
temp_file_name = tempfile.NamedTemporaryFile(
|
29 |
+
suffix='.npy', delete=False).name
|
30 |
+
np.save(temp_file_name, array)
|
31 |
+
return temp_file_name
|
32 |
+
|
33 |
+
|
34 |
+
def single_gpu_test(model,
|
35 |
+
data_loader,
|
36 |
+
show=False,
|
37 |
+
out_dir=None,
|
38 |
+
efficient_test=False,
|
39 |
+
opacity=0.5):
|
40 |
+
"""Test with single GPU.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
model (nn.Module): Model to be tested.
|
44 |
+
data_loader (utils.data.Dataloader): Pytorch data loader.
|
45 |
+
show (bool): Whether show results during inference. Default: False.
|
46 |
+
out_dir (str, optional): If specified, the results will be dumped into
|
47 |
+
the directory to save output results.
|
48 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
49 |
+
save CPU memory during evaluation. Default: False.
|
50 |
+
opacity(float): Opacity of painted segmentation map.
|
51 |
+
Default 0.5.
|
52 |
+
Must be in (0, 1] range.
|
53 |
+
Returns:
|
54 |
+
list: The prediction results.
|
55 |
+
"""
|
56 |
+
|
57 |
+
model.eval()
|
58 |
+
results = []
|
59 |
+
dataset = data_loader.dataset
|
60 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
61 |
+
for i, data in enumerate(data_loader):
|
62 |
+
with torch.no_grad():
|
63 |
+
result = model(return_loss=False, **data)
|
64 |
+
|
65 |
+
if show or out_dir:
|
66 |
+
img_tensor = data['img'][0]
|
67 |
+
img_metas = data['img_metas'][0].data[0]
|
68 |
+
imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
|
69 |
+
assert len(imgs) == len(img_metas)
|
70 |
+
|
71 |
+
for img, img_meta in zip(imgs, img_metas):
|
72 |
+
h, w, _ = img_meta['img_shape']
|
73 |
+
img_show = img[:h, :w, :]
|
74 |
+
|
75 |
+
ori_h, ori_w = img_meta['ori_shape'][:-1]
|
76 |
+
img_show = mmcv.imresize(img_show, (ori_w, ori_h))
|
77 |
+
|
78 |
+
if out_dir:
|
79 |
+
out_file = osp.join(out_dir, img_meta['ori_filename'])
|
80 |
+
else:
|
81 |
+
out_file = None
|
82 |
+
|
83 |
+
model.module.show_result(
|
84 |
+
img_show,
|
85 |
+
result,
|
86 |
+
palette=dataset.PALETTE,
|
87 |
+
show=show,
|
88 |
+
out_file=out_file,
|
89 |
+
opacity=opacity)
|
90 |
+
|
91 |
+
if isinstance(result, list):
|
92 |
+
if efficient_test:
|
93 |
+
result = [np2tmp(_) for _ in result]
|
94 |
+
results.extend(result)
|
95 |
+
else:
|
96 |
+
if efficient_test:
|
97 |
+
result = np2tmp(result)
|
98 |
+
results.append(result)
|
99 |
+
|
100 |
+
batch_size = len(result)
|
101 |
+
for _ in range(batch_size):
|
102 |
+
prog_bar.update()
|
103 |
+
return results
|
104 |
+
|
105 |
+
|
106 |
+
def multi_gpu_test(model,
|
107 |
+
data_loader,
|
108 |
+
tmpdir=None,
|
109 |
+
gpu_collect=False,
|
110 |
+
efficient_test=False):
|
111 |
+
"""Test model with multiple gpus.
|
112 |
+
|
113 |
+
This method tests model with multiple gpus and collects the results
|
114 |
+
under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
|
115 |
+
it encodes results to gpu tensors and use gpu communication for results
|
116 |
+
collection. On cpu mode it saves the results on different gpus to 'tmpdir'
|
117 |
+
and collects them by the rank 0 worker.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
model (nn.Module): Model to be tested.
|
121 |
+
data_loader (utils.data.Dataloader): Pytorch data loader.
|
122 |
+
tmpdir (str): Path of directory to save the temporary results from
|
123 |
+
different gpus under cpu mode.
|
124 |
+
gpu_collect (bool): Option to use either gpu or cpu to collect results.
|
125 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
126 |
+
save CPU memory during evaluation. Default: False.
|
127 |
+
|
128 |
+
Returns:
|
129 |
+
list: The prediction results.
|
130 |
+
"""
|
131 |
+
|
132 |
+
model.eval()
|
133 |
+
results = []
|
134 |
+
dataset = data_loader.dataset
|
135 |
+
rank, world_size = get_dist_info()
|
136 |
+
if rank == 0:
|
137 |
+
prog_bar = mmcv.ProgressBar(len(dataset))
|
138 |
+
for i, data in enumerate(data_loader):
|
139 |
+
with torch.no_grad():
|
140 |
+
result = model(return_loss=False, rescale=True, **data)
|
141 |
+
|
142 |
+
if isinstance(result, list):
|
143 |
+
if efficient_test:
|
144 |
+
result = [np2tmp(_) for _ in result]
|
145 |
+
results.extend(result)
|
146 |
+
else:
|
147 |
+
if efficient_test:
|
148 |
+
result = np2tmp(result)
|
149 |
+
results.append(result)
|
150 |
+
|
151 |
+
if rank == 0:
|
152 |
+
batch_size = data['img'][0].size(0)
|
153 |
+
for _ in range(batch_size * world_size):
|
154 |
+
prog_bar.update()
|
155 |
+
|
156 |
+
# collect results from all ranks
|
157 |
+
if gpu_collect:
|
158 |
+
results = collect_results_gpu(results, len(dataset))
|
159 |
+
else:
|
160 |
+
results = collect_results_cpu(results, len(dataset), tmpdir)
|
161 |
+
return results
|
162 |
+
|
163 |
+
|
164 |
+
def collect_results_cpu(result_part, size, tmpdir=None):
|
165 |
+
"""Collect results with CPU."""
|
166 |
+
rank, world_size = get_dist_info()
|
167 |
+
# create a tmp dir if it is not specified
|
168 |
+
if tmpdir is None:
|
169 |
+
MAX_LEN = 512
|
170 |
+
# 32 is whitespace
|
171 |
+
dir_tensor = torch.full((MAX_LEN, ),
|
172 |
+
32,
|
173 |
+
dtype=torch.uint8,
|
174 |
+
device='cuda')
|
175 |
+
if rank == 0:
|
176 |
+
tmpdir = tempfile.mkdtemp()
|
177 |
+
tmpdir = torch.tensor(
|
178 |
+
bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
|
179 |
+
dir_tensor[:len(tmpdir)] = tmpdir
|
180 |
+
dist.broadcast(dir_tensor, 0)
|
181 |
+
tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
|
182 |
+
else:
|
183 |
+
mmcv.mkdir_or_exist(tmpdir)
|
184 |
+
# dump the part result to the dir
|
185 |
+
mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
|
186 |
+
dist.barrier()
|
187 |
+
# collect all parts
|
188 |
+
if rank != 0:
|
189 |
+
return None
|
190 |
+
else:
|
191 |
+
# load results of all parts from tmp dir
|
192 |
+
part_list = []
|
193 |
+
for i in range(world_size):
|
194 |
+
part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
|
195 |
+
part_list.append(mmcv.load(part_file))
|
196 |
+
# sort the results
|
197 |
+
ordered_results = []
|
198 |
+
for res in zip(*part_list):
|
199 |
+
ordered_results.extend(list(res))
|
200 |
+
# the dataloader may pad some samples
|
201 |
+
ordered_results = ordered_results[:size]
|
202 |
+
# remove tmp dir
|
203 |
+
shutil.rmtree(tmpdir)
|
204 |
+
return ordered_results
|
205 |
+
|
206 |
+
|
207 |
+
def collect_results_gpu(result_part, size):
|
208 |
+
"""Collect results with GPU."""
|
209 |
+
rank, world_size = get_dist_info()
|
210 |
+
# dump result part to tensor with pickle
|
211 |
+
part_tensor = torch.tensor(
|
212 |
+
bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
|
213 |
+
# gather all result part tensor shape
|
214 |
+
shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
|
215 |
+
shape_list = [shape_tensor.clone() for _ in range(world_size)]
|
216 |
+
dist.all_gather(shape_list, shape_tensor)
|
217 |
+
# padding result part tensor to max length
|
218 |
+
shape_max = torch.tensor(shape_list).max()
|
219 |
+
part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
|
220 |
+
part_send[:shape_tensor[0]] = part_tensor
|
221 |
+
part_recv_list = [
|
222 |
+
part_tensor.new_zeros(shape_max) for _ in range(world_size)
|
223 |
+
]
|
224 |
+
# gather all result part
|
225 |
+
dist.all_gather(part_recv_list, part_send)
|
226 |
+
|
227 |
+
if rank == 0:
|
228 |
+
part_list = []
|
229 |
+
for recv, shape in zip(part_recv_list, shape_list):
|
230 |
+
part_list.append(
|
231 |
+
pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
|
232 |
+
# sort the results
|
233 |
+
ordered_results = []
|
234 |
+
for res in zip(*part_list):
|
235 |
+
ordered_results.extend(list(res))
|
236 |
+
# the dataloader may pad some samples
|
237 |
+
ordered_results = ordered_results[:size]
|
238 |
+
return ordered_results
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from annotator.mmpkg.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
7 |
+
from annotator.mmpkg.mmcv.runner import build_optimizer, build_runner
|
8 |
+
|
9 |
+
from annotator.mmpkg.mmseg.core import DistEvalHook, EvalHook
|
10 |
+
from annotator.mmpkg.mmseg.datasets import build_dataloader, build_dataset
|
11 |
+
from annotator.mmpkg.mmseg.utils import get_root_logger
|
12 |
+
|
13 |
+
|
14 |
+
def set_random_seed(seed, deterministic=False):
|
15 |
+
"""Set random seed.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
seed (int): Seed to be used.
|
19 |
+
deterministic (bool): Whether to set the deterministic option for
|
20 |
+
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
21 |
+
to True and `torch.backends.cudnn.benchmark` to False.
|
22 |
+
Default: False.
|
23 |
+
"""
|
24 |
+
random.seed(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
torch.cuda.manual_seed_all(seed)
|
28 |
+
if deterministic:
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
torch.backends.cudnn.benchmark = False
|
31 |
+
|
32 |
+
|
33 |
+
def train_segmentor(model,
|
34 |
+
dataset,
|
35 |
+
cfg,
|
36 |
+
distributed=False,
|
37 |
+
validate=False,
|
38 |
+
timestamp=None,
|
39 |
+
meta=None):
|
40 |
+
"""Launch segmentor training."""
|
41 |
+
logger = get_root_logger(cfg.log_level)
|
42 |
+
|
43 |
+
# prepare data loaders
|
44 |
+
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
45 |
+
data_loaders = [
|
46 |
+
build_dataloader(
|
47 |
+
ds,
|
48 |
+
cfg.data.samples_per_gpu,
|
49 |
+
cfg.data.workers_per_gpu,
|
50 |
+
# cfg.gpus will be ignored if distributed
|
51 |
+
len(cfg.gpu_ids),
|
52 |
+
dist=distributed,
|
53 |
+
seed=cfg.seed,
|
54 |
+
drop_last=True) for ds in dataset
|
55 |
+
]
|
56 |
+
|
57 |
+
# put model on gpus
|
58 |
+
if distributed:
|
59 |
+
find_unused_parameters = cfg.get('find_unused_parameters', False)
|
60 |
+
# Sets the `find_unused_parameters` parameter in
|
61 |
+
# torch.nn.parallel.DistributedDataParallel
|
62 |
+
model = MMDistributedDataParallel(
|
63 |
+
model.cuda(),
|
64 |
+
device_ids=[torch.cuda.current_device()],
|
65 |
+
broadcast_buffers=False,
|
66 |
+
find_unused_parameters=find_unused_parameters)
|
67 |
+
else:
|
68 |
+
model = MMDataParallel(
|
69 |
+
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
|
70 |
+
|
71 |
+
# build runner
|
72 |
+
optimizer = build_optimizer(model, cfg.optimizer)
|
73 |
+
|
74 |
+
if cfg.get('runner') is None:
|
75 |
+
cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
|
76 |
+
warnings.warn(
|
77 |
+
'config is now expected to have a `runner` section, '
|
78 |
+
'please set `runner` in your config.', UserWarning)
|
79 |
+
|
80 |
+
runner = build_runner(
|
81 |
+
cfg.runner,
|
82 |
+
default_args=dict(
|
83 |
+
model=model,
|
84 |
+
batch_processor=None,
|
85 |
+
optimizer=optimizer,
|
86 |
+
work_dir=cfg.work_dir,
|
87 |
+
logger=logger,
|
88 |
+
meta=meta))
|
89 |
+
|
90 |
+
# register hooks
|
91 |
+
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
|
92 |
+
cfg.checkpoint_config, cfg.log_config,
|
93 |
+
cfg.get('momentum_config', None))
|
94 |
+
|
95 |
+
# an ugly walkaround to make the .log and .log.json filenames the same
|
96 |
+
runner.timestamp = timestamp
|
97 |
+
|
98 |
+
# register eval hooks
|
99 |
+
if validate:
|
100 |
+
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
101 |
+
val_dataloader = build_dataloader(
|
102 |
+
val_dataset,
|
103 |
+
samples_per_gpu=1,
|
104 |
+
workers_per_gpu=cfg.data.workers_per_gpu,
|
105 |
+
dist=distributed,
|
106 |
+
shuffle=False)
|
107 |
+
eval_cfg = cfg.get('evaluation', {})
|
108 |
+
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
109 |
+
eval_hook = DistEvalHook if distributed else EvalHook
|
110 |
+
runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
|
111 |
+
|
112 |
+
if cfg.resume_from:
|
113 |
+
runner.resume(cfg.resume_from)
|
114 |
+
elif cfg.load_from:
|
115 |
+
runner.load_checkpoint(cfg.load_from)
|
116 |
+
runner.run(data_loaders, cfg.workflow)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .evaluation import * # noqa: F401, F403
|
2 |
+
from .seg import * # noqa: F401, F403
|
3 |
+
from .utils import * # noqa: F401, F403
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .class_names import get_classes, get_palette
|
2 |
+
from .eval_hooks import DistEvalHook, EvalHook
|
3 |
+
from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
|
4 |
+
|
5 |
+
__all__ = [
|
6 |
+
'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
|
7 |
+
'eval_metrics', 'get_classes', 'get_palette'
|
8 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import annotator.mmpkg.mmcv as mmcv
|
2 |
+
|
3 |
+
|
4 |
+
def cityscapes_classes():
|
5 |
+
"""Cityscapes class names for external use."""
|
6 |
+
return [
|
7 |
+
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
8 |
+
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
9 |
+
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
10 |
+
'bicycle'
|
11 |
+
]
|
12 |
+
|
13 |
+
|
14 |
+
def ade_classes():
|
15 |
+
"""ADE20K class names for external use."""
|
16 |
+
return [
|
17 |
+
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
18 |
+
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
19 |
+
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
20 |
+
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
21 |
+
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
22 |
+
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
23 |
+
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
24 |
+
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
25 |
+
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
26 |
+
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
27 |
+
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
28 |
+
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
29 |
+
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
30 |
+
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
31 |
+
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
32 |
+
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
33 |
+
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
34 |
+
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
35 |
+
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
36 |
+
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
37 |
+
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
38 |
+
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
39 |
+
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
40 |
+
'clock', 'flag'
|
41 |
+
]
|
42 |
+
|
43 |
+
|
44 |
+
def voc_classes():
|
45 |
+
"""Pascal VOC class names for external use."""
|
46 |
+
return [
|
47 |
+
'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
|
48 |
+
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
|
49 |
+
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
|
50 |
+
'tvmonitor'
|
51 |
+
]
|
52 |
+
|
53 |
+
|
54 |
+
def cityscapes_palette():
|
55 |
+
"""Cityscapes palette for external use."""
|
56 |
+
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
57 |
+
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
58 |
+
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
59 |
+
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
|
60 |
+
[0, 0, 230], [119, 11, 32]]
|
61 |
+
|
62 |
+
|
63 |
+
def ade_palette():
|
64 |
+
"""ADE20K palette for external use."""
|
65 |
+
return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
66 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
67 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
68 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
69 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
70 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
71 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
72 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
73 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
74 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
75 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
76 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
77 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
78 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
79 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
80 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
81 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
82 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
83 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
84 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
85 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
86 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
87 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
88 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
89 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
90 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
91 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
92 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
93 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
94 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
95 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
96 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
97 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
98 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
99 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
100 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
101 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
102 |
+
[102, 255, 0], [92, 0, 255]]
|
103 |
+
|
104 |
+
|
105 |
+
def voc_palette():
|
106 |
+
"""Pascal VOC palette for external use."""
|
107 |
+
return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
|
108 |
+
[128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
|
109 |
+
[192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
|
110 |
+
[192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
|
111 |
+
[128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
|
112 |
+
|
113 |
+
|
114 |
+
dataset_aliases = {
|
115 |
+
'cityscapes': ['cityscapes'],
|
116 |
+
'ade': ['ade', 'ade20k'],
|
117 |
+
'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
|
118 |
+
}
|
119 |
+
|
120 |
+
|
121 |
+
def get_classes(dataset):
|
122 |
+
"""Get class names of a dataset."""
|
123 |
+
alias2name = {}
|
124 |
+
for name, aliases in dataset_aliases.items():
|
125 |
+
for alias in aliases:
|
126 |
+
alias2name[alias] = name
|
127 |
+
|
128 |
+
if mmcv.is_str(dataset):
|
129 |
+
if dataset in alias2name:
|
130 |
+
labels = eval(alias2name[dataset] + '_classes()')
|
131 |
+
else:
|
132 |
+
raise ValueError(f'Unrecognized dataset: {dataset}')
|
133 |
+
else:
|
134 |
+
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
135 |
+
return labels
|
136 |
+
|
137 |
+
|
138 |
+
def get_palette(dataset):
|
139 |
+
"""Get class palette (RGB) of a dataset."""
|
140 |
+
alias2name = {}
|
141 |
+
for name, aliases in dataset_aliases.items():
|
142 |
+
for alias in aliases:
|
143 |
+
alias2name[alias] = name
|
144 |
+
|
145 |
+
if mmcv.is_str(dataset):
|
146 |
+
if dataset in alias2name:
|
147 |
+
labels = eval(alias2name[dataset] + '_palette()')
|
148 |
+
else:
|
149 |
+
raise ValueError(f'Unrecognized dataset: {dataset}')
|
150 |
+
else:
|
151 |
+
raise TypeError(f'dataset must a str, but got {type(dataset)}')
|
152 |
+
return labels
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
from annotator.mmpkg.mmcv.runner import DistEvalHook as _DistEvalHook
|
4 |
+
from annotator.mmpkg.mmcv.runner import EvalHook as _EvalHook
|
5 |
+
|
6 |
+
|
7 |
+
class EvalHook(_EvalHook):
|
8 |
+
"""Single GPU EvalHook, with efficient test support.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
by_epoch (bool): Determine perform evaluation by epoch or by iteration.
|
12 |
+
If set to True, it will perform by epoch. Otherwise, by iteration.
|
13 |
+
Default: False.
|
14 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
15 |
+
save CPU memory during evaluation. Default: False.
|
16 |
+
Returns:
|
17 |
+
list: The prediction results.
|
18 |
+
"""
|
19 |
+
|
20 |
+
greater_keys = ['mIoU', 'mAcc', 'aAcc']
|
21 |
+
|
22 |
+
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
|
23 |
+
super().__init__(*args, by_epoch=by_epoch, **kwargs)
|
24 |
+
self.efficient_test = efficient_test
|
25 |
+
|
26 |
+
def after_train_iter(self, runner):
|
27 |
+
"""After train epoch hook.
|
28 |
+
|
29 |
+
Override default ``single_gpu_test``.
|
30 |
+
"""
|
31 |
+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
32 |
+
return
|
33 |
+
from annotator.mmpkg.mmseg.apis import single_gpu_test
|
34 |
+
runner.log_buffer.clear()
|
35 |
+
results = single_gpu_test(
|
36 |
+
runner.model,
|
37 |
+
self.dataloader,
|
38 |
+
show=False,
|
39 |
+
efficient_test=self.efficient_test)
|
40 |
+
self.evaluate(runner, results)
|
41 |
+
|
42 |
+
def after_train_epoch(self, runner):
|
43 |
+
"""After train epoch hook.
|
44 |
+
|
45 |
+
Override default ``single_gpu_test``.
|
46 |
+
"""
|
47 |
+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
48 |
+
return
|
49 |
+
from annotator.mmpkg.mmseg.apis import single_gpu_test
|
50 |
+
runner.log_buffer.clear()
|
51 |
+
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
52 |
+
self.evaluate(runner, results)
|
53 |
+
|
54 |
+
|
55 |
+
class DistEvalHook(_DistEvalHook):
|
56 |
+
"""Distributed EvalHook, with efficient test support.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
by_epoch (bool): Determine perform evaluation by epoch or by iteration.
|
60 |
+
If set to True, it will perform by epoch. Otherwise, by iteration.
|
61 |
+
Default: False.
|
62 |
+
efficient_test (bool): Whether save the results as local numpy files to
|
63 |
+
save CPU memory during evaluation. Default: False.
|
64 |
+
Returns:
|
65 |
+
list: The prediction results.
|
66 |
+
"""
|
67 |
+
|
68 |
+
greater_keys = ['mIoU', 'mAcc', 'aAcc']
|
69 |
+
|
70 |
+
def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
|
71 |
+
super().__init__(*args, by_epoch=by_epoch, **kwargs)
|
72 |
+
self.efficient_test = efficient_test
|
73 |
+
|
74 |
+
def after_train_iter(self, runner):
|
75 |
+
"""After train epoch hook.
|
76 |
+
|
77 |
+
Override default ``multi_gpu_test``.
|
78 |
+
"""
|
79 |
+
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
80 |
+
return
|
81 |
+
from annotator.mmpkg.mmseg.apis import multi_gpu_test
|
82 |
+
runner.log_buffer.clear()
|
83 |
+
results = multi_gpu_test(
|
84 |
+
runner.model,
|
85 |
+
self.dataloader,
|
86 |
+
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
87 |
+
gpu_collect=self.gpu_collect,
|
88 |
+
efficient_test=self.efficient_test)
|
89 |
+
if runner.rank == 0:
|
90 |
+
print('\n')
|
91 |
+
self.evaluate(runner, results)
|
92 |
+
|
93 |
+
def after_train_epoch(self, runner):
|
94 |
+
"""After train epoch hook.
|
95 |
+
|
96 |
+
Override default ``multi_gpu_test``.
|
97 |
+
"""
|
98 |
+
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
99 |
+
return
|
100 |
+
from annotator.mmpkg.mmseg.apis import multi_gpu_test
|
101 |
+
runner.log_buffer.clear()
|
102 |
+
results = multi_gpu_test(
|
103 |
+
runner.model,
|
104 |
+
self.dataloader,
|
105 |
+
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
106 |
+
gpu_collect=self.gpu_collect)
|
107 |
+
if runner.rank == 0:
|
108 |
+
print('\n')
|
109 |
+
self.evaluate(runner, results)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import annotator.mmpkg.mmcv as mmcv
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def f_score(precision, recall, beta=1):
|
9 |
+
"""calcuate the f-score value.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
precision (float | torch.Tensor): The precision value.
|
13 |
+
recall (float | torch.Tensor): The recall value.
|
14 |
+
beta (int): Determines the weight of recall in the combined score.
|
15 |
+
Default: False.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
[torch.tensor]: The f-score value.
|
19 |
+
"""
|
20 |
+
score = (1 + beta**2) * (precision * recall) / (
|
21 |
+
(beta**2 * precision) + recall)
|
22 |
+
return score
|
23 |
+
|
24 |
+
|
25 |
+
def intersect_and_union(pred_label,
|
26 |
+
label,
|
27 |
+
num_classes,
|
28 |
+
ignore_index,
|
29 |
+
label_map=dict(),
|
30 |
+
reduce_zero_label=False):
|
31 |
+
"""Calculate intersection and Union.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
pred_label (ndarray | str): Prediction segmentation map
|
35 |
+
or predict result filename.
|
36 |
+
label (ndarray | str): Ground truth segmentation map
|
37 |
+
or label filename.
|
38 |
+
num_classes (int): Number of categories.
|
39 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
40 |
+
label_map (dict): Mapping old labels to new labels. The parameter will
|
41 |
+
work only when label is str. Default: dict().
|
42 |
+
reduce_zero_label (bool): Whether ignore zero label. The parameter will
|
43 |
+
work only when label is str. Default: False.
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
torch.Tensor: The intersection of prediction and ground truth
|
47 |
+
histogram on all classes.
|
48 |
+
torch.Tensor: The union of prediction and ground truth histogram on
|
49 |
+
all classes.
|
50 |
+
torch.Tensor: The prediction histogram on all classes.
|
51 |
+
torch.Tensor: The ground truth histogram on all classes.
|
52 |
+
"""
|
53 |
+
|
54 |
+
if isinstance(pred_label, str):
|
55 |
+
pred_label = torch.from_numpy(np.load(pred_label))
|
56 |
+
else:
|
57 |
+
pred_label = torch.from_numpy((pred_label))
|
58 |
+
|
59 |
+
if isinstance(label, str):
|
60 |
+
label = torch.from_numpy(
|
61 |
+
mmcv.imread(label, flag='unchanged', backend='pillow'))
|
62 |
+
else:
|
63 |
+
label = torch.from_numpy(label)
|
64 |
+
|
65 |
+
if label_map is not None:
|
66 |
+
for old_id, new_id in label_map.items():
|
67 |
+
label[label == old_id] = new_id
|
68 |
+
if reduce_zero_label:
|
69 |
+
label[label == 0] = 255
|
70 |
+
label = label - 1
|
71 |
+
label[label == 254] = 255
|
72 |
+
|
73 |
+
mask = (label != ignore_index)
|
74 |
+
pred_label = pred_label[mask]
|
75 |
+
label = label[mask]
|
76 |
+
|
77 |
+
intersect = pred_label[pred_label == label]
|
78 |
+
area_intersect = torch.histc(
|
79 |
+
intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
80 |
+
area_pred_label = torch.histc(
|
81 |
+
pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
82 |
+
area_label = torch.histc(
|
83 |
+
label.float(), bins=(num_classes), min=0, max=num_classes - 1)
|
84 |
+
area_union = area_pred_label + area_label - area_intersect
|
85 |
+
return area_intersect, area_union, area_pred_label, area_label
|
86 |
+
|
87 |
+
|
88 |
+
def total_intersect_and_union(results,
|
89 |
+
gt_seg_maps,
|
90 |
+
num_classes,
|
91 |
+
ignore_index,
|
92 |
+
label_map=dict(),
|
93 |
+
reduce_zero_label=False):
|
94 |
+
"""Calculate Total Intersection and Union.
|
95 |
+
|
96 |
+
Args:
|
97 |
+
results (list[ndarray] | list[str]): List of prediction segmentation
|
98 |
+
maps or list of prediction result filenames.
|
99 |
+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
100 |
+
segmentation maps or list of label filenames.
|
101 |
+
num_classes (int): Number of categories.
|
102 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
103 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
104 |
+
reduce_zero_label (bool): Whether ignore zero label. Default: False.
|
105 |
+
|
106 |
+
Returns:
|
107 |
+
ndarray: The intersection of prediction and ground truth histogram
|
108 |
+
on all classes.
|
109 |
+
ndarray: The union of prediction and ground truth histogram on all
|
110 |
+
classes.
|
111 |
+
ndarray: The prediction histogram on all classes.
|
112 |
+
ndarray: The ground truth histogram on all classes.
|
113 |
+
"""
|
114 |
+
num_imgs = len(results)
|
115 |
+
assert len(gt_seg_maps) == num_imgs
|
116 |
+
total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
|
117 |
+
total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
|
118 |
+
total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
|
119 |
+
total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
|
120 |
+
for i in range(num_imgs):
|
121 |
+
area_intersect, area_union, area_pred_label, area_label = \
|
122 |
+
intersect_and_union(
|
123 |
+
results[i], gt_seg_maps[i], num_classes, ignore_index,
|
124 |
+
label_map, reduce_zero_label)
|
125 |
+
total_area_intersect += area_intersect
|
126 |
+
total_area_union += area_union
|
127 |
+
total_area_pred_label += area_pred_label
|
128 |
+
total_area_label += area_label
|
129 |
+
return total_area_intersect, total_area_union, total_area_pred_label, \
|
130 |
+
total_area_label
|
131 |
+
|
132 |
+
|
133 |
+
def mean_iou(results,
|
134 |
+
gt_seg_maps,
|
135 |
+
num_classes,
|
136 |
+
ignore_index,
|
137 |
+
nan_to_num=None,
|
138 |
+
label_map=dict(),
|
139 |
+
reduce_zero_label=False):
|
140 |
+
"""Calculate Mean Intersection and Union (mIoU)
|
141 |
+
|
142 |
+
Args:
|
143 |
+
results (list[ndarray] | list[str]): List of prediction segmentation
|
144 |
+
maps or list of prediction result filenames.
|
145 |
+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
146 |
+
segmentation maps or list of label filenames.
|
147 |
+
num_classes (int): Number of categories.
|
148 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
149 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
150 |
+
by the numbers defined by the user. Default: None.
|
151 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
152 |
+
reduce_zero_label (bool): Whether ignore zero label. Default: False.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
dict[str, float | ndarray]:
|
156 |
+
<aAcc> float: Overall accuracy on all images.
|
157 |
+
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
|
158 |
+
<IoU> ndarray: Per category IoU, shape (num_classes, ).
|
159 |
+
"""
|
160 |
+
iou_result = eval_metrics(
|
161 |
+
results=results,
|
162 |
+
gt_seg_maps=gt_seg_maps,
|
163 |
+
num_classes=num_classes,
|
164 |
+
ignore_index=ignore_index,
|
165 |
+
metrics=['mIoU'],
|
166 |
+
nan_to_num=nan_to_num,
|
167 |
+
label_map=label_map,
|
168 |
+
reduce_zero_label=reduce_zero_label)
|
169 |
+
return iou_result
|
170 |
+
|
171 |
+
|
172 |
+
def mean_dice(results,
|
173 |
+
gt_seg_maps,
|
174 |
+
num_classes,
|
175 |
+
ignore_index,
|
176 |
+
nan_to_num=None,
|
177 |
+
label_map=dict(),
|
178 |
+
reduce_zero_label=False):
|
179 |
+
"""Calculate Mean Dice (mDice)
|
180 |
+
|
181 |
+
Args:
|
182 |
+
results (list[ndarray] | list[str]): List of prediction segmentation
|
183 |
+
maps or list of prediction result filenames.
|
184 |
+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
185 |
+
segmentation maps or list of label filenames.
|
186 |
+
num_classes (int): Number of categories.
|
187 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
188 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
189 |
+
by the numbers defined by the user. Default: None.
|
190 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
191 |
+
reduce_zero_label (bool): Whether ignore zero label. Default: False.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
dict[str, float | ndarray]: Default metrics.
|
195 |
+
<aAcc> float: Overall accuracy on all images.
|
196 |
+
<Acc> ndarray: Per category accuracy, shape (num_classes, ).
|
197 |
+
<Dice> ndarray: Per category dice, shape (num_classes, ).
|
198 |
+
"""
|
199 |
+
|
200 |
+
dice_result = eval_metrics(
|
201 |
+
results=results,
|
202 |
+
gt_seg_maps=gt_seg_maps,
|
203 |
+
num_classes=num_classes,
|
204 |
+
ignore_index=ignore_index,
|
205 |
+
metrics=['mDice'],
|
206 |
+
nan_to_num=nan_to_num,
|
207 |
+
label_map=label_map,
|
208 |
+
reduce_zero_label=reduce_zero_label)
|
209 |
+
return dice_result
|
210 |
+
|
211 |
+
|
212 |
+
def mean_fscore(results,
|
213 |
+
gt_seg_maps,
|
214 |
+
num_classes,
|
215 |
+
ignore_index,
|
216 |
+
nan_to_num=None,
|
217 |
+
label_map=dict(),
|
218 |
+
reduce_zero_label=False,
|
219 |
+
beta=1):
|
220 |
+
"""Calculate Mean Intersection and Union (mIoU)
|
221 |
+
|
222 |
+
Args:
|
223 |
+
results (list[ndarray] | list[str]): List of prediction segmentation
|
224 |
+
maps or list of prediction result filenames.
|
225 |
+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
226 |
+
segmentation maps or list of label filenames.
|
227 |
+
num_classes (int): Number of categories.
|
228 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
229 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
230 |
+
by the numbers defined by the user. Default: None.
|
231 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
232 |
+
reduce_zero_label (bool): Whether ignore zero label. Default: False.
|
233 |
+
beta (int): Determines the weight of recall in the combined score.
|
234 |
+
Default: False.
|
235 |
+
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
dict[str, float | ndarray]: Default metrics.
|
239 |
+
<aAcc> float: Overall accuracy on all images.
|
240 |
+
<Fscore> ndarray: Per category recall, shape (num_classes, ).
|
241 |
+
<Precision> ndarray: Per category precision, shape (num_classes, ).
|
242 |
+
<Recall> ndarray: Per category f-score, shape (num_classes, ).
|
243 |
+
"""
|
244 |
+
fscore_result = eval_metrics(
|
245 |
+
results=results,
|
246 |
+
gt_seg_maps=gt_seg_maps,
|
247 |
+
num_classes=num_classes,
|
248 |
+
ignore_index=ignore_index,
|
249 |
+
metrics=['mFscore'],
|
250 |
+
nan_to_num=nan_to_num,
|
251 |
+
label_map=label_map,
|
252 |
+
reduce_zero_label=reduce_zero_label,
|
253 |
+
beta=beta)
|
254 |
+
return fscore_result
|
255 |
+
|
256 |
+
|
257 |
+
def eval_metrics(results,
|
258 |
+
gt_seg_maps,
|
259 |
+
num_classes,
|
260 |
+
ignore_index,
|
261 |
+
metrics=['mIoU'],
|
262 |
+
nan_to_num=None,
|
263 |
+
label_map=dict(),
|
264 |
+
reduce_zero_label=False,
|
265 |
+
beta=1):
|
266 |
+
"""Calculate evaluation metrics
|
267 |
+
Args:
|
268 |
+
results (list[ndarray] | list[str]): List of prediction segmentation
|
269 |
+
maps or list of prediction result filenames.
|
270 |
+
gt_seg_maps (list[ndarray] | list[str]): list of ground truth
|
271 |
+
segmentation maps or list of label filenames.
|
272 |
+
num_classes (int): Number of categories.
|
273 |
+
ignore_index (int): Index that will be ignored in evaluation.
|
274 |
+
metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
|
275 |
+
nan_to_num (int, optional): If specified, NaN values will be replaced
|
276 |
+
by the numbers defined by the user. Default: None.
|
277 |
+
label_map (dict): Mapping old labels to new labels. Default: dict().
|
278 |
+
reduce_zero_label (bool): Whether ignore zero label. Default: False.
|
279 |
+
Returns:
|
280 |
+
float: Overall accuracy on all images.
|
281 |
+
ndarray: Per category accuracy, shape (num_classes, ).
|
282 |
+
ndarray: Per category evaluation metrics, shape (num_classes, ).
|
283 |
+
"""
|
284 |
+
if isinstance(metrics, str):
|
285 |
+
metrics = [metrics]
|
286 |
+
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
|
287 |
+
if not set(metrics).issubset(set(allowed_metrics)):
|
288 |
+
raise KeyError('metrics {} is not supported'.format(metrics))
|
289 |
+
|
290 |
+
total_area_intersect, total_area_union, total_area_pred_label, \
|
291 |
+
total_area_label = total_intersect_and_union(
|
292 |
+
results, gt_seg_maps, num_classes, ignore_index, label_map,
|
293 |
+
reduce_zero_label)
|
294 |
+
all_acc = total_area_intersect.sum() / total_area_label.sum()
|
295 |
+
ret_metrics = OrderedDict({'aAcc': all_acc})
|
296 |
+
for metric in metrics:
|
297 |
+
if metric == 'mIoU':
|
298 |
+
iou = total_area_intersect / total_area_union
|
299 |
+
acc = total_area_intersect / total_area_label
|
300 |
+
ret_metrics['IoU'] = iou
|
301 |
+
ret_metrics['Acc'] = acc
|
302 |
+
elif metric == 'mDice':
|
303 |
+
dice = 2 * total_area_intersect / (
|
304 |
+
total_area_pred_label + total_area_label)
|
305 |
+
acc = total_area_intersect / total_area_label
|
306 |
+
ret_metrics['Dice'] = dice
|
307 |
+
ret_metrics['Acc'] = acc
|
308 |
+
elif metric == 'mFscore':
|
309 |
+
precision = total_area_intersect / total_area_pred_label
|
310 |
+
recall = total_area_intersect / total_area_label
|
311 |
+
f_value = torch.tensor(
|
312 |
+
[f_score(x[0], x[1], beta) for x in zip(precision, recall)])
|
313 |
+
ret_metrics['Fscore'] = f_value
|
314 |
+
ret_metrics['Precision'] = precision
|
315 |
+
ret_metrics['Recall'] = recall
|
316 |
+
|
317 |
+
ret_metrics = {
|
318 |
+
metric: value.numpy()
|
319 |
+
for metric, value in ret_metrics.items()
|
320 |
+
}
|
321 |
+
if nan_to_num is not None:
|
322 |
+
ret_metrics = OrderedDict({
|
323 |
+
metric: np.nan_to_num(metric_value, nan=nan_to_num)
|
324 |
+
for metric, metric_value in ret_metrics.items()
|
325 |
+
})
|
326 |
+
return ret_metrics
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import build_pixel_sampler
|
2 |
+
from .sampler import BasePixelSampler, OHEMPixelSampler
|
3 |
+
|
4 |
+
__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
|
2 |
+
|
3 |
+
PIXEL_SAMPLERS = Registry('pixel sampler')
|
4 |
+
|
5 |
+
|
6 |
+
def build_pixel_sampler(cfg, **default_args):
|
7 |
+
"""Build pixel sampler for segmentation map."""
|
8 |
+
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_pixel_sampler import BasePixelSampler
|
2 |
+
from .ohem_pixel_sampler import OHEMPixelSampler
|
3 |
+
|
4 |
+
__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABCMeta, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class BasePixelSampler(metaclass=ABCMeta):
|
5 |
+
"""Base class of pixel sampler."""
|
6 |
+
|
7 |
+
def __init__(self, **kwargs):
|
8 |
+
pass
|
9 |
+
|
10 |
+
@abstractmethod
|
11 |
+
def sample(self, seg_logit, seg_label):
|
12 |
+
"""Placeholder for sample function."""
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
from ..builder import PIXEL_SAMPLERS
|
5 |
+
from .base_pixel_sampler import BasePixelSampler
|
6 |
+
|
7 |
+
|
8 |
+
@PIXEL_SAMPLERS.register_module()
|
9 |
+
class OHEMPixelSampler(BasePixelSampler):
|
10 |
+
"""Online Hard Example Mining Sampler for segmentation.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
context (nn.Module): The context of sampler, subclass of
|
14 |
+
:obj:`BaseDecodeHead`.
|
15 |
+
thresh (float, optional): The threshold for hard example selection.
|
16 |
+
Below which, are prediction with low confidence. If not
|
17 |
+
specified, the hard examples will be pixels of top ``min_kept``
|
18 |
+
loss. Default: None.
|
19 |
+
min_kept (int, optional): The minimum number of predictions to keep.
|
20 |
+
Default: 100000.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, context, thresh=None, min_kept=100000):
|
24 |
+
super(OHEMPixelSampler, self).__init__()
|
25 |
+
self.context = context
|
26 |
+
assert min_kept > 1
|
27 |
+
self.thresh = thresh
|
28 |
+
self.min_kept = min_kept
|
29 |
+
|
30 |
+
def sample(self, seg_logit, seg_label):
|
31 |
+
"""Sample pixels that have high loss or with low prediction confidence.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
|
35 |
+
seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
torch.Tensor: segmentation weight, shape (N, H, W)
|
39 |
+
"""
|
40 |
+
with torch.no_grad():
|
41 |
+
assert seg_logit.shape[2:] == seg_label.shape[2:]
|
42 |
+
assert seg_label.shape[1] == 1
|
43 |
+
seg_label = seg_label.squeeze(1).long()
|
44 |
+
batch_kept = self.min_kept * seg_label.size(0)
|
45 |
+
valid_mask = seg_label != self.context.ignore_index
|
46 |
+
seg_weight = seg_logit.new_zeros(size=seg_label.size())
|
47 |
+
valid_seg_weight = seg_weight[valid_mask]
|
48 |
+
if self.thresh is not None:
|
49 |
+
seg_prob = F.softmax(seg_logit, dim=1)
|
50 |
+
|
51 |
+
tmp_seg_label = seg_label.clone().unsqueeze(1)
|
52 |
+
tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
|
53 |
+
seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
|
54 |
+
sort_prob, sort_indices = seg_prob[valid_mask].sort()
|
55 |
+
|
56 |
+
if sort_prob.numel() > 0:
|
57 |
+
min_threshold = sort_prob[min(batch_kept,
|
58 |
+
sort_prob.numel() - 1)]
|
59 |
+
else:
|
60 |
+
min_threshold = 0.0
|
61 |
+
threshold = max(min_threshold, self.thresh)
|
62 |
+
valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
|
63 |
+
else:
|
64 |
+
losses = self.context.loss_decode(
|
65 |
+
seg_logit,
|
66 |
+
seg_label,
|
67 |
+
weight=None,
|
68 |
+
ignore_index=self.context.ignore_index,
|
69 |
+
reduction_override='none')
|
70 |
+
# faster than topk according to https://github.com/pytorch/pytorch/issues/22812 # noqa
|
71 |
+
_, sort_indices = losses[valid_mask].sort(descending=True)
|
72 |
+
valid_seg_weight[sort_indices[:batch_kept]] = 1.
|
73 |
+
|
74 |
+
seg_weight[valid_mask] = valid_seg_weight
|
75 |
+
|
76 |
+
return seg_weight
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .misc import add_prefix
|
2 |
+
|
3 |
+
__all__ = ['add_prefix']
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def add_prefix(inputs, prefix):
|
2 |
+
"""Add prefix for dict.
|
3 |
+
|
4 |
+
Args:
|
5 |
+
inputs (dict): The input dict with str keys.
|
6 |
+
prefix (str): The prefix to add.
|
7 |
+
|
8 |
+
Returns:
|
9 |
+
|
10 |
+
dict: The dict with keys updated with ``prefix``.
|
11 |
+
"""
|
12 |
+
|
13 |
+
outputs = dict()
|
14 |
+
for name, value in inputs.items():
|
15 |
+
outputs[f'{prefix}.{name}'] = value
|
16 |
+
|
17 |
+
return outputs
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ade import ADE20KDataset
|
2 |
+
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
|
3 |
+
from .chase_db1 import ChaseDB1Dataset
|
4 |
+
from .cityscapes import CityscapesDataset
|
5 |
+
from .custom import CustomDataset
|
6 |
+
from .dataset_wrappers import ConcatDataset, RepeatDataset
|
7 |
+
from .drive import DRIVEDataset
|
8 |
+
from .hrf import HRFDataset
|
9 |
+
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
10 |
+
from .stare import STAREDataset
|
11 |
+
from .voc import PascalVOCDataset
|
12 |
+
|
13 |
+
__all__ = [
|
14 |
+
'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
|
15 |
+
'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
|
16 |
+
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
|
17 |
+
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
|
18 |
+
'STAREDataset'
|
19 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import DATASETS
|
2 |
+
from .custom import CustomDataset
|
3 |
+
|
4 |
+
|
5 |
+
@DATASETS.register_module()
|
6 |
+
class ADE20KDataset(CustomDataset):
|
7 |
+
"""ADE20K dataset.
|
8 |
+
|
9 |
+
In segmentation map annotation for ADE20K, 0 stands for background, which
|
10 |
+
is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
|
11 |
+
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
|
12 |
+
'.png'.
|
13 |
+
"""
|
14 |
+
CLASSES = (
|
15 |
+
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
16 |
+
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
17 |
+
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
18 |
+
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
19 |
+
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
20 |
+
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
21 |
+
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
22 |
+
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
23 |
+
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
24 |
+
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
25 |
+
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
26 |
+
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
27 |
+
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
28 |
+
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
29 |
+
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
30 |
+
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
31 |
+
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
32 |
+
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
33 |
+
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
34 |
+
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
35 |
+
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
36 |
+
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
37 |
+
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
38 |
+
'clock', 'flag')
|
39 |
+
|
40 |
+
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
41 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
42 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
43 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
44 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
45 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
46 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
47 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
48 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
49 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
50 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
51 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
52 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
53 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
54 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
|
55 |
+
[11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
|
56 |
+
[0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
|
57 |
+
[255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
|
58 |
+
[0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
|
59 |
+
[173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
|
60 |
+
[255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
|
61 |
+
[255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
|
62 |
+
[255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
|
63 |
+
[0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
|
64 |
+
[0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
|
65 |
+
[143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
|
66 |
+
[8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
|
67 |
+
[255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
|
68 |
+
[92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
|
69 |
+
[163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
|
70 |
+
[255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
|
71 |
+
[255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
|
72 |
+
[10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
|
73 |
+
[255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
|
74 |
+
[41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
|
75 |
+
[71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
|
76 |
+
[184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
|
77 |
+
[102, 255, 0], [92, 0, 255]]
|
78 |
+
|
79 |
+
def __init__(self, **kwargs):
|
80 |
+
super(ADE20KDataset, self).__init__(
|
81 |
+
img_suffix='.jpg',
|
82 |
+
seg_map_suffix='.png',
|
83 |
+
reduce_zero_label=True,
|
84 |
+
**kwargs)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import platform
|
3 |
+
import random
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
from annotator.mmpkg.mmcv.parallel import collate
|
8 |
+
from annotator.mmpkg.mmcv.runner import get_dist_info
|
9 |
+
from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
|
10 |
+
from annotator.mmpkg.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
|
11 |
+
from torch.utils.data import DistributedSampler
|
12 |
+
|
13 |
+
if platform.system() != 'Windows':
|
14 |
+
# https://github.com/pytorch/pytorch/issues/973
|
15 |
+
import resource
|
16 |
+
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
17 |
+
hard_limit = rlimit[1]
|
18 |
+
soft_limit = min(4096, hard_limit)
|
19 |
+
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
|
20 |
+
|
21 |
+
DATASETS = Registry('dataset')
|
22 |
+
PIPELINES = Registry('pipeline')
|
23 |
+
|
24 |
+
|
25 |
+
def _concat_dataset(cfg, default_args=None):
|
26 |
+
"""Build :obj:`ConcatDataset by."""
|
27 |
+
from .dataset_wrappers import ConcatDataset
|
28 |
+
img_dir = cfg['img_dir']
|
29 |
+
ann_dir = cfg.get('ann_dir', None)
|
30 |
+
split = cfg.get('split', None)
|
31 |
+
num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
|
32 |
+
if ann_dir is not None:
|
33 |
+
num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
|
34 |
+
else:
|
35 |
+
num_ann_dir = 0
|
36 |
+
if split is not None:
|
37 |
+
num_split = len(split) if isinstance(split, (list, tuple)) else 1
|
38 |
+
else:
|
39 |
+
num_split = 0
|
40 |
+
if num_img_dir > 1:
|
41 |
+
assert num_img_dir == num_ann_dir or num_ann_dir == 0
|
42 |
+
assert num_img_dir == num_split or num_split == 0
|
43 |
+
else:
|
44 |
+
assert num_split == num_ann_dir or num_ann_dir <= 1
|
45 |
+
num_dset = max(num_split, num_img_dir)
|
46 |
+
|
47 |
+
datasets = []
|
48 |
+
for i in range(num_dset):
|
49 |
+
data_cfg = copy.deepcopy(cfg)
|
50 |
+
if isinstance(img_dir, (list, tuple)):
|
51 |
+
data_cfg['img_dir'] = img_dir[i]
|
52 |
+
if isinstance(ann_dir, (list, tuple)):
|
53 |
+
data_cfg['ann_dir'] = ann_dir[i]
|
54 |
+
if isinstance(split, (list, tuple)):
|
55 |
+
data_cfg['split'] = split[i]
|
56 |
+
datasets.append(build_dataset(data_cfg, default_args))
|
57 |
+
|
58 |
+
return ConcatDataset(datasets)
|
59 |
+
|
60 |
+
|
61 |
+
def build_dataset(cfg, default_args=None):
|
62 |
+
"""Build datasets."""
|
63 |
+
from .dataset_wrappers import ConcatDataset, RepeatDataset
|
64 |
+
if isinstance(cfg, (list, tuple)):
|
65 |
+
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
|
66 |
+
elif cfg['type'] == 'RepeatDataset':
|
67 |
+
dataset = RepeatDataset(
|
68 |
+
build_dataset(cfg['dataset'], default_args), cfg['times'])
|
69 |
+
elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
|
70 |
+
cfg.get('split', None), (list, tuple)):
|
71 |
+
dataset = _concat_dataset(cfg, default_args)
|
72 |
+
else:
|
73 |
+
dataset = build_from_cfg(cfg, DATASETS, default_args)
|
74 |
+
|
75 |
+
return dataset
|
76 |
+
|
77 |
+
|
78 |
+
def build_dataloader(dataset,
|
79 |
+
samples_per_gpu,
|
80 |
+
workers_per_gpu,
|
81 |
+
num_gpus=1,
|
82 |
+
dist=True,
|
83 |
+
shuffle=True,
|
84 |
+
seed=None,
|
85 |
+
drop_last=False,
|
86 |
+
pin_memory=True,
|
87 |
+
dataloader_type='PoolDataLoader',
|
88 |
+
**kwargs):
|
89 |
+
"""Build PyTorch DataLoader.
|
90 |
+
|
91 |
+
In distributed training, each GPU/process has a dataloader.
|
92 |
+
In non-distributed training, there is only one dataloader for all GPUs.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
dataset (Dataset): A PyTorch dataset.
|
96 |
+
samples_per_gpu (int): Number of training samples on each GPU, i.e.,
|
97 |
+
batch size of each GPU.
|
98 |
+
workers_per_gpu (int): How many subprocesses to use for data loading
|
99 |
+
for each GPU.
|
100 |
+
num_gpus (int): Number of GPUs. Only used in non-distributed training.
|
101 |
+
dist (bool): Distributed training/test or not. Default: True.
|
102 |
+
shuffle (bool): Whether to shuffle the data at every epoch.
|
103 |
+
Default: True.
|
104 |
+
seed (int | None): Seed to be used. Default: None.
|
105 |
+
drop_last (bool): Whether to drop the last incomplete batch in epoch.
|
106 |
+
Default: False
|
107 |
+
pin_memory (bool): Whether to use pin_memory in DataLoader.
|
108 |
+
Default: True
|
109 |
+
dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
|
110 |
+
kwargs: any keyword argument to be used to initialize DataLoader
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
DataLoader: A PyTorch dataloader.
|
114 |
+
"""
|
115 |
+
rank, world_size = get_dist_info()
|
116 |
+
if dist:
|
117 |
+
sampler = DistributedSampler(
|
118 |
+
dataset, world_size, rank, shuffle=shuffle)
|
119 |
+
shuffle = False
|
120 |
+
batch_size = samples_per_gpu
|
121 |
+
num_workers = workers_per_gpu
|
122 |
+
else:
|
123 |
+
sampler = None
|
124 |
+
batch_size = num_gpus * samples_per_gpu
|
125 |
+
num_workers = num_gpus * workers_per_gpu
|
126 |
+
|
127 |
+
init_fn = partial(
|
128 |
+
worker_init_fn, num_workers=num_workers, rank=rank,
|
129 |
+
seed=seed) if seed is not None else None
|
130 |
+
|
131 |
+
assert dataloader_type in (
|
132 |
+
'DataLoader',
|
133 |
+
'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
|
134 |
+
|
135 |
+
if dataloader_type == 'PoolDataLoader':
|
136 |
+
dataloader = PoolDataLoader
|
137 |
+
elif dataloader_type == 'DataLoader':
|
138 |
+
dataloader = DataLoader
|
139 |
+
|
140 |
+
data_loader = dataloader(
|
141 |
+
dataset,
|
142 |
+
batch_size=batch_size,
|
143 |
+
sampler=sampler,
|
144 |
+
num_workers=num_workers,
|
145 |
+
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
|
146 |
+
pin_memory=pin_memory,
|
147 |
+
shuffle=shuffle,
|
148 |
+
worker_init_fn=init_fn,
|
149 |
+
drop_last=drop_last,
|
150 |
+
**kwargs)
|
151 |
+
|
152 |
+
return data_loader
|
153 |
+
|
154 |
+
|
155 |
+
def worker_init_fn(worker_id, num_workers, rank, seed):
|
156 |
+
"""Worker init func for dataloader.
|
157 |
+
|
158 |
+
The seed of each worker equals to num_worker * rank + worker_id + user_seed
|
159 |
+
|
160 |
+
Args:
|
161 |
+
worker_id (int): Worker id.
|
162 |
+
num_workers (int): Number of workers.
|
163 |
+
rank (int): The rank of current process.
|
164 |
+
seed (int): The random seed to use.
|
165 |
+
"""
|
166 |
+
|
167 |
+
worker_seed = num_workers * rank + worker_id + seed
|
168 |
+
np.random.seed(worker_seed)
|
169 |
+
random.seed(worker_seed)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
from .builder import DATASETS
|
4 |
+
from .custom import CustomDataset
|
5 |
+
|
6 |
+
|
7 |
+
@DATASETS.register_module()
|
8 |
+
class ChaseDB1Dataset(CustomDataset):
|
9 |
+
"""Chase_db1 dataset.
|
10 |
+
|
11 |
+
In segmentation map annotation for Chase_db1, 0 stands for background,
|
12 |
+
which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
|
13 |
+
The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
14 |
+
'_1stHO.png'.
|
15 |
+
"""
|
16 |
+
|
17 |
+
CLASSES = ('background', 'vessel')
|
18 |
+
|
19 |
+
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
20 |
+
|
21 |
+
def __init__(self, **kwargs):
|
22 |
+
super(ChaseDB1Dataset, self).__init__(
|
23 |
+
img_suffix='.png',
|
24 |
+
seg_map_suffix='_1stHO.png',
|
25 |
+
reduce_zero_label=False,
|
26 |
+
**kwargs)
|
27 |
+
assert osp.exists(self.img_dir)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
import tempfile
|
3 |
+
|
4 |
+
import annotator.mmpkg.mmcv as mmcv
|
5 |
+
import numpy as np
|
6 |
+
from annotator.mmpkg.mmcv.utils import print_log
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from .builder import DATASETS
|
10 |
+
from .custom import CustomDataset
|
11 |
+
|
12 |
+
|
13 |
+
@DATASETS.register_module()
|
14 |
+
class CityscapesDataset(CustomDataset):
|
15 |
+
"""Cityscapes dataset.
|
16 |
+
|
17 |
+
The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
|
18 |
+
fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
|
19 |
+
"""
|
20 |
+
|
21 |
+
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
|
22 |
+
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
|
23 |
+
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
|
24 |
+
'bicycle')
|
25 |
+
|
26 |
+
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
27 |
+
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
|
28 |
+
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
|
29 |
+
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
|
30 |
+
[0, 80, 100], [0, 0, 230], [119, 11, 32]]
|
31 |
+
|
32 |
+
def __init__(self, **kwargs):
|
33 |
+
super(CityscapesDataset, self).__init__(
|
34 |
+
img_suffix='_leftImg8bit.png',
|
35 |
+
seg_map_suffix='_gtFine_labelTrainIds.png',
|
36 |
+
**kwargs)
|
37 |
+
|
38 |
+
@staticmethod
|
39 |
+
def _convert_to_label_id(result):
|
40 |
+
"""Convert trainId to id for cityscapes."""
|
41 |
+
if isinstance(result, str):
|
42 |
+
result = np.load(result)
|
43 |
+
import cityscapesscripts.helpers.labels as CSLabels
|
44 |
+
result_copy = result.copy()
|
45 |
+
for trainId, label in CSLabels.trainId2label.items():
|
46 |
+
result_copy[result == trainId] = label.id
|
47 |
+
|
48 |
+
return result_copy
|
49 |
+
|
50 |
+
def results2img(self, results, imgfile_prefix, to_label_id):
|
51 |
+
"""Write the segmentation results to images.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
results (list[list | tuple | ndarray]): Testing results of the
|
55 |
+
dataset.
|
56 |
+
imgfile_prefix (str): The filename prefix of the png files.
|
57 |
+
If the prefix is "somepath/xxx",
|
58 |
+
the png files will be named "somepath/xxx.png".
|
59 |
+
to_label_id (bool): whether convert output to label_id for
|
60 |
+
submission
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
list[str: str]: result txt files which contains corresponding
|
64 |
+
semantic segmentation images.
|
65 |
+
"""
|
66 |
+
mmcv.mkdir_or_exist(imgfile_prefix)
|
67 |
+
result_files = []
|
68 |
+
prog_bar = mmcv.ProgressBar(len(self))
|
69 |
+
for idx in range(len(self)):
|
70 |
+
result = results[idx]
|
71 |
+
if to_label_id:
|
72 |
+
result = self._convert_to_label_id(result)
|
73 |
+
filename = self.img_infos[idx]['filename']
|
74 |
+
basename = osp.splitext(osp.basename(filename))[0]
|
75 |
+
|
76 |
+
png_filename = osp.join(imgfile_prefix, f'{basename}.png')
|
77 |
+
|
78 |
+
output = Image.fromarray(result.astype(np.uint8)).convert('P')
|
79 |
+
import cityscapesscripts.helpers.labels as CSLabels
|
80 |
+
palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
|
81 |
+
for label_id, label in CSLabels.id2label.items():
|
82 |
+
palette[label_id] = label.color
|
83 |
+
|
84 |
+
output.putpalette(palette)
|
85 |
+
output.save(png_filename)
|
86 |
+
result_files.append(png_filename)
|
87 |
+
prog_bar.update()
|
88 |
+
|
89 |
+
return result_files
|
90 |
+
|
91 |
+
def format_results(self, results, imgfile_prefix=None, to_label_id=True):
|
92 |
+
"""Format the results into dir (standard format for Cityscapes
|
93 |
+
evaluation).
|
94 |
+
|
95 |
+
Args:
|
96 |
+
results (list): Testing results of the dataset.
|
97 |
+
imgfile_prefix (str | None): The prefix of images files. It
|
98 |
+
includes the file path and the prefix of filename, e.g.,
|
99 |
+
"a/b/prefix". If not specified, a temp file will be created.
|
100 |
+
Default: None.
|
101 |
+
to_label_id (bool): whether convert output to label_id for
|
102 |
+
submission. Default: False
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
tuple: (result_files, tmp_dir), result_files is a list containing
|
106 |
+
the image paths, tmp_dir is the temporal directory created
|
107 |
+
for saving json/png files when img_prefix is not specified.
|
108 |
+
"""
|
109 |
+
|
110 |
+
assert isinstance(results, list), 'results must be a list'
|
111 |
+
assert len(results) == len(self), (
|
112 |
+
'The length of results is not equal to the dataset len: '
|
113 |
+
f'{len(results)} != {len(self)}')
|
114 |
+
|
115 |
+
if imgfile_prefix is None:
|
116 |
+
tmp_dir = tempfile.TemporaryDirectory()
|
117 |
+
imgfile_prefix = tmp_dir.name
|
118 |
+
else:
|
119 |
+
tmp_dir = None
|
120 |
+
result_files = self.results2img(results, imgfile_prefix, to_label_id)
|
121 |
+
|
122 |
+
return result_files, tmp_dir
|
123 |
+
|
124 |
+
def evaluate(self,
|
125 |
+
results,
|
126 |
+
metric='mIoU',
|
127 |
+
logger=None,
|
128 |
+
imgfile_prefix=None,
|
129 |
+
efficient_test=False):
|
130 |
+
"""Evaluation in Cityscapes/default protocol.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
results (list): Testing results of the dataset.
|
134 |
+
metric (str | list[str]): Metrics to be evaluated.
|
135 |
+
logger (logging.Logger | None | str): Logger used for printing
|
136 |
+
related information during evaluation. Default: None.
|
137 |
+
imgfile_prefix (str | None): The prefix of output image file,
|
138 |
+
for cityscapes evaluation only. It includes the file path and
|
139 |
+
the prefix of filename, e.g., "a/b/prefix".
|
140 |
+
If results are evaluated with cityscapes protocol, it would be
|
141 |
+
the prefix of output png files. The output files would be
|
142 |
+
png images under folder "a/b/prefix/xxx.png", where "xxx" is
|
143 |
+
the image name of cityscapes. If not specified, a temp file
|
144 |
+
will be created for evaluation.
|
145 |
+
Default: None.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
dict[str, float]: Cityscapes/default metrics.
|
149 |
+
"""
|
150 |
+
|
151 |
+
eval_results = dict()
|
152 |
+
metrics = metric.copy() if isinstance(metric, list) else [metric]
|
153 |
+
if 'cityscapes' in metrics:
|
154 |
+
eval_results.update(
|
155 |
+
self._evaluate_cityscapes(results, logger, imgfile_prefix))
|
156 |
+
metrics.remove('cityscapes')
|
157 |
+
if len(metrics) > 0:
|
158 |
+
eval_results.update(
|
159 |
+
super(CityscapesDataset,
|
160 |
+
self).evaluate(results, metrics, logger, efficient_test))
|
161 |
+
|
162 |
+
return eval_results
|
163 |
+
|
164 |
+
def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
|
165 |
+
"""Evaluation in Cityscapes protocol.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
results (list): Testing results of the dataset.
|
169 |
+
logger (logging.Logger | str | None): Logger used for printing
|
170 |
+
related information during evaluation. Default: None.
|
171 |
+
imgfile_prefix (str | None): The prefix of output image file
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
dict[str: float]: Cityscapes evaluation results.
|
175 |
+
"""
|
176 |
+
try:
|
177 |
+
import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa
|
178 |
+
except ImportError:
|
179 |
+
raise ImportError('Please run "pip install cityscapesscripts" to '
|
180 |
+
'install cityscapesscripts first.')
|
181 |
+
msg = 'Evaluating in Cityscapes style'
|
182 |
+
if logger is None:
|
183 |
+
msg = '\n' + msg
|
184 |
+
print_log(msg, logger=logger)
|
185 |
+
|
186 |
+
result_files, tmp_dir = self.format_results(results, imgfile_prefix)
|
187 |
+
|
188 |
+
if tmp_dir is None:
|
189 |
+
result_dir = imgfile_prefix
|
190 |
+
else:
|
191 |
+
result_dir = tmp_dir.name
|
192 |
+
|
193 |
+
eval_results = dict()
|
194 |
+
print_log(f'Evaluating results under {result_dir} ...', logger=logger)
|
195 |
+
|
196 |
+
CSEval.args.evalInstLevelScore = True
|
197 |
+
CSEval.args.predictionPath = osp.abspath(result_dir)
|
198 |
+
CSEval.args.evalPixelAccuracy = True
|
199 |
+
CSEval.args.JSONOutput = False
|
200 |
+
|
201 |
+
seg_map_list = []
|
202 |
+
pred_list = []
|
203 |
+
|
204 |
+
# when evaluating with official cityscapesscripts,
|
205 |
+
# **_gtFine_labelIds.png is used
|
206 |
+
for seg_map in mmcv.scandir(
|
207 |
+
self.ann_dir, 'gtFine_labelIds.png', recursive=True):
|
208 |
+
seg_map_list.append(osp.join(self.ann_dir, seg_map))
|
209 |
+
pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
|
210 |
+
|
211 |
+
eval_results.update(
|
212 |
+
CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
|
213 |
+
|
214 |
+
if tmp_dir is not None:
|
215 |
+
tmp_dir.cleanup()
|
216 |
+
|
217 |
+
return eval_results
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py
ADDED
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
from collections import OrderedDict
|
4 |
+
from functools import reduce
|
5 |
+
|
6 |
+
import annotator.mmpkg.mmcv as mmcv
|
7 |
+
import numpy as np
|
8 |
+
from annotator.mmpkg.mmcv.utils import print_log
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
|
11 |
+
from annotator.mmpkg.mmseg.core import eval_metrics
|
12 |
+
from annotator.mmpkg.mmseg.utils import get_root_logger
|
13 |
+
from .builder import DATASETS
|
14 |
+
from .pipelines import Compose
|
15 |
+
|
16 |
+
|
17 |
+
@DATASETS.register_module()
|
18 |
+
class CustomDataset(Dataset):
|
19 |
+
"""Custom dataset for semantic segmentation. An example of file structure
|
20 |
+
is as followed.
|
21 |
+
|
22 |
+
.. code-block:: none
|
23 |
+
|
24 |
+
├── data
|
25 |
+
│ ├── my_dataset
|
26 |
+
│ │ ├── img_dir
|
27 |
+
│ │ │ ├── train
|
28 |
+
│ │ │ │ ├── xxx{img_suffix}
|
29 |
+
│ │ │ │ ├── yyy{img_suffix}
|
30 |
+
│ │ │ │ ├── zzz{img_suffix}
|
31 |
+
│ │ │ ├── val
|
32 |
+
│ │ ├── ann_dir
|
33 |
+
│ │ │ ├── train
|
34 |
+
│ │ │ │ ├── xxx{seg_map_suffix}
|
35 |
+
│ │ │ │ ├── yyy{seg_map_suffix}
|
36 |
+
│ │ │ │ ├── zzz{seg_map_suffix}
|
37 |
+
│ │ │ ├── val
|
38 |
+
|
39 |
+
The img/gt_semantic_seg pair of CustomDataset should be of the same
|
40 |
+
except suffix. A valid img/gt_semantic_seg filename pair should be like
|
41 |
+
``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included
|
42 |
+
in the suffix). If split is given, then ``xxx`` is specified in txt file.
|
43 |
+
Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded.
|
44 |
+
Please refer to ``docs/tutorials/new_dataset.md`` for more details.
|
45 |
+
|
46 |
+
|
47 |
+
Args:
|
48 |
+
pipeline (list[dict]): Processing pipeline
|
49 |
+
img_dir (str): Path to image directory
|
50 |
+
img_suffix (str): Suffix of images. Default: '.jpg'
|
51 |
+
ann_dir (str, optional): Path to annotation directory. Default: None
|
52 |
+
seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
|
53 |
+
split (str, optional): Split txt file. If split is specified, only
|
54 |
+
file with suffix in the splits will be loaded. Otherwise, all
|
55 |
+
images in img_dir/ann_dir will be loaded. Default: None
|
56 |
+
data_root (str, optional): Data root for img_dir/ann_dir. Default:
|
57 |
+
None.
|
58 |
+
test_mode (bool): If test_mode=True, gt wouldn't be loaded.
|
59 |
+
ignore_index (int): The label index to be ignored. Default: 255
|
60 |
+
reduce_zero_label (bool): Whether to mark label zero as ignored.
|
61 |
+
Default: False
|
62 |
+
classes (str | Sequence[str], optional): Specify classes to load.
|
63 |
+
If is None, ``cls.CLASSES`` will be used. Default: None.
|
64 |
+
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
65 |
+
The palette of segmentation map. If None is given, and
|
66 |
+
self.PALETTE is None, random palette will be generated.
|
67 |
+
Default: None
|
68 |
+
"""
|
69 |
+
|
70 |
+
CLASSES = None
|
71 |
+
|
72 |
+
PALETTE = None
|
73 |
+
|
74 |
+
def __init__(self,
|
75 |
+
pipeline,
|
76 |
+
img_dir,
|
77 |
+
img_suffix='.jpg',
|
78 |
+
ann_dir=None,
|
79 |
+
seg_map_suffix='.png',
|
80 |
+
split=None,
|
81 |
+
data_root=None,
|
82 |
+
test_mode=False,
|
83 |
+
ignore_index=255,
|
84 |
+
reduce_zero_label=False,
|
85 |
+
classes=None,
|
86 |
+
palette=None):
|
87 |
+
self.pipeline = Compose(pipeline)
|
88 |
+
self.img_dir = img_dir
|
89 |
+
self.img_suffix = img_suffix
|
90 |
+
self.ann_dir = ann_dir
|
91 |
+
self.seg_map_suffix = seg_map_suffix
|
92 |
+
self.split = split
|
93 |
+
self.data_root = data_root
|
94 |
+
self.test_mode = test_mode
|
95 |
+
self.ignore_index = ignore_index
|
96 |
+
self.reduce_zero_label = reduce_zero_label
|
97 |
+
self.label_map = None
|
98 |
+
self.CLASSES, self.PALETTE = self.get_classes_and_palette(
|
99 |
+
classes, palette)
|
100 |
+
|
101 |
+
# join paths if data_root is specified
|
102 |
+
if self.data_root is not None:
|
103 |
+
if not osp.isabs(self.img_dir):
|
104 |
+
self.img_dir = osp.join(self.data_root, self.img_dir)
|
105 |
+
if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
|
106 |
+
self.ann_dir = osp.join(self.data_root, self.ann_dir)
|
107 |
+
if not (self.split is None or osp.isabs(self.split)):
|
108 |
+
self.split = osp.join(self.data_root, self.split)
|
109 |
+
|
110 |
+
# load annotations
|
111 |
+
self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
|
112 |
+
self.ann_dir,
|
113 |
+
self.seg_map_suffix, self.split)
|
114 |
+
|
115 |
+
def __len__(self):
|
116 |
+
"""Total number of samples of data."""
|
117 |
+
return len(self.img_infos)
|
118 |
+
|
119 |
+
def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
|
120 |
+
split):
|
121 |
+
"""Load annotation from directory.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
img_dir (str): Path to image directory
|
125 |
+
img_suffix (str): Suffix of images.
|
126 |
+
ann_dir (str|None): Path to annotation directory.
|
127 |
+
seg_map_suffix (str|None): Suffix of segmentation maps.
|
128 |
+
split (str|None): Split txt file. If split is specified, only file
|
129 |
+
with suffix in the splits will be loaded. Otherwise, all images
|
130 |
+
in img_dir/ann_dir will be loaded. Default: None
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
list[dict]: All image info of dataset.
|
134 |
+
"""
|
135 |
+
|
136 |
+
img_infos = []
|
137 |
+
if split is not None:
|
138 |
+
with open(split) as f:
|
139 |
+
for line in f:
|
140 |
+
img_name = line.strip()
|
141 |
+
img_info = dict(filename=img_name + img_suffix)
|
142 |
+
if ann_dir is not None:
|
143 |
+
seg_map = img_name + seg_map_suffix
|
144 |
+
img_info['ann'] = dict(seg_map=seg_map)
|
145 |
+
img_infos.append(img_info)
|
146 |
+
else:
|
147 |
+
for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
|
148 |
+
img_info = dict(filename=img)
|
149 |
+
if ann_dir is not None:
|
150 |
+
seg_map = img.replace(img_suffix, seg_map_suffix)
|
151 |
+
img_info['ann'] = dict(seg_map=seg_map)
|
152 |
+
img_infos.append(img_info)
|
153 |
+
|
154 |
+
print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
|
155 |
+
return img_infos
|
156 |
+
|
157 |
+
def get_ann_info(self, idx):
|
158 |
+
"""Get annotation by index.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
idx (int): Index of data.
|
162 |
+
|
163 |
+
Returns:
|
164 |
+
dict: Annotation info of specified index.
|
165 |
+
"""
|
166 |
+
|
167 |
+
return self.img_infos[idx]['ann']
|
168 |
+
|
169 |
+
def pre_pipeline(self, results):
|
170 |
+
"""Prepare results dict for pipeline."""
|
171 |
+
results['seg_fields'] = []
|
172 |
+
results['img_prefix'] = self.img_dir
|
173 |
+
results['seg_prefix'] = self.ann_dir
|
174 |
+
if self.custom_classes:
|
175 |
+
results['label_map'] = self.label_map
|
176 |
+
|
177 |
+
def __getitem__(self, idx):
|
178 |
+
"""Get training/test data after pipeline.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
idx (int): Index of data.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
dict: Training/test data (with annotation if `test_mode` is set
|
185 |
+
False).
|
186 |
+
"""
|
187 |
+
|
188 |
+
if self.test_mode:
|
189 |
+
return self.prepare_test_img(idx)
|
190 |
+
else:
|
191 |
+
return self.prepare_train_img(idx)
|
192 |
+
|
193 |
+
def prepare_train_img(self, idx):
|
194 |
+
"""Get training data and annotations after pipeline.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
idx (int): Index of data.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
dict: Training data and annotation after pipeline with new keys
|
201 |
+
introduced by pipeline.
|
202 |
+
"""
|
203 |
+
|
204 |
+
img_info = self.img_infos[idx]
|
205 |
+
ann_info = self.get_ann_info(idx)
|
206 |
+
results = dict(img_info=img_info, ann_info=ann_info)
|
207 |
+
self.pre_pipeline(results)
|
208 |
+
return self.pipeline(results)
|
209 |
+
|
210 |
+
def prepare_test_img(self, idx):
|
211 |
+
"""Get testing data after pipeline.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
idx (int): Index of data.
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
dict: Testing data after pipeline with new keys introduced by
|
218 |
+
pipeline.
|
219 |
+
"""
|
220 |
+
|
221 |
+
img_info = self.img_infos[idx]
|
222 |
+
results = dict(img_info=img_info)
|
223 |
+
self.pre_pipeline(results)
|
224 |
+
return self.pipeline(results)
|
225 |
+
|
226 |
+
def format_results(self, results, **kwargs):
|
227 |
+
"""Place holder to format result to dataset specific output."""
|
228 |
+
|
229 |
+
def get_gt_seg_maps(self, efficient_test=False):
|
230 |
+
"""Get ground truth segmentation maps for evaluation."""
|
231 |
+
gt_seg_maps = []
|
232 |
+
for img_info in self.img_infos:
|
233 |
+
seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
|
234 |
+
if efficient_test:
|
235 |
+
gt_seg_map = seg_map
|
236 |
+
else:
|
237 |
+
gt_seg_map = mmcv.imread(
|
238 |
+
seg_map, flag='unchanged', backend='pillow')
|
239 |
+
gt_seg_maps.append(gt_seg_map)
|
240 |
+
return gt_seg_maps
|
241 |
+
|
242 |
+
def get_classes_and_palette(self, classes=None, palette=None):
|
243 |
+
"""Get class names of current dataset.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
classes (Sequence[str] | str | None): If classes is None, use
|
247 |
+
default CLASSES defined by builtin dataset. If classes is a
|
248 |
+
string, take it as a file name. The file contains the name of
|
249 |
+
classes where each line contains one class name. If classes is
|
250 |
+
a tuple or list, override the CLASSES defined by the dataset.
|
251 |
+
palette (Sequence[Sequence[int]]] | np.ndarray | None):
|
252 |
+
The palette of segmentation map. If None is given, random
|
253 |
+
palette will be generated. Default: None
|
254 |
+
"""
|
255 |
+
if classes is None:
|
256 |
+
self.custom_classes = False
|
257 |
+
return self.CLASSES, self.PALETTE
|
258 |
+
|
259 |
+
self.custom_classes = True
|
260 |
+
if isinstance(classes, str):
|
261 |
+
# take it as a file path
|
262 |
+
class_names = mmcv.list_from_file(classes)
|
263 |
+
elif isinstance(classes, (tuple, list)):
|
264 |
+
class_names = classes
|
265 |
+
else:
|
266 |
+
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
267 |
+
|
268 |
+
if self.CLASSES:
|
269 |
+
if not set(classes).issubset(self.CLASSES):
|
270 |
+
raise ValueError('classes is not a subset of CLASSES.')
|
271 |
+
|
272 |
+
# dictionary, its keys are the old label ids and its values
|
273 |
+
# are the new label ids.
|
274 |
+
# used for changing pixel labels in load_annotations.
|
275 |
+
self.label_map = {}
|
276 |
+
for i, c in enumerate(self.CLASSES):
|
277 |
+
if c not in class_names:
|
278 |
+
self.label_map[i] = -1
|
279 |
+
else:
|
280 |
+
self.label_map[i] = classes.index(c)
|
281 |
+
|
282 |
+
palette = self.get_palette_for_custom_classes(class_names, palette)
|
283 |
+
|
284 |
+
return class_names, palette
|
285 |
+
|
286 |
+
def get_palette_for_custom_classes(self, class_names, palette=None):
|
287 |
+
|
288 |
+
if self.label_map is not None:
|
289 |
+
# return subset of palette
|
290 |
+
palette = []
|
291 |
+
for old_id, new_id in sorted(
|
292 |
+
self.label_map.items(), key=lambda x: x[1]):
|
293 |
+
if new_id != -1:
|
294 |
+
palette.append(self.PALETTE[old_id])
|
295 |
+
palette = type(self.PALETTE)(palette)
|
296 |
+
|
297 |
+
elif palette is None:
|
298 |
+
if self.PALETTE is None:
|
299 |
+
palette = np.random.randint(0, 255, size=(len(class_names), 3))
|
300 |
+
else:
|
301 |
+
palette = self.PALETTE
|
302 |
+
|
303 |
+
return palette
|
304 |
+
|
305 |
+
def evaluate(self,
|
306 |
+
results,
|
307 |
+
metric='mIoU',
|
308 |
+
logger=None,
|
309 |
+
efficient_test=False,
|
310 |
+
**kwargs):
|
311 |
+
"""Evaluate the dataset.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
results (list): Testing results of the dataset.
|
315 |
+
metric (str | list[str]): Metrics to be evaluated. 'mIoU',
|
316 |
+
'mDice' and 'mFscore' are supported.
|
317 |
+
logger (logging.Logger | None | str): Logger used for printing
|
318 |
+
related information during evaluation. Default: None.
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
dict[str, float]: Default metrics.
|
322 |
+
"""
|
323 |
+
|
324 |
+
if isinstance(metric, str):
|
325 |
+
metric = [metric]
|
326 |
+
allowed_metrics = ['mIoU', 'mDice', 'mFscore']
|
327 |
+
if not set(metric).issubset(set(allowed_metrics)):
|
328 |
+
raise KeyError('metric {} is not supported'.format(metric))
|
329 |
+
eval_results = {}
|
330 |
+
gt_seg_maps = self.get_gt_seg_maps(efficient_test)
|
331 |
+
if self.CLASSES is None:
|
332 |
+
num_classes = len(
|
333 |
+
reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
|
334 |
+
else:
|
335 |
+
num_classes = len(self.CLASSES)
|
336 |
+
ret_metrics = eval_metrics(
|
337 |
+
results,
|
338 |
+
gt_seg_maps,
|
339 |
+
num_classes,
|
340 |
+
self.ignore_index,
|
341 |
+
metric,
|
342 |
+
label_map=self.label_map,
|
343 |
+
reduce_zero_label=self.reduce_zero_label)
|
344 |
+
|
345 |
+
if self.CLASSES is None:
|
346 |
+
class_names = tuple(range(num_classes))
|
347 |
+
else:
|
348 |
+
class_names = self.CLASSES
|
349 |
+
|
350 |
+
# summary table
|
351 |
+
ret_metrics_summary = OrderedDict({
|
352 |
+
ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
|
353 |
+
for ret_metric, ret_metric_value in ret_metrics.items()
|
354 |
+
})
|
355 |
+
|
356 |
+
# each class table
|
357 |
+
ret_metrics.pop('aAcc', None)
|
358 |
+
ret_metrics_class = OrderedDict({
|
359 |
+
ret_metric: np.round(ret_metric_value * 100, 2)
|
360 |
+
for ret_metric, ret_metric_value in ret_metrics.items()
|
361 |
+
})
|
362 |
+
ret_metrics_class.update({'Class': class_names})
|
363 |
+
ret_metrics_class.move_to_end('Class', last=False)
|
364 |
+
|
365 |
+
try:
|
366 |
+
from prettytable import PrettyTable
|
367 |
+
# for logger
|
368 |
+
class_table_data = PrettyTable()
|
369 |
+
for key, val in ret_metrics_class.items():
|
370 |
+
class_table_data.add_column(key, val)
|
371 |
+
|
372 |
+
summary_table_data = PrettyTable()
|
373 |
+
for key, val in ret_metrics_summary.items():
|
374 |
+
if key == 'aAcc':
|
375 |
+
summary_table_data.add_column(key, [val])
|
376 |
+
else:
|
377 |
+
summary_table_data.add_column('m' + key, [val])
|
378 |
+
|
379 |
+
print_log('per class results:', logger)
|
380 |
+
print_log('\n' + class_table_data.get_string(), logger=logger)
|
381 |
+
print_log('Summary:', logger)
|
382 |
+
print_log('\n' + summary_table_data.get_string(), logger=logger)
|
383 |
+
except ImportError: # prettytable is not installed
|
384 |
+
pass
|
385 |
+
|
386 |
+
# each metric dict
|
387 |
+
for key, value in ret_metrics_summary.items():
|
388 |
+
if key == 'aAcc':
|
389 |
+
eval_results[key] = value / 100.0
|
390 |
+
else:
|
391 |
+
eval_results['m' + key] = value / 100.0
|
392 |
+
|
393 |
+
ret_metrics_class.pop('Class', None)
|
394 |
+
for key, value in ret_metrics_class.items():
|
395 |
+
eval_results.update({
|
396 |
+
key + '.' + str(name): value[idx] / 100.0
|
397 |
+
for idx, name in enumerate(class_names)
|
398 |
+
})
|
399 |
+
|
400 |
+
if mmcv.is_list_of(results, str):
|
401 |
+
for file_name in results:
|
402 |
+
os.remove(file_name)
|
403 |
+
return eval_results
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
|
2 |
+
|
3 |
+
from .builder import DATASETS
|
4 |
+
|
5 |
+
|
6 |
+
@DATASETS.register_module()
|
7 |
+
class ConcatDataset(_ConcatDataset):
|
8 |
+
"""A wrapper of concatenated dataset.
|
9 |
+
|
10 |
+
Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
|
11 |
+
concat the group flag for image aspect ratio.
|
12 |
+
|
13 |
+
Args:
|
14 |
+
datasets (list[:obj:`Dataset`]): A list of datasets.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, datasets):
|
18 |
+
super(ConcatDataset, self).__init__(datasets)
|
19 |
+
self.CLASSES = datasets[0].CLASSES
|
20 |
+
self.PALETTE = datasets[0].PALETTE
|
21 |
+
|
22 |
+
|
23 |
+
@DATASETS.register_module()
|
24 |
+
class RepeatDataset(object):
|
25 |
+
"""A wrapper of repeated dataset.
|
26 |
+
|
27 |
+
The length of repeated dataset will be `times` larger than the original
|
28 |
+
dataset. This is useful when the data loading time is long but the dataset
|
29 |
+
is small. Using RepeatDataset can reduce the data loading time between
|
30 |
+
epochs.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
dataset (:obj:`Dataset`): The dataset to be repeated.
|
34 |
+
times (int): Repeat times.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, dataset, times):
|
38 |
+
self.dataset = dataset
|
39 |
+
self.times = times
|
40 |
+
self.CLASSES = dataset.CLASSES
|
41 |
+
self.PALETTE = dataset.PALETTE
|
42 |
+
self._ori_len = len(self.dataset)
|
43 |
+
|
44 |
+
def __getitem__(self, idx):
|
45 |
+
"""Get item from original dataset."""
|
46 |
+
return self.dataset[idx % self._ori_len]
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
"""The length is multiplied by ``times``"""
|
50 |
+
return self.times * self._ori_len
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
from .builder import DATASETS
|
4 |
+
from .custom import CustomDataset
|
5 |
+
|
6 |
+
|
7 |
+
@DATASETS.register_module()
|
8 |
+
class DRIVEDataset(CustomDataset):
|
9 |
+
"""DRIVE dataset.
|
10 |
+
|
11 |
+
In segmentation map annotation for DRIVE, 0 stands for background, which is
|
12 |
+
included in 2 categories. ``reduce_zero_label`` is fixed to False. The
|
13 |
+
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
14 |
+
'_manual1.png'.
|
15 |
+
"""
|
16 |
+
|
17 |
+
CLASSES = ('background', 'vessel')
|
18 |
+
|
19 |
+
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
20 |
+
|
21 |
+
def __init__(self, **kwargs):
|
22 |
+
super(DRIVEDataset, self).__init__(
|
23 |
+
img_suffix='.png',
|
24 |
+
seg_map_suffix='_manual1.png',
|
25 |
+
reduce_zero_label=False,
|
26 |
+
**kwargs)
|
27 |
+
assert osp.exists(self.img_dir)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
from .builder import DATASETS
|
4 |
+
from .custom import CustomDataset
|
5 |
+
|
6 |
+
|
7 |
+
@DATASETS.register_module()
|
8 |
+
class HRFDataset(CustomDataset):
|
9 |
+
"""HRF dataset.
|
10 |
+
|
11 |
+
In segmentation map annotation for HRF, 0 stands for background, which is
|
12 |
+
included in 2 categories. ``reduce_zero_label`` is fixed to False. The
|
13 |
+
``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
|
14 |
+
'.png'.
|
15 |
+
"""
|
16 |
+
|
17 |
+
CLASSES = ('background', 'vessel')
|
18 |
+
|
19 |
+
PALETTE = [[120, 120, 120], [6, 230, 230]]
|
20 |
+
|
21 |
+
def __init__(self, **kwargs):
|
22 |
+
super(HRFDataset, self).__init__(
|
23 |
+
img_suffix='.png',
|
24 |
+
seg_map_suffix='.png',
|
25 |
+
reduce_zero_label=False,
|
26 |
+
**kwargs)
|
27 |
+
assert osp.exists(self.img_dir)
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path as osp
|
2 |
+
|
3 |
+
from .builder import DATASETS
|
4 |
+
from .custom import CustomDataset
|
5 |
+
|
6 |
+
|
7 |
+
@DATASETS.register_module()
|
8 |
+
class PascalContextDataset(CustomDataset):
|
9 |
+
"""PascalContext dataset.
|
10 |
+
|
11 |
+
In segmentation map annotation for PascalContext, 0 stands for background,
|
12 |
+
which is included in 60 categories. ``reduce_zero_label`` is fixed to
|
13 |
+
False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
14 |
+
fixed to '.png'.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
split (str): Split txt file for PascalContext.
|
18 |
+
"""
|
19 |
+
|
20 |
+
CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
|
21 |
+
'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
|
22 |
+
'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
|
23 |
+
'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
|
24 |
+
'floor', 'flower', 'food', 'grass', 'ground', 'horse',
|
25 |
+
'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
|
26 |
+
'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
|
27 |
+
'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
|
28 |
+
'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
|
29 |
+
'window', 'wood')
|
30 |
+
|
31 |
+
PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
|
32 |
+
[4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
|
33 |
+
[230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
|
34 |
+
[150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
|
35 |
+
[143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
|
36 |
+
[0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
|
37 |
+
[255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
|
38 |
+
[255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
|
39 |
+
[255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
|
40 |
+
[224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
|
41 |
+
[255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
|
42 |
+
[6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
|
43 |
+
[140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
|
44 |
+
[255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
|
45 |
+
[255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
46 |
+
|
47 |
+
def __init__(self, split, **kwargs):
|
48 |
+
super(PascalContextDataset, self).__init__(
|
49 |
+
img_suffix='.jpg',
|
50 |
+
seg_map_suffix='.png',
|
51 |
+
split=split,
|
52 |
+
reduce_zero_label=False,
|
53 |
+
**kwargs)
|
54 |
+
assert osp.exists(self.img_dir) and self.split is not None
|
55 |
+
|
56 |
+
|
57 |
+
@DATASETS.register_module()
|
58 |
+
class PascalContextDataset59(CustomDataset):
|
59 |
+
"""PascalContext dataset.
|
60 |
+
|
61 |
+
In segmentation map annotation for PascalContext, 0 stands for background,
|
62 |
+
which is included in 60 categories. ``reduce_zero_label`` is fixed to
|
63 |
+
False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
64 |
+
fixed to '.png'.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
split (str): Split txt file for PascalContext.
|
68 |
+
"""
|
69 |
+
|
70 |
+
CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
|
71 |
+
'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
|
72 |
+
'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
|
73 |
+
'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
|
74 |
+
'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
|
75 |
+
'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
|
76 |
+
'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
|
77 |
+
'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
|
78 |
+
'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
|
79 |
+
|
80 |
+
PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
|
81 |
+
[120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
|
82 |
+
[4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
|
83 |
+
[120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
|
84 |
+
[204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
|
85 |
+
[61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
|
86 |
+
[255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
|
87 |
+
[112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
|
88 |
+
[10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
|
89 |
+
[102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
|
90 |
+
[0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
|
91 |
+
[235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
|
92 |
+
[250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
|
93 |
+
[255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
|
94 |
+
[0, 235, 255], [0, 173, 255], [31, 0, 255]]
|
95 |
+
|
96 |
+
def __init__(self, split, **kwargs):
|
97 |
+
super(PascalContextDataset59, self).__init__(
|
98 |
+
img_suffix='.jpg',
|
99 |
+
seg_map_suffix='.png',
|
100 |
+
split=split,
|
101 |
+
reduce_zero_label=True,
|
102 |
+
**kwargs)
|
103 |
+
assert osp.exists(self.img_dir) and self.split is not None
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .compose import Compose
|
2 |
+
from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
3 |
+
Transpose, to_tensor)
|
4 |
+
from .loading import LoadAnnotations, LoadImageFromFile
|
5 |
+
from .test_time_aug import MultiScaleFlipAug
|
6 |
+
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
7 |
+
PhotoMetricDistortion, RandomCrop, RandomFlip,
|
8 |
+
RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
12 |
+
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
13 |
+
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
14 |
+
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
15 |
+
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
|
16 |
+
]
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
|
3 |
+
from annotator.mmpkg.mmcv.utils import build_from_cfg
|
4 |
+
|
5 |
+
from ..builder import PIPELINES
|
6 |
+
|
7 |
+
|
8 |
+
@PIPELINES.register_module()
|
9 |
+
class Compose(object):
|
10 |
+
"""Compose multiple transforms sequentially.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
transforms (Sequence[dict | callable]): Sequence of transform object or
|
14 |
+
config dict to be composed.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, transforms):
|
18 |
+
assert isinstance(transforms, collections.abc.Sequence)
|
19 |
+
self.transforms = []
|
20 |
+
for transform in transforms:
|
21 |
+
if isinstance(transform, dict):
|
22 |
+
transform = build_from_cfg(transform, PIPELINES)
|
23 |
+
self.transforms.append(transform)
|
24 |
+
elif callable(transform):
|
25 |
+
self.transforms.append(transform)
|
26 |
+
else:
|
27 |
+
raise TypeError('transform must be callable or a dict')
|
28 |
+
|
29 |
+
def __call__(self, data):
|
30 |
+
"""Call function to apply transforms sequentially.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
data (dict): A result dict contains the data to transform.
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
dict: Transformed data.
|
37 |
+
"""
|
38 |
+
|
39 |
+
for t in self.transforms:
|
40 |
+
data = t(data)
|
41 |
+
if data is None:
|
42 |
+
return None
|
43 |
+
return data
|
44 |
+
|
45 |
+
def __repr__(self):
|
46 |
+
format_string = self.__class__.__name__ + '('
|
47 |
+
for t in self.transforms:
|
48 |
+
format_string += '\n'
|
49 |
+
format_string += f' {t}'
|
50 |
+
format_string += '\n)'
|
51 |
+
return format_string
|
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections.abc import Sequence
|
2 |
+
|
3 |
+
import annotator.mmpkg.mmcv as mmcv
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from annotator.mmpkg.mmcv.parallel import DataContainer as DC
|
7 |
+
|
8 |
+
from ..builder import PIPELINES
|
9 |
+
|
10 |
+
|
11 |
+
def to_tensor(data):
|
12 |
+
"""Convert objects of various python types to :obj:`torch.Tensor`.
|
13 |
+
|
14 |
+
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
|
15 |
+
:class:`Sequence`, :class:`int` and :class:`float`.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
|
19 |
+
be converted.
|
20 |
+
"""
|
21 |
+
|
22 |
+
if isinstance(data, torch.Tensor):
|
23 |
+
return data
|
24 |
+
elif isinstance(data, np.ndarray):
|
25 |
+
return torch.from_numpy(data)
|
26 |
+
elif isinstance(data, Sequence) and not mmcv.is_str(data):
|
27 |
+
return torch.tensor(data)
|
28 |
+
elif isinstance(data, int):
|
29 |
+
return torch.LongTensor([data])
|
30 |
+
elif isinstance(data, float):
|
31 |
+
return torch.FloatTensor([data])
|
32 |
+
else:
|
33 |
+
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
|
34 |
+
|
35 |
+
|
36 |
+
@PIPELINES.register_module()
|
37 |
+
class ToTensor(object):
|
38 |
+
"""Convert some results to :obj:`torch.Tensor` by given keys.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
keys (Sequence[str]): Keys that need to be converted to Tensor.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(self, keys):
|
45 |
+
self.keys = keys
|
46 |
+
|
47 |
+
def __call__(self, results):
|
48 |
+
"""Call function to convert data in results to :obj:`torch.Tensor`.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
results (dict): Result dict contains the data to convert.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
dict: The result dict contains the data converted
|
55 |
+
to :obj:`torch.Tensor`.
|
56 |
+
"""
|
57 |
+
|
58 |
+
for key in self.keys:
|
59 |
+
results[key] = to_tensor(results[key])
|
60 |
+
return results
|
61 |
+
|
62 |
+
def __repr__(self):
|
63 |
+
return self.__class__.__name__ + f'(keys={self.keys})'
|
64 |
+
|
65 |
+
|
66 |
+
@PIPELINES.register_module()
|
67 |
+
class ImageToTensor(object):
|
68 |
+
"""Convert image to :obj:`torch.Tensor` by given keys.
|
69 |
+
|
70 |
+
The dimension order of input image is (H, W, C). The pipeline will convert
|
71 |
+
it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
|
72 |
+
(1, H, W).
|
73 |
+
|
74 |
+
Args:
|
75 |
+
keys (Sequence[str]): Key of images to be converted to Tensor.
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(self, keys):
|
79 |
+
self.keys = keys
|
80 |
+
|
81 |
+
def __call__(self, results):
|
82 |
+
"""Call function to convert image in results to :obj:`torch.Tensor` and
|
83 |
+
transpose the channel order.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
results (dict): Result dict contains the image data to convert.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
dict: The result dict contains the image converted
|
90 |
+
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
|
91 |
+
"""
|
92 |
+
|
93 |
+
for key in self.keys:
|
94 |
+
img = results[key]
|
95 |
+
if len(img.shape) < 3:
|
96 |
+
img = np.expand_dims(img, -1)
|
97 |
+
results[key] = to_tensor(img.transpose(2, 0, 1))
|
98 |
+
return results
|
99 |
+
|
100 |
+
def __repr__(self):
|
101 |
+
return self.__class__.__name__ + f'(keys={self.keys})'
|
102 |
+
|
103 |
+
|
104 |
+
@PIPELINES.register_module()
|
105 |
+
class Transpose(object):
|
106 |
+
"""Transpose some results by given keys.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
keys (Sequence[str]): Keys of results to be transposed.
|
110 |
+
order (Sequence[int]): Order of transpose.
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(self, keys, order):
|
114 |
+
self.keys = keys
|
115 |
+
self.order = order
|
116 |
+
|
117 |
+
def __call__(self, results):
|
118 |
+
"""Call function to convert image in results to :obj:`torch.Tensor` and
|
119 |
+
transpose the channel order.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
results (dict): Result dict contains the image data to convert.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
dict: The result dict contains the image converted
|
126 |
+
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
|
127 |
+
"""
|
128 |
+
|
129 |
+
for key in self.keys:
|
130 |
+
results[key] = results[key].transpose(self.order)
|
131 |
+
return results
|
132 |
+
|
133 |
+
def __repr__(self):
|
134 |
+
return self.__class__.__name__ + \
|
135 |
+
f'(keys={self.keys}, order={self.order})'
|
136 |
+
|
137 |
+
|
138 |
+
@PIPELINES.register_module()
|
139 |
+
class ToDataContainer(object):
|
140 |
+
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
fields (Sequence[dict]): Each field is a dict like
|
144 |
+
``dict(key='xxx', **kwargs)``. The ``key`` in result will
|
145 |
+
be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
|
146 |
+
Default: ``(dict(key='img', stack=True),
|
147 |
+
dict(key='gt_semantic_seg'))``.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def __init__(self,
|
151 |
+
fields=(dict(key='img',
|
152 |
+
stack=True), dict(key='gt_semantic_seg'))):
|
153 |
+
self.fields = fields
|
154 |
+
|
155 |
+
def __call__(self, results):
|
156 |
+
"""Call function to convert data in results to
|
157 |
+
:obj:`mmcv.DataContainer`.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
results (dict): Result dict contains the data to convert.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
dict: The result dict contains the data converted to
|
164 |
+
:obj:`mmcv.DataContainer`.
|
165 |
+
"""
|
166 |
+
|
167 |
+
for field in self.fields:
|
168 |
+
field = field.copy()
|
169 |
+
key = field.pop('key')
|
170 |
+
results[key] = DC(results[key], **field)
|
171 |
+
return results
|
172 |
+
|
173 |
+
def __repr__(self):
|
174 |
+
return self.__class__.__name__ + f'(fields={self.fields})'
|
175 |
+
|
176 |
+
|
177 |
+
@PIPELINES.register_module()
|
178 |
+
class DefaultFormatBundle(object):
|
179 |
+
"""Default formatting bundle.
|
180 |
+
|
181 |
+
It simplifies the pipeline of formatting common fields, including "img"
|
182 |
+
and "gt_semantic_seg". These fields are formatted as follows.
|
183 |
+
|
184 |
+
- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
|
185 |
+
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor,
|
186 |
+
(3)to DataContainer (stack=True)
|
187 |
+
"""
|
188 |
+
|
189 |
+
def __call__(self, results):
|
190 |
+
"""Call function to transform and format common fields in results.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
results (dict): Result dict contains the data to convert.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
dict: The result dict contains the data that is formatted with
|
197 |
+
default bundle.
|
198 |
+
"""
|
199 |
+
|
200 |
+
if 'img' in results:
|
201 |
+
img = results['img']
|
202 |
+
if len(img.shape) < 3:
|
203 |
+
img = np.expand_dims(img, -1)
|
204 |
+
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
205 |
+
results['img'] = DC(to_tensor(img), stack=True)
|
206 |
+
if 'gt_semantic_seg' in results:
|
207 |
+
# convert to long
|
208 |
+
results['gt_semantic_seg'] = DC(
|
209 |
+
to_tensor(results['gt_semantic_seg'][None,
|
210 |
+
...].astype(np.int64)),
|
211 |
+
stack=True)
|
212 |
+
return results
|
213 |
+
|
214 |
+
def __repr__(self):
|
215 |
+
return self.__class__.__name__
|
216 |
+
|
217 |
+
|
218 |
+
@PIPELINES.register_module()
|
219 |
+
class Collect(object):
|
220 |
+
"""Collect data from the loader relevant to the specific task.
|
221 |
+
|
222 |
+
This is usually the last stage of the data loader pipeline. Typically keys
|
223 |
+
is set to some subset of "img", "gt_semantic_seg".
|
224 |
+
|
225 |
+
The "img_meta" item is always populated. The contents of the "img_meta"
|
226 |
+
dictionary depends on "meta_keys". By default this includes:
|
227 |
+
|
228 |
+
- "img_shape": shape of the image input to the network as a tuple
|
229 |
+
(h, w, c). Note that images may be zero padded on the bottom/right
|
230 |
+
if the batch tensor is larger than this shape.
|
231 |
+
|
232 |
+
- "scale_factor": a float indicating the preprocessing scale
|
233 |
+
|
234 |
+
- "flip": a boolean indicating if image flip transform was used
|
235 |
+
|
236 |
+
- "filename": path to the image file
|
237 |
+
|
238 |
+
- "ori_shape": original shape of the image as a tuple (h, w, c)
|
239 |
+
|
240 |
+
- "pad_shape": image shape after padding
|
241 |
+
|
242 |
+
- "img_norm_cfg": a dict of normalization information:
|
243 |
+
- mean - per channel mean subtraction
|
244 |
+
- std - per channel std divisor
|
245 |
+
- to_rgb - bool indicating if bgr was converted to rgb
|
246 |
+
|
247 |
+
Args:
|
248 |
+
keys (Sequence[str]): Keys of results to be collected in ``data``.
|
249 |
+
meta_keys (Sequence[str], optional): Meta keys to be converted to
|
250 |
+
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
|
251 |
+
Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
|
252 |
+
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
|
253 |
+
'img_norm_cfg')``
|
254 |
+
"""
|
255 |
+
|
256 |
+
def __init__(self,
|
257 |
+
keys,
|
258 |
+
meta_keys=('filename', 'ori_filename', 'ori_shape',
|
259 |
+
'img_shape', 'pad_shape', 'scale_factor', 'flip',
|
260 |
+
'flip_direction', 'img_norm_cfg')):
|
261 |
+
self.keys = keys
|
262 |
+
self.meta_keys = meta_keys
|
263 |
+
|
264 |
+
def __call__(self, results):
|
265 |
+
"""Call function to collect keys in results. The keys in ``meta_keys``
|
266 |
+
will be converted to :obj:mmcv.DataContainer.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
results (dict): Result dict contains the data to collect.
|
270 |
+
|
271 |
+
Returns:
|
272 |
+
dict: The result dict contains the following keys
|
273 |
+
- keys in``self.keys``
|
274 |
+
- ``img_metas``
|
275 |
+
"""
|
276 |
+
|
277 |
+
data = {}
|
278 |
+
img_meta = {}
|
279 |
+
for key in self.meta_keys:
|
280 |
+
img_meta[key] = results[key]
|
281 |
+
data['img_metas'] = DC(img_meta, cpu_only=True)
|
282 |
+
for key in self.keys:
|
283 |
+
data[key] = results[key]
|
284 |
+
return data
|
285 |
+
|
286 |
+
def __repr__(self):
|
287 |
+
return self.__class__.__name__ + \
|
288 |
+
f'(keys={self.keys}, meta_keys={self.meta_keys})'
|