toto10 committed on
Commit ff00ee1 · 1 Parent(s): f07387d

3648a469f0fb07172ece85292a5ce8613ac37e3239706f93cac30eb840e698f2

Files changed (50)
  1. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py +41 -0
  2. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py +107 -0
  3. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py +101 -0
  4. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py +208 -0
  5. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py +315 -0
  6. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py +140 -0
  7. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py +118 -0
  8. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py +23 -0
  9. microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py +90 -0
  10. microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py +35 -0
  11. microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py +11 -0
  12. microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py +318 -0
  13. microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py +254 -0
  14. microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py +160 -0
  15. microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py +9 -0
  16. microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py +51 -0
  17. microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py +152 -0
  18. microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py +112 -0
  19. microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py +9 -0
  20. microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py +138 -0
  21. microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py +238 -0
  22. microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py +116 -0
  23. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py +3 -0
  24. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py +8 -0
  25. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py +152 -0
  26. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py +109 -0
  27. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py +326 -0
  28. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py +4 -0
  29. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py +8 -0
  30. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py +4 -0
  31. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py +12 -0
  32. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py +76 -0
  33. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py +3 -0
  34. microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py +17 -0
  35. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py +19 -0
  36. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py +84 -0
  37. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py +169 -0
  38. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py +27 -0
  39. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py +217 -0
  40. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py +403 -0
  41. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py +50 -0
  42. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py +27 -0
  43. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py +27 -0
  44. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py +103 -0
  45. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py +16 -0
  46. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py +51 -0
  47. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py +288 -0
  48. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py +153 -0
  49. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py +133 -0
  50. microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py +889 -0
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_jit.py ADDED
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

from .parrots_wrapper import TORCH_VERSION

parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')

if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
    from parrots.jit import pat as jit
else:

    def jit(func=None,
            check_input=None,
            full_shape=True,
            derivate=False,
            coderize=False,
            optimize=False):

        def wrapper(func):

            def wrapper_inner(*args, **kargs):
                return func(*args, **kargs)

            return wrapper_inner

        if func is None:
            return wrapper
        else:
            return func


if TORCH_VERSION == 'parrots':
    from parrots.utils.tester import skip_no_elena
else:

    def skip_no_elena(func):

        def wrapper(*args, **kargs):
            return func(*args, **kargs)

        return wrapper
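On a stock PyTorch install (where TORCH_VERSION != 'parrots'), the `jit` shim above degrades to a transparent pass-through decorator. A minimal sketch of what that fallback guarantees (the decorated function is illustrative; the import path assumes this repo's layout):

from annotator.mmpkg.mmcv.utils.parrots_jit import jit

@jit(coderize=True, optimize=True)  # all options are ignored outside parrots
def add(a, b):
    return a + b

assert add(1, 2) == 3  # wrapper_inner simply forwards *args/**kargs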
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/parrots_wrapper.py ADDED
@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial

import torch

TORCH_VERSION = torch.__version__


def is_rocm_pytorch() -> bool:
    is_rocm = False
    if TORCH_VERSION != 'parrots':
        try:
            from torch.utils.cpp_extension import ROCM_HOME
            is_rocm = True if ((torch.version.hip is not None) and
                               (ROCM_HOME is not None)) else False
        except ImportError:
            pass
    return is_rocm


def _get_cuda_home():
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import CUDA_HOME
    else:
        if is_rocm_pytorch():
            from torch.utils.cpp_extension import ROCM_HOME
            CUDA_HOME = ROCM_HOME
        else:
            from torch.utils.cpp_extension import CUDA_HOME
    return CUDA_HOME


def get_build_config():
    if TORCH_VERSION == 'parrots':
        from parrots.config import get_build_info
        return get_build_info()
    else:
        return torch.__config__.show()


def _get_conv():
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
    else:
        from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
    return _ConvNd, _ConvTransposeMixin


def _get_dataloader():
    if TORCH_VERSION == 'parrots':
        from torch.utils.data import DataLoader, PoolDataLoader
    else:
        from torch.utils.data import DataLoader
        PoolDataLoader = DataLoader
    return DataLoader, PoolDataLoader


def _get_extension():
    if TORCH_VERSION == 'parrots':
        from parrots.utils.build_extension import BuildExtension, Extension
        CppExtension = partial(Extension, cuda=False)
        CUDAExtension = partial(Extension, cuda=True)
    else:
        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
                                               CUDAExtension)
    return BuildExtension, CppExtension, CUDAExtension


def _get_pool():
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
                                             _AdaptiveMaxPoolNd, _AvgPoolNd,
                                             _MaxPoolNd)
    else:
        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
                                              _MaxPoolNd)
    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd


def _get_norm():
    if TORCH_VERSION == 'parrots':
        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
        SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
    else:
        from torch.nn.modules.instancenorm import _InstanceNorm
        from torch.nn.modules.batchnorm import _BatchNorm
        SyncBatchNorm_ = torch.nn.SyncBatchNorm
    return _BatchNorm, _InstanceNorm, SyncBatchNorm_


_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()


class SyncBatchNorm(SyncBatchNorm_):

    def _check_input_dim(self, input):
        if TORCH_VERSION == 'parrots':
            if input.dim() < 2:
                raise ValueError(
                    f'expected at least 2D input (got {input.dim()}D input)')
        else:
            super()._check_input_dim(input)
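The aliases exported at the bottom let downstream code stay backend-agnostic. A short usage sketch (plain PyTorch assumed, module sizes illustrative):

import torch.nn as nn
from annotator.mmpkg.mmcv.utils.parrots_wrapper import _BatchNorm, SyncBatchNorm

bn = nn.BatchNorm2d(8)
print(isinstance(bn, _BatchNorm))  # True: the alias resolves to torch's base class
sbn = SyncBatchNorm(8)  # torch.nn.SyncBatchNorm here, with the parrots-aware
                        # _check_input_dim override layered on top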
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/path.py ADDED
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from pathlib import Path

from .misc import is_str


def is_filepath(x):
    return is_str(x) or isinstance(x, Path)


def fopen(filepath, *args, **kwargs):
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == '':
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def symlink(src, dst, overwrite=True, **kwargs):
    if os.path.lexists(dst) and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find files of interest.

    Args:
        dir_path (str | obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            suffix. Default: True.

    Returns:
        A generator for all the files of interest, with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        suffix = suffix.lower() if isinstance(suffix, str) else tuple(
            item.lower() for item in suffix)

    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive,
                                    case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)


def find_vcs_root(path, markers=('.git', )):
    """Find the root directory (the given path itself included in the search)
    that contains one of the specified markers.

    Args:
        path (str): Path of a directory or file.
        markers (list[str], optional): List of file or directory names.

    Returns:
        The directory containing one of the markers, or None if not found.
    """
    if osp.isfile(path):
        path = osp.dirname(path)

    prev, cur = None, osp.abspath(osp.expanduser(path))
    while cur != prev:
        if any(osp.exists(osp.join(cur, marker)) for marker in markers):
            return cur
        prev, cur = cur, osp.split(cur)[0]
    return None
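A quick sketch of the two main helpers (the directory names are illustrative):

from annotator.mmpkg.mmcv.utils.path import scandir, find_vcs_root

# yield relative paths of all .py files under 'src', recursively
for rel_path in scandir('src', suffix='.py', recursive=True):
    print(rel_path)

# walk upwards from the current directory until a .git marker is found
print(find_vcs_root('.'))  # e.g. '/home/user/repo', or None outside a checkout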
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/progressbar.py ADDED
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size

from .timer import Timer


class ProgressBar:
    """A progress bar which can print the progress."""

    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self):
        width, _ = get_terminal_size()
        return width

    def start(self):
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        if elapsed > 0:
            fps = self.completed / elapsed
        else:
            fps = float('inf')
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'

            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * percentage)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            self.file.write(
                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()


def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
    """Track the progress of task execution with a progress bar.

    Tasks are done with a simple for-loop.

    Args:
        func (callable): The function to be applied to each task.
        tasks (list or tuple[Iterable, int]): A list of tasks or
            (tasks, total num).
        bar_width (int): Width of the progress bar.

    Returns:
        list: The task results.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    results = []
    for task in tasks:
        results.append(func(task, **kwargs))
        prog_bar.update()
    prog_bar.file.write('\n')
    return results


def init_pool(process_num, initializer=None, initargs=None):
    if initializer is None:
        return Pool(process_num)
    elif initargs is None:
        return Pool(process_num, initializer)
    else:
        if not isinstance(initargs, tuple):
            raise TypeError('"initargs" must be a tuple')
        return Pool(process_num, initializer, initargs)


def track_parallel_progress(func,
                            tasks,
                            nproc,
                            initializer=None,
                            initargs=None,
                            bar_width=50,
                            chunksize=1,
                            skip_first=False,
                            keep_order=True,
                            file=sys.stdout):
    """Track the progress of parallel task execution with a progress bar.

    The built-in :mod:`multiprocessing` module is used for process pools and
    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (list or tuple[Iterable, int]): A list of tasks or
            (tasks, total num).
        nproc (int): Process (worker) number.
        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
            for details.
        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
            details.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
        bar_width (int): Width of the progress bar.
        skip_first (bool): Whether to skip the first sample for each worker
            when estimating fps, since the initialization step may take
            longer.
        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
            :func:`Pool.imap_unordered` is used.

    Returns:
        list: The task results.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    pool = init_pool(nproc, initializer, initargs)
    start = not skip_first
    task_num -= nproc * chunksize * int(skip_first)
    prog_bar = ProgressBar(task_num, bar_width, start, file=file)
    results = []
    if keep_order:
        gen = pool.imap(func, tasks, chunksize)
    else:
        gen = pool.imap_unordered(func, tasks, chunksize)
    for result in gen:
        results.append(result)
        if skip_first:
            if len(results) < nproc * chunksize:
                continue
            elif len(results) == nproc * chunksize:
                prog_bar.start()
                continue
        prog_bar.update()
    prog_bar.file.write('\n')
    pool.close()
    pool.join()
    return results


def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
    """Track the progress of task iteration or enumeration with a progress
    bar.

    Tasks are yielded with a simple for-loop.

    Args:
        tasks (list or tuple[Iterable, int]): A list of tasks or
            (tasks, total num).
        bar_width (int): Width of the progress bar.

    Yields:
        The tasks, one at a time.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    for task in tasks:
        yield task
        prog_bar.update()
    prog_bar.file.write('\n')
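A minimal sketch of the two tracking helpers (the worker function is illustrative; note that `track_parallel_progress` pickles `func`, so in real use it should be importable at module level):

from annotator.mmpkg.mmcv.utils.progressbar import (track_progress,
                                                    track_parallel_progress)

def square(x):
    return x * x

results = track_progress(square, list(range(100)))  # sequential; one bar update per task
# pooled variant: 4 workers, results kept in input order since keep_order=True uses imap
# results = track_parallel_progress(square, list(range(100)), nproc=4)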
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/registry.py ADDED
@@ -0,0 +1,315 @@
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import warnings
from functools import partial

from .misc import is_seq_of


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    args = cfg.copy()

    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        return obj_cls(**args)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')


class Registry:
    """A registry to map strings to classes.

    Registered objects can be built from the registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct instances
            from the registry. :func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will be
            inherited from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. Classes registered in
            a child registry can be built from the parent. Default: None.
        scope (str, optional): The scope of the registry. It is the key to
            search for child registries. If not specified, scope will be the
            name of the package where the class is defined, e.g. mmdet,
            mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of the registry.

        The name of the package where the registry is defined will be
        returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() traces back to where this function is called;
        # index 2 is the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from the key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as a child based on its scope.
        The parent registry can build objects from the child registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
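A self-contained sketch of the registry pattern these methods implement (the registry and class names are illustrative):

from annotator.mmpkg.mmcv.utils.registry import Registry

BACKBONES = Registry('backbone')

@BACKBONES.register_module()
class ToyNet:

    def __init__(self, depth=18):
        self.depth = depth

# build_from_cfg pops 'type', looks the class up, and forwards the rest as kwargs
net = BACKBONES.build(dict(type='ToyNet', depth=50))
assert net.depth == 50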
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/testing.py ADDED
@@ -0,0 +1,140 @@
# Copyright (c) Open-MMLab.
import sys
from collections.abc import Iterable
from runpy import run_path
from shlex import split
from typing import Any, Dict, List
from unittest.mock import patch


def check_python_script(cmd):
    """Run the python cmd script with `__main__`. The difference from
    `os.system` is that this function executes code in the current process,
    so that it can be tracked by coverage tools. Currently it supports two
    forms:

    - ./tests/data/scripts/hello.py zz
    - python tests/data/scripts/hello.py zz
    """
    args = split(cmd)
    if args[0] == 'python':
        args = args[1:]
    with patch.object(sys, 'argv', args):
        run_path(args[0], run_name='__main__')


def _any(judge_result):
    """Since the built-in ``any`` works only when the elements of the
    iterable are not themselves iterable, implement the function
    recursively."""
    if not isinstance(judge_result, Iterable):
        return judge_result

    try:
        for element in judge_result:
            if _any(element):
                return True
    except TypeError:
        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
        if judge_result:
            return True
    return False


def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
                                expected_subset: Dict[Any, Any]) -> bool:
    """Check if the dict_obj contains the expected_subset.

    Args:
        dict_obj (Dict[Any, Any]): Dict object to be checked.
        expected_subset (Dict[Any, Any]): Subset expected to be contained in
            dict_obj.

    Returns:
        bool: Whether the dict_obj contains the expected_subset.
    """

    for key, value in expected_subset.items():
        if key not in dict_obj.keys() or _any(dict_obj[key] != value):
            return False
    return True


def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
    """Check if the attributes of a class object are correct.

    Args:
        obj (object): Class object to be checked.
        expected_attrs (Dict[str, Any]): Dict of the expected attrs.

    Returns:
        bool: Whether the attributes of the class object are correct.
    """
    for attr, value in expected_attrs.items():
        if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
            return False
    return True


def assert_dict_has_keys(obj: Dict[str, Any],
                         expected_keys: List[str]) -> bool:
    """Check if the obj has all the expected_keys.

    Args:
        obj (Dict[str, Any]): Object to be checked.
        expected_keys (List[str]): Keys expected to be contained in the keys
            of the obj.

    Returns:
        bool: Whether the obj has the expected keys.
    """
    return set(expected_keys).issubset(set(obj.keys()))


def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
    """Check if target_keys is equal to result_keys.

    Args:
        result_keys (List[str]): Result keys to be checked.
        target_keys (List[str]): Target keys to be checked.

    Returns:
        bool: Whether target_keys is equal to result_keys.
    """
    return set(result_keys) == set(target_keys)


def assert_is_norm_layer(module) -> bool:
    """Check if the module is a norm layer.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the module is a norm layer.
    """
    from .parrots_wrapper import _BatchNorm, _InstanceNorm
    from torch.nn import GroupNorm, LayerNorm
    norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
    return isinstance(module, norm_layer_candidates)


def assert_params_all_zeros(module) -> bool:
    """Check if the parameters of the module are all zeros.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: Whether the parameters of the module are all zeros.
    """
    weight_data = module.weight.data
    is_weight_zero = weight_data.allclose(
        weight_data.new_zeros(weight_data.size()))

    if hasattr(module, 'bias') and module.bias is not None:
        bias_data = module.bias.data
        is_bias_zero = bias_data.allclose(
            bias_data.new_zeros(bias_data.size()))
    else:
        is_bias_zero = True

    return is_weight_zero and is_bias_zero
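A short sketch of the boolean checkers (values are illustrative):

from annotator.mmpkg.mmcv.utils.testing import (assert_dict_contains_subset,
                                                assert_keys_equal)

assert assert_dict_contains_subset({'a': 1, 'b': 2}, {'a': 1})
assert assert_keys_equal(['a', 'b'], ['b', 'a'])  # set comparison, order-insensitive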
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/timer.py ADDED
@@ -0,0 +1,118 @@
# Copyright (c) OpenMMLab. All rights reserved.
from time import time


class TimerError(Exception):

    def __init__(self, message):
        self.message = message
        super(TimerError, self).__init__(message)


class Timer:
    """A flexible Timer class.

    :Example:

    >>> import time
    >>> import annotator.mmpkg.mmcv as mmcv
    >>> with mmcv.Timer():
    >>>     # simulate a code block that will run for 1s
    >>>     time.sleep(1)
    1.000
    >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
    >>>     # simulate a code block that will run for 1s
    >>>     time.sleep(1)
    it takes 1.0 seconds
    >>> timer = mmcv.Timer()
    >>> time.sleep(0.5)
    >>> print(timer.since_start())
    0.500
    >>> time.sleep(0.5)
    >>> print(timer.since_last_check())
    0.500
    >>> print(timer.since_start())
    1.000
    """

    def __init__(self, start=True, print_tmpl=None):
        self._is_running = False
        self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
        if start:
            self.start()

    @property
    def is_running(self):
        """bool: indicate whether the timer is running"""
        return self._is_running

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        print(self.print_tmpl.format(self.since_last_check()))
        self._is_running = False

    def start(self):
        """Start the timer."""
        if not self._is_running:
            self._t_start = time()
            self._is_running = True
        self._t_last = time()

    def since_start(self):
        """Total time since the timer was started.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise TimerError('timer is not running')
        self._t_last = time()
        return self._t_last - self._t_start

    def since_last_check(self):
        """Time since the last check.

        Either :func:`since_start` or :func:`since_last_check` counts as a
        check operation.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise TimerError('timer is not running')
        dur = time() - self._t_last
        self._t_last = time()
        return dur


_g_timers = {}  # global timers


def check_time(timer_id):
    """Add checkpoints in a single line.

    This method is suitable for running a task on a list of items. A timer
    will be registered when the method is called for the first time.

    :Example:

    >>> import time
    >>> import annotator.mmpkg.mmcv as mmcv
    >>> for i in range(1, 6):
    >>>     # simulate a code block
    >>>     time.sleep(i)
    >>>     mmcv.check_time('task1')
    2.000
    3.000
    4.000
    5.000

    Args:
        timer_id (str): Timer identifier.
    """
    if timer_id not in _g_timers:
        _g_timers[timer_id] = Timer()
        return 0
    else:
        return _g_timers[timer_id].since_last_check()
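A sketch of the check_time idiom (the timer id is arbitrary; timings are approximate):

import time
from annotator.mmpkg.mmcv.utils.timer import check_time

for i in range(3):
    time.sleep(0.1)
    dur = check_time('loop')  # 0 on the first call (registers the timer),
    print(f'{dur:.1f}')       # then ~0.1s on each subsequent iteration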
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/trace.py ADDED
@@ -0,0 +1,23 @@
import warnings

import torch

from annotator.mmpkg.mmcv.utils import digit_version


def is_jit_tracing() -> bool:
    if (torch.__version__ != 'parrots'
            and digit_version(torch.__version__) >= digit_version('1.6.0')):
        on_trace = torch.jit.is_tracing()
        # In PyTorch 1.6, torch.jit.is_tracing has a bug.
        # See https://github.com/pytorch/pytorch/issues/42448
        if isinstance(on_trace, bool):
            return on_trace
        else:
            return torch._C._is_tracing()
    else:
        warnings.warn(
            'torch.jit.is_tracing is only supported after v1.6.0. '
            'Therefore is_tracing returns False automatically. Please '
            'set on_trace manually if you are using trace.', UserWarning)
        return False
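A sketch of the intended use: guarding data-dependent control flow so a function stays traceable (the function is illustrative):

import torch
from annotator.mmpkg.mmcv.utils.trace import is_jit_tracing

def forward(x):
    if not is_jit_tracing() and x.sum() < 0:
        # data-dependent branch, skipped while torch.jit.trace is recording
        x = x.clamp(min=0)
    return x * 2

print(is_jit_tracing())  # False when called outside torch.jit.trace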
microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/version_utils.py ADDED
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import subprocess
import warnings

from packaging.version import parse


def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.

    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    assert 'parrots' not in version_str
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # version.pre can be None
        if version.pre:
            if version.pre[0] not in mapping:
                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
                              'version checking may go wrong')
            else:
                val = mapping[version.pre[0]]
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])

    elif version.is_postrelease:
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)


def _minimal_ext_cmd(cmd):
    # construct a minimal environment
    env = {}
    for k in ['SYSTEMROOT', 'PATH', 'HOME']:
        v = os.environ.get(k)
        if v is not None:
            env[k] = v
    # LANGUAGE is used on win32
    env['LANGUAGE'] = 'C'
    env['LANG'] = 'C'
    env['LC_ALL'] = 'C'
    out = subprocess.Popen(
        cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
    return out


def get_git_hash(fallback='unknown', digits=None):
    """Get the git hash of the current repo.

    Args:
        fallback (str, optional): The fallback string when the git hash is
            unavailable. Defaults to 'unknown'.
        digits (int, optional): Number of leading hash digits to keep.
            Defaults to None, meaning all digits are kept.

    Returns:
        str: Git commit hash.
    """

    if digits is not None and not isinstance(digits, int):
        raise TypeError('digits must be None or an integer')

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
        if digits is not None:
            sha = sha[:digits]
    except OSError:
        sha = fallback

    return sha
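A few illustrative inputs and the tuples digit_version produces (padded to `length` integers, then suffixed with two pre/post-release slots):

from annotator.mmpkg.mmcv.utils.version_utils import digit_version

print(digit_version('1.3.17'))    # (1, 3, 17, 0, 0, 0)
print(digit_version('2.0.0rc1'))  # (2, 0, 0, 0, -1, 1) -- 'rc' maps to -1
assert digit_version('1.6.0a1') < digit_version('1.6.0')  # alpha < release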
microsoftexcel-controlnet/annotator/mmpkg/mmcv/version.py ADDED
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
__version__ = '1.3.17'


def parse_version_info(version_str: str, length: int = 4) -> tuple:
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
            (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
    """
    from packaging.version import parse
    version = parse(version_str)
    assert version.release, f'failed to parse version {version_str}'
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        release.extend(list(version.pre))
    elif version.is_postrelease:
        # version.post is an int, not a sequence, so the original
        # `list(version.post)` would raise; pair it to mirror the
        # ('rc', 1) form used for pre-releases
        release.extend(['post', version.post])
    else:
        release.extend([0, 0])
    return tuple(release)


version_info = tuple(int(x) for x in __version__.split('.')[:3])

__all__ = ['__version__', 'version_info', 'parse_version_info']
microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/__init__.py ADDED
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .io import Cache, VideoReader, frames2video
from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
                      flowwrite, quantize_flow, sparse_flow_from_bytes)
from .processing import concat_video, convert_video, cut_video, resize_video

__all__ = [
    'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
    'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
    'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
]
microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/io.py ADDED
@@ -0,0 +1,318 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict

import cv2
from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
                 CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
                 CAP_PROP_POS_FRAMES, VideoWriter_fourcc)

from annotator.mmpkg.mmcv.utils import (check_file_exist, mkdir_or_exist,
                                        scandir, track_progress)


class Cache:

    def __init__(self, capacity):
        self._cache = OrderedDict()
        self._capacity = int(capacity)
        if capacity <= 0:
            raise ValueError('capacity must be a positive integer')

    @property
    def capacity(self):
        return self._capacity

    @property
    def size(self):
        return len(self._cache)

    def put(self, key, val):
        if key in self._cache:
            return
        if len(self._cache) >= self.capacity:
            self._cache.popitem(last=False)
        self._cache[key] = val

    def get(self, key, default=None):
        val = self._cache[key] if key in self._cache else default
        return val


class VideoReader:
    """Video class with similar usage to a list object.

    This video wrapper class provides convenient APIs to access frames.
    There exists an issue of OpenCV's VideoCapture class that jumping to a
    certain frame may be inaccurate. It is fixed in this class by checking
    the position after jumping each time.
    A cache is used when decoding videos, so if the same frame is visited a
    second time it does not need to be decoded again when it is stored in
    the cache.

    :Example:

    >>> import annotator.mmpkg.mmcv as mmcv
    >>> v = mmcv.VideoReader('sample.mp4')
    >>> len(v)  # get the total frame number with `len()`
    120
    >>> for img in v:  # v is iterable
    >>>     mmcv.imshow(img)
    >>> v[5]  # get the 6th frame
    """

    def __init__(self, filename, cache_capacity=10):
        # Check whether the video path is a url
        if not filename.startswith(('https://', 'http://')):
            check_file_exist(filename, 'Video file not found: ' + filename)
        self._vcap = cv2.VideoCapture(filename)
        assert cache_capacity > 0
        self._cache = Cache(cache_capacity)
        self._position = 0
        # get basic info
        self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
        self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
        self._fps = self._vcap.get(CAP_PROP_FPS)
        self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
        self._fourcc = self._vcap.get(CAP_PROP_FOURCC)

    @property
    def vcap(self):
        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
        return self._vcap

    @property
    def opened(self):
        """bool: Indicate whether the video is opened."""
        return self._vcap.isOpened()

    @property
    def width(self):
        """int: Width of video frames."""
        return self._width

    @property
    def height(self):
        """int: Height of video frames."""
        return self._height

    @property
    def resolution(self):
        """tuple: Video resolution (width, height)."""
        return (self._width, self._height)

    @property
    def fps(self):
        """float: FPS of the video."""
        return self._fps

    @property
    def frame_cnt(self):
        """int: Total frames of the video."""
        return self._frame_cnt

    @property
    def fourcc(self):
        """str: "Four character code" of the video."""
        return self._fourcc

    @property
    def position(self):
        """int: Current cursor position, indicating frame decoded."""
        return self._position

    def _get_real_position(self):
        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))

    def _set_real_position(self, frame_id):
        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
        pos = self._get_real_position()
        for _ in range(frame_id - pos):
            self._vcap.read()
        self._position = frame_id

    def read(self):
        """Read the next frame.

        If the next frame has been decoded before and is in the cache, it is
        returned directly; otherwise it is decoded, cached and returned.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        # pos = self._position
        if self._cache:
            img = self._cache.get(self._position)
            if img is not None:
                ret = True
            else:
                if self._position != self._get_real_position():
                    self._set_real_position(self._position)
                ret, img = self._vcap.read()
                if ret:
                    self._cache.put(self._position, img)
        else:
            ret, img = self._vcap.read()
        if ret:
            self._position += 1
        return img

    def get_frame(self, frame_id):
        """Get frame by index.

        Args:
            frame_id (int): Index of the expected frame, 0-based.

        Returns:
            ndarray or None: Return the frame if successful, otherwise None.
        """
        if frame_id < 0 or frame_id >= self._frame_cnt:
            raise IndexError(
                f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
        if frame_id == self._position:
            return self.read()
        if self._cache:
            img = self._cache.get(frame_id)
            if img is not None:
                self._position = frame_id + 1
                return img
        self._set_real_position(frame_id)
        ret, img = self._vcap.read()
        if ret:
            if self._cache:
                self._cache.put(self._position, img)
            self._position += 1
        return img

    def current_frame(self):
        """Get the current frame (the frame that was just visited).

        Returns:
            ndarray or None: If the video is fresh, return None, otherwise
            return the frame.
        """
        if self._position == 0:
            return None
        return self._cache.get(self._position - 1)

    def cvt2frames(self,
                   frame_dir,
                   file_start=0,
                   filename_tmpl='{:06d}.jpg',
                   start=0,
                   max_num=0,
                   show_progress=True):
        """Convert a video to frame images.

        Args:
            frame_dir (str): Output directory to store all the frame images.
            file_start (int): Filenames will start from the specified number.
            filename_tmpl (str): Filename template with the index as the
                placeholder.
            start (int): The starting frame index.
            max_num (int): Maximum number of frames to be written.
            show_progress (bool): Whether to show a progress bar.
        """
        mkdir_or_exist(frame_dir)
        if max_num == 0:
            task_num = self.frame_cnt - start
        else:
            task_num = min(self.frame_cnt - start, max_num)
        if task_num <= 0:
            raise ValueError('start must be less than total frame number')
        if start > 0:
            self._set_real_position(start)

        def write_frame(file_idx):
            img = self.read()
            if img is None:
                return
            filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
            cv2.imwrite(filename, img)

        if show_progress:
            track_progress(write_frame, range(file_start,
                                              file_start + task_num))
        else:
            for i in range(task_num):
                write_frame(file_start + i)

    def __len__(self):
        return self.frame_cnt

    def __getitem__(self, index):
        if isinstance(index, slice):
            return [
                self.get_frame(i)
                for i in range(*index.indices(self.frame_cnt))
            ]
        # support negative indexing
        if index < 0:
            index += self.frame_cnt
            if index < 0:
                raise IndexError('index out of range')
        return self.get_frame(index)

    def __iter__(self):
        self._set_real_position(0)
        return self

    def __next__(self):
        img = self.read()
        if img is not None:
            return img
        else:
            raise StopIteration

    next = __next__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._vcap.release()


def frames2video(frame_dir,
                 video_file,
                 fps=30,
                 fourcc='XVID',
                 filename_tmpl='{:06d}.jpg',
                 start=0,
                 end=0,
                 show_progress=True):
    """Read the frame images from a directory and join them as a video.

    Args:
        frame_dir (str): The directory containing video frames.
        video_file (str): Output filename.
        fps (float): FPS of the output video.
        fourcc (str): Fourcc of the output video, this should be compatible
            with the output file type.
        filename_tmpl (str): Filename template with the index as the variable.
        start (int): Starting frame index.
        end (int): Ending frame index.
        show_progress (bool): Whether to show a progress bar.
    """
    if end == 0:
        ext = filename_tmpl.split('.')[-1]
        end = len([name for name in scandir(frame_dir, ext)])
    first_file = osp.join(frame_dir, filename_tmpl.format(start))
    check_file_exist(first_file, 'The start frame not found: ' + first_file)
    img = cv2.imread(first_file)
    height, width = img.shape[:2]
    resolution = (width, height)
    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
                              resolution)

    def write_frame(file_idx):
        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
        img = cv2.imread(filename)
        vwriter.write(img)

    if show_progress:
        track_progress(write_frame, range(start, end))
    else:
        for i in range(start, end):
            write_frame(i)
    vwriter.release()
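A sketch of the reader/writer round trip (filenames are illustrative; assumes the package re-exports the video helpers at its top level, as the docstrings above do):

import annotator.mmpkg.mmcv as mmcv

video = mmcv.VideoReader('sample.mp4')
print(len(video), video.fps, video.resolution)
frame = video[5]                # cached random access to the 6th frame
video.cvt2frames('frames')      # dump frames as frames/000000.jpg, ...
mmcv.frames2video('frames', 'out.avi', fps=video.fps)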
microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/optflow.py ADDED
@@ -0,0 +1,254 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import warnings
+
+ import cv2
+ import numpy as np
+
+ from annotator.mmpkg.mmcv.arraymisc import dequantize, quantize
+ from annotator.mmpkg.mmcv.image import imread, imwrite
+ from annotator.mmpkg.mmcv.utils import is_str
+
+
+ def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
+     """Read an optical flow map.
+
+     Args:
+         flow_or_path (ndarray or str): A flow map or filepath.
+         quantize (bool): Whether to read a quantized pair. If set to True,
+             remaining args will be passed to :func:`dequantize_flow`.
+         concat_axis (int): The axis along which dx and dy are concatenated,
+             can be either 0 or 1. Ignored if quantize is False.
+
+     Returns:
+         ndarray: Optical flow represented as a (h, w, 2) numpy array.
+     """
+     if isinstance(flow_or_path, np.ndarray):
+         if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
+             raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
+         return flow_or_path
+     elif not is_str(flow_or_path):
+         raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
+                         f'not {type(flow_or_path)}')
+
+     if not quantize:
+         with open(flow_or_path, 'rb') as f:
+             try:
+                 header = f.read(4).decode('utf-8')
+             except Exception:
+                 raise IOError(f'Invalid flow file: {flow_or_path}')
+             else:
+                 if header != 'PIEH':
+                     raise IOError(f'Invalid flow file: {flow_or_path}, '
+                                   'header does not contain PIEH')
+
+             w = np.fromfile(f, np.int32, 1).squeeze()
+             h = np.fromfile(f, np.int32, 1).squeeze()
+             flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+     else:
+         assert concat_axis in [0, 1]
+         cat_flow = imread(flow_or_path, flag='unchanged')
+         if cat_flow.ndim != 2:
+             raise IOError(
+                 f'{flow_or_path} is not a valid quantized flow file, '
+                 f'its dimension is {cat_flow.ndim}.')
+         assert cat_flow.shape[concat_axis] % 2 == 0
+         dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+         flow = dequantize_flow(dx, dy, *args, **kwargs)
+
+     return flow.astype(np.float32)
+
+
+ def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+     """Write optical flow to file.
+
+     If the flow is not quantized, it will be saved as a .flo file losslessly;
+     otherwise as a jpeg image, which is lossy but of much smaller size. (dx
+     and dy will be concatenated horizontally into a single image if quantize
+     is True.)
+
+     Args:
+         flow (ndarray): (h, w, 2) array of optical flow.
+         filename (str): Output filepath.
+         quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+             images. If set to True, remaining args will be passed to
+             :func:`quantize_flow`.
+         concat_axis (int): The axis along which dx and dy are concatenated,
+             can be either 0 or 1. Ignored if quantize is False.
+     """
+     if not quantize:
+         with open(filename, 'wb') as f:
+             f.write('PIEH'.encode('utf-8'))
+             np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+             flow = flow.astype(np.float32)
+             flow.tofile(f)
+             f.flush()
+     else:
+         assert concat_axis in [0, 1]
+         dx, dy = quantize_flow(flow, *args, **kwargs)
+         dxdy = np.concatenate((dx, dy), axis=concat_axis)
+         imwrite(dxdy, filename)
+
+
+ def quantize_flow(flow, max_val=0.02, norm=True):
+     """Quantize flow to [0, 255].
+
+     After this step, the size of flow will be much smaller, and can be
+     dumped as jpeg images.
+
+     Args:
+         flow (ndarray): (h, w, 2) array of optical flow.
+         max_val (float): Maximum value of flow; values beyond
+             [-max_val, max_val] will be truncated.
+         norm (bool): Whether to divide flow values by image width/height.
+
+     Returns:
+         tuple[ndarray]: Quantized dx and dy.
+     """
+     h, w, _ = flow.shape
+     dx = flow[..., 0]
+     dy = flow[..., 1]
+     if norm:
+         dx = dx / w  # avoid inplace operations
+         dy = dy / h
+     # use 255 levels instead of 256 to make sure 0 is 0 after dequantization
+     flow_comps = [
+         quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
+     ]
+     return tuple(flow_comps)
+
+
+ def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+     """Recover from quantized flow.
+
+     Args:
+         dx (ndarray): Quantized dx.
+         dy (ndarray): Quantized dy.
+         max_val (float): Maximum value used when quantizing.
+         denorm (bool): Whether to multiply flow values by width/height.
+
+     Returns:
+         ndarray: Dequantized flow.
+     """
+     assert dx.shape == dy.shape
+     assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+     dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+     if denorm:
+         dx *= dx.shape[1]
+         dy *= dx.shape[0]
+     flow = np.dstack((dx, dy))
+     return flow
+
+
+ def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
+     """Use flow to warp img.
+
+     Args:
+         img (ndarray, float or uint8): Image to be warped.
+         flow (ndarray, float): Optical flow.
+         filling_value (int): The missing pixels will be set to filling_value.
+         interpolate_mode (str): bilinear -> Bilinear Interpolation;
+             nearest -> Nearest Neighbor.
+
+     Returns:
+         ndarray: Warped image with the same shape as img.
+     """
+     warnings.warn('This function is just for prototyping and cannot '
+                   'guarantee the computational efficiency.')
+     assert flow.ndim == 3, 'Flow must be a 3D array.'
+     height = flow.shape[0]
+     width = flow.shape[1]
+     channels = img.shape[2]
+
+     output = np.ones(
+         (height, width, channels), dtype=img.dtype) * filling_value
+
+     grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
+     dx = grid[:, :, 0] + flow[:, :, 1]
+     dy = grid[:, :, 1] + flow[:, :, 0]
+     sx = np.floor(dx).astype(int)
+     sy = np.floor(dy).astype(int)
+     valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
+
+     if interpolate_mode == 'nearest':
+         output[valid, :] = img[dx[valid].round().astype(int),
+                                dy[valid].round().astype(int), :]
+     elif interpolate_mode == 'bilinear':
+         # dirty workaround for integer positions
+         eps_ = 1e-6
+         dx, dy = dx + eps_, dy + eps_
+         left_top_ = img[np.floor(dx[valid]).astype(int),
+                         np.floor(dy[valid]).astype(int), :] * (
+                             np.ceil(dx[valid]) - dx[valid])[:, None] * (
+                                 np.ceil(dy[valid]) - dy[valid])[:, None]
+         left_down_ = img[np.ceil(dx[valid]).astype(int),
+                          np.floor(dy[valid]).astype(int), :] * (
+                              dx[valid] - np.floor(dx[valid]))[:, None] * (
+                                  np.ceil(dy[valid]) - dy[valid])[:, None]
+         right_top_ = img[np.floor(dx[valid]).astype(int),
+                          np.ceil(dy[valid]).astype(int), :] * (
+                              np.ceil(dx[valid]) - dx[valid])[:, None] * (
+                                  dy[valid] - np.floor(dy[valid]))[:, None]
+         right_down_ = img[np.ceil(dx[valid]).astype(int),
+                           np.ceil(dy[valid]).astype(int), :] * (
+                               dx[valid] - np.floor(dx[valid]))[:, None] * (
+                                   dy[valid] - np.floor(dy[valid]))[:, None]
+         output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
+     else:
+         raise NotImplementedError(
+             'We only support interpolation modes of nearest and bilinear, '
+             f'but got {interpolate_mode}.')
+     return output.astype(img.dtype)
+
+
+ def flow_from_bytes(content):
+     """Read dense optical flow from bytes.
+
+     .. note::
+         This flow-loading function works for the FlyingChairs,
+         FlyingThings3D, Sintel and FlyingChairsOcc datasets, but cannot load
+         the data from ChairsSDHom.
+
+     Args:
+         content (bytes): Optical flow bytes got from files or other streams.
+
+     Returns:
+         ndarray: Loaded optical flow with the shape (H, W, 2).
+     """
+
+     # header in first 4 bytes
+     header = content[:4]
+     if header.decode('utf-8') != 'PIEH':
+         raise Exception('Flow file header does not contain PIEH')
+     # width in second 4 bytes
+     width = np.frombuffer(content[4:], np.int32, 1).squeeze()
+     # height in third 4 bytes
+     height = np.frombuffer(content[8:], np.int32, 1).squeeze()
+     # after first 12 bytes, all bytes are flow
+     flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
+         (height, width, 2))
+
+     return flow
+
+
+ def sparse_flow_from_bytes(content):
+     """Read the optical flow in KITTI datasets from bytes.
+
+     This function is modified from RAFT's loader for the `KITTI datasets
+     <https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/core/utils/frame_utils.py#L102>`_.
+
+     Args:
+         content (bytes): Optical flow bytes got from files or other streams.
+
+     Returns:
+         Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
+         and flow valid mask with the shape (H, W).
+     """  # noqa
+
+     content = np.frombuffer(content, np.uint8)
+     flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+     flow = flow[:, :, ::-1].astype(np.float32)
+     # flow shape (H, W, 2), valid shape (H, W)
+     flow, valid = flow[:, :, :2], flow[:, :, 2]
+     flow = (flow - 2**15) / 64.0
+     return flow, valid
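
A minimal quantize/dequantize round-trip sketch for the functions above (illustrative only, not part of this commit; the flow values are synthetic):

import numpy as np
from annotator.mmpkg.mmcv.video.optflow import quantize_flow, dequantize_flow

flow = np.random.uniform(-0.01, 0.01, (240, 320, 2)).astype(np.float32)
dx, dy = quantize_flow(flow, max_val=0.02, norm=True)   # two uint8 maps
restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)
print(restored.shape)  # (240, 320, 2); values close to `flow`, not exact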
microsoftexcel-controlnet/annotator/mmpkg/mmcv/video/processing.py ADDED
@@ -0,0 +1,160 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import os
+ import os.path as osp
+ import subprocess
+ import tempfile
+
+ from annotator.mmpkg.mmcv.utils import requires_executable
+
+
+ @requires_executable('ffmpeg')
+ def convert_video(in_file,
+                   out_file,
+                   print_cmd=False,
+                   pre_options='',
+                   **kwargs):
+     """Convert a video with ffmpeg.
+
+     This provides a general API to ffmpeg; the executed command is::
+
+         `ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`
+
+     Options (kwargs) are mapped to ffmpeg commands with the following rules:
+
+     - key=val: "-key val"
+     - key=True: "-key"
+     - key=False: ""
+
+     Args:
+         in_file (str): Input video filename.
+         out_file (str): Output video filename.
+         pre_options (str): Options that appear before "-i <in_file>".
+         print_cmd (bool): Whether to print the final ffmpeg command.
+     """
+     options = []
+     for k, v in kwargs.items():
+         if isinstance(v, bool):
+             if v:
+                 options.append(f'-{k}')
+         elif k == 'log_level':
+             assert v in [
+                 'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
+                 'verbose', 'debug', 'trace'
+             ]
+             options.append(f'-loglevel {v}')
+         else:
+             options.append(f'-{k} {v}')
+     cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
+           f'{out_file}'
+     if print_cmd:
+         print(cmd)
+     subprocess.call(cmd, shell=True)
+
+
+ @requires_executable('ffmpeg')
+ def resize_video(in_file,
+                  out_file,
+                  size=None,
+                  ratio=None,
+                  keep_ar=False,
+                  log_level='info',
+                  print_cmd=False):
+     """Resize a video.
+
+     Args:
+         in_file (str): Input video filename.
+         out_file (str): Output video filename.
+         size (tuple): Expected size (w, h), e.g., (320, 240) or (320, -1).
+         ratio (tuple or float): Expected resize ratio; (2, 0.5) means
+             (w*2, h*0.5).
+         keep_ar (bool): Whether to keep the original aspect ratio.
+         log_level (str): Logging level of ffmpeg.
+         print_cmd (bool): Whether to print the final ffmpeg command.
+     """
+     if size is None and ratio is None:
+         raise ValueError('expected size or ratio must be specified')
+     if size is not None and ratio is not None:
+         raise ValueError('size and ratio cannot be specified at the same time')
+     options = {'log_level': log_level}
+     if size:
+         if not keep_ar:
+             options['vf'] = f'scale={size[0]}:{size[1]}'
+         else:
+             options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
+                             'force_original_aspect_ratio=decrease'
+     else:
+         if not isinstance(ratio, tuple):
+             ratio = (ratio, ratio)
+         options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
+     convert_video(in_file, out_file, print_cmd, **options)
+
+
+ @requires_executable('ffmpeg')
+ def cut_video(in_file,
+               out_file,
+               start=None,
+               end=None,
+               vcodec=None,
+               acodec=None,
+               log_level='info',
+               print_cmd=False):
+     """Cut a clip from a video.
+
+     Args:
+         in_file (str): Input video filename.
+         out_file (str): Output video filename.
+         start (None or float): Start time (in seconds).
+         end (None or float): End time (in seconds).
+         vcodec (None or str): Output video codec, None for unchanged.
+         acodec (None or str): Output audio codec, None for unchanged.
+         log_level (str): Logging level of ffmpeg.
+         print_cmd (bool): Whether to print the final ffmpeg command.
+     """
+     options = {'log_level': log_level}
+     if vcodec is None:
+         options['vcodec'] = 'copy'
+     if acodec is None:
+         options['acodec'] = 'copy'
+     if start:
+         options['ss'] = start
+     else:
+         start = 0
+     if end:
+         options['t'] = end - start
+     convert_video(in_file, out_file, print_cmd, **options)
+
+
+ @requires_executable('ffmpeg')
+ def concat_video(video_list,
+                  out_file,
+                  vcodec=None,
+                  acodec=None,
+                  log_level='info',
+                  print_cmd=False):
+     """Concatenate multiple videos into a single one.
+
+     Args:
+         video_list (list): A list of video filenames.
+         out_file (str): Output video filename.
+         vcodec (None or str): Output video codec, None for unchanged.
+         acodec (None or str): Output audio codec, None for unchanged.
+         log_level (str): Logging level of ffmpeg.
+         print_cmd (bool): Whether to print the final ffmpeg command.
+     """
+     tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
+     with open(tmp_filename, 'w') as f:
+         for filename in video_list:
+             f.write(f'file {osp.abspath(filename)}\n')
+     options = {'log_level': log_level}
+     if vcodec is None:
+         options['vcodec'] = 'copy'
+     if acodec is None:
+         options['acodec'] = 'copy'
+     convert_video(
+         tmp_filename,
+         out_file,
+         print_cmd,
+         pre_options='-f concat -safe 0',
+         **options)
+     os.close(tmp_filehandler)
+     os.remove(tmp_filename)
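
A quick `cut_video` sketch (illustrative only, not part of this commit; requires the ffmpeg executable on PATH, and the in.mp4 filename is hypothetical):

from annotator.mmpkg.mmcv.video.processing import cut_video

# Clip seconds 2-10 without re-encoding (codecs default to 'copy').
cut_video('in.mp4', 'clip.mp4', start=2, end=10, print_cmd=True)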
microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .color import Color, color_val
+ from .image import imshow, imshow_bboxes, imshow_det_bboxes
+ from .optflow import flow2rgb, flowshow, make_color_wheel
+
+ __all__ = [
+     'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+     'flowshow', 'flow2rgb', 'make_color_wheel'
+ ]
microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/color.py ADDED
@@ -0,0 +1,51 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from enum import Enum
+
+ import numpy as np
+
+ from annotator.mmpkg.mmcv.utils import is_str
+
+
+ class Color(Enum):
+     """An enum that defines common colors.
+
+     Contains red, green, blue, cyan, yellow, magenta, white and black.
+     """
+     red = (0, 0, 255)
+     green = (0, 255, 0)
+     blue = (255, 0, 0)
+     cyan = (255, 255, 0)
+     yellow = (0, 255, 255)
+     magenta = (255, 0, 255)
+     white = (255, 255, 255)
+     black = (0, 0, 0)
+
+
+ def color_val(color):
+     """Convert various inputs to color tuples.
+
+     Args:
+         color (:obj:`Color`/str/tuple/int/ndarray): Color input.
+
+     Returns:
+         tuple[int]: A tuple of 3 integers indicating BGR channels.
+     """
+     if is_str(color):
+         return Color[color].value
+     elif isinstance(color, Color):
+         return color.value
+     elif isinstance(color, tuple):
+         assert len(color) == 3
+         for channel in color:
+             assert 0 <= channel <= 255
+         return color
+     elif isinstance(color, int):
+         assert 0 <= color <= 255
+         return color, color, color
+     elif isinstance(color, np.ndarray):
+         assert color.ndim == 1 and color.size == 3
+         assert np.all((color >= 0) & (color <= 255))
+         color = color.astype(np.uint8)
+         return tuple(color)
+     else:
+         raise TypeError(f'Invalid type for color: {type(color)}')
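
A quick `color_val` sketch (illustrative only, not part of this commit); each call below yields a BGR tuple:

from annotator.mmpkg.mmcv.visualization import color_val

print(color_val('green'))      # (0, 255, 0)
print(color_val(128))          # (128, 128, 128), grayscale shorthand
print(color_val((255, 0, 0)))  # passed through after range validation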
microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/image.py ADDED
@@ -0,0 +1,152 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ import cv2
+ import numpy as np
+
+ from annotator.mmpkg.mmcv.image import imread, imwrite
+ from .color import color_val
+
+
+ def imshow(img, win_name='', wait_time=0):
+     """Show an image.
+
+     Args:
+         img (str or ndarray): The image to be displayed.
+         win_name (str): The window name.
+         wait_time (int): Value of waitKey param.
+     """
+     cv2.imshow(win_name, imread(img))
+     if wait_time == 0:  # prevent hanging if the window was closed
+         while True:
+             ret = cv2.waitKey(1)
+
+             closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
+             # if user closed window or if some key pressed
+             if closed or ret != -1:
+                 break
+     else:
+         ret = cv2.waitKey(wait_time)
+
+
+ def imshow_bboxes(img,
+                   bboxes,
+                   colors='green',
+                   top_k=-1,
+                   thickness=1,
+                   show=True,
+                   win_name='',
+                   wait_time=0,
+                   out_file=None):
+     """Draw bboxes on an image.
+
+     Args:
+         img (str or ndarray): The image to be displayed.
+         bboxes (list or ndarray): A list of ndarray of shape (k, 4).
+         colors (list[str or tuple or Color]): A list of colors.
+         top_k (int): Plot the first k bboxes only if set positive.
+         thickness (int): Thickness of lines.
+         show (bool): Whether to show the image.
+         win_name (str): The window name.
+         wait_time (int): Value of waitKey param.
+         out_file (str, optional): The filename to write the image.
+
+     Returns:
+         ndarray: The image with bboxes drawn on it.
+     """
+     img = imread(img)
+     img = np.ascontiguousarray(img)
+
+     if isinstance(bboxes, np.ndarray):
+         bboxes = [bboxes]
+     if not isinstance(colors, list):
+         colors = [colors for _ in range(len(bboxes))]
+     colors = [color_val(c) for c in colors]
+     assert len(bboxes) == len(colors)
+
+     for i, _bboxes in enumerate(bboxes):
+         _bboxes = _bboxes.astype(np.int32)
+         if top_k <= 0:
+             _top_k = _bboxes.shape[0]
+         else:
+             _top_k = min(top_k, _bboxes.shape[0])
+         for j in range(_top_k):
+             left_top = (_bboxes[j, 0], _bboxes[j, 1])
+             right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
+             cv2.rectangle(
+                 img, left_top, right_bottom, colors[i], thickness=thickness)
+
+     if show:
+         imshow(img, win_name, wait_time)
+     if out_file is not None:
+         imwrite(img, out_file)
+     return img
+
+
+ def imshow_det_bboxes(img,
+                       bboxes,
+                       labels,
+                       class_names=None,
+                       score_thr=0,
+                       bbox_color='green',
+                       text_color='green',
+                       thickness=1,
+                       font_scale=0.5,
+                       show=True,
+                       win_name='',
+                       wait_time=0,
+                       out_file=None):
+     """Draw bboxes and class labels (with scores) on an image.
+
+     Args:
+         img (str or ndarray): The image to be displayed.
+         bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+             (n, 5).
+         labels (ndarray): Labels of bboxes.
+         class_names (list[str]): Names of each class.
+         score_thr (float): Minimum score of bboxes to be shown.
+         bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+         text_color (str or tuple or :obj:`Color`): Color of texts.
+         thickness (int): Thickness of lines.
+         font_scale (float): Font scale of texts.
+         show (bool): Whether to show the image.
+         win_name (str): The window name.
+         wait_time (int): Value of waitKey param.
+         out_file (str or None): The filename to write the image.
+
+     Returns:
+         ndarray: The image with bboxes drawn on it.
+     """
+     assert bboxes.ndim == 2
+     assert labels.ndim == 1
+     assert bboxes.shape[0] == labels.shape[0]
+     assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
+     img = imread(img)
+     img = np.ascontiguousarray(img)
+
+     if score_thr > 0:
+         assert bboxes.shape[1] == 5
+         scores = bboxes[:, -1]
+         inds = scores > score_thr
+         bboxes = bboxes[inds, :]
+         labels = labels[inds]
+
+     bbox_color = color_val(bbox_color)
+     text_color = color_val(text_color)
+
+     for bbox, label in zip(bboxes, labels):
+         bbox_int = bbox.astype(np.int32)
+         left_top = (bbox_int[0], bbox_int[1])
+         right_bottom = (bbox_int[2], bbox_int[3])
+         cv2.rectangle(
+             img, left_top, right_bottom, bbox_color, thickness=thickness)
+         label_text = class_names[
+             label] if class_names is not None else f'cls {label}'
+         if len(bbox) > 4:
+             label_text += f'|{bbox[-1]:.02f}'
+         cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
+                     cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+
+     if show:
+         imshow(img, win_name, wait_time)
+     if out_file is not None:
+         imwrite(img, out_file)
+     return img
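
An `imshow_bboxes` sketch (illustrative only, not part of this commit); it draws two boxes on a synthetic canvas and writes a file rather than opening a window:

import numpy as np
from annotator.mmpkg.mmcv.visualization import imshow_bboxes

canvas = np.zeros((240, 320, 3), dtype=np.uint8)
boxes = np.array([[20, 20, 100, 120], [150, 40, 300, 200]])
imshow_bboxes(canvas, boxes, colors='green', show=False, out_file='boxes.png')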
microsoftexcel-controlnet/annotator/mmpkg/mmcv/visualization/optflow.py ADDED
@@ -0,0 +1,112 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from __future__ import division
+
+ import numpy as np
+
+ from annotator.mmpkg.mmcv.image import rgb2bgr
+ from annotator.mmpkg.mmcv.video import flowread
+ from .image import imshow
+
+
+ def flowshow(flow, win_name='', wait_time=0):
+     """Show optical flow.
+
+     Args:
+         flow (ndarray or str): The optical flow to be displayed.
+         win_name (str): The window name.
+         wait_time (int): Value of waitKey param.
+     """
+     flow = flowread(flow)
+     flow_img = flow2rgb(flow)
+     imshow(rgb2bgr(flow_img), win_name, wait_time)
+
+
+ def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+     """Convert flow map to RGB image.
+
+     Args:
+         flow (ndarray): Array of optical flow.
+         color_wheel (ndarray or None): Color wheel used to map flow field to
+             RGB colorspace. Default color wheel will be used if not specified.
+         unknown_thr (float): Values above this threshold will be marked as
+             unknown and thus ignored.
+
+     Returns:
+         ndarray: RGB image that can be visualized.
+     """
+     assert flow.ndim == 3 and flow.shape[-1] == 2
+     if color_wheel is None:
+         color_wheel = make_color_wheel()
+     assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+     num_bins = color_wheel.shape[0]
+
+     dx = flow[:, :, 0].copy()
+     dy = flow[:, :, 1].copy()
+
+     ignore_inds = (
+         np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+         (np.abs(dy) > unknown_thr))
+     dx[ignore_inds] = 0
+     dy[ignore_inds] = 0
+
+     rad = np.sqrt(dx**2 + dy**2)
+     if np.any(rad > np.finfo(float).eps):
+         max_rad = np.max(rad)
+         dx /= max_rad
+         dy /= max_rad
+
+     rad = np.sqrt(dx**2 + dy**2)
+     angle = np.arctan2(-dy, -dx) / np.pi
+
+     bin_real = (angle + 1) / 2 * (num_bins - 1)
+     bin_left = np.floor(bin_real).astype(int)
+     bin_right = (bin_left + 1) % num_bins
+     w = (bin_real - bin_left.astype(np.float32))[..., None]
+     flow_img = (1 -
+                 w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+     small_ind = rad <= 1
+     flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+     flow_img[np.logical_not(small_ind)] *= 0.75
+
+     flow_img[ignore_inds, :] = 0
+
+     return flow_img
+
+
+ def make_color_wheel(bins=None):
+     """Build a color wheel.
+
+     Args:
+         bins (list or tuple, optional): Specify the number of bins for each
+             color range, corresponding to six ranges: red -> yellow,
+             yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+             magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+             (see Middlebury).
+
+     Returns:
+         ndarray: Color wheel of shape (total_bins, 3).
+     """
+     if bins is None:
+         bins = [15, 6, 4, 11, 13, 6]
+     assert len(bins) == 6
+
+     RY, YG, GC, CB, BM, MR = tuple(bins)
+
+     ry = [1, np.arange(RY) / RY, 0]
+     yg = [1 - np.arange(YG) / YG, 1, 0]
+     gc = [0, 1, np.arange(GC) / GC]
+     cb = [0, 1 - np.arange(CB) / CB, 1]
+     bm = [np.arange(BM) / BM, 0, 1]
+     mr = [1, 0, 1 - np.arange(MR) / MR]
+
+     num_bins = RY + YG + GC + CB + BM + MR
+
+     color_wheel = np.zeros((3, num_bins), dtype=np.float32)
+
+     col = 0
+     for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
+         for j in range(3):
+             color_wheel[j, col:col + bins[i]] = color[j]
+         col += bins[i]
+
+     return color_wheel.T
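
A `flow2rgb` sketch (illustrative only, not part of this commit); a constant horizontal flow maps to a single hue:

import numpy as np
from annotator.mmpkg.mmcv.visualization import flow2rgb

flow = np.zeros((64, 64, 2), dtype=np.float32)
flow[..., 0] = 1.0    # constant horizontal motion
rgb = flow2rgb(flow)  # float image in [0, 1], shape (64, 64, 3)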
microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+ from .test import multi_gpu_test, single_gpu_test
+ from .train import get_root_logger, set_random_seed, train_segmentor
+
+ __all__ = [
+     'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+     'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+     'show_result_pyplot'
+ ]
microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/inference.py ADDED
@@ -0,0 +1,138 @@
+ import matplotlib.pyplot as plt
+ import annotator.mmpkg.mmcv as mmcv
+ import torch
+ from annotator.mmpkg.mmcv.parallel import collate, scatter
+ from annotator.mmpkg.mmcv.runner import load_checkpoint
+
+ from annotator.mmpkg.mmseg.datasets.pipelines import Compose
+ from annotator.mmpkg.mmseg.models import build_segmentor
+ from modules import devices
+
+
+ def init_segmentor(config, checkpoint=None, device=devices.get_device_for("controlnet")):
+     """Initialize a segmentor from config file.
+
+     Args:
+         config (str or :obj:`mmcv.Config`): Config file path or the config
+             object.
+         checkpoint (str, optional): Checkpoint path. If left as None, the
+             model will not load any weights.
+         device (str, optional): CPU/CUDA device option. Defaults to the
+             device returned by ``devices.get_device_for("controlnet")``.
+             Use 'cpu' for loading the model on CPU.
+
+     Returns:
+         nn.Module: The constructed segmentor.
+     """
+     if isinstance(config, str):
+         config = mmcv.Config.fromfile(config)
+     elif not isinstance(config, mmcv.Config):
+         raise TypeError('config must be a filename or Config object, '
+                         'but got {}'.format(type(config)))
+     config.model.pretrained = None
+     config.model.train_cfg = None
+     model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+     if checkpoint is not None:
+         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+         model.CLASSES = checkpoint['meta']['CLASSES']
+         model.PALETTE = checkpoint['meta']['PALETTE']
+     model.cfg = config  # save the config in the model for convenience
+     model.to(device)
+     model.eval()
+     return model
+
+
+ class LoadImage:
+     """A simple pipeline to load image."""
+
+     def __call__(self, results):
+         """Call function to load images into results.
+
+         Args:
+             results (dict): A result dict that contains the file name
+                 of the image to be read.
+
+         Returns:
+             dict: ``results`` will be returned containing the loaded image.
+         """
+
+         if isinstance(results['img'], str):
+             results['filename'] = results['img']
+             results['ori_filename'] = results['img']
+         else:
+             results['filename'] = None
+             results['ori_filename'] = None
+         img = mmcv.imread(results['img'])
+         results['img'] = img
+         results['img_shape'] = img.shape
+         results['ori_shape'] = img.shape
+         return results
+
+
+ def inference_segmentor(model, img):
+     """Inference image(s) with the segmentor.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         img (str/ndarray or list[str/ndarray]): Either image files or loaded
+             images.
+
+     Returns:
+         (list[Tensor]): The segmentation result.
+     """
+     cfg = model.cfg
+     device = next(model.parameters()).device  # model device
+     # build the data pipeline
+     test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+     test_pipeline = Compose(test_pipeline)
+     # prepare data
+     data = dict(img=img)
+     data = test_pipeline(data)
+     data = collate([data], samples_per_gpu=1)
+     if next(model.parameters()).is_cuda:
+         # scatter to specified GPU
+         data = scatter(data, [device])[0]
+     else:
+         data['img'][0] = data['img'][0].to(devices.get_device_for("controlnet"))
+         data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+     # forward the model
+     with torch.no_grad():
+         result = model(return_loss=False, rescale=True, **data)
+     return result
+
+
+ def show_result_pyplot(model,
+                        img,
+                        result,
+                        palette=None,
+                        fig_size=(15, 10),
+                        opacity=0.5,
+                        title='',
+                        block=True):
+     """Visualize the segmentation results on the image.
+
+     Args:
+         model (nn.Module): The loaded segmentor.
+         img (str or np.ndarray): Image filename or loaded image.
+         result (list): The segmentation result.
+         palette (list[list[int]]] | None): The palette of segmentation
+             map. If None is given, a random palette will be generated.
+             Default: None
+         fig_size (tuple): Figure size of the pyplot figure.
+         opacity (float): Opacity of painted segmentation map.
+             Default 0.5. Must be in (0, 1] range.
+         title (str): The title of the pyplot figure.
+             Default is ''.
+         block (bool): Whether to block the pyplot figure.
+             Default is True.
+
+     Returns:
+         ndarray: The image with the segmentation map drawn on it, in RGB
+             order.
+     """
+     if hasattr(model, 'module'):
+         model = model.module
+     img = model.show_result(
+         img, result, palette=palette, show=False, opacity=opacity)
+     # plt.figure(figsize=fig_size)
+     # plt.imshow(mmcv.bgr2rgb(img))
+     # plt.title(title)
+     # plt.tight_layout()
+     # plt.show(block=block)
+     return mmcv.bgr2rgb(img)
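
A minimal inference sketch for the two APIs above (illustrative only, not part of this commit; the config, checkpoint and image paths are hypothetical placeholders):

from annotator.mmpkg.mmseg.apis import init_segmentor, inference_segmentor

model = init_segmentor('configs/my_cfg.py', 'weights.pth')   # paths assumed
result = inference_segmentor(model, 'demo.png')              # list with one seg map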
microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/test.py ADDED
@@ -0,0 +1,238 @@
+ import os.path as osp
+ import pickle
+ import shutil
+ import tempfile
+
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ from annotator.mmpkg.mmcv.image import tensor2imgs
+ from annotator.mmpkg.mmcv.runner import get_dist_info
+
+
+ def np2tmp(array, temp_file_name=None):
+     """Save ndarray to local numpy file.
+
+     Args:
+         array (ndarray): Ndarray to save.
+         temp_file_name (str): Numpy file name. If ``temp_file_name`` is None,
+             this function will generate a file name with
+             tempfile.NamedTemporaryFile to save the ndarray. Default: None.
+
+     Returns:
+         str: The numpy file name.
+     """
+
+     if temp_file_name is None:
+         temp_file_name = tempfile.NamedTemporaryFile(
+             suffix='.npy', delete=False).name
+     np.save(temp_file_name, array)
+     return temp_file_name
+
+
+ def single_gpu_test(model,
+                     data_loader,
+                     show=False,
+                     out_dir=None,
+                     efficient_test=False,
+                     opacity=0.5):
+     """Test with single GPU.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (utils.data.Dataloader): Pytorch data loader.
+         show (bool): Whether to show results during inference. Default: False.
+         out_dir (str, optional): If specified, the drawn results will be
+             dumped into this directory.
+         efficient_test (bool): Whether to save the results as local numpy
+             files to reduce CPU memory usage during evaluation.
+             Default: False.
+         opacity (float): Opacity of painted segmentation map.
+             Default 0.5. Must be in (0, 1] range.
+
+     Returns:
+         list: The prediction results.
+     """
+
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     prog_bar = mmcv.ProgressBar(len(dataset))
+     for i, data in enumerate(data_loader):
+         with torch.no_grad():
+             result = model(return_loss=False, **data)
+
+         if show or out_dir:
+             img_tensor = data['img'][0]
+             img_metas = data['img_metas'][0].data[0]
+             imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+             assert len(imgs) == len(img_metas)
+
+             for img, img_meta in zip(imgs, img_metas):
+                 h, w, _ = img_meta['img_shape']
+                 img_show = img[:h, :w, :]
+
+                 ori_h, ori_w = img_meta['ori_shape'][:-1]
+                 img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+                 if out_dir:
+                     out_file = osp.join(out_dir, img_meta['ori_filename'])
+                 else:
+                     out_file = None
+
+                 model.module.show_result(
+                     img_show,
+                     result,
+                     palette=dataset.PALETTE,
+                     show=show,
+                     out_file=out_file,
+                     opacity=opacity)
+
+         if isinstance(result, list):
+             if efficient_test:
+                 result = [np2tmp(_) for _ in result]
+             results.extend(result)
+         else:
+             if efficient_test:
+                 result = np2tmp(result)
+             results.append(result)
+
+         batch_size = len(result)
+         for _ in range(batch_size):
+             prog_bar.update()
+     return results
+
+
+ def multi_gpu_test(model,
+                    data_loader,
+                    tmpdir=None,
+                    gpu_collect=False,
+                    efficient_test=False):
+     """Test model with multiple gpus.
+
+     This method tests the model with multiple gpus and collects the results
+     under two different modes: gpu and cpu. By setting 'gpu_collect=True'
+     it encodes results to gpu tensors and uses gpu communication for results
+     collection. In cpu mode it saves the results on different gpus to
+     'tmpdir' and collects them by the rank 0 worker.
+
+     Args:
+         model (nn.Module): Model to be tested.
+         data_loader (utils.data.Dataloader): Pytorch data loader.
+         tmpdir (str): Path of directory to save the temporary results from
+             different gpus under cpu mode.
+         gpu_collect (bool): Option to use either gpu or cpu to collect
+             results.
+         efficient_test (bool): Whether to save the results as local numpy
+             files to reduce CPU memory usage during evaluation.
+             Default: False.
+
+     Returns:
+         list: The prediction results.
+     """
+
+     model.eval()
+     results = []
+     dataset = data_loader.dataset
+     rank, world_size = get_dist_info()
+     if rank == 0:
+         prog_bar = mmcv.ProgressBar(len(dataset))
+     for i, data in enumerate(data_loader):
+         with torch.no_grad():
+             result = model(return_loss=False, rescale=True, **data)
+
+         if isinstance(result, list):
+             if efficient_test:
+                 result = [np2tmp(_) for _ in result]
+             results.extend(result)
+         else:
+             if efficient_test:
+                 result = np2tmp(result)
+             results.append(result)
+
+         if rank == 0:
+             batch_size = data['img'][0].size(0)
+             for _ in range(batch_size * world_size):
+                 prog_bar.update()
+
+     # collect results from all ranks
+     if gpu_collect:
+         results = collect_results_gpu(results, len(dataset))
+     else:
+         results = collect_results_cpu(results, len(dataset), tmpdir)
+     return results
+
+
+ def collect_results_cpu(result_part, size, tmpdir=None):
+     """Collect results with CPU."""
+     rank, world_size = get_dist_info()
+     # create a tmp dir if it is not specified
+     if tmpdir is None:
+         MAX_LEN = 512
+         # 32 is whitespace
+         dir_tensor = torch.full((MAX_LEN, ),
+                                 32,
+                                 dtype=torch.uint8,
+                                 device='cuda')
+         if rank == 0:
+             tmpdir = tempfile.mkdtemp()
+             tmpdir = torch.tensor(
+                 bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+             dir_tensor[:len(tmpdir)] = tmpdir
+         dist.broadcast(dir_tensor, 0)
+         tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+     else:
+         mmcv.mkdir_or_exist(tmpdir)
+     # dump the part result to the dir
+     mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+     dist.barrier()
+     # collect all parts
+     if rank != 0:
+         return None
+     else:
+         # load results of all parts from tmp dir
+         part_list = []
+         for i in range(world_size):
+             part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+             part_list.append(mmcv.load(part_file))
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         # remove tmp dir
+         shutil.rmtree(tmpdir)
+         return ordered_results
+
+
+ def collect_results_gpu(result_part, size):
+     """Collect results with GPU."""
+     rank, world_size = get_dist_info()
+     # dump result part to tensor with pickle
+     part_tensor = torch.tensor(
+         bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+     # gather all result part tensor shape
+     shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+     shape_list = [shape_tensor.clone() for _ in range(world_size)]
+     dist.all_gather(shape_list, shape_tensor)
+     # padding result part tensor to max length
+     shape_max = torch.tensor(shape_list).max()
+     part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+     part_send[:shape_tensor[0]] = part_tensor
+     part_recv_list = [
+         part_tensor.new_zeros(shape_max) for _ in range(world_size)
+     ]
+     # gather all result parts
+     dist.all_gather(part_recv_list, part_send)
+
+     if rank == 0:
+         part_list = []
+         for recv, shape in zip(part_recv_list, shape_list):
+             part_list.append(
+                 pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+         # sort the results
+         ordered_results = []
+         for res in zip(*part_list):
+             ordered_results.extend(list(res))
+         # the dataloader may pad some samples
+         ordered_results = ordered_results[:size]
+         return ordered_results
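
A tiny `np2tmp` sketch (illustrative only, not part of this commit); this is the mechanism `efficient_test` uses to keep results on disk instead of in memory:

import numpy as np
from annotator.mmpkg.mmseg.apis.test import np2tmp

path = np2tmp(np.arange(6).reshape(2, 3))  # writes a temporary .npy file
print(np.load(path))                       # [[0 1 2] [3 4 5]]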
microsoftexcel-controlnet/annotator/mmpkg/mmseg/apis/train.py ADDED
@@ -0,0 +1,116 @@
+ import random
+ import warnings
+
+ import numpy as np
+ import torch
+ from annotator.mmpkg.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+ from annotator.mmpkg.mmcv.runner import build_optimizer, build_runner
+
+ from annotator.mmpkg.mmseg.core import DistEvalHook, EvalHook
+ from annotator.mmpkg.mmseg.datasets import build_dataloader, build_dataset
+ from annotator.mmpkg.mmseg.utils import get_root_logger
+
+
+ def set_random_seed(seed, deterministic=False):
+     """Set random seed.
+
+     Args:
+         seed (int): Seed to be used.
+         deterministic (bool): Whether to set the deterministic option for
+             the CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+             to True and `torch.backends.cudnn.benchmark` to False.
+             Default: False.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     if deterministic:
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+
+
+ def train_segmentor(model,
+                     dataset,
+                     cfg,
+                     distributed=False,
+                     validate=False,
+                     timestamp=None,
+                     meta=None):
+     """Launch segmentor training."""
+     logger = get_root_logger(cfg.log_level)
+
+     # prepare data loaders
+     dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+     data_loaders = [
+         build_dataloader(
+             ds,
+             cfg.data.samples_per_gpu,
+             cfg.data.workers_per_gpu,
+             # cfg.gpus will be ignored if distributed
+             len(cfg.gpu_ids),
+             dist=distributed,
+             seed=cfg.seed,
+             drop_last=True) for ds in dataset
+     ]
+
+     # put model on gpus
+     if distributed:
+         find_unused_parameters = cfg.get('find_unused_parameters', False)
+         # Sets the `find_unused_parameters` parameter in
+         # torch.nn.parallel.DistributedDataParallel
+         model = MMDistributedDataParallel(
+             model.cuda(),
+             device_ids=[torch.cuda.current_device()],
+             broadcast_buffers=False,
+             find_unused_parameters=find_unused_parameters)
+     else:
+         model = MMDataParallel(
+             model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+     # build runner
+     optimizer = build_optimizer(model, cfg.optimizer)
+
+     if cfg.get('runner') is None:
+         cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+         warnings.warn(
+             'config is now expected to have a `runner` section, '
+             'please set `runner` in your config.', UserWarning)
+
+     runner = build_runner(
+         cfg.runner,
+         default_args=dict(
+             model=model,
+             batch_processor=None,
+             optimizer=optimizer,
+             work_dir=cfg.work_dir,
+             logger=logger,
+             meta=meta))
+
+     # register hooks
+     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                    cfg.checkpoint_config, cfg.log_config,
+                                    cfg.get('momentum_config', None))
+
+     # an ugly workaround to make the .log and .log.json filenames the same
+     runner.timestamp = timestamp
+
+     # register eval hooks
+     if validate:
+         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+         val_dataloader = build_dataloader(
+             val_dataset,
+             samples_per_gpu=1,
+             workers_per_gpu=cfg.data.workers_per_gpu,
+             dist=distributed,
+             shuffle=False)
+         eval_cfg = cfg.get('evaluation', {})
+         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+         eval_hook = DistEvalHook if distributed else EvalHook
+         runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+     if cfg.resume_from:
+         runner.resume(cfg.resume_from)
+     elif cfg.load_from:
+         runner.load_checkpoint(cfg.load_from)
+     runner.run(data_loaders, cfg.workflow)
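
A `set_random_seed` sketch (illustrative only, not part of this commit):

from annotator.mmpkg.mmseg.apis import set_random_seed

# Seed python, numpy and torch; deterministic=True trades speed for
# reproducibility by pinning the CUDNN backend.
set_random_seed(42, deterministic=True)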
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .evaluation import *  # noqa: F401, F403
+ from .seg import *  # noqa: F401, F403
+ from .utils import *  # noqa: F401, F403
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .class_names import get_classes, get_palette
+ from .eval_hooks import DistEvalHook, EvalHook
+ from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+
+ __all__ = [
+     'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+     'eval_metrics', 'get_classes', 'get_palette'
+ ]
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/class_names.py ADDED
@@ -0,0 +1,152 @@
+ import annotator.mmpkg.mmcv as mmcv
+
+
+ def cityscapes_classes():
+     """Cityscapes class names for external use."""
+     return [
+         'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+         'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+         'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+         'bicycle'
+     ]
+
+
+ def ade_classes():
+     """ADE20K class names for external use."""
+     return [
+         'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+         'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+         'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+         'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+         'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+         'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+         'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+         'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+         'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+         'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+         'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+         'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+         'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+         'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+         'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+         'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+         'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+         'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+         'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+         'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+         'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+         'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+         'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+         'clock', 'flag'
+     ]
+
+
+ def voc_classes():
+     """Pascal VOC class names for external use."""
+     return [
+         'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+         'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+         'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+         'tvmonitor'
+     ]
+
+
+ def cityscapes_palette():
+     """Cityscapes palette for external use."""
+     return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+             [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+             [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+             [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+             [0, 0, 230], [119, 11, 32]]
+
+
+ def ade_palette():
+     """ADE20K palette for external use."""
+     return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+             [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+             [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+             [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+             [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+             [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+             [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+             [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+             [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+             [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+             [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+             [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+             [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+             [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+             [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+             [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+             [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+             [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+             [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+             [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+             [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+             [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+             [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+             [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+             [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+             [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+             [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+             [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+             [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+             [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+             [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+             [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+             [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+             [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+             [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+             [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+             [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+             [102, 255, 0], [92, 0, 255]]
+
+
+ def voc_palette():
+     """Pascal VOC palette for external use."""
+     return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+             [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+             [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+             [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+             [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+ dataset_aliases = {
+     'cityscapes': ['cityscapes'],
+     'ade': ['ade', 'ade20k'],
+     'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+ }
+
+
+ def get_classes(dataset):
+     """Get class names of a dataset."""
+     alias2name = {}
+     for name, aliases in dataset_aliases.items():
+         for alias in aliases:
+             alias2name[alias] = name
+
+     if mmcv.is_str(dataset):
+         if dataset in alias2name:
+             labels = eval(alias2name[dataset] + '_classes()')
+         else:
+             raise ValueError(f'Unrecognized dataset: {dataset}')
+     else:
+         raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+     return labels
+
+
+ def get_palette(dataset):
+     """Get class palette (RGB) of a dataset."""
+     alias2name = {}
+     for name, aliases in dataset_aliases.items():
+         for alias in aliases:
+             alias2name[alias] = name
+
+     if mmcv.is_str(dataset):
+         if dataset in alias2name:
+             labels = eval(alias2name[dataset] + '_palette()')
+         else:
+             raise ValueError(f'Unrecognized dataset: {dataset}')
+     else:
+         raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+     return labels
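
A quick sketch for `get_classes`/`get_palette` (illustrative only, not part of this commit):

from annotator.mmpkg.mmseg.core.evaluation import get_classes, get_palette

print(get_classes('voc')[:3])  # ['background', 'aeroplane', 'bicycle']
print(get_palette('voc')[0])   # [0, 0, 0]; aliases like 'pascal_voc' also work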
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/eval_hooks.py ADDED
@@ -0,0 +1,109 @@
+ import os.path as osp
+
+ from annotator.mmpkg.mmcv.runner import DistEvalHook as _DistEvalHook
+ from annotator.mmpkg.mmcv.runner import EvalHook as _EvalHook
+
+
+ class EvalHook(_EvalHook):
+     """Single GPU EvalHook, with efficient test support.
+
+     Args:
+         by_epoch (bool): Determine whether to perform evaluation by epoch or
+             by iteration. If set to True, it will perform by epoch.
+             Otherwise, by iteration. Default: False.
+         efficient_test (bool): Whether to save the results as local numpy
+             files to reduce CPU memory usage during evaluation.
+             Default: False.
+     """
+
+     greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+     def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+         super().__init__(*args, by_epoch=by_epoch, **kwargs)
+         self.efficient_test = efficient_test
+
+     def after_train_iter(self, runner):
+         """After train iteration hook.
+
+         Overrides the default ``single_gpu_test``.
+         """
+         if self.by_epoch or not self.every_n_iters(runner, self.interval):
+             return
+         from annotator.mmpkg.mmseg.apis import single_gpu_test
+         runner.log_buffer.clear()
+         results = single_gpu_test(
+             runner.model,
+             self.dataloader,
+             show=False,
+             efficient_test=self.efficient_test)
+         self.evaluate(runner, results)
+
+     def after_train_epoch(self, runner):
+         """After train epoch hook.
+
+         Overrides the default ``single_gpu_test``.
+         """
+         if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+             return
+         from annotator.mmpkg.mmseg.apis import single_gpu_test
+         runner.log_buffer.clear()
+         results = single_gpu_test(runner.model, self.dataloader, show=False)
+         self.evaluate(runner, results)
+
+
+ class DistEvalHook(_DistEvalHook):
+     """Distributed EvalHook, with efficient test support.
+
+     Args:
+         by_epoch (bool): Determine whether to perform evaluation by epoch or
+             by iteration. If set to True, it will perform by epoch.
+             Otherwise, by iteration. Default: False.
+         efficient_test (bool): Whether to save the results as local numpy
+             files to reduce CPU memory usage during evaluation.
+             Default: False.
+     """
+
+     greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+     def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+         super().__init__(*args, by_epoch=by_epoch, **kwargs)
+         self.efficient_test = efficient_test
+
+     def after_train_iter(self, runner):
+         """After train iteration hook.
+
+         Overrides the default ``multi_gpu_test``.
+         """
+         if self.by_epoch or not self.every_n_iters(runner, self.interval):
+             return
+         from annotator.mmpkg.mmseg.apis import multi_gpu_test
+         runner.log_buffer.clear()
+         results = multi_gpu_test(
+             runner.model,
+             self.dataloader,
+             tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+             gpu_collect=self.gpu_collect,
+             efficient_test=self.efficient_test)
+         if runner.rank == 0:
+             print('\n')
+             self.evaluate(runner, results)
+
+     def after_train_epoch(self, runner):
+         """After train epoch hook.
+
+         Overrides the default ``multi_gpu_test``.
+         """
+         if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+             return
+         from annotator.mmpkg.mmseg.apis import multi_gpu_test
+         runner.log_buffer.clear()
+         results = multi_gpu_test(
+             runner.model,
+             self.dataloader,
+             tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+             gpu_collect=self.gpu_collect)
+         if runner.rank == 0:
+             print('\n')
+             self.evaluate(runner, results)
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/evaluation/metrics.py ADDED
@@ -0,0 +1,326 @@
+ from collections import OrderedDict
+
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ import torch
+
+
+ def f_score(precision, recall, beta=1):
+     """Calculate the f-score value.
+
+     Args:
+         precision (float | torch.Tensor): The precision value.
+         recall (float | torch.Tensor): The recall value.
+         beta (int): Determines the weight of recall in the combined score.
+             Default: 1.
+
+     Returns:
+         torch.Tensor: The f-score value.
+     """
+     score = (1 + beta**2) * (precision * recall) / (
+         (beta**2 * precision) + recall)
+     return score
+
+
+ def intersect_and_union(pred_label,
+                         label,
+                         num_classes,
+                         ignore_index,
+                         label_map=dict(),
+                         reduce_zero_label=False):
+     """Calculate Intersection and Union.
+
+     Args:
+         pred_label (ndarray | str): Prediction segmentation map
+             or predict result filename.
+         label (ndarray | str): Ground truth segmentation map
+             or label filename.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         label_map (dict): Mapping old labels to new labels. The parameter
+             will work only when label is str. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label. The
+             parameter will work only when label is str. Default: False.
+
+     Returns:
+         torch.Tensor: The intersection of prediction and ground truth
+             histogram on all classes.
+         torch.Tensor: The union of prediction and ground truth histogram on
+             all classes.
+         torch.Tensor: The prediction histogram on all classes.
+         torch.Tensor: The ground truth histogram on all classes.
+     """
+
+     if isinstance(pred_label, str):
+         pred_label = torch.from_numpy(np.load(pred_label))
+     else:
+         pred_label = torch.from_numpy(pred_label)
+
+     if isinstance(label, str):
+         label = torch.from_numpy(
+             mmcv.imread(label, flag='unchanged', backend='pillow'))
+     else:
+         label = torch.from_numpy(label)
+
+     if label_map is not None:
+         for old_id, new_id in label_map.items():
+             label[label == old_id] = new_id
+     if reduce_zero_label:
+         label[label == 0] = 255
+         label = label - 1
+         label[label == 254] = 255
+
+     mask = (label != ignore_index)
+     pred_label = pred_label[mask]
+     label = label[mask]
+
+     intersect = pred_label[pred_label == label]
+     area_intersect = torch.histc(
+         intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
+     area_pred_label = torch.histc(
+         pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
+     area_label = torch.histc(
+         label.float(), bins=num_classes, min=0, max=num_classes - 1)
+     area_union = area_pred_label + area_label - area_intersect
+     return area_intersect, area_union, area_pred_label, area_label
+
+
+ def total_intersect_and_union(results,
+                               gt_seg_maps,
+                               num_classes,
+                               ignore_index,
+                               label_map=dict(),
+                               reduce_zero_label=False):
+     """Calculate Total Intersection and Union.
+
+     Args:
+         results (list[ndarray] | list[str]): List of prediction segmentation
+             maps or list of prediction result filenames.
+         gt_seg_maps (list[ndarray] | list[str]): List of ground truth
+             segmentation maps or list of label filenames.
+         num_classes (int): Number of categories.
+         ignore_index (int): Index that will be ignored in evaluation.
+         label_map (dict): Mapping old labels to new labels. Default: dict().
+         reduce_zero_label (bool): Whether to ignore the zero label.
+             Default: False.
+
+     Returns:
+         ndarray: The intersection of prediction and ground truth histogram
+             on all classes.
+         ndarray: The union of prediction and ground truth histogram on all
+             classes.
+         ndarray: The prediction histogram on all classes.
+         ndarray: The ground truth histogram on all classes.
+     """
+     num_imgs = len(results)
+     assert len(gt_seg_maps) == num_imgs
+     total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
+     total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
+     total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
+     total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+     for i in range(num_imgs):
+         area_intersect, area_union, area_pred_label, area_label = \
+             intersect_and_union(
+                 results[i], gt_seg_maps[i], num_classes, ignore_index,
+                 label_map, reduce_zero_label)
+         total_area_intersect += area_intersect
+         total_area_union += area_union
+         total_area_pred_label += area_pred_label
+         total_area_label += area_label
+     return total_area_intersect, total_area_union, total_area_pred_label, \
+         total_area_label
+
+
+ def mean_iou(results,
+              gt_seg_maps,
+              num_classes,
+              ignore_index,
+              nan_to_num=None,
+              label_map=dict(),
+              reduce_zero_label=False):
+     """Calculate Mean Intersection and Union (mIoU)
141
+
142
+ Args:
143
+ results (list[ndarray] | list[str]): List of prediction segmentation
144
+ maps or list of prediction result filenames.
145
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
146
+ segmentation maps or list of label filenames.
147
+ num_classes (int): Number of categories.
148
+ ignore_index (int): Index that will be ignored in evaluation.
149
+ nan_to_num (int, optional): If specified, NaN values will be replaced
150
+ by the numbers defined by the user. Default: None.
151
+ label_map (dict): Mapping old labels to new labels. Default: dict().
152
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
153
+
154
+ Returns:
155
+ dict[str, float | ndarray]:
156
+ <aAcc> float: Overall accuracy on all images.
157
+ <Acc> ndarray: Per category accuracy, shape (num_classes, ).
158
+ <IoU> ndarray: Per category IoU, shape (num_classes, ).
159
+ """
160
+ iou_result = eval_metrics(
161
+ results=results,
162
+ gt_seg_maps=gt_seg_maps,
163
+ num_classes=num_classes,
164
+ ignore_index=ignore_index,
165
+ metrics=['mIoU'],
166
+ nan_to_num=nan_to_num,
167
+ label_map=label_map,
168
+ reduce_zero_label=reduce_zero_label)
169
+ return iou_result
170
+
171
+
172
+ def mean_dice(results,
173
+ gt_seg_maps,
174
+ num_classes,
175
+ ignore_index,
176
+ nan_to_num=None,
177
+ label_map=dict(),
178
+ reduce_zero_label=False):
179
+ """Calculate Mean Dice (mDice)
180
+
181
+ Args:
182
+ results (list[ndarray] | list[str]): List of prediction segmentation
183
+ maps or list of prediction result filenames.
184
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
185
+ segmentation maps or list of label filenames.
186
+ num_classes (int): Number of categories.
187
+ ignore_index (int): Index that will be ignored in evaluation.
188
+ nan_to_num (int, optional): If specified, NaN values will be replaced
189
+ by the numbers defined by the user. Default: None.
190
+ label_map (dict): Mapping old labels to new labels. Default: dict().
191
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
192
+
193
+ Returns:
194
+ dict[str, float | ndarray]: Default metrics.
195
+ <aAcc> float: Overall accuracy on all images.
196
+ <Acc> ndarray: Per category accuracy, shape (num_classes, ).
197
+ <Dice> ndarray: Per category dice, shape (num_classes, ).
198
+ """
199
+
200
+ dice_result = eval_metrics(
201
+ results=results,
202
+ gt_seg_maps=gt_seg_maps,
203
+ num_classes=num_classes,
204
+ ignore_index=ignore_index,
205
+ metrics=['mDice'],
206
+ nan_to_num=nan_to_num,
207
+ label_map=label_map,
208
+ reduce_zero_label=reduce_zero_label)
209
+ return dice_result
210
+
211
+
212
+ def mean_fscore(results,
213
+ gt_seg_maps,
214
+ num_classes,
215
+ ignore_index,
216
+ nan_to_num=None,
217
+ label_map=dict(),
218
+ reduce_zero_label=False,
219
+ beta=1):
220
+ """Calculate Mean Intersection and Union (mIoU)
221
+
222
+ Args:
223
+ results (list[ndarray] | list[str]): List of prediction segmentation
224
+ maps or list of prediction result filenames.
225
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
226
+ segmentation maps or list of label filenames.
227
+ num_classes (int): Number of categories.
228
+ ignore_index (int): Index that will be ignored in evaluation.
229
+ nan_to_num (int, optional): If specified, NaN values will be replaced
230
+ by the numbers defined by the user. Default: None.
231
+ label_map (dict): Mapping old labels to new labels. Default: dict().
232
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
233
+ beta (int): Determines the weight of recall in the combined score.
234
+ Default: False.
235
+
236
+
237
+ Returns:
238
+ dict[str, float | ndarray]: Default metrics.
239
+ <aAcc> float: Overall accuracy on all images.
240
+ <Fscore> ndarray: Per category recall, shape (num_classes, ).
241
+ <Precision> ndarray: Per category precision, shape (num_classes, ).
242
+ <Recall> ndarray: Per category f-score, shape (num_classes, ).
243
+ """
244
+ fscore_result = eval_metrics(
245
+ results=results,
246
+ gt_seg_maps=gt_seg_maps,
247
+ num_classes=num_classes,
248
+ ignore_index=ignore_index,
249
+ metrics=['mFscore'],
250
+ nan_to_num=nan_to_num,
251
+ label_map=label_map,
252
+ reduce_zero_label=reduce_zero_label,
253
+ beta=beta)
254
+ return fscore_result
255
+
256
+
257
+ def eval_metrics(results,
258
+ gt_seg_maps,
259
+ num_classes,
260
+ ignore_index,
261
+ metrics=['mIoU'],
262
+ nan_to_num=None,
263
+ label_map=dict(),
264
+ reduce_zero_label=False,
265
+ beta=1):
266
+ """Calculate evaluation metrics
267
+ Args:
268
+ results (list[ndarray] | list[str]): List of prediction segmentation
269
+ maps or list of prediction result filenames.
270
+ gt_seg_maps (list[ndarray] | list[str]): list of ground truth
271
+ segmentation maps or list of label filenames.
272
+ num_classes (int): Number of categories.
273
+ ignore_index (int): Index that will be ignored in evaluation.
274
+ metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'.
275
+ nan_to_num (int, optional): If specified, NaN values will be replaced
276
+ by the numbers defined by the user. Default: None.
277
+ label_map (dict): Mapping old labels to new labels. Default: dict().
278
+ reduce_zero_label (bool): Whether ignore zero label. Default: False.
279
+ Returns:
280
+ float: Overall accuracy on all images.
281
+ ndarray: Per category accuracy, shape (num_classes, ).
282
+ ndarray: Per category evaluation metrics, shape (num_classes, ).
283
+ """
284
+ if isinstance(metrics, str):
285
+ metrics = [metrics]
286
+ allowed_metrics = ['mIoU', 'mDice', 'mFscore']
287
+ if not set(metrics).issubset(set(allowed_metrics)):
288
+ raise KeyError('metrics {} is not supported'.format(metrics))
289
+
290
+ total_area_intersect, total_area_union, total_area_pred_label, \
291
+ total_area_label = total_intersect_and_union(
292
+ results, gt_seg_maps, num_classes, ignore_index, label_map,
293
+ reduce_zero_label)
294
+ all_acc = total_area_intersect.sum() / total_area_label.sum()
295
+ ret_metrics = OrderedDict({'aAcc': all_acc})
296
+ for metric in metrics:
297
+ if metric == 'mIoU':
298
+ iou = total_area_intersect / total_area_union
299
+ acc = total_area_intersect / total_area_label
300
+ ret_metrics['IoU'] = iou
301
+ ret_metrics['Acc'] = acc
302
+ elif metric == 'mDice':
303
+ dice = 2 * total_area_intersect / (
304
+ total_area_pred_label + total_area_label)
305
+ acc = total_area_intersect / total_area_label
306
+ ret_metrics['Dice'] = dice
307
+ ret_metrics['Acc'] = acc
308
+ elif metric == 'mFscore':
309
+ precision = total_area_intersect / total_area_pred_label
310
+ recall = total_area_intersect / total_area_label
311
+ f_value = torch.tensor(
312
+ [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
313
+ ret_metrics['Fscore'] = f_value
314
+ ret_metrics['Precision'] = precision
315
+ ret_metrics['Recall'] = recall
316
+
317
+ ret_metrics = {
318
+ metric: value.numpy()
319
+ for metric, value in ret_metrics.items()
320
+ }
321
+ if nan_to_num is not None:
322
+ ret_metrics = OrderedDict({
323
+ metric: np.nan_to_num(metric_value, nan=nan_to_num)
324
+ for metric, metric_value in ret_metrics.items()
325
+ })
326
+ return ret_metrics
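Editor's note: the helpers above compose into ``eval_metrics``. A minimal sketch of calling it on hypothetical toy arrays (assuming this module is importable; the data below is made up for illustration):

import numpy as np

# Two 2x2 predictions vs. ground truths over 3 classes.
pred = [np.array([[0, 1], [1, 2]]), np.array([[2, 2], [0, 0]])]
gt = [np.array([[0, 1], [1, 1]]), np.array([[2, 2], [0, 1]])]

metrics = eval_metrics(
    results=pred,
    gt_seg_maps=gt,
    num_classes=3,
    ignore_index=255,
    metrics=['mIoU', 'mFscore'])
print(metrics['aAcc'])   # overall pixel accuracy (scalar)
print(metrics['IoU'])    # per-class IoU, shape (3,)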
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .builder import build_pixel_sampler
+ from .sampler import BasePixelSampler, OHEMPixelSampler
+ 
+ __all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/builder.py ADDED
@@ -0,0 +1,8 @@
+ from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
+ 
+ PIXEL_SAMPLERS = Registry('pixel sampler')
+ 
+ 
+ def build_pixel_sampler(cfg, **default_args):
+     """Build pixel sampler for segmentation map."""
+     return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
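Editor's note: a short usage sketch. ``head`` is hypothetical here and stands for any decode head exposing the attributes a sampler needs; the OHEM sampler defined below must already be registered for the config lookup to succeed.

cfg = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)
sampler = build_pixel_sampler(cfg, context=head)  # head: hypothetical decode head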
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .base_pixel_sampler import BasePixelSampler
+ from .ohem_pixel_sampler import OHEMPixelSampler
+ 
+ __all__ = ['BasePixelSampler', 'OHEMPixelSampler']
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/base_pixel_sampler.py ADDED
@@ -0,0 +1,12 @@
+ from abc import ABCMeta, abstractmethod
+ 
+ 
+ class BasePixelSampler(metaclass=ABCMeta):
+     """Base class of pixel sampler."""
+ 
+     def __init__(self, **kwargs):
+         pass
+ 
+     @abstractmethod
+     def sample(self, seg_logit, seg_label):
+         """Placeholder for sample function."""
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/seg/sampler/ohem_pixel_sampler.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn.functional as F
+ 
+ from ..builder import PIXEL_SAMPLERS
+ from .base_pixel_sampler import BasePixelSampler
+ 
+ 
+ @PIXEL_SAMPLERS.register_module()
+ class OHEMPixelSampler(BasePixelSampler):
+     """Online Hard Example Mining Sampler for segmentation.
+ 
+     Args:
+         context (nn.Module): The context of sampler, subclass of
+             :obj:`BaseDecodeHead`.
+         thresh (float, optional): The threshold for hard example selection.
+             Predictions below this threshold are regarded as hard examples
+             with low confidence. If not specified, the hard examples will be
+             the pixels with the top ``min_kept`` losses. Default: None.
+         min_kept (int, optional): The minimum number of predictions to keep.
+             Default: 100000.
+     """
+ 
+     def __init__(self, context, thresh=None, min_kept=100000):
+         super(OHEMPixelSampler, self).__init__()
+         self.context = context
+         assert min_kept > 1
+         self.thresh = thresh
+         self.min_kept = min_kept
+ 
+     def sample(self, seg_logit, seg_label):
+         """Sample pixels that have high loss or low prediction confidence.
+ 
+         Args:
+             seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+             seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+ 
+         Returns:
+             torch.Tensor: segmentation weight, shape (N, H, W)
+         """
+         with torch.no_grad():
+             assert seg_logit.shape[2:] == seg_label.shape[2:]
+             assert seg_label.shape[1] == 1
+             seg_label = seg_label.squeeze(1).long()
+             batch_kept = self.min_kept * seg_label.size(0)
+             valid_mask = seg_label != self.context.ignore_index
+             seg_weight = seg_logit.new_zeros(size=seg_label.size())
+             valid_seg_weight = seg_weight[valid_mask]
+             if self.thresh is not None:
+                 seg_prob = F.softmax(seg_logit, dim=1)
+ 
+                 tmp_seg_label = seg_label.clone().unsqueeze(1)
+                 tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                 seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                 sort_prob, sort_indices = seg_prob[valid_mask].sort()
+ 
+                 if sort_prob.numel() > 0:
+                     min_threshold = sort_prob[min(batch_kept,
+                                                   sort_prob.numel() - 1)]
+                 else:
+                     min_threshold = 0.0
+                 threshold = max(min_threshold, self.thresh)
+                 valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+             else:
+                 losses = self.context.loss_decode(
+                     seg_logit,
+                     seg_label,
+                     weight=None,
+                     ignore_index=self.context.ignore_index,
+                     reduction_override='none')
+                 # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                 _, sort_indices = losses[valid_mask].sort(descending=True)
+                 valid_seg_weight[sort_indices[:batch_kept]] = 1.
+ 
+             seg_weight[valid_mask] = valid_seg_weight
+ 
+             return seg_weight
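Editor's note: a minimal sketch of driving the sampler with a stand-in context. Everything here is hypothetical illustration, not part of the commit; ``_DummyContext`` only provides the two attributes the sampler actually touches.

import torch
import torch.nn as nn
import torch.nn.functional as F

class _DummyContext(nn.Module):
    # Hypothetical stand-in for a decode head.
    ignore_index = 255

    def loss_decode(self, logit, label, weight=None, ignore_index=255,
                    reduction_override='none'):
        # Per-pixel cross-entropy, shape (N, H, W).
        return F.cross_entropy(
            logit, label, ignore_index=ignore_index, reduction='none')

sampler = OHEMPixelSampler(_DummyContext(), thresh=None, min_kept=2)
logit = torch.randn(1, 3, 4, 4)             # (N, C, H, W)
label = torch.randint(0, 3, (1, 1, 4, 4))   # (N, 1, H, W)
weight = sampler.sample(logit, label)       # (N, H, W); 1.0 on the kept pixels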
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .misc import add_prefix
+ 
+ __all__ = ['add_prefix']
microsoftexcel-controlnet/annotator/mmpkg/mmseg/core/utils/misc.py ADDED
@@ -0,0 +1,17 @@
+ def add_prefix(inputs, prefix):
+     """Add a prefix to the keys of a dict.
+ 
+     Args:
+         inputs (dict): The input dict with str keys.
+         prefix (str): The prefix to add.
+ 
+     Returns:
+         dict: The dict with keys updated with ``prefix``.
+     """
+ 
+     outputs = dict()
+     for name, value in inputs.items():
+         outputs[f'{prefix}.{name}'] = value
+ 
+     return outputs
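Editor's note: for instance (hypothetical values), merging an auxiliary head's log items into one dict:

losses = add_prefix({'loss_seg': 0.42, 'acc_seg': 0.91}, 'aux')
# -> {'aux.loss_seg': 0.42, 'aux.acc_seg': 0.91}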
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .ade import ADE20KDataset
+ from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+ from .chase_db1 import ChaseDB1Dataset
+ from .cityscapes import CityscapesDataset
+ from .custom import CustomDataset
+ from .dataset_wrappers import ConcatDataset, RepeatDataset
+ from .drive import DRIVEDataset
+ from .hrf import HRFDataset
+ from .pascal_context import PascalContextDataset, PascalContextDataset59
+ from .stare import STAREDataset
+ from .voc import PascalVOCDataset
+ 
+ __all__ = [
+     'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+     'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+     'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+     'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+     'STAREDataset'
+ ]
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/ade.py ADDED
@@ -0,0 +1,84 @@
+ from .builder import DATASETS
+ from .custom import CustomDataset
+ 
+ 
+ @DATASETS.register_module()
+ class ADE20KDataset(CustomDataset):
+     """ADE20K dataset.
+ 
+     In segmentation map annotation for ADE20K, 0 stands for background, which
+     is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+     The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+     '.png'.
+     """
+     CLASSES = (
+         'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+         'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+         'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+         'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+         'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+         'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+         'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+         'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+         'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+         'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+         'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+         'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+         'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+         'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+         'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+         'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+         'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+         'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+         'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+         'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+         'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+         'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+         'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+         'clock', 'flag')
+ 
+     PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+                [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+                [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+                [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+                [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+                [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+                [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+                [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+                [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+                [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+                [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+                [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+                [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+                [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+                [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+                [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+                [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+                [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+                [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+                [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+                [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+                [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+                [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+                [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+                [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+                [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+                [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+                [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+                [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+                [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+                [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+                [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+                [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+                [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+                [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+                [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+                [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+                [102, 255, 0], [92, 0, 255]]
+ 
+     def __init__(self, **kwargs):
+         super(ADE20KDataset, self).__init__(
+             img_suffix='.jpg',
+             seg_map_suffix='.png',
+             reduce_zero_label=True,
+             **kwargs)
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/builder.py ADDED
@@ -0,0 +1,169 @@
+ import copy
+ import platform
+ import random
+ from functools import partial
+ 
+ import numpy as np
+ from annotator.mmpkg.mmcv.parallel import collate
+ from annotator.mmpkg.mmcv.runner import get_dist_info
+ from annotator.mmpkg.mmcv.utils import Registry, build_from_cfg
+ from annotator.mmpkg.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+ from torch.utils.data import DistributedSampler
+ 
+ if platform.system() != 'Windows':
+     # https://github.com/pytorch/pytorch/issues/973
+     import resource
+     rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+     hard_limit = rlimit[1]
+     soft_limit = min(4096, hard_limit)
+     resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+ 
+ DATASETS = Registry('dataset')
+ PIPELINES = Registry('pipeline')
+ 
+ 
+ def _concat_dataset(cfg, default_args=None):
+     """Build :obj:`ConcatDataset` from a config with list-valued fields."""
+     from .dataset_wrappers import ConcatDataset
+     img_dir = cfg['img_dir']
+     ann_dir = cfg.get('ann_dir', None)
+     split = cfg.get('split', None)
+     num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+     if ann_dir is not None:
+         num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+     else:
+         num_ann_dir = 0
+     if split is not None:
+         num_split = len(split) if isinstance(split, (list, tuple)) else 1
+     else:
+         num_split = 0
+     if num_img_dir > 1:
+         assert num_img_dir == num_ann_dir or num_ann_dir == 0
+         assert num_img_dir == num_split or num_split == 0
+     else:
+         assert num_split == num_ann_dir or num_ann_dir <= 1
+     num_dset = max(num_split, num_img_dir)
+ 
+     datasets = []
+     for i in range(num_dset):
+         data_cfg = copy.deepcopy(cfg)
+         if isinstance(img_dir, (list, tuple)):
+             data_cfg['img_dir'] = img_dir[i]
+         if isinstance(ann_dir, (list, tuple)):
+             data_cfg['ann_dir'] = ann_dir[i]
+         if isinstance(split, (list, tuple)):
+             data_cfg['split'] = split[i]
+         datasets.append(build_dataset(data_cfg, default_args))
+ 
+     return ConcatDataset(datasets)
+ 
+ 
+ def build_dataset(cfg, default_args=None):
+     """Build datasets."""
+     from .dataset_wrappers import ConcatDataset, RepeatDataset
+     if isinstance(cfg, (list, tuple)):
+         dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+     elif cfg['type'] == 'RepeatDataset':
+         dataset = RepeatDataset(
+             build_dataset(cfg['dataset'], default_args), cfg['times'])
+     elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
+             cfg.get('split', None), (list, tuple)):
+         dataset = _concat_dataset(cfg, default_args)
+     else:
+         dataset = build_from_cfg(cfg, DATASETS, default_args)
+ 
+     return dataset
+ 
+ 
+ def build_dataloader(dataset,
+                      samples_per_gpu,
+                      workers_per_gpu,
+                      num_gpus=1,
+                      dist=True,
+                      shuffle=True,
+                      seed=None,
+                      drop_last=False,
+                      pin_memory=True,
+                      dataloader_type='PoolDataLoader',
+                      **kwargs):
+     """Build PyTorch DataLoader.
+ 
+     In distributed training, each GPU/process has a dataloader.
+     In non-distributed training, there is only one dataloader for all GPUs.
+ 
+     Args:
+         dataset (Dataset): A PyTorch dataset.
+         samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+             batch size of each GPU.
+         workers_per_gpu (int): How many subprocesses to use for data loading
+             for each GPU.
+         num_gpus (int): Number of GPUs. Only used in non-distributed training.
+         dist (bool): Distributed training/test or not. Default: True.
+         shuffle (bool): Whether to shuffle the data at every epoch.
+             Default: True.
+         seed (int | None): Seed to be used. Default: None.
+         drop_last (bool): Whether to drop the last incomplete batch in epoch.
+             Default: False.
+         pin_memory (bool): Whether to use pin_memory in DataLoader.
+             Default: True.
+         dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'.
+         kwargs: Any keyword arguments to be used to initialize the DataLoader.
+ 
+     Returns:
+         DataLoader: A PyTorch dataloader.
+     """
+     rank, world_size = get_dist_info()
+     if dist:
+         sampler = DistributedSampler(
+             dataset, world_size, rank, shuffle=shuffle)
+         shuffle = False
+         batch_size = samples_per_gpu
+         num_workers = workers_per_gpu
+     else:
+         sampler = None
+         batch_size = num_gpus * samples_per_gpu
+         num_workers = num_gpus * workers_per_gpu
+ 
+     init_fn = partial(
+         worker_init_fn, num_workers=num_workers, rank=rank,
+         seed=seed) if seed is not None else None
+ 
+     assert dataloader_type in (
+         'DataLoader',
+         'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
+ 
+     if dataloader_type == 'PoolDataLoader':
+         dataloader = PoolDataLoader
+     elif dataloader_type == 'DataLoader':
+         dataloader = DataLoader
+ 
+     data_loader = dataloader(
+         dataset,
+         batch_size=batch_size,
+         sampler=sampler,
+         num_workers=num_workers,
+         collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+         pin_memory=pin_memory,
+         shuffle=shuffle,
+         worker_init_fn=init_fn,
+         drop_last=drop_last,
+         **kwargs)
+ 
+     return data_loader
+ 
+ 
+ def worker_init_fn(worker_id, num_workers, rank, seed):
+     """Worker init func for dataloader.
+ 
+     The seed of each worker equals ``num_workers * rank + worker_id + seed``.
+ 
+     Args:
+         worker_id (int): Worker id.
+         num_workers (int): Number of workers.
+         rank (int): The rank of current process.
+         seed (int): The random seed to use.
+     """
+ 
+     worker_seed = num_workers * rank + worker_id + seed
+     np.random.seed(worker_seed)
+     random.seed(worker_seed)
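Editor's note: a quick sketch of the per-worker seeding arithmetic, with hypothetical values (4 workers on rank 1, user seed 42):

from functools import partial

init_fn = partial(worker_init_fn, num_workers=4, rank=1, seed=42)
init_fn(worker_id=0)   # seeds numpy/random with 4 * 1 + 0 + 42 = 46
init_fn(worker_id=3)   # seeds numpy/random with 4 * 1 + 3 + 42 = 49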
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/chase_db1.py ADDED
@@ -0,0 +1,27 @@
+ import os.path as osp
+ 
+ from .builder import DATASETS
+ from .custom import CustomDataset
+ 
+ 
+ @DATASETS.register_module()
+ class ChaseDB1Dataset(CustomDataset):
+     """Chase_db1 dataset.
+ 
+     In segmentation map annotation for Chase_db1, 0 stands for background,
+     which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
+     The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+     '_1stHO.png'.
+     """
+ 
+     CLASSES = ('background', 'vessel')
+ 
+     PALETTE = [[120, 120, 120], [6, 230, 230]]
+ 
+     def __init__(self, **kwargs):
+         super(ChaseDB1Dataset, self).__init__(
+             img_suffix='.png',
+             seg_map_suffix='_1stHO.png',
+             reduce_zero_label=False,
+             **kwargs)
+         assert osp.exists(self.img_dir)
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/cityscapes.py ADDED
@@ -0,0 +1,217 @@
+ import os.path as osp
+ import tempfile
+ 
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ from annotator.mmpkg.mmcv.utils import print_log
+ from PIL import Image
+ 
+ from .builder import DATASETS
+ from .custom import CustomDataset
+ 
+ 
+ @DATASETS.register_module()
+ class CityscapesDataset(CustomDataset):
+     """Cityscapes dataset.
+ 
+     The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+     fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
+     """
+ 
+     CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+                'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+                'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+                'bicycle')
+ 
+     PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+                [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+                [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+                [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+                [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+ 
+     def __init__(self, **kwargs):
+         super(CityscapesDataset, self).__init__(
+             img_suffix='_leftImg8bit.png',
+             seg_map_suffix='_gtFine_labelTrainIds.png',
+             **kwargs)
+ 
+     @staticmethod
+     def _convert_to_label_id(result):
+         """Convert trainId to id for cityscapes."""
+         if isinstance(result, str):
+             result = np.load(result)
+         import cityscapesscripts.helpers.labels as CSLabels
+         result_copy = result.copy()
+         for trainId, label in CSLabels.trainId2label.items():
+             result_copy[result == trainId] = label.id
+ 
+         return result_copy
+ 
+     def results2img(self, results, imgfile_prefix, to_label_id):
+         """Write the segmentation results to images.
+ 
+         Args:
+             results (list[list | tuple | ndarray]): Testing results of the
+                 dataset.
+             imgfile_prefix (str): The filename prefix of the png files.
+                 If the prefix is "somepath/xxx",
+                 the png files will be named "somepath/xxx.png".
+             to_label_id (bool): whether to convert output to label_id for
+                 submission.
+ 
+         Returns:
+             list[str]: Paths of the saved png files which contain the
+                 semantic segmentation results.
+         """
+         mmcv.mkdir_or_exist(imgfile_prefix)
+         result_files = []
+         prog_bar = mmcv.ProgressBar(len(self))
+         for idx in range(len(self)):
+             result = results[idx]
+             if to_label_id:
+                 result = self._convert_to_label_id(result)
+             filename = self.img_infos[idx]['filename']
+             basename = osp.splitext(osp.basename(filename))[0]
+ 
+             png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+ 
+             output = Image.fromarray(result.astype(np.uint8)).convert('P')
+             import cityscapesscripts.helpers.labels as CSLabels
+             palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+             for label_id, label in CSLabels.id2label.items():
+                 palette[label_id] = label.color
+ 
+             output.putpalette(palette)
+             output.save(png_filename)
+             result_files.append(png_filename)
+             prog_bar.update()
+ 
+         return result_files
+ 
+     def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+         """Format the results into dir (standard format for Cityscapes
+         evaluation).
+ 
+         Args:
+             results (list): Testing results of the dataset.
+             imgfile_prefix (str | None): The prefix of images files. It
+                 includes the file path and the prefix of filename, e.g.,
+                 "a/b/prefix". If not specified, a temp file will be created.
+                 Default: None.
+             to_label_id (bool): whether to convert output to label_id for
+                 submission. Default: True.
+ 
+         Returns:
+             tuple: (result_files, tmp_dir), result_files is a list containing
+                 the image paths, tmp_dir is the temporary directory created
+                 for saving json/png files when img_prefix is not specified.
+         """
+ 
+         assert isinstance(results, list), 'results must be a list'
+         assert len(results) == len(self), (
+             'The length of results is not equal to the dataset len: '
+             f'{len(results)} != {len(self)}')
+ 
+         if imgfile_prefix is None:
+             tmp_dir = tempfile.TemporaryDirectory()
+             imgfile_prefix = tmp_dir.name
+         else:
+             tmp_dir = None
+         result_files = self.results2img(results, imgfile_prefix, to_label_id)
+ 
+         return result_files, tmp_dir
+ 
+     def evaluate(self,
+                  results,
+                  metric='mIoU',
+                  logger=None,
+                  imgfile_prefix=None,
+                  efficient_test=False):
+         """Evaluation in Cityscapes/default protocol.
+ 
+         Args:
+             results (list): Testing results of the dataset.
+             metric (str | list[str]): Metrics to be evaluated.
+             logger (logging.Logger | None | str): Logger used for printing
+                 related information during evaluation. Default: None.
+             imgfile_prefix (str | None): The prefix of output image file,
+                 for cityscapes evaluation only. It includes the file path and
+                 the prefix of filename, e.g., "a/b/prefix".
+                 If results are evaluated with cityscapes protocol, it would be
+                 the prefix of output png files. The output files would be
+                 png images under folder "a/b/prefix/xxx.png", where "xxx" is
+                 the image name of cityscapes. If not specified, a temp file
+                 will be created for evaluation.
+                 Default: None.
+             efficient_test (bool): Whether gt seg maps are kept as file paths
+                 to save memory during the default evaluation. Default: False.
+ 
+         Returns:
+             dict[str, float]: Cityscapes/default metrics.
+         """
+ 
+         eval_results = dict()
+         metrics = metric.copy() if isinstance(metric, list) else [metric]
+         if 'cityscapes' in metrics:
+             eval_results.update(
+                 self._evaluate_cityscapes(results, logger, imgfile_prefix))
+             metrics.remove('cityscapes')
+         if len(metrics) > 0:
+             eval_results.update(
+                 super(CityscapesDataset,
+                       self).evaluate(results, metrics, logger, efficient_test))
+ 
+         return eval_results
+ 
+     def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+         """Evaluation in Cityscapes protocol.
+ 
+         Args:
+             results (list): Testing results of the dataset.
+             logger (logging.Logger | str | None): Logger used for printing
+                 related information during evaluation. Default: None.
+             imgfile_prefix (str | None): The prefix of output image file.
+ 
+         Returns:
+             dict[str: float]: Cityscapes evaluation results.
+         """
+         try:
+             import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
+         except ImportError:
+             raise ImportError('Please run "pip install cityscapesscripts" to '
+                               'install cityscapesscripts first.')
+         msg = 'Evaluating in Cityscapes style'
+         if logger is None:
+             msg = '\n' + msg
+         print_log(msg, logger=logger)
+ 
+         result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+ 
+         if tmp_dir is None:
+             result_dir = imgfile_prefix
+         else:
+             result_dir = tmp_dir.name
+ 
+         eval_results = dict()
+         print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+ 
+         CSEval.args.evalInstLevelScore = True
+         CSEval.args.predictionPath = osp.abspath(result_dir)
+         CSEval.args.evalPixelAccuracy = True
+         CSEval.args.JSONOutput = False
+ 
+         seg_map_list = []
+         pred_list = []
+ 
+         # when evaluating with official cityscapesscripts,
+         # **_gtFine_labelIds.png is used
+         for seg_map in mmcv.scandir(
+                 self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+             seg_map_list.append(osp.join(self.ann_dir, seg_map))
+             pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+ 
+         eval_results.update(
+             CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+ 
+         if tmp_dir is not None:
+             tmp_dir.cleanup()
+ 
+         return eval_results
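Editor's note: a hedged sketch of what ``_convert_to_label_id`` does, using a hand-written subset of the cityscapesscripts trainId-to-id table instead of the real ``CSLabels`` module (the three values below match the official mapping for these ids):

import numpy as np

TRAINID2ID = {0: 7, 1: 8, 2: 11}   # road, sidewalk, building
result = np.array([[0, 1], [2, 0]])
converted = result.copy()
for train_id, label_id in TRAINID2ID.items():
    converted[result == train_id] = label_id
# converted == [[7, 8], [11, 7]]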
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/custom.py ADDED
@@ -0,0 +1,403 @@
+ import os
+ import os.path as osp
+ from collections import OrderedDict
+ from functools import reduce
+ 
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ from annotator.mmpkg.mmcv.utils import print_log
+ from torch.utils.data import Dataset
+ 
+ from annotator.mmpkg.mmseg.core import eval_metrics
+ from annotator.mmpkg.mmseg.utils import get_root_logger
+ from .builder import DATASETS
+ from .pipelines import Compose
+ 
+ 
+ @DATASETS.register_module()
+ class CustomDataset(Dataset):
+     """Custom dataset for semantic segmentation. An example of file structure
+     is as follows.
+ 
+     .. code-block:: none
+ 
+         ├── data
+         │   ├── my_dataset
+         │   │   ├── img_dir
+         │   │   │   ├── train
+         │   │   │   │   ├── xxx{img_suffix}
+         │   │   │   │   ├── yyy{img_suffix}
+         │   │   │   │   ├── zzz{img_suffix}
+         │   │   │   ├── val
+         │   │   ├── ann_dir
+         │   │   │   ├── train
+         │   │   │   │   ├── xxx{seg_map_suffix}
+         │   │   │   │   ├── yyy{seg_map_suffix}
+         │   │   │   │   ├── zzz{seg_map_suffix}
+         │   │   │   ├── val
+ 
+     An img/gt_semantic_seg pair of CustomDataset should share the same
+     basename and differ only in suffix. A valid img/gt_semantic_seg filename
+     pair should be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}``
+     (extension is also included in the suffix). If split is given, then
+     ``xxx`` is specified in the txt file. Otherwise, all files in
+     ``img_dir/`` and ``ann_dir`` will be loaded.
+     Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+ 
+ 
+     Args:
+         pipeline (list[dict]): Processing pipeline.
+         img_dir (str): Path to image directory.
+         img_suffix (str): Suffix of images. Default: '.jpg'
+         ann_dir (str, optional): Path to annotation directory. Default: None
+         seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+         split (str, optional): Split txt file. If split is specified, only
+             the files listed in the split will be loaded. Otherwise, all
+             images in img_dir/ann_dir will be loaded. Default: None
+         data_root (str, optional): Data root for img_dir/ann_dir. Default:
+             None.
+         test_mode (bool): If test_mode=True, ground truth annotations will
+             not be loaded.
+         ignore_index (int): The label index to be ignored. Default: 255
+         reduce_zero_label (bool): Whether to mark label zero as ignored.
+             Default: False
+         classes (str | Sequence[str], optional): Specify classes to load.
+             If is None, ``cls.CLASSES`` will be used. Default: None.
+         palette (Sequence[Sequence[int]]] | np.ndarray | None):
+             The palette of segmentation map. If None is given, and
+             self.PALETTE is None, random palette will be generated.
+             Default: None
+     """
+ 
+     CLASSES = None
+ 
+     PALETTE = None
+ 
+     def __init__(self,
+                  pipeline,
+                  img_dir,
+                  img_suffix='.jpg',
+                  ann_dir=None,
+                  seg_map_suffix='.png',
+                  split=None,
+                  data_root=None,
+                  test_mode=False,
+                  ignore_index=255,
+                  reduce_zero_label=False,
+                  classes=None,
+                  palette=None):
+         self.pipeline = Compose(pipeline)
+         self.img_dir = img_dir
+         self.img_suffix = img_suffix
+         self.ann_dir = ann_dir
+         self.seg_map_suffix = seg_map_suffix
+         self.split = split
+         self.data_root = data_root
+         self.test_mode = test_mode
+         self.ignore_index = ignore_index
+         self.reduce_zero_label = reduce_zero_label
+         self.label_map = None
+         self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+             classes, palette)
+ 
+         # join paths if data_root is specified
+         if self.data_root is not None:
+             if not osp.isabs(self.img_dir):
+                 self.img_dir = osp.join(self.data_root, self.img_dir)
+             if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+                 self.ann_dir = osp.join(self.data_root, self.ann_dir)
+             if not (self.split is None or osp.isabs(self.split)):
+                 self.split = osp.join(self.data_root, self.split)
+ 
+         # load annotations
+         self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+                                                self.ann_dir,
+                                                self.seg_map_suffix, self.split)
+ 
+     def __len__(self):
+         """Total number of samples of data."""
+         return len(self.img_infos)
+ 
+     def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+                          split):
+         """Load annotation from directory.
+ 
+         Args:
+             img_dir (str): Path to image directory.
+             img_suffix (str): Suffix of images.
+             ann_dir (str|None): Path to annotation directory.
+             seg_map_suffix (str|None): Suffix of segmentation maps.
+             split (str|None): Split txt file. If split is specified, only
+                 the files listed in the split will be loaded. Otherwise, all
+                 images in img_dir/ann_dir will be loaded. Default: None
+ 
+         Returns:
+             list[dict]: All image info of dataset.
+         """
+ 
+         img_infos = []
+         if split is not None:
+             with open(split) as f:
+                 for line in f:
+                     img_name = line.strip()
+                     img_info = dict(filename=img_name + img_suffix)
+                     if ann_dir is not None:
+                         seg_map = img_name + seg_map_suffix
+                         img_info['ann'] = dict(seg_map=seg_map)
+                     img_infos.append(img_info)
+         else:
+             for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+                 img_info = dict(filename=img)
+                 if ann_dir is not None:
+                     seg_map = img.replace(img_suffix, seg_map_suffix)
+                     img_info['ann'] = dict(seg_map=seg_map)
+                 img_infos.append(img_info)
+ 
+         print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+         return img_infos
+ 
+     def get_ann_info(self, idx):
+         """Get annotation by index.
+ 
+         Args:
+             idx (int): Index of data.
+ 
+         Returns:
+             dict: Annotation info of specified index.
+         """
+ 
+         return self.img_infos[idx]['ann']
+ 
+     def pre_pipeline(self, results):
+         """Prepare results dict for pipeline."""
+         results['seg_fields'] = []
+         results['img_prefix'] = self.img_dir
+         results['seg_prefix'] = self.ann_dir
+         if self.custom_classes:
+             results['label_map'] = self.label_map
+ 
+     def __getitem__(self, idx):
+         """Get training/test data after pipeline.
+ 
+         Args:
+             idx (int): Index of data.
+ 
+         Returns:
+             dict: Training/test data (with annotation if `test_mode` is set
+                 False).
+         """
+ 
+         if self.test_mode:
+             return self.prepare_test_img(idx)
+         else:
+             return self.prepare_train_img(idx)
+ 
+     def prepare_train_img(self, idx):
+         """Get training data and annotations after pipeline.
+ 
+         Args:
+             idx (int): Index of data.
+ 
+         Returns:
+             dict: Training data and annotation after pipeline with new keys
+                 introduced by pipeline.
+         """
+ 
+         img_info = self.img_infos[idx]
+         ann_info = self.get_ann_info(idx)
+         results = dict(img_info=img_info, ann_info=ann_info)
+         self.pre_pipeline(results)
+         return self.pipeline(results)
+ 
+     def prepare_test_img(self, idx):
+         """Get testing data after pipeline.
+ 
+         Args:
+             idx (int): Index of data.
+ 
+         Returns:
+             dict: Testing data after pipeline with new keys introduced by
+                 pipeline.
+         """
+ 
+         img_info = self.img_infos[idx]
+         results = dict(img_info=img_info)
+         self.pre_pipeline(results)
+         return self.pipeline(results)
+ 
+     def format_results(self, results, **kwargs):
+         """Place holder to format result to dataset specific output."""
+ 
+     def get_gt_seg_maps(self, efficient_test=False):
+         """Get ground truth segmentation maps for evaluation."""
+         gt_seg_maps = []
+         for img_info in self.img_infos:
+             seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+             if efficient_test:
+                 gt_seg_map = seg_map
+             else:
+                 gt_seg_map = mmcv.imread(
+                     seg_map, flag='unchanged', backend='pillow')
+             gt_seg_maps.append(gt_seg_map)
+         return gt_seg_maps
+ 
+     def get_classes_and_palette(self, classes=None, palette=None):
+         """Get class names of current dataset.
+ 
+         Args:
+             classes (Sequence[str] | str | None): If classes is None, use
+                 default CLASSES defined by builtin dataset. If classes is a
+                 string, take it as a file name. The file contains the name of
+                 classes where each line contains one class name. If classes is
+                 a tuple or list, override the CLASSES defined by the dataset.
+             palette (Sequence[Sequence[int]]] | np.ndarray | None):
+                 The palette of segmentation map. If None is given, random
+                 palette will be generated. Default: None
+         """
+         if classes is None:
+             self.custom_classes = False
+             return self.CLASSES, self.PALETTE
+ 
+         self.custom_classes = True
+         if isinstance(classes, str):
+             # take it as a file path
+             class_names = mmcv.list_from_file(classes)
+         elif isinstance(classes, (tuple, list)):
+             class_names = classes
+         else:
+             raise ValueError(f'Unsupported type {type(classes)} of classes.')
+ 
+         if self.CLASSES:
+             if not set(classes).issubset(self.CLASSES):
+                 raise ValueError('classes is not a subset of CLASSES.')
+ 
+             # dictionary, its keys are the old label ids and its values
+             # are the new label ids.
+             # used for changing pixel labels in load_annotations.
+             self.label_map = {}
+             for i, c in enumerate(self.CLASSES):
+                 if c not in class_names:
+                     self.label_map[i] = -1
+                 else:
+                     self.label_map[i] = classes.index(c)
+ 
+         palette = self.get_palette_for_custom_classes(class_names, palette)
+ 
+         return class_names, palette
+ 
+     def get_palette_for_custom_classes(self, class_names, palette=None):
+ 
+         if self.label_map is not None:
+             # return subset of palette
+             palette = []
+             for old_id, new_id in sorted(
+                     self.label_map.items(), key=lambda x: x[1]):
+                 if new_id != -1:
+                     palette.append(self.PALETTE[old_id])
+             palette = type(self.PALETTE)(palette)
+ 
+         elif palette is None:
+             if self.PALETTE is None:
+                 palette = np.random.randint(0, 255, size=(len(class_names), 3))
+             else:
+                 palette = self.PALETTE
+ 
+         return palette
+ 
+     def evaluate(self,
+                  results,
+                  metric='mIoU',
+                  logger=None,
+                  efficient_test=False,
+                  **kwargs):
+         """Evaluate the dataset.
+ 
+         Args:
+             results (list): Testing results of the dataset.
+             metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+                 'mDice' and 'mFscore' are supported.
+             logger (logging.Logger | None | str): Logger used for printing
+                 related information during evaluation. Default: None.
+             efficient_test (bool): Whether gt seg maps are kept as file paths
+                 to save memory. Default: False.
+ 
+         Returns:
+             dict[str, float]: Default metrics.
+         """
+ 
+         if isinstance(metric, str):
+             metric = [metric]
+         allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+         if not set(metric).issubset(set(allowed_metrics)):
+             raise KeyError('metric {} is not supported'.format(metric))
+         eval_results = {}
+         gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+         if self.CLASSES is None:
+             num_classes = len(
+                 reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+         else:
+             num_classes = len(self.CLASSES)
+         ret_metrics = eval_metrics(
+             results,
+             gt_seg_maps,
+             num_classes,
+             self.ignore_index,
+             metric,
+             label_map=self.label_map,
+             reduce_zero_label=self.reduce_zero_label)
+ 
+         if self.CLASSES is None:
+             class_names = tuple(range(num_classes))
+         else:
+             class_names = self.CLASSES
+ 
+         # summary table
+         ret_metrics_summary = OrderedDict({
+             ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+             for ret_metric, ret_metric_value in ret_metrics.items()
+         })
+ 
+         # each class table
+         ret_metrics.pop('aAcc', None)
+         ret_metrics_class = OrderedDict({
+             ret_metric: np.round(ret_metric_value * 100, 2)
+             for ret_metric, ret_metric_value in ret_metrics.items()
+         })
+         ret_metrics_class.update({'Class': class_names})
+         ret_metrics_class.move_to_end('Class', last=False)
+ 
+         try:
+             from prettytable import PrettyTable
+             # for logger
+             class_table_data = PrettyTable()
+             for key, val in ret_metrics_class.items():
+                 class_table_data.add_column(key, val)
+ 
+             summary_table_data = PrettyTable()
+             for key, val in ret_metrics_summary.items():
+                 if key == 'aAcc':
+                     summary_table_data.add_column(key, [val])
+                 else:
+                     summary_table_data.add_column('m' + key, [val])
+ 
+             print_log('per class results:', logger)
+             print_log('\n' + class_table_data.get_string(), logger=logger)
+             print_log('Summary:', logger)
+             print_log('\n' + summary_table_data.get_string(), logger=logger)
+         except ImportError:  # prettytable is not installed
+             pass
+ 
+         # each metric dict
+         for key, value in ret_metrics_summary.items():
+             if key == 'aAcc':
+                 eval_results[key] = value / 100.0
+             else:
+                 eval_results['m' + key] = value / 100.0
+ 
+         ret_metrics_class.pop('Class', None)
+         for key, value in ret_metrics_class.items():
+             eval_results.update({
+                 key + '.' + str(name): value[idx] / 100.0
+                 for idx, name in enumerate(class_names)
+             })
+ 
+         if mmcv.is_list_of(results, str):
+             for file_name in results:
+                 os.remove(file_name)
+         return eval_results
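Editor's note: a minimal sketch of the ``label_map`` that ``get_classes_and_palette`` builds when a class subset is requested. The class names are hypothetical; old ids map to their new position, dropped classes map to -1.

CLASSES = ('road', 'car', 'person')
classes = ['person', 'road']
label_map = {}
for i, c in enumerate(CLASSES):
    label_map[i] = classes.index(c) if c in classes else -1
assert label_map == {0: 1, 1: -1, 2: 0}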
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/dataset_wrappers.py ADDED
@@ -0,0 +1,50 @@
+ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+ 
+ from .builder import DATASETS
+ 
+ 
+ @DATASETS.register_module()
+ class ConcatDataset(_ConcatDataset):
+     """A wrapper of concatenated dataset.
+ 
+     Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+     concat the group flag for image aspect ratio.
+ 
+     Args:
+         datasets (list[:obj:`Dataset`]): A list of datasets.
+     """
+ 
+     def __init__(self, datasets):
+         super(ConcatDataset, self).__init__(datasets)
+         self.CLASSES = datasets[0].CLASSES
+         self.PALETTE = datasets[0].PALETTE
+ 
+ 
+ @DATASETS.register_module()
+ class RepeatDataset(object):
+     """A wrapper of repeated dataset.
+ 
+     The length of repeated dataset will be `times` larger than the original
+     dataset. This is useful when the data loading time is long but the dataset
+     is small. Using RepeatDataset can reduce the data loading time between
+     epochs.
+ 
+     Args:
+         dataset (:obj:`Dataset`): The dataset to be repeated.
+         times (int): Repeat times.
+     """
+ 
+     def __init__(self, dataset, times):
+         self.dataset = dataset
+         self.times = times
+         self.CLASSES = dataset.CLASSES
+         self.PALETTE = dataset.PALETTE
+         self._ori_len = len(self.dataset)
+ 
+     def __getitem__(self, idx):
+         """Get item from original dataset."""
+         return self.dataset[idx % self._ori_len]
+ 
+     def __len__(self):
+         """The length is multiplied by ``times``"""
+         return self.times * self._ori_len
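Editor's note: a quick sketch of the repeat-wrapping behaviour with a hypothetical 10-sample stub dataset that exposes the two attributes ``RepeatDataset`` reads:

class _TinyDataset:
    CLASSES, PALETTE = ('fg',), [[0, 0, 0]]   # attributes RepeatDataset copies

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx

wrapped = RepeatDataset(_TinyDataset(), times=5)
assert len(wrapped) == 50
assert wrapped[37] == 7   # 37 % 10 indexes back into the base dataset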
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/drive.py ADDED
@@ -0,0 +1,27 @@
+ import os.path as osp
+ 
+ from .builder import DATASETS
+ from .custom import CustomDataset
+ 
+ 
+ @DATASETS.register_module()
+ class DRIVEDataset(CustomDataset):
+     """DRIVE dataset.
+ 
+     In segmentation map annotation for DRIVE, 0 stands for background, which is
+     included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+     '_manual1.png'.
+     """
+ 
+     CLASSES = ('background', 'vessel')
+ 
+     PALETTE = [[120, 120, 120], [6, 230, 230]]
+ 
+     def __init__(self, **kwargs):
+         super(DRIVEDataset, self).__init__(
+             img_suffix='.png',
+             seg_map_suffix='_manual1.png',
+             reduce_zero_label=False,
+             **kwargs)
+         assert osp.exists(self.img_dir)
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/hrf.py ADDED
@@ -0,0 +1,27 @@
+ import os.path as osp
+ 
+ from .builder import DATASETS
+ from .custom import CustomDataset
+ 
+ 
+ @DATASETS.register_module()
+ class HRFDataset(CustomDataset):
+     """HRF dataset.
+ 
+     In segmentation map annotation for HRF, 0 stands for background, which is
+     included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+     ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+     '.png'.
+     """
+ 
+     CLASSES = ('background', 'vessel')
+ 
+     PALETTE = [[120, 120, 120], [6, 230, 230]]
+ 
+     def __init__(self, **kwargs):
+         super(HRFDataset, self).__init__(
+             img_suffix='.png',
+             seg_map_suffix='.png',
+             reduce_zero_label=False,
+             **kwargs)
+         assert osp.exists(self.img_dir)
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pascal_context.py ADDED
@@ -0,0 +1,103 @@
1
+ import os.path as osp
2
+
3
+ from .builder import DATASETS
4
+ from .custom import CustomDataset
5
+
6
+
7
+ @DATASETS.register_module()
8
+ class PascalContextDataset(CustomDataset):
9
+ """PascalContext dataset.
10
+
11
+ In segmentation map annotation for PascalContext, 0 stands for background,
12
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
13
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
14
+ fixed to '.png'.
15
+
16
+ Args:
17
+ split (str): Split txt file for PascalContext.
18
+ """
19
+
20
+ CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
21
+ 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
22
+ 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
23
+ 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
24
+ 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
25
+ 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
26
+ 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
27
+ 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
28
+ 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
29
+ 'window', 'wood')
30
+
31
+ PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
32
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
33
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
34
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
35
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
36
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
37
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
38
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
39
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
40
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
41
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
42
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
43
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
44
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
45
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
46
+
47
+ def __init__(self, split, **kwargs):
48
+ super(PascalContextDataset, self).__init__(
49
+ img_suffix='.jpg',
50
+ seg_map_suffix='.png',
51
+ split=split,
52
+ reduce_zero_label=False,
53
+ **kwargs)
54
+ assert osp.exists(self.img_dir) and self.split is not None
55
+
56
+
57
+ @DATASETS.register_module()
58
+ class PascalContextDataset59(CustomDataset):
59
+ """PascalContext dataset.
60
+
61
+ In segmentation map annotation for PascalContext, 0 stands for background,
62
+ which is included in 60 categories. ``reduce_zero_label`` is fixed to
63
+ False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
64
+ fixed to '.png'.
65
+
66
+ Args:
67
+ split (str): Split txt file for PascalContext.
68
+ """
69
+
70
+ CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
71
+ 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
72
+ 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
73
+ 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
74
+ 'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
75
+ 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
76
+ 'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
77
+ 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
78
+ 'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
79
+
80
+ PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
81
+ [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
82
+ [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
83
+ [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
84
+ [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
85
+ [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
86
+ [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
87
+ [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
88
+ [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
89
+ [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
90
+ [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
91
+ [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
92
+ [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
93
+ [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
94
+ [0, 235, 255], [0, 173, 255], [31, 0, 255]]
95
+
96
+ def __init__(self, split, **kwargs):
97
+ super(PascalContextDataset59, self).__init__(
98
+ img_suffix='.jpg',
99
+ seg_map_suffix='.png',
100
+ split=split,
101
+ reduce_zero_label=True,
102
+ **kwargs)
103
+ assert osp.exists(self.img_dir) and self.split is not None
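Note: both variants are registered under DATASETS, so they are normally built from a config dict. A sketch, assuming ``build_dataset`` is exported from the datasets package as in upstream mmseg; all paths are placeholders:

    from annotator.mmpkg.mmseg.datasets import build_dataset  # assumed export

    cfg = dict(
        type='PascalContextDataset59',
        data_root='data/VOCdevkit/VOC2010',                   # placeholder
        img_dir='JPEGImages',
        ann_dir='SegmentationClassContext',
        split='ImageSets/SegmentationContext/train.txt',      # split is required
        pipeline=[])
    dataset = build_dataset(cfg)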
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .compose import Compose
+ from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+                         Transpose, to_tensor)
+ from .loading import LoadAnnotations, LoadImageFromFile
+ from .test_time_aug import MultiScaleFlipAug
+ from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+                          PhotoMetricDistortion, RandomCrop, RandomFlip,
+                          RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+ 
+ __all__ = [
+     'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+     'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+     'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+     'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+     'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+ ]
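Note: the exported transforms are usually chained in a config-driven pipeline. A representative training pipeline sketch (all sizes and normalization values are illustrative, not mandated by this commit):

    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', reduce_zero_label=True),
        dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=(512, 512), cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=(512, 512), pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),   # registered below in formating.py
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]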
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/compose.py ADDED
@@ -0,0 +1,51 @@
+ import collections
+ 
+ from annotator.mmpkg.mmcv.utils import build_from_cfg
+ 
+ from ..builder import PIPELINES
+ 
+ 
+ @PIPELINES.register_module()
+ class Compose(object):
+     """Compose multiple transforms sequentially.
+ 
+     Args:
+         transforms (Sequence[dict | callable]): Sequence of transform objects
+             or config dicts to be composed.
+     """
+ 
+     def __init__(self, transforms):
+         assert isinstance(transforms, collections.abc.Sequence)
+         self.transforms = []
+         for transform in transforms:
+             if isinstance(transform, dict):
+                 transform = build_from_cfg(transform, PIPELINES)
+                 self.transforms.append(transform)
+             elif callable(transform):
+                 self.transforms.append(transform)
+             else:
+                 raise TypeError('transform must be callable or a dict')
+ 
+     def __call__(self, data):
+         """Call function to apply transforms sequentially.
+ 
+         Args:
+             data (dict): A result dict contains the data to transform.
+ 
+         Returns:
+             dict: Transformed data.
+         """
+ 
+         for t in self.transforms:
+             data = t(data)
+             if data is None:
+                 return None
+         return data
+ 
+     def __repr__(self):
+         format_string = self.__class__.__name__ + '('
+         for t in self.transforms:
+             format_string += '\n'
+             format_string += f'    {t}'
+         format_string += '\n)'
+         return format_string
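Note: Compose accepts config dicts and plain callables interchangeably, and a transform returning None aborts the whole chain. A sketch with one of each (``drop_small`` is a hypothetical callable, not part of this commit):

    def drop_small(results):
        # Hypothetical filter: veto samples whose shorter side is under 32 px.
        h, w = results['img'].shape[:2]
        return None if min(h, w) < 32 else results

    pipeline = Compose([dict(type='LoadImageFromFile'), drop_small])
    # pipeline(data) returns the transformed dict, or None if any step vetoed it.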
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/formating.py ADDED
@@ -0,0 +1,288 @@
+ from collections.abc import Sequence
+ 
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ import torch
+ from annotator.mmpkg.mmcv.parallel import DataContainer as DC
+ 
+ from ..builder import PIPELINES
+ 
+ 
+ def to_tensor(data):
+     """Convert objects of various python types to :obj:`torch.Tensor`.
+ 
+     Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+     :class:`Sequence`, :class:`int` and :class:`float`.
+ 
+     Args:
+         data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+             be converted.
+     """
+ 
+     if isinstance(data, torch.Tensor):
+         return data
+     elif isinstance(data, np.ndarray):
+         return torch.from_numpy(data)
+     elif isinstance(data, Sequence) and not mmcv.is_str(data):
+         return torch.tensor(data)
+     elif isinstance(data, int):
+         return torch.LongTensor([data])
+     elif isinstance(data, float):
+         return torch.FloatTensor([data])
+     else:
+         raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+ 
+ 
+ @PIPELINES.register_module()
+ class ToTensor(object):
+     """Convert some results to :obj:`torch.Tensor` by given keys.
+ 
+     Args:
+         keys (Sequence[str]): Keys that need to be converted to Tensor.
+     """
+ 
+     def __init__(self, keys):
+         self.keys = keys
+ 
+     def __call__(self, results):
+         """Call function to convert data in results to :obj:`torch.Tensor`.
+ 
+         Args:
+             results (dict): Result dict contains the data to convert.
+ 
+         Returns:
+             dict: The result dict contains the data converted
+                 to :obj:`torch.Tensor`.
+         """
+ 
+         for key in self.keys:
+             results[key] = to_tensor(results[key])
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(keys={self.keys})'
+ 
+ 
+ @PIPELINES.register_module()
+ class ImageToTensor(object):
+     """Convert image to :obj:`torch.Tensor` by given keys.
+ 
+     The dimension order of input image is (H, W, C). The pipeline will convert
+     it to (C, H, W). If only 2 dimensions (H, W) are given, the output would
+     be (1, H, W).
+ 
+     Args:
+         keys (Sequence[str]): Key of images to be converted to Tensor.
+     """
+ 
+     def __init__(self, keys):
+         self.keys = keys
+ 
+     def __call__(self, results):
+         """Call function to convert image in results to :obj:`torch.Tensor`
+         and transpose the channel order.
+ 
+         Args:
+             results (dict): Result dict contains the image data to convert.
+ 
+         Returns:
+             dict: The result dict contains the image converted
+                 to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+         """
+ 
+         for key in self.keys:
+             img = results[key]
+             if len(img.shape) < 3:
+                 img = np.expand_dims(img, -1)
+             results[key] = to_tensor(img.transpose(2, 0, 1))
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(keys={self.keys})'
+ 
+ 
+ @PIPELINES.register_module()
+ class Transpose(object):
+     """Transpose some results by given keys.
+ 
+     Args:
+         keys (Sequence[str]): Keys of results to be transposed.
+         order (Sequence[int]): Order of transpose.
+     """
+ 
+     def __init__(self, keys, order):
+         self.keys = keys
+         self.order = order
+ 
+     def __call__(self, results):
+         """Call function to transpose the data in results by ``self.order``.
+ 
+         Args:
+             results (dict): Result dict contains the data to transpose.
+ 
+         Returns:
+             dict: The result dict contains the data of the given keys
+                 transposed to ``self.order``.
+         """
+ 
+         for key in self.keys:
+             results[key] = results[key].transpose(self.order)
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + \
+             f'(keys={self.keys}, order={self.order})'
+ 
+ 
+ @PIPELINES.register_module()
+ class ToDataContainer(object):
+     """Convert results to :obj:`mmcv.DataContainer` by given fields.
+ 
+     Args:
+         fields (Sequence[dict]): Each field is a dict like
+             ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+             be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+             Default: ``(dict(key='img', stack=True),
+             dict(key='gt_semantic_seg'))``.
+     """
+ 
+     def __init__(self,
+                  fields=(dict(key='img', stack=True),
+                          dict(key='gt_semantic_seg'))):
+         self.fields = fields
+ 
+     def __call__(self, results):
+         """Call function to convert data in results to
+         :obj:`mmcv.DataContainer`.
+ 
+         Args:
+             results (dict): Result dict contains the data to convert.
+ 
+         Returns:
+             dict: The result dict contains the data converted to
+                 :obj:`mmcv.DataContainer`.
+         """
+ 
+         for field in self.fields:
+             field = field.copy()
+             key = field.pop('key')
+             results[key] = DC(results[key], **field)
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(fields={self.fields})'
+ 
+ 
+ @PIPELINES.register_module()
+ class DefaultFormatBundle(object):
+     """Default formatting bundle.
+ 
+     It simplifies the pipeline of formatting common fields, including "img"
+     and "gt_semantic_seg". These fields are formatted as follows.
+ 
+     - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
+     - gt_semantic_seg: (1) unsqueeze dim-0, (2) to tensor,
+       (3) to DataContainer (stack=True)
+     """
+ 
+     def __call__(self, results):
+         """Call function to transform and format common fields in results.
+ 
+         Args:
+             results (dict): Result dict contains the data to convert.
+ 
+         Returns:
+             dict: The result dict contains the data that is formatted with
+                 default bundle.
+         """
+ 
+         if 'img' in results:
+             img = results['img']
+             if len(img.shape) < 3:
+                 img = np.expand_dims(img, -1)
+             img = np.ascontiguousarray(img.transpose(2, 0, 1))
+             results['img'] = DC(to_tensor(img), stack=True)
+         if 'gt_semantic_seg' in results:
+             # convert to long
+             results['gt_semantic_seg'] = DC(
+                 to_tensor(results['gt_semantic_seg'][None,
+                                                      ...].astype(np.int64)),
+                 stack=True)
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__
+ 
+ 
+ @PIPELINES.register_module()
+ class Collect(object):
+     """Collect data from the loader relevant to the specific task.
+ 
+     This is usually the last stage of the data loader pipeline. Typically
+     ``keys`` is set to some subset of "img" and "gt_semantic_seg".
+ 
+     The "img_meta" item is always populated. The contents of the "img_meta"
+     dictionary depend on "meta_keys". By default this includes:
+ 
+     - "img_shape": shape of the image input to the network as a tuple
+       (h, w, c). Note that images may be zero padded on the bottom/right
+       if the batch tensor is larger than this shape.
+ 
+     - "scale_factor": a float indicating the preprocessing scale
+ 
+     - "flip": a boolean indicating if image flip transform was used
+ 
+     - "filename": path to the image file
+ 
+     - "ori_shape": original shape of the image as a tuple (h, w, c)
+ 
+     - "pad_shape": image shape after padding
+ 
+     - "img_norm_cfg": a dict of normalization information:
+       - mean - per channel mean subtraction
+       - std - per channel std divisor
+       - to_rgb - bool indicating if bgr was converted to rgb
+ 
+     Args:
+         keys (Sequence[str]): Keys of results to be collected in ``data``.
+         meta_keys (Sequence[str], optional): Meta keys to be converted to
+             ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+             Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+             'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+             'img_norm_cfg')``
+     """
+ 
+     def __init__(self,
+                  keys,
+                  meta_keys=('filename', 'ori_filename', 'ori_shape',
+                             'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                             'flip_direction', 'img_norm_cfg')):
+         self.keys = keys
+         self.meta_keys = meta_keys
+ 
+     def __call__(self, results):
+         """Call function to collect keys in results. The keys in ``meta_keys``
+         will be converted to :obj:`mmcv.DataContainer`.
+ 
+         Args:
+             results (dict): Result dict contains the data to collect.
+ 
+         Returns:
+             dict: The result dict contains the following keys
+                 - keys in ``self.keys``
+                 - ``img_metas``
+         """
+ 
+         data = {}
+         img_meta = {}
+         for key in self.meta_keys:
+             img_meta[key] = results[key]
+         data['img_metas'] = DC(img_meta, cpu_only=True)
+         for key in self.keys:
+             data[key] = results[key]
+         return data
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + \
+             f'(keys={self.keys}, meta_keys={self.meta_keys})'
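Note: to_tensor dispatches purely on input type. A few representative conversions, runnable wherever torch and numpy are available:

    import numpy as np
    import torch

    assert to_tensor(np.zeros((2, 3))).shape == (2, 3)
    assert to_tensor([1, 2, 3]).tolist() == [1, 2, 3]   # Sequence -> tensor
    assert to_tensor(5).dtype == torch.int64            # int -> LongTensor
    assert to_tensor(0.5).dtype == torch.float32        # float -> FloatTensor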
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/loading.py ADDED
@@ -0,0 +1,153 @@
+ import os.path as osp
+ 
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ 
+ from ..builder import PIPELINES
+ 
+ 
+ @PIPELINES.register_module()
+ class LoadImageFromFile(object):
+     """Load an image from file.
+ 
+     Required keys are "img_prefix" and "img_info" (a dict that must contain
+     the key "filename"). Added or updated keys are "filename", "img",
+     "img_shape", "ori_shape" (same as `img_shape`), "pad_shape" (same as
+     `img_shape`), "scale_factor" (1.0) and "img_norm_cfg" (means=0 and
+     stds=1).
+ 
+     Args:
+         to_float32 (bool): Whether to convert the loaded image to a float32
+             numpy array. If set to False, the loaded image is a uint8 array.
+             Defaults to False.
+         color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+             Defaults to 'color'.
+         file_client_args (dict): Arguments to instantiate a FileClient.
+             See :class:`mmcv.fileio.FileClient` for details.
+             Defaults to ``dict(backend='disk')``.
+         imdecode_backend (str): Backend for :func:`mmcv.imfrombytes`.
+             Default: 'cv2'
+     """
+ 
+     def __init__(self,
+                  to_float32=False,
+                  color_type='color',
+                  file_client_args=dict(backend='disk'),
+                  imdecode_backend='cv2'):
+         self.to_float32 = to_float32
+         self.color_type = color_type
+         self.file_client_args = file_client_args.copy()
+         self.file_client = None
+         self.imdecode_backend = imdecode_backend
+ 
+     def __call__(self, results):
+         """Call functions to load image and get image meta information.
+ 
+         Args:
+             results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+ 
+         Returns:
+             dict: The dict contains loaded image and meta information.
+         """
+ 
+         if self.file_client is None:
+             self.file_client = mmcv.FileClient(**self.file_client_args)
+ 
+         if results.get('img_prefix') is not None:
+             filename = osp.join(results['img_prefix'],
+                                 results['img_info']['filename'])
+         else:
+             filename = results['img_info']['filename']
+         img_bytes = self.file_client.get(filename)
+         img = mmcv.imfrombytes(
+             img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+         if self.to_float32:
+             img = img.astype(np.float32)
+ 
+         results['filename'] = filename
+         results['ori_filename'] = results['img_info']['filename']
+         results['img'] = img
+         results['img_shape'] = img.shape
+         results['ori_shape'] = img.shape
+         # Set initial values for default meta_keys
+         results['pad_shape'] = img.shape
+         results['scale_factor'] = 1.0
+         num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+         results['img_norm_cfg'] = dict(
+             mean=np.zeros(num_channels, dtype=np.float32),
+             std=np.ones(num_channels, dtype=np.float32),
+             to_rgb=False)
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(to_float32={self.to_float32},'
+         repr_str += f"color_type='{self.color_type}',"
+         repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class LoadAnnotations(object):
+     """Load annotations for semantic segmentation.
+ 
+     Args:
+         reduce_zero_label (bool): Whether to reduce all label values by 1.
+             Usually used for datasets where 0 is the background label.
+             Default: False.
+         file_client_args (dict): Arguments to instantiate a FileClient.
+             See :class:`mmcv.fileio.FileClient` for details.
+             Defaults to ``dict(backend='disk')``.
+         imdecode_backend (str): Backend for :func:`mmcv.imfrombytes`.
+             Default: 'pillow'
+     """
+ 
+     def __init__(self,
+                  reduce_zero_label=False,
+                  file_client_args=dict(backend='disk'),
+                  imdecode_backend='pillow'):
+         self.reduce_zero_label = reduce_zero_label
+         self.file_client_args = file_client_args.copy()
+         self.file_client = None
+         self.imdecode_backend = imdecode_backend
+ 
+     def __call__(self, results):
+         """Call function to load multiple types of annotations.
+ 
+         Args:
+             results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+ 
+         Returns:
+             dict: The dict contains loaded semantic segmentation annotations.
+         """
+ 
+         if self.file_client is None:
+             self.file_client = mmcv.FileClient(**self.file_client_args)
+ 
+         if results.get('seg_prefix', None) is not None:
+             filename = osp.join(results['seg_prefix'],
+                                 results['ann_info']['seg_map'])
+         else:
+             filename = results['ann_info']['seg_map']
+         img_bytes = self.file_client.get(filename)
+         gt_semantic_seg = mmcv.imfrombytes(
+             img_bytes, flag='unchanged',
+             backend=self.imdecode_backend).squeeze().astype(np.uint8)
+         # modify if custom classes
+         if results.get('label_map', None) is not None:
+             for old_id, new_id in results['label_map'].items():
+                 gt_semantic_seg[gt_semantic_seg == old_id] = new_id
+         # reduce zero_label
+         if self.reduce_zero_label:
+             # avoid using underflow conversion
+             gt_semantic_seg[gt_semantic_seg == 0] = 255
+             gt_semantic_seg = gt_semantic_seg - 1
+             gt_semantic_seg[gt_semantic_seg == 254] = 255
+         results['gt_semantic_seg'] = gt_semantic_seg
+         results['seg_fields'].append('gt_semantic_seg')
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+         repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+         return repr_str
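Note: the reduce_zero_label branch maps label 0 to the ignore index 255 and shifts every other label down by one, carefully avoiding uint8 underflow. The same arithmetic in isolation:

    import numpy as np

    seg = np.array([0, 1, 2, 255], dtype=np.uint8)
    seg[seg == 0] = 255     # background becomes ignore
    seg = seg - 1           # shift labels down; 255 wraps to 254
    seg[seg == 254] = 255   # restore the ignore index
    assert seg.tolist() == [255, 0, 1, 255]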
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/test_time_aug.py ADDED
@@ -0,0 +1,133 @@
+ import warnings
+ 
+ import annotator.mmpkg.mmcv as mmcv
+ 
+ from ..builder import PIPELINES
+ from .compose import Compose
+ 
+ 
+ @PIPELINES.register_module()
+ class MultiScaleFlipAug(object):
+     """Test-time augmentation with multiple scales and flipping.
+ 
+     An example configuration is as follows:
+ 
+     .. code-block::
+ 
+         img_scale=(2048, 1024),
+         img_ratios=[0.5, 1.0],
+         flip=True,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ]
+ 
+     After MultiScaleFlipAug with the above configuration, the results are
+     wrapped into lists of the same length as follows:
+ 
+     .. code-block::
+ 
+         dict(
+             img=[...],
+             img_shape=[...],
+             scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+             flip=[False, True, False, True]
+             ...
+         )
+ 
+     Args:
+         transforms (list[dict]): Transforms to apply in each augmentation.
+         img_scale (None | tuple | list[tuple]): Image scales for resizing.
+         img_ratios (float | list[float]): Image ratios for resizing.
+         flip (bool): Whether to apply flip augmentation. Default: False.
+         flip_direction (str | list[str]): Flip augmentation directions,
+             options are "horizontal" and "vertical". If flip_direction is a
+             list, multiple flip augmentations will be applied.
+             It has no effect when flip == False. Default: "horizontal".
+     """
+ 
+     def __init__(self,
+                  transforms,
+                  img_scale,
+                  img_ratios=None,
+                  flip=False,
+                  flip_direction='horizontal'):
+         self.transforms = Compose(transforms)
+         if img_ratios is not None:
+             img_ratios = img_ratios if isinstance(img_ratios,
+                                                   list) else [img_ratios]
+             assert mmcv.is_list_of(img_ratios, float)
+         if img_scale is None:
+             # mode 1: given img_scale=None and a range of image ratios
+             self.img_scale = None
+             assert mmcv.is_list_of(img_ratios, float)
+         elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+                 img_ratios, float):
+             assert len(img_scale) == 2
+             # mode 2: given a scale and a range of image ratios
+             self.img_scale = [(int(img_scale[0] * ratio),
+                                int(img_scale[1] * ratio))
+                               for ratio in img_ratios]
+         else:
+             # mode 3: given multiple scales
+             self.img_scale = img_scale if isinstance(img_scale,
+                                                      list) else [img_scale]
+         assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+         self.flip = flip
+         self.img_ratios = img_ratios
+         self.flip_direction = flip_direction if isinstance(
+             flip_direction, list) else [flip_direction]
+         assert mmcv.is_list_of(self.flip_direction, str)
+         if not self.flip and self.flip_direction != ['horizontal']:
+             warnings.warn(
+                 'flip_direction has no effect when flip is set to False')
+         if (self.flip
+                 and not any([t['type'] == 'RandomFlip' for t in transforms])):
+             warnings.warn(
+                 'flip has no effect when RandomFlip is not in transforms')
+ 
+     def __call__(self, results):
+         """Call function to apply test time augment transforms on results.
+ 
+         Args:
+             results (dict): Result dict contains the data to transform.
+ 
+         Returns:
+             dict[str: list]: The augmented data, where each value is wrapped
+                 into a list.
+         """
+ 
+         aug_data = []
+         if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+             h, w = results['img'].shape[:2]
+             img_scale = [(int(w * ratio), int(h * ratio))
+                          for ratio in self.img_ratios]
+         else:
+             img_scale = self.img_scale
+         flip_aug = [False, True] if self.flip else [False]
+         for scale in img_scale:
+             for flip in flip_aug:
+                 for direction in self.flip_direction:
+                     _results = results.copy()
+                     _results['scale'] = scale
+                     _results['flip'] = flip
+                     _results['flip_direction'] = direction
+                     data = self.transforms(_results)
+                     aug_data.append(data)
+         # list of dict to dict of list
+         aug_data_dict = {key: [] for key in aug_data[0]}
+         for data in aug_data:
+             for key, val in data.items():
+                 aug_data_dict[key].append(val)
+         return aug_data_dict
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(transforms={self.transforms}, '
+         repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+         repr_str += f'flip_direction={self.flip_direction})'
+         return repr_str
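Note: the number of augmented views is the product of scales, flip flags, and flip directions. With the docstring's example configuration:

    from itertools import product

    img_scale = [(1024, 512), (2048, 1024)]   # two scales (ratios 0.5 and 1.0)
    flip_aug = [False, True]                  # flip=True
    flip_direction = ['horizontal']           # the default
    views = list(product(img_scale, flip_aug, flip_direction))
    assert len(views) == 4                    # matches the example output above
    # Note: with several directions, the flip=False views repeat once per
    # direction, because the direction loop runs regardless of the flag.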
microsoftexcel-controlnet/annotator/mmpkg/mmseg/datasets/pipelines/transforms.py ADDED
@@ -0,0 +1,889 @@
+ import annotator.mmpkg.mmcv as mmcv
+ import numpy as np
+ from annotator.mmpkg.mmcv.utils import deprecated_api_warning, is_tuple_of
+ from numpy import random
+ 
+ from ..builder import PIPELINES
+ 
+ 
+ @PIPELINES.register_module()
+ class Resize(object):
+     """Resize images & seg.
+ 
+     This transform resizes the input image to some scale. If the input dict
+     contains the key "scale", then the scale in the input dict is used,
+     otherwise the specified scale in the init method is used.
+ 
+     ``img_scale`` can be None, a tuple (single-scale) or a list of tuples
+     (multi-scale). There are 4 multiscale modes:
+ 
+     - ``ratio_range is not None``:
+       1. When img_scale is None, img_scale is the shape of the image in
+          results (img_scale = results['img'].shape[:2]) and the image is
+          resized based on the original size. (mode 1)
+       2. When img_scale is a tuple (single-scale), randomly sample a ratio
+          from the ratio range and multiply it with the image scale. (mode 2)
+ 
+     - ``ratio_range is None and multiscale_mode == "range"``: randomly sample
+       a scale from a range. (mode 3)
+ 
+     - ``ratio_range is None and multiscale_mode == "value"``: randomly sample
+       a scale from multiple scales. (mode 4)
+ 
+     Args:
+         img_scale (tuple or list[tuple]): Image scales for resizing.
+         multiscale_mode (str): Either "range" or "value".
+         ratio_range (tuple[float]): (min_ratio, max_ratio)
+         keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+             image.
+     """
+ 
+     def __init__(self,
+                  img_scale=None,
+                  multiscale_mode='range',
+                  ratio_range=None,
+                  keep_ratio=True):
+         if img_scale is None:
+             self.img_scale = None
+         else:
+             if isinstance(img_scale, list):
+                 self.img_scale = img_scale
+             else:
+                 self.img_scale = [img_scale]
+             assert mmcv.is_list_of(self.img_scale, tuple)
+ 
+         if ratio_range is not None:
+             # mode 1: given img_scale=None and a range of image ratios
+             # mode 2: given a scale and a range of image ratios
+             assert self.img_scale is None or len(self.img_scale) == 1
+         else:
+             # mode 3 and 4: given multiple scales or a range of scales
+             assert multiscale_mode in ['value', 'range']
+ 
+         self.multiscale_mode = multiscale_mode
+         self.ratio_range = ratio_range
+         self.keep_ratio = keep_ratio
+ 
+     @staticmethod
+     def random_select(img_scales):
+         """Randomly select an img_scale from given candidates.
+ 
+         Args:
+             img_scales (list[tuple]): Image scales for selection.
+ 
+         Returns:
+             (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
+                 where ``img_scale`` is the selected image scale and
+                 ``scale_idx`` is the selected index in the given candidates.
+         """
+ 
+         assert mmcv.is_list_of(img_scales, tuple)
+         scale_idx = np.random.randint(len(img_scales))
+         img_scale = img_scales[scale_idx]
+         return img_scale, scale_idx
+ 
+     @staticmethod
+     def random_sample(img_scales):
+         """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+ 
+         Args:
+             img_scales (list[tuple]): Image scale range for sampling.
+                 There must be two tuples in img_scales, which specify the
+                 lower and upper bound of image scales.
+ 
+         Returns:
+             (tuple, None): Returns a tuple ``(img_scale, None)``, where
+                 ``img_scale`` is the sampled scale and None is just a
+                 placeholder to be consistent with :func:`random_select`.
+         """
+ 
+         assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+         img_scale_long = [max(s) for s in img_scales]
+         img_scale_short = [min(s) for s in img_scales]
+         long_edge = np.random.randint(
+             min(img_scale_long),
+             max(img_scale_long) + 1)
+         short_edge = np.random.randint(
+             min(img_scale_short),
+             max(img_scale_short) + 1)
+         img_scale = (long_edge, short_edge)
+         return img_scale, None
+ 
+     @staticmethod
+     def random_sample_ratio(img_scale, ratio_range):
+         """Randomly sample an img_scale when ``ratio_range`` is specified.
+ 
+         A ratio will be randomly sampled from the range specified by
+         ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+         generate the sampled scale.
+ 
+         Args:
+             img_scale (tuple): Image scale base to multiply with the ratio.
+             ratio_range (tuple[float]): The minimum and maximum ratio to scale
+                 the ``img_scale``.
+ 
+         Returns:
+             (tuple, None): Returns a tuple ``(scale, None)``, where
+                 ``scale`` is the sampled ratio multiplied with ``img_scale``
+                 and None is just a placeholder to be consistent with
+                 :func:`random_select`.
+         """
+ 
+         assert isinstance(img_scale, tuple) and len(img_scale) == 2
+         min_ratio, max_ratio = ratio_range
+         assert min_ratio <= max_ratio
+         ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+         scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+         return scale, None
+ 
+     def _random_scale(self, results):
+         """Randomly sample an img_scale according to ``ratio_range`` and
+         ``multiscale_mode``.
+ 
+         If ``ratio_range`` is specified, a ratio will be sampled and be
+         multiplied with ``img_scale``.
+         If multiple scales are specified by ``img_scale``, a scale will be
+         sampled according to ``multiscale_mode``.
+         Otherwise, a single scale will be used.
+ 
+         Args:
+             results (dict): Result dict from :obj:`dataset`.
+ 
+         Returns:
+             dict: Two new keys ``scale`` and ``scale_idx`` are added into
+                 ``results``, which would be used by subsequent pipelines.
+         """
+ 
+         if self.ratio_range is not None:
+             if self.img_scale is None:
+                 h, w = results['img'].shape[:2]
+                 scale, scale_idx = self.random_sample_ratio((w, h),
+                                                             self.ratio_range)
+             else:
+                 scale, scale_idx = self.random_sample_ratio(
+                     self.img_scale[0], self.ratio_range)
+         elif len(self.img_scale) == 1:
+             scale, scale_idx = self.img_scale[0], 0
+         elif self.multiscale_mode == 'range':
+             scale, scale_idx = self.random_sample(self.img_scale)
+         elif self.multiscale_mode == 'value':
+             scale, scale_idx = self.random_select(self.img_scale)
+         else:
+             raise NotImplementedError
+ 
+         results['scale'] = scale
+         results['scale_idx'] = scale_idx
+ 
+     def _resize_img(self, results):
+         """Resize images with ``results['scale']``."""
+         if self.keep_ratio:
+             img, scale_factor = mmcv.imrescale(
+                 results['img'], results['scale'], return_scale=True)
+             # the w_scale and h_scale have a minor difference;
+             # a real fix should be done in mmcv.imrescale in the future
+             new_h, new_w = img.shape[:2]
+             h, w = results['img'].shape[:2]
+             w_scale = new_w / w
+             h_scale = new_h / h
+         else:
+             img, w_scale, h_scale = mmcv.imresize(
+                 results['img'], results['scale'], return_scale=True)
+         scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+                                 dtype=np.float32)
+         results['img'] = img
+         results['img_shape'] = img.shape
+         results['pad_shape'] = img.shape  # in case that there is no padding
+         results['scale_factor'] = scale_factor
+         results['keep_ratio'] = self.keep_ratio
+ 
+     def _resize_seg(self, results):
+         """Resize semantic segmentation map with ``results['scale']``."""
+         for key in results.get('seg_fields', []):
+             if self.keep_ratio:
+                 gt_seg = mmcv.imrescale(
+                     results[key], results['scale'], interpolation='nearest')
+             else:
+                 gt_seg = mmcv.imresize(
+                     results[key], results['scale'], interpolation='nearest')
+             results[key] = gt_seg
+ 
+     def __call__(self, results):
+         """Call function to resize images and semantic segmentation maps.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Resized results; 'img_shape', 'pad_shape', 'scale_factor'
+                 and 'keep_ratio' keys are added into the result dict.
+         """
+ 
+         if 'scale' not in results:
+             self._random_scale(results)
+         self._resize_img(results)
+         self._resize_seg(results)
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += (f'(img_scale={self.img_scale}, '
+                      f'multiscale_mode={self.multiscale_mode}, '
+                      f'ratio_range={self.ratio_range}, '
+                      f'keep_ratio={self.keep_ratio})')
+         return repr_str
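Note: mode 2 above reduces to scaling a base size by a uniformly sampled ratio. The arithmetic in isolation, with illustrative values:

    import numpy as np

    img_scale, (min_ratio, max_ratio) = (2048, 512), (0.5, 2.0)  # illustrative
    ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
    scale = (int(img_scale[0] * ratio), int(img_scale[1] * ratio))
    assert 1024 <= scale[0] <= 4096 and 256 <= scale[1] <= 1024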
+ 
+ 
+ @PIPELINES.register_module()
+ class RandomFlip(object):
+     """Flip the image & seg.
+ 
+     If the input dict contains the key "flip", then the flag will be used,
+     otherwise it will be randomly decided by a ratio specified in the init
+     method.
+ 
+     Args:
+         prob (float, optional): The flipping probability. Default: None.
+         direction (str, optional): The flipping direction. Options are
+             'horizontal' and 'vertical'. Default: 'horizontal'.
+     """
+ 
+     @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+     def __init__(self, prob=None, direction='horizontal'):
+         self.prob = prob
+         self.direction = direction
+         if prob is not None:
+             assert prob >= 0 and prob <= 1
+         assert direction in ['horizontal', 'vertical']
+ 
+     def __call__(self, results):
+         """Call function to flip the image and semantic segmentation maps.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Flipped results; 'flip' and 'flip_direction' keys are added
+                 into the result dict.
+         """
+ 
+         if 'flip' not in results:
+             flip = True if np.random.rand() < self.prob else False
+             results['flip'] = flip
+         if 'flip_direction' not in results:
+             results['flip_direction'] = self.direction
+         if results['flip']:
+             # flip image
+             results['img'] = mmcv.imflip(
+                 results['img'], direction=results['flip_direction'])
+ 
+             # flip segs
+             for key in results.get('seg_fields', []):
+                 # use copy() to make numpy stride positive
+                 results[key] = mmcv.imflip(
+                     results[key], direction=results['flip_direction']).copy()
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(prob={self.prob})'
+ 
+ 
+ @PIPELINES.register_module()
+ class Pad(object):
+     """Pad the image & mask.
+ 
+     There are two padding modes: (1) pad to a fixed size and (2) pad to the
+     minimum size that is divisible by some number.
+     Added keys are "pad_shape", "pad_fixed_size" and "pad_size_divisor".
+ 
+     Args:
+         size (tuple, optional): Fixed padding size.
+         size_divisor (int, optional): The divisor of padded size.
+         pad_val (float, optional): Padding value. Default: 0.
+         seg_pad_val (float, optional): Padding value of segmentation map.
+             Default: 255.
+     """
+ 
+     def __init__(self,
+                  size=None,
+                  size_divisor=None,
+                  pad_val=0,
+                  seg_pad_val=255):
+         self.size = size
+         self.size_divisor = size_divisor
+         self.pad_val = pad_val
+         self.seg_pad_val = seg_pad_val
+         # only one of size and size_divisor should be valid
+         assert size is not None or size_divisor is not None
+         assert size is None or size_divisor is None
+ 
+     def _pad_img(self, results):
+         """Pad images according to ``self.size``."""
+         if self.size is not None:
+             padded_img = mmcv.impad(
+                 results['img'], shape=self.size, pad_val=self.pad_val)
+         elif self.size_divisor is not None:
+             padded_img = mmcv.impad_to_multiple(
+                 results['img'], self.size_divisor, pad_val=self.pad_val)
+         results['img'] = padded_img
+         results['pad_shape'] = padded_img.shape
+         results['pad_fixed_size'] = self.size
+         results['pad_size_divisor'] = self.size_divisor
+ 
+     def _pad_seg(self, results):
+         """Pad masks according to ``results['pad_shape']``."""
+         for key in results.get('seg_fields', []):
+             results[key] = mmcv.impad(
+                 results[key],
+                 shape=results['pad_shape'][:2],
+                 pad_val=self.seg_pad_val)
+ 
+     def __call__(self, results):
+         """Call function to pad images, masks, semantic segmentation maps.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Updated result dict.
+         """
+ 
+         self._pad_img(results)
+         self._pad_seg(results)
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
+                     f'pad_val={self.pad_val})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class Normalize(object):
+     """Normalize the image.
+ 
+     Added key is "img_norm_cfg".
+ 
+     Args:
+         mean (sequence): Mean values of 3 channels.
+         std (sequence): Std values of 3 channels.
+         to_rgb (bool): Whether to convert the image from BGR to RGB.
+             Default: True.
+     """
+ 
+     def __init__(self, mean, std, to_rgb=True):
+         self.mean = np.array(mean, dtype=np.float32)
+         self.std = np.array(std, dtype=np.float32)
+         self.to_rgb = to_rgb
+ 
+     def __call__(self, results):
+         """Call function to normalize images.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Normalized results, 'img_norm_cfg' key is added into
+                 result dict.
+         """
+ 
+         results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
+                                           self.to_rgb)
+         results['img_norm_cfg'] = dict(
+             mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
+                     f'{self.to_rgb})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class Rerange(object):
+     """Rerange the image pixel value.
+ 
+     Args:
+         min_value (float or int): Minimum value of the reranged image.
+             Default: 0.
+         max_value (float or int): Maximum value of the reranged image.
+             Default: 255.
+     """
+ 
+     def __init__(self, min_value=0, max_value=255):
+         assert isinstance(min_value, float) or isinstance(min_value, int)
+         assert isinstance(max_value, float) or isinstance(max_value, int)
+         assert min_value < max_value
+         self.min_value = min_value
+         self.max_value = max_value
+ 
+     def __call__(self, results):
+         """Call function to rerange images.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Reranged results.
+         """
+ 
+         img = results['img']
+         img_min_value = np.min(img)
+         img_max_value = np.max(img)
+ 
+         assert img_min_value < img_max_value
+         # rerange to [0, 1]
+         img = (img - img_min_value) / (img_max_value - img_min_value)
+         # rerange to [min_value, max_value]
+         img = img * (self.max_value - self.min_value) + self.min_value
+         results['img'] = img
+ 
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class CLAHE(object):
+     """Use the CLAHE method to process the image.
+ 
+     See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+     Graphics Gems, 1994:474-485.` for more information.
+ 
+     Args:
+         clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+         tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+             Input image will be divided into equally sized rectangular tiles.
+             It defines the number of tiles in row and column. Default: (8, 8).
+     """
+ 
+     def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+         assert isinstance(clip_limit, (float, int))
+         self.clip_limit = clip_limit
+         assert is_tuple_of(tile_grid_size, int)
+         assert len(tile_grid_size) == 2
+         self.tile_grid_size = tile_grid_size
+ 
+     def __call__(self, results):
+         """Call function to process images with the CLAHE method.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Processed results.
+         """
+ 
+         for i in range(results['img'].shape[2]):
+             results['img'][:, :, i] = mmcv.clahe(
+                 np.array(results['img'][:, :, i], dtype=np.uint8),
+                 self.clip_limit, self.tile_grid_size)
+ 
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(clip_limit={self.clip_limit}, '\
+                     f'tile_grid_size={self.tile_grid_size})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class RandomCrop(object):
+     """Random crop the image & seg.
+ 
+     Args:
+         crop_size (tuple): Expected size after cropping, (h, w).
+         cat_max_ratio (float): The maximum ratio that a single category could
+             occupy.
+     """
+ 
+     def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+         assert crop_size[0] > 0 and crop_size[1] > 0
+         self.crop_size = crop_size
+         self.cat_max_ratio = cat_max_ratio
+         self.ignore_index = ignore_index
+ 
+     def get_crop_bbox(self, img):
+         """Randomly get a crop bounding box."""
+         margin_h = max(img.shape[0] - self.crop_size[0], 0)
+         margin_w = max(img.shape[1] - self.crop_size[1], 0)
+         offset_h = np.random.randint(0, margin_h + 1)
+         offset_w = np.random.randint(0, margin_w + 1)
+         crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+         crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+ 
+         return crop_y1, crop_y2, crop_x1, crop_x2
+ 
+     def crop(self, img, crop_bbox):
+         """Crop from ``img``"""
+         crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+         img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+         return img
+ 
+     def __call__(self, results):
+         """Call function to randomly crop images, semantic segmentation maps.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Randomly cropped results; the 'img_shape' key in the result
+                 dict is updated according to the crop size.
+         """
+ 
+         img = results['img']
+         crop_bbox = self.get_crop_bbox(img)
+         if self.cat_max_ratio < 1.:
+             # re-sample the crop up to 10 times until no single category
+             # occupies more than cat_max_ratio of the crop
+             for _ in range(10):
+                 seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
+                 labels, cnt = np.unique(seg_temp, return_counts=True)
+                 cnt = cnt[labels != self.ignore_index]
+                 if len(cnt) > 1 and np.max(cnt) / np.sum(
+                         cnt) < self.cat_max_ratio:
+                     break
+                 crop_bbox = self.get_crop_bbox(img)
+ 
+         # crop the image
+         img = self.crop(img, crop_bbox)
+         img_shape = img.shape
+         results['img'] = img
+         results['img_shape'] = img_shape
+ 
+         # crop semantic seg
+         for key in results.get('seg_fields', []):
+             results[key] = self.crop(results[key], crop_bbox)
+ 
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(crop_size={self.crop_size})'
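Note: the cat_max_ratio loop re-draws the crop until no single class dominates. The acceptance test on its own, as a hypothetical helper:

    import numpy as np

    def crop_is_diverse(seg_crop, cat_max_ratio=0.75, ignore_index=255):
        # Accept only if the most frequent class stays under the cap.
        labels, cnt = np.unique(seg_crop, return_counts=True)
        cnt = cnt[labels != ignore_index]
        return len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio

    assert not crop_is_diverse(np.zeros((4, 4), np.uint8))        # one class only
    assert crop_is_diverse(np.array([[0, 1], [1, 0]], np.uint8))  # 50/50 split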
+ 
+ 
+ @PIPELINES.register_module()
+ class RandomRotate(object):
+     """Rotate the image & seg.
+ 
+     Args:
+         prob (float): The rotation probability.
+         degree (float, tuple[float]): Range of degrees to select from. If
+             degree is a number instead of a tuple like (min, max),
+             the range of degree will be (``-degree``, ``+degree``).
+         pad_val (float, optional): Padding value of image. Default: 0.
+         seg_pad_val (float, optional): Padding value of segmentation map.
+             Default: 255.
+         center (tuple[float], optional): Center point (w, h) of the rotation
+             in the source image. If not specified, the center of the image
+             will be used. Default: None.
+         auto_bound (bool): Whether to adjust the image size to cover the whole
+             rotated image. Default: False.
+     """
+ 
+     def __init__(self,
+                  prob,
+                  degree,
+                  pad_val=0,
+                  seg_pad_val=255,
+                  center=None,
+                  auto_bound=False):
+         self.prob = prob
+         assert prob >= 0 and prob <= 1
+         if isinstance(degree, (float, int)):
+             assert degree > 0, f'degree {degree} should be positive'
+             self.degree = (-degree, degree)
+         else:
+             self.degree = degree
+         assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+                                       f'tuple of (min, max)'
+         self.pad_val = pad_val
+         self.seg_pad_val = seg_pad_val
+         self.center = center
+         self.auto_bound = auto_bound
+ 
+     def __call__(self, results):
+         """Call function to rotate image, semantic segmentation maps.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Rotated results.
+         """
+ 
+         rotate = True if np.random.rand() < self.prob else False
+         degree = np.random.uniform(min(*self.degree), max(*self.degree))
+         if rotate:
+             # rotate image
+             results['img'] = mmcv.imrotate(
+                 results['img'],
+                 angle=degree,
+                 border_value=self.pad_val,
+                 center=self.center,
+                 auto_bound=self.auto_bound)
+ 
+             # rotate segs
+             for key in results.get('seg_fields', []):
+                 results[key] = mmcv.imrotate(
+                     results[key],
+                     angle=degree,
+                     border_value=self.seg_pad_val,
+                     center=self.center,
+                     auto_bound=self.auto_bound,
+                     interpolation='nearest')
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(prob={self.prob}, ' \
+                     f'degree={self.degree}, ' \
+                     f'pad_val={self.pad_val}, ' \
+                     f'seg_pad_val={self.seg_pad_val}, ' \
+                     f'center={self.center}, ' \
+                     f'auto_bound={self.auto_bound})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class RGB2Gray(object):
+     """Convert RGB image to grayscale image.
+ 
+     This transform calculates the weighted mean of input image channels with
+     ``weights`` and then expands the channels to ``out_channels``. When
+     ``out_channels`` is None, the number of output channels is the same as
+     input channels.
+ 
+     Args:
+         out_channels (int): Expected number of output channels after
+             transforming. Default: None.
+         weights (tuple[float]): The weights to calculate the weighted mean.
+             Default: (0.299, 0.587, 0.114).
+     """
+ 
+     def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+         assert out_channels is None or out_channels > 0
+         self.out_channels = out_channels
+         assert isinstance(weights, tuple)
+         for item in weights:
+             assert isinstance(item, (float, int))
+         self.weights = weights
+ 
+     def __call__(self, results):
+         """Call function to convert RGB image to grayscale image.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Result dict with grayscale image.
+         """
+         img = results['img']
+         assert len(img.shape) == 3
+         assert img.shape[2] == len(self.weights)
+         weights = np.array(self.weights).reshape((1, 1, -1))
+         img = (img * weights).sum(2, keepdims=True)
+         if self.out_channels is None:
+             img = img.repeat(weights.shape[2], axis=2)
+         else:
+             img = img.repeat(self.out_channels, axis=2)
+ 
+         results['img'] = img
+         results['img_shape'] = img.shape
+ 
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(out_channels={self.out_channels}, ' \
+                     f'weights={self.weights})'
+         return repr_str
+ 
+ 
+ @PIPELINES.register_module()
+ class AdjustGamma(object):
+     """Using gamma correction to process the image.
+ 
+     Args:
+         gamma (float or int): Gamma value used in gamma correction.
+             Default: 1.0.
+     """
+ 
+     def __init__(self, gamma=1.0):
+         assert isinstance(gamma, float) or isinstance(gamma, int)
+         assert gamma > 0
+         self.gamma = gamma
+         inv_gamma = 1.0 / gamma
+         self.table = np.array([(i / 255.0)**inv_gamma * 255
+                                for i in np.arange(256)]).astype('uint8')
+ 
+     def __call__(self, results):
+         """Call function to process the image with gamma correction.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Processed results.
+         """
+ 
+         results['img'] = mmcv.lut_transform(
+             np.array(results['img'], dtype=np.uint8), self.table)
+ 
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(gamma={self.gamma})'
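Note: gamma correction above is precomputed as a 256-entry lookup table. The mapping itself, with an illustrative gamma:

    import numpy as np

    gamma = 2.2                                    # illustrative value
    table = np.array([(i / 255.0)**(1.0 / gamma) * 255
                      for i in np.arange(256)]).astype('uint8')
    # endpoints are fixed; mid-grey brightens when gamma > 1 (128 -> ~186)
    assert table[0] == 0 and table[255] == 255 and table[128] > 128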
+ 
+ 
+ @PIPELINES.register_module()
+ class SegRescale(object):
+     """Rescale semantic segmentation maps.
+ 
+     Args:
+         scale_factor (float): The scale factor of the final output.
+     """
+ 
+     def __init__(self, scale_factor=1):
+         self.scale_factor = scale_factor
+ 
+     def __call__(self, results):
+         """Call function to scale the semantic segmentation map.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Result dict with semantic segmentation map scaled.
+         """
+         for key in results.get('seg_fields', []):
+             if self.scale_factor != 1:
+                 results[key] = mmcv.imrescale(
+                     results[key], self.scale_factor, interpolation='nearest')
+         return results
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+ 
+ 
+ @PIPELINES.register_module()
+ class PhotoMetricDistortion(object):
+     """Apply photometric distortion to the image sequentially; every
+     transformation is applied with a probability of 0.5. Random contrast is
+     applied either second or second to last.
+ 
+     1. random brightness
+     2. random contrast (mode 1)
+     3. convert color from BGR to HSV
+     4. random saturation
+     5. random hue
+     6. convert color from HSV to BGR
+     7. random contrast (mode 0)
+ 
+     Args:
+         brightness_delta (int): delta of brightness.
+         contrast_range (tuple): range of contrast.
+         saturation_range (tuple): range of saturation.
+         hue_delta (int): delta of hue.
+     """
+ 
+     def __init__(self,
+                  brightness_delta=32,
+                  contrast_range=(0.5, 1.5),
+                  saturation_range=(0.5, 1.5),
+                  hue_delta=18):
+         self.brightness_delta = brightness_delta
+         self.contrast_lower, self.contrast_upper = contrast_range
+         self.saturation_lower, self.saturation_upper = saturation_range
+         self.hue_delta = hue_delta
+ 
+     def convert(self, img, alpha=1, beta=0):
+         """Multiply with ``alpha``, add ``beta``, then clip to [0, 255]."""
+         img = img.astype(np.float32) * alpha + beta
+         img = np.clip(img, 0, 255)
+         return img.astype(np.uint8)
+ 
+     def brightness(self, img):
+         """Brightness distortion."""
+         if random.randint(2):
+             return self.convert(
+                 img,
+                 beta=random.uniform(-self.brightness_delta,
+                                     self.brightness_delta))
+         return img
+ 
+     def contrast(self, img):
+         """Contrast distortion."""
+         if random.randint(2):
+             return self.convert(
+                 img,
+                 alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+         return img
+ 
+     def saturation(self, img):
+         """Saturation distortion."""
+         if random.randint(2):
+             img = mmcv.bgr2hsv(img)
+             img[:, :, 1] = self.convert(
+                 img[:, :, 1],
+                 alpha=random.uniform(self.saturation_lower,
+                                      self.saturation_upper))
+             img = mmcv.hsv2bgr(img)
+         return img
+ 
+     def hue(self, img):
+         """Hue distortion."""
+         if random.randint(2):
+             img = mmcv.bgr2hsv(img)
+             img[:, :,
+                 0] = (img[:, :, 0].astype(int) +
+                       random.randint(-self.hue_delta, self.hue_delta)) % 180
+             img = mmcv.hsv2bgr(img)
+         return img
+ 
+     def __call__(self, results):
+         """Call function to perform photometric distortion on images.
+ 
+         Args:
+             results (dict): Result dict from loading pipeline.
+ 
+         Returns:
+             dict: Result dict with images distorted.
+         """
+ 
+         img = results['img']
+         # random brightness
+         img = self.brightness(img)
+ 
+         # mode == 1 --> do random contrast first
+         # mode == 0 --> do random contrast last
+         mode = random.randint(2)
+         if mode == 1:
+             img = self.contrast(img)
+ 
+         # random saturation
+         img = self.saturation(img)
+ 
+         # random hue
+         img = self.hue(img)
+ 
+         # random contrast
+         if mode == 0:
+             img = self.contrast(img)
+ 
+         results['img'] = img
+         return results
+ 
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += (f'(brightness_delta={self.brightness_delta}, '
+                      f'contrast_range=({self.contrast_lower}, '
+                      f'{self.contrast_upper}), '
+                      f'saturation_range=({self.saturation_lower}, '
+                      f'{self.saturation_upper}), '
+                      f'hue_delta={self.hue_delta})')
+         return repr_str
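Note: every sub-distortion above funnels through convert(), an affine map followed by clipping back to uint8. The core in isolation, with illustrative alpha/beta:

    import numpy as np

    def convert(img, alpha=1, beta=0):
        # Multiply by alpha, add beta, clip to [0, 255], return uint8.
        img = img.astype(np.float32) * alpha + beta
        return np.clip(img, 0, 255).astype(np.uint8)

    img = np.array([[0, 200]], dtype=np.uint8)
    assert convert(img, alpha=1.5).tolist() == [[0, 255]]   # 300 clips to 255
    assert convert(img, beta=-32).tolist() == [[0, 168]]    # brightness shift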