toto10 commited on
Commit
373d463
·
1 Parent(s): b311d13

14c20fdaa90f290a63f8278d3172dcf42927bd5cd67fb86f306067688f2c86f2

Browse files
Files changed (50) hide show
  1. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/collate.py +84 -0
  2. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py +89 -0
  3. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py +89 -0
  4. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py +112 -0
  5. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py +70 -0
  6. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/registry.py +8 -0
  7. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py +59 -0
  8. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/utils.py +20 -0
  9. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/__init__.py +47 -0
  10. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_module.py +195 -0
  11. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py +542 -0
  12. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/builder.py +24 -0
  13. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py +707 -0
  14. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py +44 -0
  15. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py +164 -0
  16. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py +187 -0
  17. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py +410 -0
  18. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py +29 -0
  19. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py +167 -0
  20. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py +11 -0
  21. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py +89 -0
  22. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py +509 -0
  23. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py +92 -0
  24. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py +18 -0
  25. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py +15 -0
  26. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py +166 -0
  27. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py +58 -0
  28. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py +78 -0
  29. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py +82 -0
  30. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py +117 -0
  31. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py +57 -0
  32. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py +256 -0
  33. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py +56 -0
  34. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py +670 -0
  35. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py +25 -0
  36. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py +493 -0
  37. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py +508 -0
  38. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py +180 -0
  39. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py +20 -0
  40. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py +22 -0
  41. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py +273 -0
  42. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py +41 -0
  43. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py +9 -0
  44. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py +44 -0
  45. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py +249 -0
  46. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/priority.py +60 -0
  47. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/utils.py +93 -0
  48. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/__init__.py +69 -0
  49. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/config.py +688 -0
  50. extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/env.py +95 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/collate.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from collections.abc import Mapping, Sequence
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.utils.data.dataloader import default_collate
7
+
8
+ from .data_container import DataContainer
9
+
10
+
11
+ def collate(batch, samples_per_gpu=1):
12
+ """Puts each data field into a tensor/DataContainer with outer dimension
13
+ batch size.
14
+
15
+ Extend default_collate to add support for
16
+ :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
17
+
18
+ 1. cpu_only = True, e.g., meta data
19
+ 2. cpu_only = False, stack = True, e.g., images tensors
20
+ 3. cpu_only = False, stack = False, e.g., gt bboxes
21
+ """
22
+
23
+ if not isinstance(batch, Sequence):
24
+ raise TypeError(f'{batch.dtype} is not supported.')
25
+
26
+ if isinstance(batch[0], DataContainer):
27
+ stacked = []
28
+ if batch[0].cpu_only:
29
+ for i in range(0, len(batch), samples_per_gpu):
30
+ stacked.append(
31
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
32
+ return DataContainer(
33
+ stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
34
+ elif batch[0].stack:
35
+ for i in range(0, len(batch), samples_per_gpu):
36
+ assert isinstance(batch[i].data, torch.Tensor)
37
+
38
+ if batch[i].pad_dims is not None:
39
+ ndim = batch[i].dim()
40
+ assert ndim > batch[i].pad_dims
41
+ max_shape = [0 for _ in range(batch[i].pad_dims)]
42
+ for dim in range(1, batch[i].pad_dims + 1):
43
+ max_shape[dim - 1] = batch[i].size(-dim)
44
+ for sample in batch[i:i + samples_per_gpu]:
45
+ for dim in range(0, ndim - batch[i].pad_dims):
46
+ assert batch[i].size(dim) == sample.size(dim)
47
+ for dim in range(1, batch[i].pad_dims + 1):
48
+ max_shape[dim - 1] = max(max_shape[dim - 1],
49
+ sample.size(-dim))
50
+ padded_samples = []
51
+ for sample in batch[i:i + samples_per_gpu]:
52
+ pad = [0 for _ in range(batch[i].pad_dims * 2)]
53
+ for dim in range(1, batch[i].pad_dims + 1):
54
+ pad[2 * dim -
55
+ 1] = max_shape[dim - 1] - sample.size(-dim)
56
+ padded_samples.append(
57
+ F.pad(
58
+ sample.data, pad, value=sample.padding_value))
59
+ stacked.append(default_collate(padded_samples))
60
+ elif batch[i].pad_dims is None:
61
+ stacked.append(
62
+ default_collate([
63
+ sample.data
64
+ for sample in batch[i:i + samples_per_gpu]
65
+ ]))
66
+ else:
67
+ raise ValueError(
68
+ 'pad_dims should be either None or integers (1-3)')
69
+
70
+ else:
71
+ for i in range(0, len(batch), samples_per_gpu):
72
+ stacked.append(
73
+ [sample.data for sample in batch[i:i + samples_per_gpu]])
74
+ return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
75
+ elif isinstance(batch[0], Sequence):
76
+ transposed = zip(*batch)
77
+ return [collate(samples, samples_per_gpu) for samples in transposed]
78
+ elif isinstance(batch[0], Mapping):
79
+ return {
80
+ key: collate([d[key] for d in batch], samples_per_gpu)
81
+ for key in batch[0]
82
+ }
83
+ else:
84
+ return default_collate(batch)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import functools
3
+
4
+ import torch
5
+
6
+
7
+ def assert_tensor_type(func):
8
+
9
+ @functools.wraps(func)
10
+ def wrapper(*args, **kwargs):
11
+ if not isinstance(args[0].data, torch.Tensor):
12
+ raise AttributeError(
13
+ f'{args[0].__class__.__name__} has no attribute '
14
+ f'{func.__name__} for type {args[0].datatype}')
15
+ return func(*args, **kwargs)
16
+
17
+ return wrapper
18
+
19
+
20
+ class DataContainer:
21
+ """A container for any type of objects.
22
+
23
+ Typically tensors will be stacked in the collate function and sliced along
24
+ some dimension in the scatter function. This behavior has some limitations.
25
+ 1. All tensors have to be the same size.
26
+ 2. Types are limited (numpy array or Tensor).
27
+
28
+ We design `DataContainer` and `MMDataParallel` to overcome these
29
+ limitations. The behavior can be either of the following.
30
+
31
+ - copy to GPU, pad all tensors to the same size and stack them
32
+ - copy to GPU without stacking
33
+ - leave the objects as is and pass it to the model
34
+ - pad_dims specifies the number of last few dimensions to do padding
35
+ """
36
+
37
+ def __init__(self,
38
+ data,
39
+ stack=False,
40
+ padding_value=0,
41
+ cpu_only=False,
42
+ pad_dims=2):
43
+ self._data = data
44
+ self._cpu_only = cpu_only
45
+ self._stack = stack
46
+ self._padding_value = padding_value
47
+ assert pad_dims in [None, 1, 2, 3]
48
+ self._pad_dims = pad_dims
49
+
50
+ def __repr__(self):
51
+ return f'{self.__class__.__name__}({repr(self.data)})'
52
+
53
+ def __len__(self):
54
+ return len(self._data)
55
+
56
+ @property
57
+ def data(self):
58
+ return self._data
59
+
60
+ @property
61
+ def datatype(self):
62
+ if isinstance(self.data, torch.Tensor):
63
+ return self.data.type()
64
+ else:
65
+ return type(self.data)
66
+
67
+ @property
68
+ def cpu_only(self):
69
+ return self._cpu_only
70
+
71
+ @property
72
+ def stack(self):
73
+ return self._stack
74
+
75
+ @property
76
+ def padding_value(self):
77
+ return self._padding_value
78
+
79
+ @property
80
+ def pad_dims(self):
81
+ return self._pad_dims
82
+
83
+ @assert_tensor_type
84
+ def size(self, *args, **kwargs):
85
+ return self.data.size(*args, **kwargs)
86
+
87
+ @assert_tensor_type
88
+ def dim(self):
89
+ return self.data.dim()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from itertools import chain
3
+
4
+ from torch.nn.parallel import DataParallel
5
+
6
+ from .scatter_gather import scatter_kwargs
7
+
8
+
9
+ class MMDataParallel(DataParallel):
10
+ """The DataParallel module that supports DataContainer.
11
+
12
+ MMDataParallel has two main differences with PyTorch DataParallel:
13
+
14
+ - It supports a custom type :class:`DataContainer` which allows more
15
+ flexible control of input data during both GPU and CPU inference.
16
+ - It implement two more APIs ``train_step()`` and ``val_step()``.
17
+
18
+ Args:
19
+ module (:class:`nn.Module`): Module to be encapsulated.
20
+ device_ids (list[int]): Device IDS of modules to be scattered to.
21
+ Defaults to None when GPU is not available.
22
+ output_device (str | int): Device ID for output. Defaults to None.
23
+ dim (int): Dimension used to scatter the data. Defaults to 0.
24
+ """
25
+
26
+ def __init__(self, *args, dim=0, **kwargs):
27
+ super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
28
+ self.dim = dim
29
+
30
+ def forward(self, *inputs, **kwargs):
31
+ """Override the original forward function.
32
+
33
+ The main difference lies in the CPU inference where the data in
34
+ :class:`DataContainers` will still be gathered.
35
+ """
36
+ if not self.device_ids:
37
+ # We add the following line thus the module could gather and
38
+ # convert data containers as those in GPU inference
39
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
40
+ return self.module(*inputs[0], **kwargs[0])
41
+ else:
42
+ return super().forward(*inputs, **kwargs)
43
+
44
+ def scatter(self, inputs, kwargs, device_ids):
45
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
46
+
47
+ def train_step(self, *inputs, **kwargs):
48
+ if not self.device_ids:
49
+ # We add the following line thus the module could gather and
50
+ # convert data containers as those in GPU inference
51
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
52
+ return self.module.train_step(*inputs[0], **kwargs[0])
53
+
54
+ assert len(self.device_ids) == 1, \
55
+ ('MMDataParallel only supports single GPU training, if you need to'
56
+ ' train with multiple GPUs, please use MMDistributedDataParallel'
57
+ 'instead.')
58
+
59
+ for t in chain(self.module.parameters(), self.module.buffers()):
60
+ if t.device != self.src_device_obj:
61
+ raise RuntimeError(
62
+ 'module must have its parameters and buffers '
63
+ f'on device {self.src_device_obj} (device_ids[0]) but '
64
+ f'found one of them on device: {t.device}')
65
+
66
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
67
+ return self.module.train_step(*inputs[0], **kwargs[0])
68
+
69
+ def val_step(self, *inputs, **kwargs):
70
+ if not self.device_ids:
71
+ # We add the following line thus the module could gather and
72
+ # convert data containers as those in GPU inference
73
+ inputs, kwargs = self.scatter(inputs, kwargs, [-1])
74
+ return self.module.val_step(*inputs[0], **kwargs[0])
75
+
76
+ assert len(self.device_ids) == 1, \
77
+ ('MMDataParallel only supports single GPU training, if you need to'
78
+ ' train with multiple GPUs, please use MMDistributedDataParallel'
79
+ ' instead.')
80
+
81
+ for t in chain(self.module.parameters(), self.module.buffers()):
82
+ if t.device != self.src_device_obj:
83
+ raise RuntimeError(
84
+ 'module must have its parameters and buffers '
85
+ f'on device {self.src_device_obj} (device_ids[0]) but '
86
+ f'found one of them on device: {t.device}')
87
+
88
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
89
+ return self.module.val_step(*inputs[0], **kwargs[0])
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from torch.nn.parallel.distributed import (DistributedDataParallel,
4
+ _find_tensors)
5
+
6
+ from annotator.mmpkg.mmcv import print_log
7
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
8
+ from .scatter_gather import scatter_kwargs
9
+
10
+
11
+ class MMDistributedDataParallel(DistributedDataParallel):
12
+ """The DDP module that supports DataContainer.
13
+
14
+ MMDDP has two main differences with PyTorch DDP:
15
+
16
+ - It supports a custom type :class:`DataContainer` which allows more
17
+ flexible control of input data.
18
+ - It implement two APIs ``train_step()`` and ``val_step()``.
19
+ """
20
+
21
+ def to_kwargs(self, inputs, kwargs, device_id):
22
+ # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
23
+ # to move all tensors to device_id
24
+ return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
25
+
26
+ def scatter(self, inputs, kwargs, device_ids):
27
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
28
+
29
+ def train_step(self, *inputs, **kwargs):
30
+ """train_step() API for module wrapped by DistributedDataParallel.
31
+
32
+ This method is basically the same as
33
+ ``DistributedDataParallel.forward()``, while replacing
34
+ ``self.module.forward()`` with ``self.module.train_step()``.
35
+ It is compatible with PyTorch 1.1 - 1.5.
36
+ """
37
+
38
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
39
+ # end of backward to the beginning of forward.
40
+ if ('parrots' not in TORCH_VERSION
41
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
42
+ and self.reducer._rebuild_buckets()):
43
+ print_log(
44
+ 'Reducer buckets have been rebuilt in this iteration.',
45
+ logger='mmcv')
46
+
47
+ if getattr(self, 'require_forward_param_sync', True):
48
+ self._sync_params()
49
+ if self.device_ids:
50
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
51
+ if len(self.device_ids) == 1:
52
+ output = self.module.train_step(*inputs[0], **kwargs[0])
53
+ else:
54
+ outputs = self.parallel_apply(
55
+ self._module_copies[:len(inputs)], inputs, kwargs)
56
+ output = self.gather(outputs, self.output_device)
57
+ else:
58
+ output = self.module.train_step(*inputs, **kwargs)
59
+
60
+ if torch.is_grad_enabled() and getattr(
61
+ self, 'require_backward_grad_sync', True):
62
+ if self.find_unused_parameters:
63
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
64
+ else:
65
+ self.reducer.prepare_for_backward([])
66
+ else:
67
+ if ('parrots' not in TORCH_VERSION
68
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
69
+ self.require_forward_param_sync = False
70
+ return output
71
+
72
+ def val_step(self, *inputs, **kwargs):
73
+ """val_step() API for module wrapped by DistributedDataParallel.
74
+
75
+ This method is basically the same as
76
+ ``DistributedDataParallel.forward()``, while replacing
77
+ ``self.module.forward()`` with ``self.module.val_step()``.
78
+ It is compatible with PyTorch 1.1 - 1.5.
79
+ """
80
+ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
81
+ # end of backward to the beginning of forward.
82
+ if ('parrots' not in TORCH_VERSION
83
+ and digit_version(TORCH_VERSION) >= digit_version('1.7')
84
+ and self.reducer._rebuild_buckets()):
85
+ print_log(
86
+ 'Reducer buckets have been rebuilt in this iteration.',
87
+ logger='mmcv')
88
+
89
+ if getattr(self, 'require_forward_param_sync', True):
90
+ self._sync_params()
91
+ if self.device_ids:
92
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
93
+ if len(self.device_ids) == 1:
94
+ output = self.module.val_step(*inputs[0], **kwargs[0])
95
+ else:
96
+ outputs = self.parallel_apply(
97
+ self._module_copies[:len(inputs)], inputs, kwargs)
98
+ output = self.gather(outputs, self.output_device)
99
+ else:
100
+ output = self.module.val_step(*inputs, **kwargs)
101
+
102
+ if torch.is_grad_enabled() and getattr(
103
+ self, 'require_backward_grad_sync', True):
104
+ if self.find_unused_parameters:
105
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
106
+ else:
107
+ self.reducer.prepare_for_backward([])
108
+ else:
109
+ if ('parrots' not in TORCH_VERSION
110
+ and digit_version(TORCH_VERSION) > digit_version('1.2')):
111
+ self.require_forward_param_sync = False
112
+ return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torch.nn as nn
5
+ from torch._utils import (_flatten_dense_tensors, _take_tensors,
6
+ _unflatten_dense_tensors)
7
+
8
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
9
+ from .registry import MODULE_WRAPPERS
10
+ from .scatter_gather import scatter_kwargs
11
+
12
+
13
+ @MODULE_WRAPPERS.register_module()
14
+ class MMDistributedDataParallel(nn.Module):
15
+
16
+ def __init__(self,
17
+ module,
18
+ dim=0,
19
+ broadcast_buffers=True,
20
+ bucket_cap_mb=25):
21
+ super(MMDistributedDataParallel, self).__init__()
22
+ self.module = module
23
+ self.dim = dim
24
+ self.broadcast_buffers = broadcast_buffers
25
+
26
+ self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
27
+ self._sync_params()
28
+
29
+ def _dist_broadcast_coalesced(self, tensors, buffer_size):
30
+ for tensors in _take_tensors(tensors, buffer_size):
31
+ flat_tensors = _flatten_dense_tensors(tensors)
32
+ dist.broadcast(flat_tensors, 0)
33
+ for tensor, synced in zip(
34
+ tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
35
+ tensor.copy_(synced)
36
+
37
+ def _sync_params(self):
38
+ module_states = list(self.module.state_dict().values())
39
+ if len(module_states) > 0:
40
+ self._dist_broadcast_coalesced(module_states,
41
+ self.broadcast_bucket_size)
42
+ if self.broadcast_buffers:
43
+ if (TORCH_VERSION != 'parrots'
44
+ and digit_version(TORCH_VERSION) < digit_version('1.0')):
45
+ buffers = [b.data for b in self.module._all_buffers()]
46
+ else:
47
+ buffers = [b.data for b in self.module.buffers()]
48
+ if len(buffers) > 0:
49
+ self._dist_broadcast_coalesced(buffers,
50
+ self.broadcast_bucket_size)
51
+
52
+ def scatter(self, inputs, kwargs, device_ids):
53
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
54
+
55
+ def forward(self, *inputs, **kwargs):
56
+ inputs, kwargs = self.scatter(inputs, kwargs,
57
+ [torch.cuda.current_device()])
58
+ return self.module(*inputs[0], **kwargs[0])
59
+
60
+ def train_step(self, *inputs, **kwargs):
61
+ inputs, kwargs = self.scatter(inputs, kwargs,
62
+ [torch.cuda.current_device()])
63
+ output = self.module.train_step(*inputs[0], **kwargs[0])
64
+ return output
65
+
66
+ def val_step(self, *inputs, **kwargs):
67
+ inputs, kwargs = self.scatter(inputs, kwargs,
68
+ [torch.cuda.current_device()])
69
+ output = self.module.val_step(*inputs[0], **kwargs[0])
70
+ return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/registry.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
3
+
4
+ from annotator.mmpkg.mmcv.utils import Registry
5
+
6
+ MODULE_WRAPPERS = Registry('module wrapper')
7
+ MODULE_WRAPPERS.register_module(module=DataParallel)
8
+ MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from torch.nn.parallel._functions import Scatter as OrigScatter
4
+
5
+ from ._functions import Scatter
6
+ from .data_container import DataContainer
7
+
8
+
9
+ def scatter(inputs, target_gpus, dim=0):
10
+ """Scatter inputs to target gpus.
11
+
12
+ The only difference from original :func:`scatter` is to add support for
13
+ :type:`~mmcv.parallel.DataContainer`.
14
+ """
15
+
16
+ def scatter_map(obj):
17
+ if isinstance(obj, torch.Tensor):
18
+ if target_gpus != [-1]:
19
+ return OrigScatter.apply(target_gpus, None, dim, obj)
20
+ else:
21
+ # for CPU inference we use self-implemented scatter
22
+ return Scatter.forward(target_gpus, obj)
23
+ if isinstance(obj, DataContainer):
24
+ if obj.cpu_only:
25
+ return obj.data
26
+ else:
27
+ return Scatter.forward(target_gpus, obj.data)
28
+ if isinstance(obj, tuple) and len(obj) > 0:
29
+ return list(zip(*map(scatter_map, obj)))
30
+ if isinstance(obj, list) and len(obj) > 0:
31
+ out = list(map(list, zip(*map(scatter_map, obj))))
32
+ return out
33
+ if isinstance(obj, dict) and len(obj) > 0:
34
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
35
+ return out
36
+ return [obj for targets in target_gpus]
37
+
38
+ # After scatter_map is called, a scatter_map cell will exist. This cell
39
+ # has a reference to the actual function scatter_map, which has references
40
+ # to a closure that has a reference to the scatter_map cell (because the
41
+ # fn is recursive). To avoid this reference cycle, we set the function to
42
+ # None, clearing the cell
43
+ try:
44
+ return scatter_map(inputs)
45
+ finally:
46
+ scatter_map = None
47
+
48
+
49
+ def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
50
+ """Scatter with support for kwargs dictionary."""
51
+ inputs = scatter(inputs, target_gpus, dim) if inputs else []
52
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
53
+ if len(inputs) < len(kwargs):
54
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
55
+ elif len(kwargs) < len(inputs):
56
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
57
+ inputs = tuple(inputs)
58
+ kwargs = tuple(kwargs)
59
+ return inputs, kwargs
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .registry import MODULE_WRAPPERS
3
+
4
+
5
+ def is_module_wrapper(module):
6
+ """Check if a module is a module wrapper.
7
+
8
+ The following 3 modules in MMCV (and their subclasses) are regarded as
9
+ module wrappers: DataParallel, DistributedDataParallel,
10
+ MMDistributedDataParallel (the deprecated version). You may add you own
11
+ module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
12
+
13
+ Args:
14
+ module (nn.Module): The module to be checked.
15
+
16
+ Returns:
17
+ bool: True if the input module is a module wrapper.
18
+ """
19
+ module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
20
+ return isinstance(module, module_wrappers)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base_module import BaseModule, ModuleList, Sequential
3
+ from .base_runner import BaseRunner
4
+ from .builder import RUNNERS, build_runner
5
+ from .checkpoint import (CheckpointLoader, _load_checkpoint,
6
+ _load_checkpoint_with_prefix, load_checkpoint,
7
+ load_state_dict, save_checkpoint, weights_to_cpu)
8
+ from .default_constructor import DefaultRunnerConstructor
9
+ from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
10
+ init_dist, master_only)
11
+ from .epoch_based_runner import EpochBasedRunner, Runner
12
+ from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
13
+ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
14
+ DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
15
+ Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
16
+ GradientCumulativeOptimizerHook, Hook, IterTimerHook,
17
+ LoggerHook, LrUpdaterHook, MlflowLoggerHook,
18
+ NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
19
+ SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
20
+ WandbLoggerHook)
21
+ from .iter_based_runner import IterBasedRunner, IterLoader
22
+ from .log_buffer import LogBuffer
23
+ from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
24
+ DefaultOptimizerConstructor, build_optimizer,
25
+ build_optimizer_constructor)
26
+ from .priority import Priority, get_priority
27
+ from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
28
+
29
+ __all__ = [
30
+ 'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
31
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
32
+ 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
33
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
34
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
35
+ 'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
36
+ 'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
37
+ 'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
38
+ 'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
39
+ 'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
40
+ 'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
41
+ 'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
42
+ 'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
43
+ 'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
44
+ '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
45
+ 'ModuleList', 'GradientCumulativeOptimizerHook',
46
+ 'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
47
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_module.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import warnings
4
+ from abc import ABCMeta
5
+ from collections import defaultdict
6
+ from logging import FileHandler
7
+
8
+ import torch.nn as nn
9
+
10
+ from annotator.mmpkg.mmcv.runner.dist_utils import master_only
11
+ from annotator.mmpkg.mmcv.utils.logging import get_logger, logger_initialized, print_log
12
+
13
+
14
+ class BaseModule(nn.Module, metaclass=ABCMeta):
15
+ """Base module for all modules in openmmlab.
16
+
17
+ ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
18
+ functionality of parameter initialization. Compared with
19
+ ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
20
+
21
+ - ``init_cfg``: the config to control the initialization.
22
+ - ``init_weights``: The function of parameter
23
+ initialization and recording initialization
24
+ information.
25
+ - ``_params_init_info``: Used to track the parameter
26
+ initialization information. This attribute only
27
+ exists during executing the ``init_weights``.
28
+
29
+ Args:
30
+ init_cfg (dict, optional): Initialization config dict.
31
+ """
32
+
33
+ def __init__(self, init_cfg=None):
34
+ """Initialize BaseModule, inherited from `torch.nn.Module`"""
35
+
36
+ # NOTE init_cfg can be defined in different levels, but init_cfg
37
+ # in low levels has a higher priority.
38
+
39
+ super(BaseModule, self).__init__()
40
+ # define default value of init_cfg instead of hard code
41
+ # in init_weights() function
42
+ self._is_init = False
43
+
44
+ self.init_cfg = copy.deepcopy(init_cfg)
45
+
46
+ # Backward compatibility in derived classes
47
+ # if pretrained is not None:
48
+ # warnings.warn('DeprecationWarning: pretrained is a deprecated \
49
+ # key, please consider using init_cfg')
50
+ # self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
51
+
52
+ @property
53
+ def is_init(self):
54
+ return self._is_init
55
+
56
+ def init_weights(self):
57
+ """Initialize the weights."""
58
+
59
+ is_top_level_module = False
60
+ # check if it is top-level module
61
+ if not hasattr(self, '_params_init_info'):
62
+ # The `_params_init_info` is used to record the initialization
63
+ # information of the parameters
64
+ # the key should be the obj:`nn.Parameter` of model and the value
65
+ # should be a dict containing
66
+ # - init_info (str): The string that describes the initialization.
67
+ # - tmp_mean_value (FloatTensor): The mean of the parameter,
68
+ # which indicates whether the parameter has been modified.
69
+ # this attribute would be deleted after all parameters
70
+ # is initialized.
71
+ self._params_init_info = defaultdict(dict)
72
+ is_top_level_module = True
73
+
74
+ # Initialize the `_params_init_info`,
75
+ # When detecting the `tmp_mean_value` of
76
+ # the corresponding parameter is changed, update related
77
+ # initialization information
78
+ for name, param in self.named_parameters():
79
+ self._params_init_info[param][
80
+ 'init_info'] = f'The value is the same before and ' \
81
+ f'after calling `init_weights` ' \
82
+ f'of {self.__class__.__name__} '
83
+ self._params_init_info[param][
84
+ 'tmp_mean_value'] = param.data.mean()
85
+
86
+ # pass `params_init_info` to all submodules
87
+ # All submodules share the same `params_init_info`,
88
+ # so it will be updated when parameters are
89
+ # modified at any level of the model.
90
+ for sub_module in self.modules():
91
+ sub_module._params_init_info = self._params_init_info
92
+
93
+ # Get the initialized logger, if not exist,
94
+ # create a logger named `mmcv`
95
+ logger_names = list(logger_initialized.keys())
96
+ logger_name = logger_names[0] if logger_names else 'mmcv'
97
+
98
+ from ..cnn import initialize
99
+ from ..cnn.utils.weight_init import update_init_info
100
+ module_name = self.__class__.__name__
101
+ if not self._is_init:
102
+ if self.init_cfg:
103
+ print_log(
104
+ f'initialize {module_name} with init_cfg {self.init_cfg}',
105
+ logger=logger_name)
106
+ initialize(self, self.init_cfg)
107
+ if isinstance(self.init_cfg, dict):
108
+ # prevent the parameters of
109
+ # the pre-trained model
110
+ # from being overwritten by
111
+ # the `init_weights`
112
+ if self.init_cfg['type'] == 'Pretrained':
113
+ return
114
+
115
+ for m in self.children():
116
+ if hasattr(m, 'init_weights'):
117
+ m.init_weights()
118
+ # users may overload the `init_weights`
119
+ update_init_info(
120
+ m,
121
+ init_info=f'Initialized by '
122
+ f'user-defined `init_weights`'
123
+ f' in {m.__class__.__name__} ')
124
+
125
+ self._is_init = True
126
+ else:
127
+ warnings.warn(f'init_weights of {self.__class__.__name__} has '
128
+ f'been called more than once.')
129
+
130
+ if is_top_level_module:
131
+ self._dump_init_info(logger_name)
132
+
133
+ for sub_module in self.modules():
134
+ del sub_module._params_init_info
135
+
136
+ @master_only
137
+ def _dump_init_info(self, logger_name):
138
+ """Dump the initialization information to a file named
139
+ `initialization.log.json` in workdir.
140
+
141
+ Args:
142
+ logger_name (str): The name of logger.
143
+ """
144
+
145
+ logger = get_logger(logger_name)
146
+
147
+ with_file_handler = False
148
+ # dump the information to the logger file if there is a `FileHandler`
149
+ for handler in logger.handlers:
150
+ if isinstance(handler, FileHandler):
151
+ handler.stream.write(
152
+ 'Name of parameter - Initialization information\n')
153
+ for name, param in self.named_parameters():
154
+ handler.stream.write(
155
+ f'\n{name} - {param.shape}: '
156
+ f"\n{self._params_init_info[param]['init_info']} \n")
157
+ handler.stream.flush()
158
+ with_file_handler = True
159
+ if not with_file_handler:
160
+ for name, param in self.named_parameters():
161
+ print_log(
162
+ f'\n{name} - {param.shape}: '
163
+ f"\n{self._params_init_info[param]['init_info']} \n ",
164
+ logger=logger_name)
165
+
166
+ def __repr__(self):
167
+ s = super().__repr__()
168
+ if self.init_cfg:
169
+ s += f'\ninit_cfg={self.init_cfg}'
170
+ return s
171
+
172
+
173
+ class Sequential(BaseModule, nn.Sequential):
174
+ """Sequential module in openmmlab.
175
+
176
+ Args:
177
+ init_cfg (dict, optional): Initialization config dict.
178
+ """
179
+
180
+ def __init__(self, *args, init_cfg=None):
181
+ BaseModule.__init__(self, init_cfg)
182
+ nn.Sequential.__init__(self, *args)
183
+
184
+
185
+ class ModuleList(BaseModule, nn.ModuleList):
186
+ """ModuleList in openmmlab.
187
+
188
+ Args:
189
+ modules (iterable, optional): an iterable of modules to add.
190
+ init_cfg (dict, optional): Initialization config dict.
191
+ """
192
+
193
+ def __init__(self, modules=None, init_cfg=None):
194
+ BaseModule.__init__(self, init_cfg)
195
+ nn.ModuleList.__init__(self, modules)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import logging
4
+ import os.path as osp
5
+ import warnings
6
+ from abc import ABCMeta, abstractmethod
7
+
8
+ import torch
9
+ from torch.optim import Optimizer
10
+
11
+ import annotator.mmpkg.mmcv as mmcv
12
+ from ..parallel import is_module_wrapper
13
+ from .checkpoint import load_checkpoint
14
+ from .dist_utils import get_dist_info
15
+ from .hooks import HOOKS, Hook
16
+ from .log_buffer import LogBuffer
17
+ from .priority import Priority, get_priority
18
+ from .utils import get_time_str
19
+
20
+
21
+ class BaseRunner(metaclass=ABCMeta):
22
+ """The base class of Runner, a training helper for PyTorch.
23
+
24
+ All subclasses should implement the following APIs:
25
+
26
+ - ``run()``
27
+ - ``train()``
28
+ - ``val()``
29
+ - ``save_checkpoint()``
30
+
31
+ Args:
32
+ model (:obj:`torch.nn.Module`): The model to be run.
33
+ batch_processor (callable): A callable method that process a data
34
+ batch. The interface of this method should be
35
+ `batch_processor(model, data, train_mode) -> dict`
36
+ optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
37
+ optimizer (in most cases) or a dict of optimizers (in models that
38
+ requires more than one optimizer, e.g., GAN).
39
+ work_dir (str, optional): The working directory to save checkpoints
40
+ and logs. Defaults to None.
41
+ logger (:obj:`logging.Logger`): Logger used during training.
42
+ Defaults to None. (The default value is just for backward
43
+ compatibility)
44
+ meta (dict | None): A dict records some import information such as
45
+ environment info and seed, which will be logged in logger hook.
46
+ Defaults to None.
47
+ max_epochs (int, optional): Total training epochs.
48
+ max_iters (int, optional): Total training iterations.
49
+ """
50
+
51
+ def __init__(self,
52
+ model,
53
+ batch_processor=None,
54
+ optimizer=None,
55
+ work_dir=None,
56
+ logger=None,
57
+ meta=None,
58
+ max_iters=None,
59
+ max_epochs=None):
60
+ if batch_processor is not None:
61
+ if not callable(batch_processor):
62
+ raise TypeError('batch_processor must be callable, '
63
+ f'but got {type(batch_processor)}')
64
+ warnings.warn('batch_processor is deprecated, please implement '
65
+ 'train_step() and val_step() in the model instead.')
66
+ # raise an error is `batch_processor` is not None and
67
+ # `model.train_step()` exists.
68
+ if is_module_wrapper(model):
69
+ _model = model.module
70
+ else:
71
+ _model = model
72
+ if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
73
+ raise RuntimeError(
74
+ 'batch_processor and model.train_step()/model.val_step() '
75
+ 'cannot be both available.')
76
+ else:
77
+ assert hasattr(model, 'train_step')
78
+
79
+ # check the type of `optimizer`
80
+ if isinstance(optimizer, dict):
81
+ for name, optim in optimizer.items():
82
+ if not isinstance(optim, Optimizer):
83
+ raise TypeError(
84
+ f'optimizer must be a dict of torch.optim.Optimizers, '
85
+ f'but optimizer["{name}"] is a {type(optim)}')
86
+ elif not isinstance(optimizer, Optimizer) and optimizer is not None:
87
+ raise TypeError(
88
+ f'optimizer must be a torch.optim.Optimizer object '
89
+ f'or dict or None, but got {type(optimizer)}')
90
+
91
+ # check the type of `logger`
92
+ if not isinstance(logger, logging.Logger):
93
+ raise TypeError(f'logger must be a logging.Logger object, '
94
+ f'but got {type(logger)}')
95
+
96
+ # check the type of `meta`
97
+ if meta is not None and not isinstance(meta, dict):
98
+ raise TypeError(
99
+ f'meta must be a dict or None, but got {type(meta)}')
100
+
101
+ self.model = model
102
+ self.batch_processor = batch_processor
103
+ self.optimizer = optimizer
104
+ self.logger = logger
105
+ self.meta = meta
106
+ # create work_dir
107
+ if mmcv.is_str(work_dir):
108
+ self.work_dir = osp.abspath(work_dir)
109
+ mmcv.mkdir_or_exist(self.work_dir)
110
+ elif work_dir is None:
111
+ self.work_dir = None
112
+ else:
113
+ raise TypeError('"work_dir" must be a str or None')
114
+
115
+ # get model name from the model class
116
+ if hasattr(self.model, 'module'):
117
+ self._model_name = self.model.module.__class__.__name__
118
+ else:
119
+ self._model_name = self.model.__class__.__name__
120
+
121
+ self._rank, self._world_size = get_dist_info()
122
+ self.timestamp = get_time_str()
123
+ self.mode = None
124
+ self._hooks = []
125
+ self._epoch = 0
126
+ self._iter = 0
127
+ self._inner_iter = 0
128
+
129
+ if max_epochs is not None and max_iters is not None:
130
+ raise ValueError(
131
+ 'Only one of `max_epochs` or `max_iters` can be set.')
132
+
133
+ self._max_epochs = max_epochs
134
+ self._max_iters = max_iters
135
+ # TODO: Redesign LogBuffer, it is not flexible and elegant enough
136
+ self.log_buffer = LogBuffer()
137
+
138
+ @property
139
+ def model_name(self):
140
+ """str: Name of the model, usually the module class name."""
141
+ return self._model_name
142
+
143
+ @property
144
+ def rank(self):
145
+ """int: Rank of current process. (distributed training)"""
146
+ return self._rank
147
+
148
+ @property
149
+ def world_size(self):
150
+ """int: Number of processes participating in the job.
151
+ (distributed training)"""
152
+ return self._world_size
153
+
154
+ @property
155
+ def hooks(self):
156
+ """list[:obj:`Hook`]: A list of registered hooks."""
157
+ return self._hooks
158
+
159
+ @property
160
+ def epoch(self):
161
+ """int: Current epoch."""
162
+ return self._epoch
163
+
164
+ @property
165
+ def iter(self):
166
+ """int: Current iteration."""
167
+ return self._iter
168
+
169
+ @property
170
+ def inner_iter(self):
171
+ """int: Iteration in an epoch."""
172
+ return self._inner_iter
173
+
174
+ @property
175
+ def max_epochs(self):
176
+ """int: Maximum training epochs."""
177
+ return self._max_epochs
178
+
179
+ @property
180
+ def max_iters(self):
181
+ """int: Maximum training iterations."""
182
+ return self._max_iters
183
+
184
+ @abstractmethod
185
+ def train(self):
186
+ pass
187
+
188
+ @abstractmethod
189
+ def val(self):
190
+ pass
191
+
192
+ @abstractmethod
193
+ def run(self, data_loaders, workflow, **kwargs):
194
+ pass
195
+
196
+ @abstractmethod
197
+ def save_checkpoint(self,
198
+ out_dir,
199
+ filename_tmpl,
200
+ save_optimizer=True,
201
+ meta=None,
202
+ create_symlink=True):
203
+ pass
204
+
205
+ def current_lr(self):
206
+ """Get current learning rates.
207
+
208
+ Returns:
209
+ list[float] | dict[str, list[float]]: Current learning rates of all
210
+ param groups. If the runner has a dict of optimizers, this
211
+ method will return a dict.
212
+ """
213
+ if isinstance(self.optimizer, torch.optim.Optimizer):
214
+ lr = [group['lr'] for group in self.optimizer.param_groups]
215
+ elif isinstance(self.optimizer, dict):
216
+ lr = dict()
217
+ for name, optim in self.optimizer.items():
218
+ lr[name] = [group['lr'] for group in optim.param_groups]
219
+ else:
220
+ raise RuntimeError(
221
+ 'lr is not applicable because optimizer does not exist.')
222
+ return lr
223
+
224
+ def current_momentum(self):
225
+ """Get current momentums.
226
+
227
+ Returns:
228
+ list[float] | dict[str, list[float]]: Current momentums of all
229
+ param groups. If the runner has a dict of optimizers, this
230
+ method will return a dict.
231
+ """
232
+
233
+ def _get_momentum(optimizer):
234
+ momentums = []
235
+ for group in optimizer.param_groups:
236
+ if 'momentum' in group.keys():
237
+ momentums.append(group['momentum'])
238
+ elif 'betas' in group.keys():
239
+ momentums.append(group['betas'][0])
240
+ else:
241
+ momentums.append(0)
242
+ return momentums
243
+
244
+ if self.optimizer is None:
245
+ raise RuntimeError(
246
+ 'momentum is not applicable because optimizer does not exist.')
247
+ elif isinstance(self.optimizer, torch.optim.Optimizer):
248
+ momentums = _get_momentum(self.optimizer)
249
+ elif isinstance(self.optimizer, dict):
250
+ momentums = dict()
251
+ for name, optim in self.optimizer.items():
252
+ momentums[name] = _get_momentum(optim)
253
+ return momentums
254
+
255
+ def register_hook(self, hook, priority='NORMAL'):
256
+ """Register a hook into the hook list.
257
+
258
+ The hook will be inserted into a priority queue, with the specified
259
+ priority (See :class:`Priority` for details of priorities).
260
+ For hooks with the same priority, they will be triggered in the same
261
+ order as they are registered.
262
+
263
+ Args:
264
+ hook (:obj:`Hook`): The hook to be registered.
265
+ priority (int or str or :obj:`Priority`): Hook priority.
266
+ Lower value means higher priority.
267
+ """
268
+ assert isinstance(hook, Hook)
269
+ if hasattr(hook, 'priority'):
270
+ raise ValueError('"priority" is a reserved attribute for hooks')
271
+ priority = get_priority(priority)
272
+ hook.priority = priority
273
+ # insert the hook to a sorted list
274
+ inserted = False
275
+ for i in range(len(self._hooks) - 1, -1, -1):
276
+ if priority >= self._hooks[i].priority:
277
+ self._hooks.insert(i + 1, hook)
278
+ inserted = True
279
+ break
280
+ if not inserted:
281
+ self._hooks.insert(0, hook)
282
+
283
+ def register_hook_from_cfg(self, hook_cfg):
284
+ """Register a hook from its cfg.
285
+
286
+ Args:
287
+ hook_cfg (dict): Hook config. It should have at least keys 'type'
288
+ and 'priority' indicating its type and priority.
289
+
290
+ Notes:
291
+ The specific hook class to register should not use 'type' and
292
+ 'priority' arguments during initialization.
293
+ """
294
+ hook_cfg = hook_cfg.copy()
295
+ priority = hook_cfg.pop('priority', 'NORMAL')
296
+ hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
297
+ self.register_hook(hook, priority=priority)
298
+
299
+ def call_hook(self, fn_name):
300
+ """Call all hooks.
301
+
302
+ Args:
303
+ fn_name (str): The function name in each hook to be called, such as
304
+ "before_train_epoch".
305
+ """
306
+ for hook in self._hooks:
307
+ getattr(hook, fn_name)(self)
308
+
309
+ def get_hook_info(self):
310
+ # Get hooks info in each stage
311
+ stage_hook_map = {stage: [] for stage in Hook.stages}
312
+ for hook in self.hooks:
313
+ try:
314
+ priority = Priority(hook.priority).name
315
+ except ValueError:
316
+ priority = hook.priority
317
+ classname = hook.__class__.__name__
318
+ hook_info = f'({priority:<12}) {classname:<35}'
319
+ for trigger_stage in hook.get_triggered_stages():
320
+ stage_hook_map[trigger_stage].append(hook_info)
321
+
322
+ stage_hook_infos = []
323
+ for stage in Hook.stages:
324
+ hook_infos = stage_hook_map[stage]
325
+ if len(hook_infos) > 0:
326
+ info = f'{stage}:\n'
327
+ info += '\n'.join(hook_infos)
328
+ info += '\n -------------------- '
329
+ stage_hook_infos.append(info)
330
+ return '\n'.join(stage_hook_infos)
331
+
332
+ def load_checkpoint(self,
333
+ filename,
334
+ map_location='cpu',
335
+ strict=False,
336
+ revise_keys=[(r'^module.', '')]):
337
+ return load_checkpoint(
338
+ self.model,
339
+ filename,
340
+ map_location,
341
+ strict,
342
+ self.logger,
343
+ revise_keys=revise_keys)
344
+
345
+ def resume(self,
346
+ checkpoint,
347
+ resume_optimizer=True,
348
+ map_location='default'):
349
+ if map_location == 'default':
350
+ if torch.cuda.is_available():
351
+ device_id = torch.cuda.current_device()
352
+ checkpoint = self.load_checkpoint(
353
+ checkpoint,
354
+ map_location=lambda storage, loc: storage.cuda(device_id))
355
+ else:
356
+ checkpoint = self.load_checkpoint(checkpoint)
357
+ else:
358
+ checkpoint = self.load_checkpoint(
359
+ checkpoint, map_location=map_location)
360
+
361
+ self._epoch = checkpoint['meta']['epoch']
362
+ self._iter = checkpoint['meta']['iter']
363
+ if self.meta is None:
364
+ self.meta = {}
365
+ self.meta.setdefault('hook_msgs', {})
366
+ # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
367
+ self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
368
+
369
+ # Re-calculate the number of iterations when resuming
370
+ # models with different number of GPUs
371
+ if 'config' in checkpoint['meta']:
372
+ config = mmcv.Config.fromstring(
373
+ checkpoint['meta']['config'], file_format='.py')
374
+ previous_gpu_ids = config.get('gpu_ids', None)
375
+ if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
376
+ previous_gpu_ids) != self.world_size:
377
+ self._iter = int(self._iter * len(previous_gpu_ids) /
378
+ self.world_size)
379
+ self.logger.info('the iteration number is changed due to '
380
+ 'change of GPU number')
381
+
382
+ # resume meta information meta
383
+ self.meta = checkpoint['meta']
384
+
385
+ if 'optimizer' in checkpoint and resume_optimizer:
386
+ if isinstance(self.optimizer, Optimizer):
387
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
388
+ elif isinstance(self.optimizer, dict):
389
+ for k in self.optimizer.keys():
390
+ self.optimizer[k].load_state_dict(
391
+ checkpoint['optimizer'][k])
392
+ else:
393
+ raise TypeError(
394
+ 'Optimizer should be dict or torch.optim.Optimizer '
395
+ f'but got {type(self.optimizer)}')
396
+
397
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
398
+
399
+ def register_lr_hook(self, lr_config):
400
+ if lr_config is None:
401
+ return
402
+ elif isinstance(lr_config, dict):
403
+ assert 'policy' in lr_config
404
+ policy_type = lr_config.pop('policy')
405
+ # If the type of policy is all in lower case, e.g., 'cyclic',
406
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
407
+ # This is for the convenient usage of Lr updater.
408
+ # Since this is not applicable for `
409
+ # CosineAnnealingLrUpdater`,
410
+ # the string will not be changed if it contains capital letters.
411
+ if policy_type == policy_type.lower():
412
+ policy_type = policy_type.title()
413
+ hook_type = policy_type + 'LrUpdaterHook'
414
+ lr_config['type'] = hook_type
415
+ hook = mmcv.build_from_cfg(lr_config, HOOKS)
416
+ else:
417
+ hook = lr_config
418
+ self.register_hook(hook, priority='VERY_HIGH')
419
+
420
+ def register_momentum_hook(self, momentum_config):
421
+ if momentum_config is None:
422
+ return
423
+ if isinstance(momentum_config, dict):
424
+ assert 'policy' in momentum_config
425
+ policy_type = momentum_config.pop('policy')
426
+ # If the type of policy is all in lower case, e.g., 'cyclic',
427
+ # then its first letter will be capitalized, e.g., to be 'Cyclic'.
428
+ # This is for the convenient usage of momentum updater.
429
+ # Since this is not applicable for
430
+ # `CosineAnnealingMomentumUpdater`,
431
+ # the string will not be changed if it contains capital letters.
432
+ if policy_type == policy_type.lower():
433
+ policy_type = policy_type.title()
434
+ hook_type = policy_type + 'MomentumUpdaterHook'
435
+ momentum_config['type'] = hook_type
436
+ hook = mmcv.build_from_cfg(momentum_config, HOOKS)
437
+ else:
438
+ hook = momentum_config
439
+ self.register_hook(hook, priority='HIGH')
440
+
441
+ def register_optimizer_hook(self, optimizer_config):
442
+ if optimizer_config is None:
443
+ return
444
+ if isinstance(optimizer_config, dict):
445
+ optimizer_config.setdefault('type', 'OptimizerHook')
446
+ hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
447
+ else:
448
+ hook = optimizer_config
449
+ self.register_hook(hook, priority='ABOVE_NORMAL')
450
+
451
+ def register_checkpoint_hook(self, checkpoint_config):
452
+ if checkpoint_config is None:
453
+ return
454
+ if isinstance(checkpoint_config, dict):
455
+ checkpoint_config.setdefault('type', 'CheckpointHook')
456
+ hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
457
+ else:
458
+ hook = checkpoint_config
459
+ self.register_hook(hook, priority='NORMAL')
460
+
461
+ def register_logger_hooks(self, log_config):
462
+ if log_config is None:
463
+ return
464
+ log_interval = log_config['interval']
465
+ for info in log_config['hooks']:
466
+ logger_hook = mmcv.build_from_cfg(
467
+ info, HOOKS, default_args=dict(interval=log_interval))
468
+ self.register_hook(logger_hook, priority='VERY_LOW')
469
+
470
+ def register_timer_hook(self, timer_config):
471
+ if timer_config is None:
472
+ return
473
+ if isinstance(timer_config, dict):
474
+ timer_config_ = copy.deepcopy(timer_config)
475
+ hook = mmcv.build_from_cfg(timer_config_, HOOKS)
476
+ else:
477
+ hook = timer_config
478
+ self.register_hook(hook, priority='LOW')
479
+
480
+ def register_custom_hooks(self, custom_config):
481
+ if custom_config is None:
482
+ return
483
+
484
+ if not isinstance(custom_config, list):
485
+ custom_config = [custom_config]
486
+
487
+ for item in custom_config:
488
+ if isinstance(item, dict):
489
+ self.register_hook_from_cfg(item)
490
+ else:
491
+ self.register_hook(item, priority='NORMAL')
492
+
493
+ def register_profiler_hook(self, profiler_config):
494
+ if profiler_config is None:
495
+ return
496
+ if isinstance(profiler_config, dict):
497
+ profiler_config.setdefault('type', 'ProfilerHook')
498
+ hook = mmcv.build_from_cfg(profiler_config, HOOKS)
499
+ else:
500
+ hook = profiler_config
501
+ self.register_hook(hook)
502
+
503
+ def register_training_hooks(self,
504
+ lr_config,
505
+ optimizer_config=None,
506
+ checkpoint_config=None,
507
+ log_config=None,
508
+ momentum_config=None,
509
+ timer_config=dict(type='IterTimerHook'),
510
+ custom_hooks_config=None):
511
+ """Register default and custom hooks for training.
512
+
513
+ Default and custom hooks include:
514
+
515
+ +----------------------+-------------------------+
516
+ | Hooks | Priority |
517
+ +======================+=========================+
518
+ | LrUpdaterHook | VERY_HIGH (10) |
519
+ +----------------------+-------------------------+
520
+ | MomentumUpdaterHook | HIGH (30) |
521
+ +----------------------+-------------------------+
522
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
523
+ +----------------------+-------------------------+
524
+ | CheckpointSaverHook | NORMAL (50) |
525
+ +----------------------+-------------------------+
526
+ | IterTimerHook | LOW (70) |
527
+ +----------------------+-------------------------+
528
+ | LoggerHook(s) | VERY_LOW (90) |
529
+ +----------------------+-------------------------+
530
+ | CustomHook(s) | defaults to NORMAL (50) |
531
+ +----------------------+-------------------------+
532
+
533
+ If custom hooks have same priority with default hooks, custom hooks
534
+ will be triggered after default hooks.
535
+ """
536
+ self.register_lr_hook(lr_config)
537
+ self.register_momentum_hook(momentum_config)
538
+ self.register_optimizer_hook(optimizer_config)
539
+ self.register_checkpoint_hook(checkpoint_config)
540
+ self.register_timer_hook(timer_config)
541
+ self.register_logger_hooks(log_config)
542
+ self.register_custom_hooks(custom_hooks_config)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/builder.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+
4
+ from ..utils import Registry
5
+
6
+ RUNNERS = Registry('runner')
7
+ RUNNER_BUILDERS = Registry('runner builder')
8
+
9
+
10
+ def build_runner_constructor(cfg):
11
+ return RUNNER_BUILDERS.build(cfg)
12
+
13
+
14
+ def build_runner(cfg, default_args=None):
15
+ runner_cfg = copy.deepcopy(cfg)
16
+ constructor_type = runner_cfg.pop('constructor',
17
+ 'DefaultRunnerConstructor')
18
+ runner_constructor = build_runner_constructor(
19
+ dict(
20
+ type=constructor_type,
21
+ runner_cfg=runner_cfg,
22
+ default_args=default_args))
23
+ runner = runner_constructor()
24
+ return runner
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import io
3
+ import os
4
+ import os.path as osp
5
+ import pkgutil
6
+ import re
7
+ import time
8
+ import warnings
9
+ from collections import OrderedDict
10
+ from importlib import import_module
11
+ from tempfile import TemporaryDirectory
12
+
13
+ import torch
14
+ import torchvision
15
+ from torch.optim import Optimizer
16
+ from torch.utils import model_zoo
17
+
18
+ import annotator.mmpkg.mmcv as mmcv
19
+ from ..fileio import FileClient
20
+ from ..fileio import load as load_file
21
+ from ..parallel import is_module_wrapper
22
+ from ..utils import mkdir_or_exist
23
+ from .dist_utils import get_dist_info
24
+
25
+ ENV_MMCV_HOME = 'MMCV_HOME'
26
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27
+ DEFAULT_CACHE_DIR = '~/.cache'
28
+
29
+
30
+ def _get_mmcv_home():
31
+ mmcv_home = os.path.expanduser(
32
+ os.getenv(
33
+ ENV_MMCV_HOME,
34
+ os.path.join(
35
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36
+
37
+ mkdir_or_exist(mmcv_home)
38
+ return mmcv_home
39
+
40
+
41
+ def load_state_dict(module, state_dict, strict=False, logger=None):
42
+ """Load state_dict to a module.
43
+
44
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
45
+ Default value for ``strict`` is set to ``False`` and the message for
46
+ param mismatch will be shown even if strict is False.
47
+
48
+ Args:
49
+ module (Module): Module that receives the state_dict.
50
+ state_dict (OrderedDict): Weights.
51
+ strict (bool): whether to strictly enforce that the keys
52
+ in :attr:`state_dict` match the keys returned by this module's
53
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
54
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
55
+ message. If not specified, print function will be used.
56
+ """
57
+ unexpected_keys = []
58
+ all_missing_keys = []
59
+ err_msg = []
60
+
61
+ metadata = getattr(state_dict, '_metadata', None)
62
+ state_dict = state_dict.copy()
63
+ if metadata is not None:
64
+ state_dict._metadata = metadata
65
+
66
+ # use _load_from_state_dict to enable checkpoint version control
67
+ def load(module, prefix=''):
68
+ # recursively check parallel module in case that the model has a
69
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
70
+ if is_module_wrapper(module):
71
+ module = module.module
72
+ local_metadata = {} if metadata is None else metadata.get(
73
+ prefix[:-1], {})
74
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
75
+ all_missing_keys, unexpected_keys,
76
+ err_msg)
77
+ for name, child in module._modules.items():
78
+ if child is not None:
79
+ load(child, prefix + name + '.')
80
+
81
+ load(module)
82
+ load = None # break load->load reference cycle
83
+
84
+ # ignore "num_batches_tracked" of BN layers
85
+ missing_keys = [
86
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
87
+ ]
88
+
89
+ if unexpected_keys:
90
+ err_msg.append('unexpected key in source '
91
+ f'state_dict: {", ".join(unexpected_keys)}\n')
92
+ if missing_keys:
93
+ err_msg.append(
94
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
95
+
96
+ rank, _ = get_dist_info()
97
+ if len(err_msg) > 0 and rank == 0:
98
+ err_msg.insert(
99
+ 0, 'The model and loaded state dict do not match exactly\n')
100
+ err_msg = '\n'.join(err_msg)
101
+ if strict:
102
+ raise RuntimeError(err_msg)
103
+ elif logger is not None:
104
+ logger.warning(err_msg)
105
+ else:
106
+ print(err_msg)
107
+
108
+
109
+ def get_torchvision_models():
110
+ model_urls = dict()
111
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
112
+ if ispkg:
113
+ continue
114
+ _zoo = import_module(f'torchvision.models.{name}')
115
+ if hasattr(_zoo, 'model_urls'):
116
+ _urls = getattr(_zoo, 'model_urls')
117
+ model_urls.update(_urls)
118
+ return model_urls
119
+
120
+
121
+ def get_external_models():
122
+ mmcv_home = _get_mmcv_home()
123
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
124
+ default_urls = load_file(default_json_path)
125
+ assert isinstance(default_urls, dict)
126
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
127
+ if osp.exists(external_json_path):
128
+ external_urls = load_file(external_json_path)
129
+ assert isinstance(external_urls, dict)
130
+ default_urls.update(external_urls)
131
+
132
+ return default_urls
133
+
134
+
135
+ def get_mmcls_models():
136
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
137
+ mmcls_urls = load_file(mmcls_json_path)
138
+
139
+ return mmcls_urls
140
+
141
+
142
+ def get_deprecated_model_names():
143
+ deprecate_json_path = osp.join(mmcv.__path__[0],
144
+ 'model_zoo/deprecated.json')
145
+ deprecate_urls = load_file(deprecate_json_path)
146
+ assert isinstance(deprecate_urls, dict)
147
+
148
+ return deprecate_urls
149
+
150
+
151
+ def _process_mmcls_checkpoint(checkpoint):
152
+ state_dict = checkpoint['state_dict']
153
+ new_state_dict = OrderedDict()
154
+ for k, v in state_dict.items():
155
+ if k.startswith('backbone.'):
156
+ new_state_dict[k[9:]] = v
157
+ new_checkpoint = dict(state_dict=new_state_dict)
158
+
159
+ return new_checkpoint
160
+
161
+
162
+ class CheckpointLoader:
163
+ """A general checkpoint loader to manage all schemes."""
164
+
165
+ _schemes = {}
166
+
167
+ @classmethod
168
+ def _register_scheme(cls, prefixes, loader, force=False):
169
+ if isinstance(prefixes, str):
170
+ prefixes = [prefixes]
171
+ else:
172
+ assert isinstance(prefixes, (list, tuple))
173
+ for prefix in prefixes:
174
+ if (prefix not in cls._schemes) or force:
175
+ cls._schemes[prefix] = loader
176
+ else:
177
+ raise KeyError(
178
+ f'{prefix} is already registered as a loader backend, '
179
+ 'add "force=True" if you want to override it')
180
+ # sort, longer prefixes take priority
181
+ cls._schemes = OrderedDict(
182
+ sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
183
+
184
+ @classmethod
185
+ def register_scheme(cls, prefixes, loader=None, force=False):
186
+ """Register a loader to CheckpointLoader.
187
+
188
+ This method can be used as a normal class method or a decorator.
189
+
190
+ Args:
191
+ prefixes (str or list[str] or tuple[str]):
192
+ The prefix of the registered loader.
193
+ loader (function, optional): The loader function to be registered.
194
+ When this method is used as a decorator, loader is None.
195
+ Defaults to None.
196
+ force (bool, optional): Whether to override the loader
197
+ if the prefix has already been registered. Defaults to False.
198
+ """
199
+
200
+ if loader is not None:
201
+ cls._register_scheme(prefixes, loader, force=force)
202
+ return
203
+
204
+ def _register(loader_cls):
205
+ cls._register_scheme(prefixes, loader_cls, force=force)
206
+ return loader_cls
207
+
208
+ return _register
209
+
210
+ @classmethod
211
+ def _get_checkpoint_loader(cls, path):
212
+ """Finds a loader that supports the given path. Falls back to the local
213
+ loader if no other loader is found.
214
+
215
+ Args:
216
+ path (str): checkpoint path
217
+
218
+ Returns:
219
+ loader (function): checkpoint loader
220
+ """
221
+
222
+ for p in cls._schemes:
223
+ if path.startswith(p):
224
+ return cls._schemes[p]
225
+
226
+ @classmethod
227
+ def load_checkpoint(cls, filename, map_location=None, logger=None):
228
+ """load checkpoint through URL scheme path.
229
+
230
+ Args:
231
+ filename (str): checkpoint file name with given prefix
232
+ map_location (str, optional): Same as :func:`torch.load`.
233
+ Default: None
234
+ logger (:mod:`logging.Logger`, optional): The logger for message.
235
+ Default: None
236
+
237
+ Returns:
238
+ dict or OrderedDict: The loaded checkpoint.
239
+ """
240
+
241
+ checkpoint_loader = cls._get_checkpoint_loader(filename)
242
+ class_name = checkpoint_loader.__name__
243
+ mmcv.print_log(
244
+ f'load checkpoint from {class_name[10:]} path: {filename}', logger)
245
+ return checkpoint_loader(filename, map_location)
246
+
247
+
248
+ @CheckpointLoader.register_scheme(prefixes='')
249
+ def load_from_local(filename, map_location):
250
+ """load checkpoint by local file path.
251
+
252
+ Args:
253
+ filename (str): local checkpoint file path
254
+ map_location (str, optional): Same as :func:`torch.load`.
255
+
256
+ Returns:
257
+ dict or OrderedDict: The loaded checkpoint.
258
+ """
259
+
260
+ if not osp.isfile(filename):
261
+ raise IOError(f'{filename} is not a checkpoint file')
262
+ checkpoint = torch.load(filename, map_location=map_location)
263
+ return checkpoint
264
+
265
+
266
+ @CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
267
+ def load_from_http(filename, map_location=None, model_dir=None):
268
+ """load checkpoint through HTTP or HTTPS scheme path. In distributed
269
+ setting, this function only download checkpoint at local rank 0.
270
+
271
+ Args:
272
+ filename (str): checkpoint file path with modelzoo or
273
+ torchvision prefix
274
+ map_location (str, optional): Same as :func:`torch.load`.
275
+ model_dir (string, optional): directory in which to save the object,
276
+ Default: None
277
+
278
+ Returns:
279
+ dict or OrderedDict: The loaded checkpoint.
280
+ """
281
+ rank, world_size = get_dist_info()
282
+ rank = int(os.environ.get('LOCAL_RANK', rank))
283
+ if rank == 0:
284
+ checkpoint = model_zoo.load_url(
285
+ filename, model_dir=model_dir, map_location=map_location)
286
+ if world_size > 1:
287
+ torch.distributed.barrier()
288
+ if rank > 0:
289
+ checkpoint = model_zoo.load_url(
290
+ filename, model_dir=model_dir, map_location=map_location)
291
+ return checkpoint
292
+
293
+
294
+ @CheckpointLoader.register_scheme(prefixes='pavi://')
295
+ def load_from_pavi(filename, map_location=None):
296
+ """load checkpoint through the file path prefixed with pavi. In distributed
297
+ setting, this function download ckpt at all ranks to different temporary
298
+ directories.
299
+
300
+ Args:
301
+ filename (str): checkpoint file path with pavi prefix
302
+ map_location (str, optional): Same as :func:`torch.load`.
303
+ Default: None
304
+
305
+ Returns:
306
+ dict or OrderedDict: The loaded checkpoint.
307
+ """
308
+ assert filename.startswith('pavi://'), \
309
+ f'Expected filename startswith `pavi://`, but get {filename}'
310
+ model_path = filename[7:]
311
+
312
+ try:
313
+ from pavi import modelcloud
314
+ except ImportError:
315
+ raise ImportError(
316
+ 'Please install pavi to load checkpoint from modelcloud.')
317
+
318
+ model = modelcloud.get(model_path)
319
+ with TemporaryDirectory() as tmp_dir:
320
+ downloaded_file = osp.join(tmp_dir, model.name)
321
+ model.download(downloaded_file)
322
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
323
+ return checkpoint
324
+
325
+
326
+ @CheckpointLoader.register_scheme(prefixes='s3://')
327
+ def load_from_ceph(filename, map_location=None, backend='petrel'):
328
+ """load checkpoint through the file path prefixed with s3. In distributed
329
+ setting, this function download ckpt at all ranks to different temporary
330
+ directories.
331
+
332
+ Args:
333
+ filename (str): checkpoint file path with s3 prefix
334
+ map_location (str, optional): Same as :func:`torch.load`.
335
+ backend (str, optional): The storage backend type. Options are 'ceph',
336
+ 'petrel'. Default: 'petrel'.
337
+
338
+ .. warning::
339
+ :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
340
+ please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
341
+
342
+ Returns:
343
+ dict or OrderedDict: The loaded checkpoint.
344
+ """
345
+ allowed_backends = ['ceph', 'petrel']
346
+ if backend not in allowed_backends:
347
+ raise ValueError(f'Load from Backend {backend} is not supported.')
348
+
349
+ if backend == 'ceph':
350
+ warnings.warn(
351
+ 'CephBackend will be deprecated, please use PetrelBackend instead')
352
+
353
+ # CephClient and PetrelBackend have the same prefix 's3://' and the latter
354
+ # will be chosen as default. If PetrelBackend can not be instantiated
355
+ # successfully, the CephClient will be chosen.
356
+ try:
357
+ file_client = FileClient(backend=backend)
358
+ except ImportError:
359
+ allowed_backends.remove(backend)
360
+ file_client = FileClient(backend=allowed_backends[0])
361
+
362
+ with io.BytesIO(file_client.get(filename)) as buffer:
363
+ checkpoint = torch.load(buffer, map_location=map_location)
364
+ return checkpoint
365
+
366
+
367
+ @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
368
+ def load_from_torchvision(filename, map_location=None):
369
+ """load checkpoint through the file path prefixed with modelzoo or
370
+ torchvision.
371
+
372
+ Args:
373
+ filename (str): checkpoint file path with modelzoo or
374
+ torchvision prefix
375
+ map_location (str, optional): Same as :func:`torch.load`.
376
+
377
+ Returns:
378
+ dict or OrderedDict: The loaded checkpoint.
379
+ """
380
+ model_urls = get_torchvision_models()
381
+ if filename.startswith('modelzoo://'):
382
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
383
+ 'use "torchvision://" instead')
384
+ model_name = filename[11:]
385
+ else:
386
+ model_name = filename[14:]
387
+ return load_from_http(model_urls[model_name], map_location=map_location)
388
+
389
+
390
+ @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
391
+ def load_from_openmmlab(filename, map_location=None):
392
+ """load checkpoint through the file path prefixed with open-mmlab or
393
+ openmmlab.
394
+
395
+ Args:
396
+ filename (str): checkpoint file path with open-mmlab or
397
+ openmmlab prefix
398
+ map_location (str, optional): Same as :func:`torch.load`.
399
+ Default: None
400
+
401
+ Returns:
402
+ dict or OrderedDict: The loaded checkpoint.
403
+ """
404
+
405
+ model_urls = get_external_models()
406
+ prefix_str = 'open-mmlab://'
407
+ if filename.startswith(prefix_str):
408
+ model_name = filename[13:]
409
+ else:
410
+ model_name = filename[12:]
411
+ prefix_str = 'openmmlab://'
412
+
413
+ deprecated_urls = get_deprecated_model_names()
414
+ if model_name in deprecated_urls:
415
+ warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
416
+ f'of {prefix_str}{deprecated_urls[model_name]}')
417
+ model_name = deprecated_urls[model_name]
418
+ model_url = model_urls[model_name]
419
+ # check if is url
420
+ if model_url.startswith(('http://', 'https://')):
421
+ checkpoint = load_from_http(model_url, map_location=map_location)
422
+ else:
423
+ filename = osp.join(_get_mmcv_home(), model_url)
424
+ if not osp.isfile(filename):
425
+ raise IOError(f'{filename} is not a checkpoint file')
426
+ checkpoint = torch.load(filename, map_location=map_location)
427
+ return checkpoint
428
+
429
+
430
+ @CheckpointLoader.register_scheme(prefixes='mmcls://')
431
+ def load_from_mmcls(filename, map_location=None):
432
+ """load checkpoint through the file path prefixed with mmcls.
433
+
434
+ Args:
435
+ filename (str): checkpoint file path with mmcls prefix
436
+ map_location (str, optional): Same as :func:`torch.load`.
437
+
438
+ Returns:
439
+ dict or OrderedDict: The loaded checkpoint.
440
+ """
441
+
442
+ model_urls = get_mmcls_models()
443
+ model_name = filename[8:]
444
+ checkpoint = load_from_http(
445
+ model_urls[model_name], map_location=map_location)
446
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
447
+ return checkpoint
448
+
449
+
450
+ def _load_checkpoint(filename, map_location=None, logger=None):
451
+ """Load checkpoint from somewhere (modelzoo, file, url).
452
+
453
+ Args:
454
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
455
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
456
+ details.
457
+ map_location (str, optional): Same as :func:`torch.load`.
458
+ Default: None.
459
+ logger (:mod:`logging.Logger`, optional): The logger for error message.
460
+ Default: None
461
+
462
+ Returns:
463
+ dict or OrderedDict: The loaded checkpoint. It can be either an
464
+ OrderedDict storing model weights or a dict containing other
465
+ information, which depends on the checkpoint.
466
+ """
467
+ return CheckpointLoader.load_checkpoint(filename, map_location, logger)
468
+
469
+
470
+ def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
471
+ """Load partial pretrained model with specific prefix.
472
+
473
+ Args:
474
+ prefix (str): The prefix of sub-module.
475
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
476
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
477
+ details.
478
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
479
+
480
+ Returns:
481
+ dict or OrderedDict: The loaded checkpoint.
482
+ """
483
+
484
+ checkpoint = _load_checkpoint(filename, map_location=map_location)
485
+
486
+ if 'state_dict' in checkpoint:
487
+ state_dict = checkpoint['state_dict']
488
+ else:
489
+ state_dict = checkpoint
490
+ if not prefix.endswith('.'):
491
+ prefix += '.'
492
+ prefix_len = len(prefix)
493
+
494
+ state_dict = {
495
+ k[prefix_len:]: v
496
+ for k, v in state_dict.items() if k.startswith(prefix)
497
+ }
498
+
499
+ assert state_dict, f'{prefix} is not in the pretrained model'
500
+ return state_dict
501
+
502
+
503
+ def load_checkpoint(model,
504
+ filename,
505
+ map_location=None,
506
+ strict=False,
507
+ logger=None,
508
+ revise_keys=[(r'^module\.', '')]):
509
+ """Load checkpoint from a file or URI.
510
+
511
+ Args:
512
+ model (Module): Module to load checkpoint.
513
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
514
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
515
+ details.
516
+ map_location (str): Same as :func:`torch.load`.
517
+ strict (bool): Whether to allow different params for the model and
518
+ checkpoint.
519
+ logger (:mod:`logging.Logger` or None): The logger for error message.
520
+ revise_keys (list): A list of customized keywords to modify the
521
+ state_dict in checkpoint. Each item is a (pattern, replacement)
522
+ pair of the regular expression operations. Default: strip
523
+ the prefix 'module.' by [(r'^module\\.', '')].
524
+
525
+ Returns:
526
+ dict or OrderedDict: The loaded checkpoint.
527
+ """
528
+ checkpoint = _load_checkpoint(filename, map_location, logger)
529
+ # OrderedDict is a subclass of dict
530
+ if not isinstance(checkpoint, dict):
531
+ raise RuntimeError(
532
+ f'No state_dict found in checkpoint file {filename}')
533
+ # get state_dict from checkpoint
534
+ if 'state_dict' in checkpoint:
535
+ state_dict = checkpoint['state_dict']
536
+ else:
537
+ state_dict = checkpoint
538
+
539
+ # strip prefix of state_dict
540
+ metadata = getattr(state_dict, '_metadata', OrderedDict())
541
+ for p, r in revise_keys:
542
+ state_dict = OrderedDict(
543
+ {re.sub(p, r, k): v
544
+ for k, v in state_dict.items()})
545
+ # Keep metadata in state_dict
546
+ state_dict._metadata = metadata
547
+
548
+ # load state_dict
549
+ load_state_dict(model, state_dict, strict, logger)
550
+ return checkpoint
551
+
552
+
553
+ def weights_to_cpu(state_dict):
554
+ """Copy a model state_dict to cpu.
555
+
556
+ Args:
557
+ state_dict (OrderedDict): Model weights on GPU.
558
+
559
+ Returns:
560
+ OrderedDict: Model weights on GPU.
561
+ """
562
+ state_dict_cpu = OrderedDict()
563
+ for key, val in state_dict.items():
564
+ state_dict_cpu[key] = val.cpu()
565
+ # Keep metadata in state_dict
566
+ state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
567
+ return state_dict_cpu
568
+
569
+
570
+ def _save_to_state_dict(module, destination, prefix, keep_vars):
571
+ """Saves module state to `destination` dictionary.
572
+
573
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
574
+
575
+ Args:
576
+ module (nn.Module): The module to generate state_dict.
577
+ destination (dict): A dict where state will be stored.
578
+ prefix (str): The prefix for parameters and buffers used in this
579
+ module.
580
+ """
581
+ for name, param in module._parameters.items():
582
+ if param is not None:
583
+ destination[prefix + name] = param if keep_vars else param.detach()
584
+ for name, buf in module._buffers.items():
585
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
586
+ if buf is not None:
587
+ destination[prefix + name] = buf if keep_vars else buf.detach()
588
+
589
+
590
+ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
591
+ """Returns a dictionary containing a whole state of the module.
592
+
593
+ Both parameters and persistent buffers (e.g. running averages) are
594
+ included. Keys are corresponding parameter and buffer names.
595
+
596
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
597
+ recursively check parallel module in case that the model has a complicated
598
+ structure, e.g., nn.Module(nn.Module(DDP)).
599
+
600
+ Args:
601
+ module (nn.Module): The module to generate state_dict.
602
+ destination (OrderedDict): Returned dict for the state of the
603
+ module.
604
+ prefix (str): Prefix of the key.
605
+ keep_vars (bool): Whether to keep the variable property of the
606
+ parameters. Default: False.
607
+
608
+ Returns:
609
+ dict: A dictionary containing a whole state of the module.
610
+ """
611
+ # recursively check parallel module in case that the model has a
612
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
613
+ if is_module_wrapper(module):
614
+ module = module.module
615
+
616
+ # below is the same as torch.nn.Module.state_dict()
617
+ if destination is None:
618
+ destination = OrderedDict()
619
+ destination._metadata = OrderedDict()
620
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
621
+ version=module._version)
622
+ _save_to_state_dict(module, destination, prefix, keep_vars)
623
+ for name, child in module._modules.items():
624
+ if child is not None:
625
+ get_state_dict(
626
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
627
+ for hook in module._state_dict_hooks.values():
628
+ hook_result = hook(module, destination, prefix, local_metadata)
629
+ if hook_result is not None:
630
+ destination = hook_result
631
+ return destination
632
+
633
+
634
+ def save_checkpoint(model,
635
+ filename,
636
+ optimizer=None,
637
+ meta=None,
638
+ file_client_args=None):
639
+ """Save checkpoint to file.
640
+
641
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
642
+ ``optimizer``. By default ``meta`` will contain version and time info.
643
+
644
+ Args:
645
+ model (Module): Module whose params are to be saved.
646
+ filename (str): Checkpoint filename.
647
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
648
+ meta (dict, optional): Metadata to be saved in checkpoint.
649
+ file_client_args (dict, optional): Arguments to instantiate a
650
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
651
+ Default: None.
652
+ `New in version 1.3.16.`
653
+ """
654
+ if meta is None:
655
+ meta = {}
656
+ elif not isinstance(meta, dict):
657
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
658
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
659
+
660
+ if is_module_wrapper(model):
661
+ model = model.module
662
+
663
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
664
+ # save class name to the meta
665
+ meta.update(CLASSES=model.CLASSES)
666
+
667
+ checkpoint = {
668
+ 'meta': meta,
669
+ 'state_dict': weights_to_cpu(get_state_dict(model))
670
+ }
671
+ # save optimizer state dict in the checkpoint
672
+ if isinstance(optimizer, Optimizer):
673
+ checkpoint['optimizer'] = optimizer.state_dict()
674
+ elif isinstance(optimizer, dict):
675
+ checkpoint['optimizer'] = {}
676
+ for name, optim in optimizer.items():
677
+ checkpoint['optimizer'][name] = optim.state_dict()
678
+
679
+ if filename.startswith('pavi://'):
680
+ if file_client_args is not None:
681
+ raise ValueError(
682
+ 'file_client_args should be "None" if filename starts with'
683
+ f'"pavi://", but got {file_client_args}')
684
+ try:
685
+ from pavi import modelcloud
686
+ from pavi import exception
687
+ except ImportError:
688
+ raise ImportError(
689
+ 'Please install pavi to load checkpoint from modelcloud.')
690
+ model_path = filename[7:]
691
+ root = modelcloud.Folder()
692
+ model_dir, model_name = osp.split(model_path)
693
+ try:
694
+ model = modelcloud.get(model_dir)
695
+ except exception.NodeNotFoundError:
696
+ model = root.create_training_model(model_dir)
697
+ with TemporaryDirectory() as tmp_dir:
698
+ checkpoint_file = osp.join(tmp_dir, model_name)
699
+ with open(checkpoint_file, 'wb') as f:
700
+ torch.save(checkpoint, f)
701
+ f.flush()
702
+ model.create_file(checkpoint_file, name=model_name)
703
+ else:
704
+ file_client = FileClient.infer_client(file_client_args, filename)
705
+ with io.BytesIO() as f:
706
+ torch.save(checkpoint, f)
707
+ file_client.put(f.getvalue(), filename)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .builder import RUNNER_BUILDERS, RUNNERS
2
+
3
+
4
+ @RUNNER_BUILDERS.register_module()
5
+ class DefaultRunnerConstructor:
6
+ """Default constructor for runners.
7
+
8
+ Custom existing `Runner` like `EpocBasedRunner` though `RunnerConstructor`.
9
+ For example, We can inject some new properties and functions for `Runner`.
10
+
11
+ Example:
12
+ >>> from annotator.mmpkg.mmcv.runner import RUNNER_BUILDERS, build_runner
13
+ >>> # Define a new RunnerReconstructor
14
+ >>> @RUNNER_BUILDERS.register_module()
15
+ >>> class MyRunnerConstructor:
16
+ ... def __init__(self, runner_cfg, default_args=None):
17
+ ... if not isinstance(runner_cfg, dict):
18
+ ... raise TypeError('runner_cfg should be a dict',
19
+ ... f'but got {type(runner_cfg)}')
20
+ ... self.runner_cfg = runner_cfg
21
+ ... self.default_args = default_args
22
+ ...
23
+ ... def __call__(self):
24
+ ... runner = RUNNERS.build(self.runner_cfg,
25
+ ... default_args=self.default_args)
26
+ ... # Add new properties for existing runner
27
+ ... runner.my_name = 'my_runner'
28
+ ... runner.my_function = lambda self: print(self.my_name)
29
+ ... ...
30
+ >>> # build your runner
31
+ >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
32
+ ... constructor='MyRunnerConstructor')
33
+ >>> runner = build_runner(runner_cfg)
34
+ """
35
+
36
+ def __init__(self, runner_cfg, default_args=None):
37
+ if not isinstance(runner_cfg, dict):
38
+ raise TypeError('runner_cfg should be a dict',
39
+ f'but got {type(runner_cfg)}')
40
+ self.runner_cfg = runner_cfg
41
+ self.default_args = default_args
42
+
43
+ def __call__(self):
44
+ return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.multiprocessing as mp
9
+ from torch import distributed as dist
10
+ from torch._utils import (_flatten_dense_tensors, _take_tensors,
11
+ _unflatten_dense_tensors)
12
+
13
+
14
+ def init_dist(launcher, backend='nccl', **kwargs):
15
+ if mp.get_start_method(allow_none=True) is None:
16
+ mp.set_start_method('spawn')
17
+ if launcher == 'pytorch':
18
+ _init_dist_pytorch(backend, **kwargs)
19
+ elif launcher == 'mpi':
20
+ _init_dist_mpi(backend, **kwargs)
21
+ elif launcher == 'slurm':
22
+ _init_dist_slurm(backend, **kwargs)
23
+ else:
24
+ raise ValueError(f'Invalid launcher type: {launcher}')
25
+
26
+
27
+ def _init_dist_pytorch(backend, **kwargs):
28
+ # TODO: use local_rank instead of rank % num_gpus
29
+ rank = int(os.environ['RANK'])
30
+ num_gpus = torch.cuda.device_count()
31
+ torch.cuda.set_device(rank % num_gpus)
32
+ dist.init_process_group(backend=backend, **kwargs)
33
+
34
+
35
+ def _init_dist_mpi(backend, **kwargs):
36
+ # TODO: use local_rank instead of rank % num_gpus
37
+ rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
38
+ num_gpus = torch.cuda.device_count()
39
+ torch.cuda.set_device(rank % num_gpus)
40
+ dist.init_process_group(backend=backend, **kwargs)
41
+
42
+
43
+ def _init_dist_slurm(backend, port=None):
44
+ """Initialize slurm distributed training environment.
45
+
46
+ If argument ``port`` is not specified, then the master port will be system
47
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
48
+ environment variable, then a default port ``29500`` will be used.
49
+
50
+ Args:
51
+ backend (str): Backend of torch.distributed.
52
+ port (int, optional): Master port. Defaults to None.
53
+ """
54
+ proc_id = int(os.environ['SLURM_PROCID'])
55
+ ntasks = int(os.environ['SLURM_NTASKS'])
56
+ node_list = os.environ['SLURM_NODELIST']
57
+ num_gpus = torch.cuda.device_count()
58
+ torch.cuda.set_device(proc_id % num_gpus)
59
+ addr = subprocess.getoutput(
60
+ f'scontrol show hostname {node_list} | head -n1')
61
+ # specify master port
62
+ if port is not None:
63
+ os.environ['MASTER_PORT'] = str(port)
64
+ elif 'MASTER_PORT' in os.environ:
65
+ pass # use MASTER_PORT in the environment variable
66
+ else:
67
+ # 29500 is torch.distributed default port
68
+ os.environ['MASTER_PORT'] = '29500'
69
+ # use MASTER_ADDR in the environment variable if it already exists
70
+ if 'MASTER_ADDR' not in os.environ:
71
+ os.environ['MASTER_ADDR'] = addr
72
+ os.environ['WORLD_SIZE'] = str(ntasks)
73
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
74
+ os.environ['RANK'] = str(proc_id)
75
+ dist.init_process_group(backend=backend)
76
+
77
+
78
+ def get_dist_info():
79
+ if dist.is_available() and dist.is_initialized():
80
+ rank = dist.get_rank()
81
+ world_size = dist.get_world_size()
82
+ else:
83
+ rank = 0
84
+ world_size = 1
85
+ return rank, world_size
86
+
87
+
88
+ def master_only(func):
89
+
90
+ @functools.wraps(func)
91
+ def wrapper(*args, **kwargs):
92
+ rank, _ = get_dist_info()
93
+ if rank == 0:
94
+ return func(*args, **kwargs)
95
+
96
+ return wrapper
97
+
98
+
99
+ def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
100
+ """Allreduce parameters.
101
+
102
+ Args:
103
+ params (list[torch.Parameters]): List of parameters or buffers of a
104
+ model.
105
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
106
+ Defaults to True.
107
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
108
+ Defaults to -1.
109
+ """
110
+ _, world_size = get_dist_info()
111
+ if world_size == 1:
112
+ return
113
+ params = [param.data for param in params]
114
+ if coalesce:
115
+ _allreduce_coalesced(params, world_size, bucket_size_mb)
116
+ else:
117
+ for tensor in params:
118
+ dist.all_reduce(tensor.div_(world_size))
119
+
120
+
121
+ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
122
+ """Allreduce gradients.
123
+
124
+ Args:
125
+ params (list[torch.Parameters]): List of parameters of a model
126
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
127
+ Defaults to True.
128
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
129
+ Defaults to -1.
130
+ """
131
+ grads = [
132
+ param.grad.data for param in params
133
+ if param.requires_grad and param.grad is not None
134
+ ]
135
+ _, world_size = get_dist_info()
136
+ if world_size == 1:
137
+ return
138
+ if coalesce:
139
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
140
+ else:
141
+ for tensor in grads:
142
+ dist.all_reduce(tensor.div_(world_size))
143
+
144
+
145
+ def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
146
+ if bucket_size_mb > 0:
147
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
148
+ buckets = _take_tensors(tensors, bucket_size_bytes)
149
+ else:
150
+ buckets = OrderedDict()
151
+ for tensor in tensors:
152
+ tp = tensor.type()
153
+ if tp not in buckets:
154
+ buckets[tp] = []
155
+ buckets[tp].append(tensor)
156
+ buckets = buckets.values()
157
+
158
+ for bucket in buckets:
159
+ flat_tensors = _flatten_dense_tensors(bucket)
160
+ dist.all_reduce(flat_tensors)
161
+ flat_tensors.div_(world_size)
162
+ for tensor, synced in zip(
163
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
164
+ tensor.copy_(synced)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import platform
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+
10
+ import annotator.mmpkg.mmcv as mmcv
11
+ from .base_runner import BaseRunner
12
+ from .builder import RUNNERS
13
+ from .checkpoint import save_checkpoint
14
+ from .utils import get_host_info
15
+
16
+
17
+ @RUNNERS.register_module()
18
+ class EpochBasedRunner(BaseRunner):
19
+ """Epoch-based Runner.
20
+
21
+ This runner train models epoch by epoch.
22
+ """
23
+
24
+ def run_iter(self, data_batch, train_mode, **kwargs):
25
+ if self.batch_processor is not None:
26
+ outputs = self.batch_processor(
27
+ self.model, data_batch, train_mode=train_mode, **kwargs)
28
+ elif train_mode:
29
+ outputs = self.model.train_step(data_batch, self.optimizer,
30
+ **kwargs)
31
+ else:
32
+ outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
33
+ if not isinstance(outputs, dict):
34
+ raise TypeError('"batch_processor()" or "model.train_step()"'
35
+ 'and "model.val_step()" must return a dict')
36
+ if 'log_vars' in outputs:
37
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
38
+ self.outputs = outputs
39
+
40
+ def train(self, data_loader, **kwargs):
41
+ self.model.train()
42
+ self.mode = 'train'
43
+ self.data_loader = data_loader
44
+ self._max_iters = self._max_epochs * len(self.data_loader)
45
+ self.call_hook('before_train_epoch')
46
+ time.sleep(2) # Prevent possible deadlock during epoch transition
47
+ for i, data_batch in enumerate(self.data_loader):
48
+ self._inner_iter = i
49
+ self.call_hook('before_train_iter')
50
+ self.run_iter(data_batch, train_mode=True, **kwargs)
51
+ self.call_hook('after_train_iter')
52
+ self._iter += 1
53
+
54
+ self.call_hook('after_train_epoch')
55
+ self._epoch += 1
56
+
57
+ @torch.no_grad()
58
+ def val(self, data_loader, **kwargs):
59
+ self.model.eval()
60
+ self.mode = 'val'
61
+ self.data_loader = data_loader
62
+ self.call_hook('before_val_epoch')
63
+ time.sleep(2) # Prevent possible deadlock during epoch transition
64
+ for i, data_batch in enumerate(self.data_loader):
65
+ self._inner_iter = i
66
+ self.call_hook('before_val_iter')
67
+ self.run_iter(data_batch, train_mode=False)
68
+ self.call_hook('after_val_iter')
69
+
70
+ self.call_hook('after_val_epoch')
71
+
72
+ def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
73
+ """Start running.
74
+
75
+ Args:
76
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
77
+ and validation.
78
+ workflow (list[tuple]): A list of (phase, epochs) to specify the
79
+ running order and epochs. E.g, [('train', 2), ('val', 1)] means
80
+ running 2 epochs for training and 1 epoch for validation,
81
+ iteratively.
82
+ """
83
+ assert isinstance(data_loaders, list)
84
+ assert mmcv.is_list_of(workflow, tuple)
85
+ assert len(data_loaders) == len(workflow)
86
+ if max_epochs is not None:
87
+ warnings.warn(
88
+ 'setting max_epochs in run is deprecated, '
89
+ 'please set max_epochs in runner_config', DeprecationWarning)
90
+ self._max_epochs = max_epochs
91
+
92
+ assert self._max_epochs is not None, (
93
+ 'max_epochs must be specified during instantiation')
94
+
95
+ for i, flow in enumerate(workflow):
96
+ mode, epochs = flow
97
+ if mode == 'train':
98
+ self._max_iters = self._max_epochs * len(data_loaders[i])
99
+ break
100
+
101
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
102
+ self.logger.info('Start running, host: %s, work_dir: %s',
103
+ get_host_info(), work_dir)
104
+ self.logger.info('Hooks will be executed in the following order:\n%s',
105
+ self.get_hook_info())
106
+ self.logger.info('workflow: %s, max: %d epochs', workflow,
107
+ self._max_epochs)
108
+ self.call_hook('before_run')
109
+
110
+ while self.epoch < self._max_epochs:
111
+ for i, flow in enumerate(workflow):
112
+ mode, epochs = flow
113
+ if isinstance(mode, str): # self.train()
114
+ if not hasattr(self, mode):
115
+ raise ValueError(
116
+ f'runner has no method named "{mode}" to run an '
117
+ 'epoch')
118
+ epoch_runner = getattr(self, mode)
119
+ else:
120
+ raise TypeError(
121
+ 'mode in workflow must be a str, but got {}'.format(
122
+ type(mode)))
123
+
124
+ for _ in range(epochs):
125
+ if mode == 'train' and self.epoch >= self._max_epochs:
126
+ break
127
+ epoch_runner(data_loaders[i], **kwargs)
128
+
129
+ time.sleep(1) # wait for some hooks like loggers to finish
130
+ self.call_hook('after_run')
131
+
132
+ def save_checkpoint(self,
133
+ out_dir,
134
+ filename_tmpl='epoch_{}.pth',
135
+ save_optimizer=True,
136
+ meta=None,
137
+ create_symlink=True):
138
+ """Save the checkpoint.
139
+
140
+ Args:
141
+ out_dir (str): The directory that checkpoints are saved.
142
+ filename_tmpl (str, optional): The checkpoint filename template,
143
+ which contains a placeholder for the epoch number.
144
+ Defaults to 'epoch_{}.pth'.
145
+ save_optimizer (bool, optional): Whether to save the optimizer to
146
+ the checkpoint. Defaults to True.
147
+ meta (dict, optional): The meta information to be saved in the
148
+ checkpoint. Defaults to None.
149
+ create_symlink (bool, optional): Whether to create a symlink
150
+ "latest.pth" to point to the latest checkpoint.
151
+ Defaults to True.
152
+ """
153
+ if meta is None:
154
+ meta = {}
155
+ elif not isinstance(meta, dict):
156
+ raise TypeError(
157
+ f'meta should be a dict or None, but got {type(meta)}')
158
+ if self.meta is not None:
159
+ meta.update(self.meta)
160
+ # Note: meta.update(self.meta) should be done before
161
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
162
+ # there will be problems with resumed checkpoints.
163
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
164
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
165
+
166
+ filename = filename_tmpl.format(self.epoch + 1)
167
+ filepath = osp.join(out_dir, filename)
168
+ optimizer = self.optimizer if save_optimizer else None
169
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
170
+ # in some environments, `os.symlink` is not supported, you may need to
171
+ # set `create_symlink` to False
172
+ if create_symlink:
173
+ dst_file = osp.join(out_dir, 'latest.pth')
174
+ if platform.system() != 'Windows':
175
+ mmcv.symlink(filename, dst_file)
176
+ else:
177
+ shutil.copy(filepath, dst_file)
178
+
179
+
180
+ @RUNNERS.register_module()
181
+ class Runner(EpochBasedRunner):
182
+ """Deprecated name of EpochBasedRunner."""
183
+
184
+ def __init__(self, *args, **kwargs):
185
+ warnings.warn(
186
+ 'Runner was deprecated, please use EpochBasedRunner instead')
187
+ super().__init__(*args, **kwargs)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import functools
3
+ import warnings
4
+ from collections import abc
5
+ from inspect import getfullargspec
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
12
+ from .dist_utils import allreduce_grads as _allreduce_grads
13
+
14
+ try:
15
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
16
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
17
+ # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
18
+ # manually, so the behavior may not be consistent with real amp.
19
+ from torch.cuda.amp import autocast
20
+ except ImportError:
21
+ pass
22
+
23
+
24
+ def cast_tensor_type(inputs, src_type, dst_type):
25
+ """Recursively convert Tensor in inputs from src_type to dst_type.
26
+
27
+ Args:
28
+ inputs: Inputs that to be casted.
29
+ src_type (torch.dtype): Source type..
30
+ dst_type (torch.dtype): Destination type.
31
+
32
+ Returns:
33
+ The same type with inputs, but all contained Tensors have been cast.
34
+ """
35
+ if isinstance(inputs, nn.Module):
36
+ return inputs
37
+ elif isinstance(inputs, torch.Tensor):
38
+ return inputs.to(dst_type)
39
+ elif isinstance(inputs, str):
40
+ return inputs
41
+ elif isinstance(inputs, np.ndarray):
42
+ return inputs
43
+ elif isinstance(inputs, abc.Mapping):
44
+ return type(inputs)({
45
+ k: cast_tensor_type(v, src_type, dst_type)
46
+ for k, v in inputs.items()
47
+ })
48
+ elif isinstance(inputs, abc.Iterable):
49
+ return type(inputs)(
50
+ cast_tensor_type(item, src_type, dst_type) for item in inputs)
51
+ else:
52
+ return inputs
53
+
54
+
55
+ def auto_fp16(apply_to=None, out_fp32=False):
56
+ """Decorator to enable fp16 training automatically.
57
+
58
+ This decorator is useful when you write custom modules and want to support
59
+ mixed precision training. If inputs arguments are fp32 tensors, they will
60
+ be converted to fp16 automatically. Arguments other than fp32 tensors are
61
+ ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
62
+ backend, otherwise, original mmcv implementation will be adopted.
63
+
64
+ Args:
65
+ apply_to (Iterable, optional): The argument names to be converted.
66
+ `None` indicates all arguments.
67
+ out_fp32 (bool): Whether to convert the output back to fp32.
68
+
69
+ Example:
70
+
71
+ >>> import torch.nn as nn
72
+ >>> class MyModule1(nn.Module):
73
+ >>>
74
+ >>> # Convert x and y to fp16
75
+ >>> @auto_fp16()
76
+ >>> def forward(self, x, y):
77
+ >>> pass
78
+
79
+ >>> import torch.nn as nn
80
+ >>> class MyModule2(nn.Module):
81
+ >>>
82
+ >>> # convert pred to fp16
83
+ >>> @auto_fp16(apply_to=('pred', ))
84
+ >>> def do_something(self, pred, others):
85
+ >>> pass
86
+ """
87
+
88
+ def auto_fp16_wrapper(old_func):
89
+
90
+ @functools.wraps(old_func)
91
+ def new_func(*args, **kwargs):
92
+ # check if the module has set the attribute `fp16_enabled`, if not,
93
+ # just fallback to the original method.
94
+ if not isinstance(args[0], torch.nn.Module):
95
+ raise TypeError('@auto_fp16 can only be used to decorate the '
96
+ 'method of nn.Module')
97
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
98
+ return old_func(*args, **kwargs)
99
+
100
+ # get the arg spec of the decorated method
101
+ args_info = getfullargspec(old_func)
102
+ # get the argument names to be casted
103
+ args_to_cast = args_info.args if apply_to is None else apply_to
104
+ # convert the args that need to be processed
105
+ new_args = []
106
+ # NOTE: default args are not taken into consideration
107
+ if args:
108
+ arg_names = args_info.args[:len(args)]
109
+ for i, arg_name in enumerate(arg_names):
110
+ if arg_name in args_to_cast:
111
+ new_args.append(
112
+ cast_tensor_type(args[i], torch.float, torch.half))
113
+ else:
114
+ new_args.append(args[i])
115
+ # convert the kwargs that need to be processed
116
+ new_kwargs = {}
117
+ if kwargs:
118
+ for arg_name, arg_value in kwargs.items():
119
+ if arg_name in args_to_cast:
120
+ new_kwargs[arg_name] = cast_tensor_type(
121
+ arg_value, torch.float, torch.half)
122
+ else:
123
+ new_kwargs[arg_name] = arg_value
124
+ # apply converted arguments to the decorated method
125
+ if (TORCH_VERSION != 'parrots' and
126
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
127
+ with autocast(enabled=True):
128
+ output = old_func(*new_args, **new_kwargs)
129
+ else:
130
+ output = old_func(*new_args, **new_kwargs)
131
+ # cast the results back to fp32 if necessary
132
+ if out_fp32:
133
+ output = cast_tensor_type(output, torch.half, torch.float)
134
+ return output
135
+
136
+ return new_func
137
+
138
+ return auto_fp16_wrapper
139
+
140
+
141
+ def force_fp32(apply_to=None, out_fp16=False):
142
+ """Decorator to convert input arguments to fp32 in force.
143
+
144
+ This decorator is useful when you write custom modules and want to support
145
+ mixed precision training. If there are some inputs that must be processed
146
+ in fp32 mode, then this decorator can handle it. If inputs arguments are
147
+ fp16 tensors, they will be converted to fp32 automatically. Arguments other
148
+ than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
149
+ torch.cuda.amp is used as the backend, otherwise, original mmcv
150
+ implementation will be adopted.
151
+
152
+ Args:
153
+ apply_to (Iterable, optional): The argument names to be converted.
154
+ `None` indicates all arguments.
155
+ out_fp16 (bool): Whether to convert the output back to fp16.
156
+
157
+ Example:
158
+
159
+ >>> import torch.nn as nn
160
+ >>> class MyModule1(nn.Module):
161
+ >>>
162
+ >>> # Convert x and y to fp32
163
+ >>> @force_fp32()
164
+ >>> def loss(self, x, y):
165
+ >>> pass
166
+
167
+ >>> import torch.nn as nn
168
+ >>> class MyModule2(nn.Module):
169
+ >>>
170
+ >>> # convert pred to fp32
171
+ >>> @force_fp32(apply_to=('pred', ))
172
+ >>> def post_process(self, pred, others):
173
+ >>> pass
174
+ """
175
+
176
+ def force_fp32_wrapper(old_func):
177
+
178
+ @functools.wraps(old_func)
179
+ def new_func(*args, **kwargs):
180
+ # check if the module has set the attribute `fp16_enabled`, if not,
181
+ # just fallback to the original method.
182
+ if not isinstance(args[0], torch.nn.Module):
183
+ raise TypeError('@force_fp32 can only be used to decorate the '
184
+ 'method of nn.Module')
185
+ if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
186
+ return old_func(*args, **kwargs)
187
+ # get the arg spec of the decorated method
188
+ args_info = getfullargspec(old_func)
189
+ # get the argument names to be casted
190
+ args_to_cast = args_info.args if apply_to is None else apply_to
191
+ # convert the args that need to be processed
192
+ new_args = []
193
+ if args:
194
+ arg_names = args_info.args[:len(args)]
195
+ for i, arg_name in enumerate(arg_names):
196
+ if arg_name in args_to_cast:
197
+ new_args.append(
198
+ cast_tensor_type(args[i], torch.half, torch.float))
199
+ else:
200
+ new_args.append(args[i])
201
+ # convert the kwargs that need to be processed
202
+ new_kwargs = dict()
203
+ if kwargs:
204
+ for arg_name, arg_value in kwargs.items():
205
+ if arg_name in args_to_cast:
206
+ new_kwargs[arg_name] = cast_tensor_type(
207
+ arg_value, torch.half, torch.float)
208
+ else:
209
+ new_kwargs[arg_name] = arg_value
210
+ # apply converted arguments to the decorated method
211
+ if (TORCH_VERSION != 'parrots' and
212
+ digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
213
+ with autocast(enabled=False):
214
+ output = old_func(*new_args, **new_kwargs)
215
+ else:
216
+ output = old_func(*new_args, **new_kwargs)
217
+ # cast the results back to fp32 if necessary
218
+ if out_fp16:
219
+ output = cast_tensor_type(output, torch.float, torch.half)
220
+ return output
221
+
222
+ return new_func
223
+
224
+ return force_fp32_wrapper
225
+
226
+
227
+ def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
228
+ warnings.warning(
229
+ '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
230
+ 'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads')
231
+ _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
232
+
233
+
234
+ def wrap_fp16_model(model):
235
+ """Wrap the FP32 model to FP16.
236
+
237
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
238
+ backend, otherwise, original mmcv implementation will be adopted.
239
+
240
+ For PyTorch >= 1.6, this function will
241
+ 1. Set fp16 flag inside the model to True.
242
+
243
+ Otherwise:
244
+ 1. Convert FP32 model to FP16.
245
+ 2. Remain some necessary layers to be FP32, e.g., normalization layers.
246
+ 3. Set `fp16_enabled` flag inside the model to True.
247
+
248
+ Args:
249
+ model (nn.Module): Model in FP32.
250
+ """
251
+ if (TORCH_VERSION == 'parrots'
252
+ or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
253
+ # convert model to fp16
254
+ model.half()
255
+ # patch the normalization layers to make it work in fp32 mode
256
+ patch_norm_fp32(model)
257
+ # set `fp16_enabled` flag
258
+ for m in model.modules():
259
+ if hasattr(m, 'fp16_enabled'):
260
+ m.fp16_enabled = True
261
+
262
+
263
+ def patch_norm_fp32(module):
264
+ """Recursively convert normalization layers from FP16 to FP32.
265
+
266
+ Args:
267
+ module (nn.Module): The modules to be converted in FP16.
268
+
269
+ Returns:
270
+ nn.Module: The converted module, the normalization layers have been
271
+ converted to FP32.
272
+ """
273
+ if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
274
+ module.float()
275
+ if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
276
+ module.forward = patch_forward_method(module.forward, torch.half,
277
+ torch.float)
278
+ for child in module.children():
279
+ patch_norm_fp32(child)
280
+ return module
281
+
282
+
283
+ def patch_forward_method(func, src_type, dst_type, convert_output=True):
284
+ """Patch the forward method of a module.
285
+
286
+ Args:
287
+ func (callable): The original forward method.
288
+ src_type (torch.dtype): Type of input arguments to be converted from.
289
+ dst_type (torch.dtype): Type of input arguments to be converted to.
290
+ convert_output (bool): Whether to convert the output back to src_type.
291
+
292
+ Returns:
293
+ callable: The patched forward method.
294
+ """
295
+
296
+ def new_forward(*args, **kwargs):
297
+ output = func(*cast_tensor_type(args, src_type, dst_type),
298
+ **cast_tensor_type(kwargs, src_type, dst_type))
299
+ if convert_output:
300
+ output = cast_tensor_type(output, dst_type, src_type)
301
+ return output
302
+
303
+ return new_forward
304
+
305
+
306
+ class LossScaler:
307
+ """Class that manages loss scaling in mixed precision training which
308
+ supports both dynamic or static mode.
309
+
310
+ The implementation refers to
311
+ https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
312
+ Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
313
+ It's important to understand how :class:`LossScaler` operates.
314
+ Loss scaling is designed to combat the problem of underflowing
315
+ gradients encountered at long times when training fp16 networks.
316
+ Dynamic loss scaling begins by attempting a very high loss
317
+ scale. Ironically, this may result in OVERflowing gradients.
318
+ If overflowing gradients are encountered, :class:`FP16_Optimizer` then
319
+ skips the update step for this particular iteration/minibatch,
320
+ and :class:`LossScaler` adjusts the loss scale to a lower value.
321
+ If a certain number of iterations occur without overflowing gradients
322
+ detected,:class:`LossScaler` increases the loss scale once more.
323
+ In this way :class:`LossScaler` attempts to "ride the edge" of always
324
+ using the highest loss scale possible without incurring overflow.
325
+
326
+ Args:
327
+ init_scale (float): Initial loss scale value, default: 2**32.
328
+ scale_factor (float): Factor used when adjusting the loss scale.
329
+ Default: 2.
330
+ mode (str): Loss scaling mode. 'dynamic' or 'static'
331
+ scale_window (int): Number of consecutive iterations without an
332
+ overflow to wait before increasing the loss scale. Default: 1000.
333
+ """
334
+
335
+ def __init__(self,
336
+ init_scale=2**32,
337
+ mode='dynamic',
338
+ scale_factor=2.,
339
+ scale_window=1000):
340
+ self.cur_scale = init_scale
341
+ self.cur_iter = 0
342
+ assert mode in ('dynamic',
343
+ 'static'), 'mode can only be dynamic or static'
344
+ self.mode = mode
345
+ self.last_overflow_iter = -1
346
+ self.scale_factor = scale_factor
347
+ self.scale_window = scale_window
348
+
349
+ def has_overflow(self, params):
350
+ """Check if params contain overflow."""
351
+ if self.mode != 'dynamic':
352
+ return False
353
+ for p in params:
354
+ if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
355
+ return True
356
+ return False
357
+
358
+ def _has_inf_or_nan(x):
359
+ """Check if params contain NaN."""
360
+ try:
361
+ cpu_sum = float(x.float().sum())
362
+ except RuntimeError as instance:
363
+ if 'value cannot be converted' not in instance.args[0]:
364
+ raise
365
+ return True
366
+ else:
367
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') \
368
+ or cpu_sum != cpu_sum:
369
+ return True
370
+ return False
371
+
372
+ def update_scale(self, overflow):
373
+ """update the current loss scale value when overflow happens."""
374
+ if self.mode != 'dynamic':
375
+ return
376
+ if overflow:
377
+ self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
378
+ self.last_overflow_iter = self.cur_iter
379
+ else:
380
+ if (self.cur_iter - self.last_overflow_iter) % \
381
+ self.scale_window == 0:
382
+ self.cur_scale *= self.scale_factor
383
+ self.cur_iter += 1
384
+
385
+ def state_dict(self):
386
+ """Returns the state of the scaler as a :class:`dict`."""
387
+ return dict(
388
+ cur_scale=self.cur_scale,
389
+ cur_iter=self.cur_iter,
390
+ mode=self.mode,
391
+ last_overflow_iter=self.last_overflow_iter,
392
+ scale_factor=self.scale_factor,
393
+ scale_window=self.scale_window)
394
+
395
+ def load_state_dict(self, state_dict):
396
+ """Loads the loss_scaler state dict.
397
+
398
+ Args:
399
+ state_dict (dict): scaler state.
400
+ """
401
+ self.cur_scale = state_dict['cur_scale']
402
+ self.cur_iter = state_dict['cur_iter']
403
+ self.mode = state_dict['mode']
404
+ self.last_overflow_iter = state_dict['last_overflow_iter']
405
+ self.scale_factor = state_dict['scale_factor']
406
+ self.scale_window = state_dict['scale_window']
407
+
408
+ @property
409
+ def loss_scale(self):
410
+ return self.cur_scale
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .checkpoint import CheckpointHook
3
+ from .closure import ClosureHook
4
+ from .ema import EMAHook
5
+ from .evaluation import DistEvalHook, EvalHook
6
+ from .hook import HOOKS, Hook
7
+ from .iter_timer import IterTimerHook
8
+ from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
9
+ NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
10
+ TextLoggerHook, WandbLoggerHook)
11
+ from .lr_updater import LrUpdaterHook
12
+ from .memory import EmptyCacheHook
13
+ from .momentum_updater import MomentumUpdaterHook
14
+ from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
15
+ GradientCumulativeOptimizerHook, OptimizerHook)
16
+ from .profiler import ProfilerHook
17
+ from .sampler_seed import DistSamplerSeedHook
18
+ from .sync_buffer import SyncBuffersHook
19
+
20
+ __all__ = [
21
+ 'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
22
+ 'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
23
+ 'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
24
+ 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
25
+ 'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
26
+ 'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
27
+ 'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
28
+ 'GradientCumulativeFp16OptimizerHook'
29
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import warnings
4
+
5
+ from annotator.mmpkg.mmcv.fileio import FileClient
6
+ from ..dist_utils import allreduce_params, master_only
7
+ from .hook import HOOKS, Hook
8
+
9
+
10
+ @HOOKS.register_module()
11
+ class CheckpointHook(Hook):
12
+ """Save checkpoints periodically.
13
+
14
+ Args:
15
+ interval (int): The saving period. If ``by_epoch=True``, interval
16
+ indicates epochs, otherwise it indicates iterations.
17
+ Default: -1, which means "never".
18
+ by_epoch (bool): Saving checkpoints by epoch or by iteration.
19
+ Default: True.
20
+ save_optimizer (bool): Whether to save optimizer state_dict in the
21
+ checkpoint. It is usually used for resuming experiments.
22
+ Default: True.
23
+ out_dir (str, optional): The root directory to save checkpoints. If not
24
+ specified, ``runner.work_dir`` will be used by default. If
25
+ specified, the ``out_dir`` will be the concatenation of ``out_dir``
26
+ and the last level directory of ``runner.work_dir``.
27
+ `Changed in version 1.3.16.`
28
+ max_keep_ckpts (int, optional): The maximum checkpoints to keep.
29
+ In some cases we want only the latest few checkpoints and would
30
+ like to delete old ones to save the disk space.
31
+ Default: -1, which means unlimited.
32
+ save_last (bool, optional): Whether to force the last checkpoint to be
33
+ saved regardless of interval. Default: True.
34
+ sync_buffer (bool, optional): Whether to synchronize buffers in
35
+ different gpus. Default: False.
36
+ file_client_args (dict, optional): Arguments to instantiate a
37
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
38
+ Default: None.
39
+ `New in version 1.3.16.`
40
+
41
+ .. warning::
42
+ Before v1.3.16, the ``out_dir`` argument indicates the path where the
43
+ checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
44
+ root directory and the final path to save checkpoint is the
45
+ concatenation of ``out_dir`` and the last level directory of
46
+ ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
47
+ and the value of ``runner.work_dir`` is "/path/of/B", then the final
48
+ path will be "/path/of/A/B".
49
+ """
50
+
51
+ def __init__(self,
52
+ interval=-1,
53
+ by_epoch=True,
54
+ save_optimizer=True,
55
+ out_dir=None,
56
+ max_keep_ckpts=-1,
57
+ save_last=True,
58
+ sync_buffer=False,
59
+ file_client_args=None,
60
+ **kwargs):
61
+ self.interval = interval
62
+ self.by_epoch = by_epoch
63
+ self.save_optimizer = save_optimizer
64
+ self.out_dir = out_dir
65
+ self.max_keep_ckpts = max_keep_ckpts
66
+ self.save_last = save_last
67
+ self.args = kwargs
68
+ self.sync_buffer = sync_buffer
69
+ self.file_client_args = file_client_args
70
+
71
+ def before_run(self, runner):
72
+ if not self.out_dir:
73
+ self.out_dir = runner.work_dir
74
+
75
+ self.file_client = FileClient.infer_client(self.file_client_args,
76
+ self.out_dir)
77
+
78
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
79
+ # `self.out_dir` is set so the final `self.out_dir` is the
80
+ # concatenation of `self.out_dir` and the last level directory of
81
+ # `runner.work_dir`
82
+ if self.out_dir != runner.work_dir:
83
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
84
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
85
+
86
+ runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
87
+ f'{self.file_client.name}.'))
88
+
89
+ # disable the create_symlink option because some file backends do not
90
+ # allow to create a symlink
91
+ if 'create_symlink' in self.args:
92
+ if self.args[
93
+ 'create_symlink'] and not self.file_client.allow_symlink:
94
+ self.args['create_symlink'] = False
95
+ warnings.warn(
96
+ ('create_symlink is set as True by the user but is changed'
97
+ 'to be False because creating symbolic link is not '
98
+ f'allowed in {self.file_client.name}'))
99
+ else:
100
+ self.args['create_symlink'] = self.file_client.allow_symlink
101
+
102
+ def after_train_epoch(self, runner):
103
+ if not self.by_epoch:
104
+ return
105
+
106
+ # save checkpoint for following cases:
107
+ # 1. every ``self.interval`` epochs
108
+ # 2. reach the last epoch of training
109
+ if self.every_n_epochs(
110
+ runner, self.interval) or (self.save_last
111
+ and self.is_last_epoch(runner)):
112
+ runner.logger.info(
113
+ f'Saving checkpoint at {runner.epoch + 1} epochs')
114
+ if self.sync_buffer:
115
+ allreduce_params(runner.model.buffers())
116
+ self._save_checkpoint(runner)
117
+
118
+ @master_only
119
+ def _save_checkpoint(self, runner):
120
+ """Save the current checkpoint and delete unwanted checkpoint."""
121
+ runner.save_checkpoint(
122
+ self.out_dir, save_optimizer=self.save_optimizer, **self.args)
123
+ if runner.meta is not None:
124
+ if self.by_epoch:
125
+ cur_ckpt_filename = self.args.get(
126
+ 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
127
+ else:
128
+ cur_ckpt_filename = self.args.get(
129
+ 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
130
+ runner.meta.setdefault('hook_msgs', dict())
131
+ runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
132
+ self.out_dir, cur_ckpt_filename)
133
+ # remove other checkpoints
134
+ if self.max_keep_ckpts > 0:
135
+ if self.by_epoch:
136
+ name = 'epoch_{}.pth'
137
+ current_ckpt = runner.epoch + 1
138
+ else:
139
+ name = 'iter_{}.pth'
140
+ current_ckpt = runner.iter + 1
141
+ redundant_ckpts = range(
142
+ current_ckpt - self.max_keep_ckpts * self.interval, 0,
143
+ -self.interval)
144
+ filename_tmpl = self.args.get('filename_tmpl', name)
145
+ for _step in redundant_ckpts:
146
+ ckpt_path = self.file_client.join_path(
147
+ self.out_dir, filename_tmpl.format(_step))
148
+ if self.file_client.isfile(ckpt_path):
149
+ self.file_client.remove(ckpt_path)
150
+ else:
151
+ break
152
+
153
+ def after_train_iter(self, runner):
154
+ if self.by_epoch:
155
+ return
156
+
157
+ # save checkpoint for following cases:
158
+ # 1. every ``self.interval`` iterations
159
+ # 2. reach the last iteration of training
160
+ if self.every_n_iters(
161
+ runner, self.interval) or (self.save_last
162
+ and self.is_last_iter(runner)):
163
+ runner.logger.info(
164
+ f'Saving checkpoint at {runner.iter + 1} iterations')
165
+ if self.sync_buffer:
166
+ allreduce_params(runner.model.buffers())
167
+ self._save_checkpoint(runner)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .hook import HOOKS, Hook
3
+
4
+
5
+ @HOOKS.register_module()
6
+ class ClosureHook(Hook):
7
+
8
+ def __init__(self, fn_name, fn):
9
+ assert hasattr(self, fn_name)
10
+ assert callable(fn)
11
+ setattr(self, fn_name, fn)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ...parallel import is_module_wrapper
3
+ from ..hooks.hook import HOOKS, Hook
4
+
5
+
6
+ @HOOKS.register_module()
7
+ class EMAHook(Hook):
8
+ r"""Exponential Moving Average Hook.
9
+
10
+ Use Exponential Moving Average on all parameters of model in training
11
+ process. All parameters have a ema backup, which update by the formula
12
+ as below. EMAHook takes priority over EvalHook and CheckpointSaverHook.
13
+
14
+ .. math::
15
+
16
+ \text{Xema\_{t+1}} = (1 - \text{momentum}) \times
17
+ \text{Xema\_{t}} + \text{momentum} \times X_t
18
+
19
+ Args:
20
+ momentum (float): The momentum used for updating ema parameter.
21
+ Defaults to 0.0002.
22
+ interval (int): Update ema parameter every interval iteration.
23
+ Defaults to 1.
24
+ warm_up (int): During first warm_up steps, we may use smaller momentum
25
+ to update ema parameters more slowly. Defaults to 100.
26
+ resume_from (str): The checkpoint path. Defaults to None.
27
+ """
28
+
29
+ def __init__(self,
30
+ momentum=0.0002,
31
+ interval=1,
32
+ warm_up=100,
33
+ resume_from=None):
34
+ assert isinstance(interval, int) and interval > 0
35
+ self.warm_up = warm_up
36
+ self.interval = interval
37
+ assert momentum > 0 and momentum < 1
38
+ self.momentum = momentum**interval
39
+ self.checkpoint = resume_from
40
+
41
+ def before_run(self, runner):
42
+ """To resume model with it's ema parameters more friendly.
43
+
44
+ Register ema parameter as ``named_buffer`` to model
45
+ """
46
+ model = runner.model
47
+ if is_module_wrapper(model):
48
+ model = model.module
49
+ self.param_ema_buffer = {}
50
+ self.model_parameters = dict(model.named_parameters(recurse=True))
51
+ for name, value in self.model_parameters.items():
52
+ # "." is not allowed in module's buffer name
53
+ buffer_name = f"ema_{name.replace('.', '_')}"
54
+ self.param_ema_buffer[name] = buffer_name
55
+ model.register_buffer(buffer_name, value.data.clone())
56
+ self.model_buffers = dict(model.named_buffers(recurse=True))
57
+ if self.checkpoint is not None:
58
+ runner.resume(self.checkpoint)
59
+
60
+ def after_train_iter(self, runner):
61
+ """Update ema parameter every self.interval iterations."""
62
+ curr_step = runner.iter
63
+ # We warm up the momentum considering the instability at beginning
64
+ momentum = min(self.momentum,
65
+ (1 + curr_step) / (self.warm_up + curr_step))
66
+ if curr_step % self.interval != 0:
67
+ return
68
+ for name, parameter in self.model_parameters.items():
69
+ buffer_name = self.param_ema_buffer[name]
70
+ buffer_parameter = self.model_buffers[buffer_name]
71
+ buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
72
+
73
+ def after_train_epoch(self, runner):
74
+ """We load parameter values from ema backup to model before the
75
+ EvalHook."""
76
+ self._swap_ema_parameters()
77
+
78
+ def before_train_epoch(self, runner):
79
+ """We recover model's parameter from ema backup after last epoch's
80
+ EvalHook."""
81
+ self._swap_ema_parameters()
82
+
83
+ def _swap_ema_parameters(self):
84
+ """Swap the parameter of model with parameter in ema_buffer."""
85
+ for name, value in self.model_parameters.items():
86
+ temp = value.data.clone()
87
+ ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
88
+ value.data.copy_(ema_buffer.data)
89
+ ema_buffer.data.copy_(temp)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import warnings
4
+ from math import inf
5
+
6
+ import torch.distributed as dist
7
+ from torch.nn.modules.batchnorm import _BatchNorm
8
+ from torch.utils.data import DataLoader
9
+
10
+ from annotator.mmpkg.mmcv.fileio import FileClient
11
+ from annotator.mmpkg.mmcv.utils import is_seq_of
12
+ from .hook import Hook
13
+ from .logger import LoggerHook
14
+
15
+
16
+ class EvalHook(Hook):
17
+ """Non-Distributed evaluation hook.
18
+
19
+ This hook will regularly perform evaluation in a given interval when
20
+ performing in non-distributed environment.
21
+
22
+ Args:
23
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
24
+ implemented ``evaluate`` function.
25
+ start (int | None, optional): Evaluation starting epoch. It enables
26
+ evaluation before the training starts if ``start`` <= the resuming
27
+ epoch. If None, whether to evaluate is merely decided by
28
+ ``interval``. Default: None.
29
+ interval (int): Evaluation interval. Default: 1.
30
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
31
+ If set to True, it will perform by epoch. Otherwise, by iteration.
32
+ Default: True.
33
+ save_best (str, optional): If a metric is specified, it would measure
34
+ the best checkpoint during evaluation. The information about best
35
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
36
+ best score value and best checkpoint path, which will be also
37
+ loaded when resume checkpoint. Options are the evaluation metrics
38
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
39
+ detection and instance segmentation. ``AR@100`` for proposal
40
+ recall. If ``save_best`` is ``auto``, the first key of the returned
41
+ ``OrderedDict`` result will be used. Default: None.
42
+ rule (str | None, optional): Comparison rule for best score. If set to
43
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
44
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
45
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
46
+ Default: None.
47
+ test_fn (callable, optional): test a model with samples from a
48
+ dataloader, and return the test results. If ``None``, the default
49
+ test function ``mmcv.engine.single_gpu_test`` will be used.
50
+ (default: ``None``)
51
+ greater_keys (List[str] | None, optional): Metric keys that will be
52
+ inferred by 'greater' comparison rule. If ``None``,
53
+ _default_greater_keys will be used. (default: ``None``)
54
+ less_keys (List[str] | None, optional): Metric keys that will be
55
+ inferred by 'less' comparison rule. If ``None``, _default_less_keys
56
+ will be used. (default: ``None``)
57
+ out_dir (str, optional): The root directory to save checkpoints. If not
58
+ specified, `runner.work_dir` will be used by default. If specified,
59
+ the `out_dir` will be the concatenation of `out_dir` and the last
60
+ level directory of `runner.work_dir`.
61
+ `New in version 1.3.16.`
62
+ file_client_args (dict): Arguments to instantiate a FileClient.
63
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
64
+ `New in version 1.3.16.`
65
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
66
+ the dataset.
67
+
68
+ Notes:
69
+ If new arguments are added for EvalHook, tools/test.py,
70
+ tools/eval_metric.py may be affected.
71
+ """
72
+
73
+ # Since the key for determine greater or less is related to the downstream
74
+ # tasks, downstream repos may need to overwrite the following inner
75
+ # variable accordingly.
76
+
77
+ rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
78
+ init_value_map = {'greater': -inf, 'less': inf}
79
+ _default_greater_keys = [
80
+ 'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
81
+ 'mAcc', 'aAcc'
82
+ ]
83
+ _default_less_keys = ['loss']
84
+
85
+ def __init__(self,
86
+ dataloader,
87
+ start=None,
88
+ interval=1,
89
+ by_epoch=True,
90
+ save_best=None,
91
+ rule=None,
92
+ test_fn=None,
93
+ greater_keys=None,
94
+ less_keys=None,
95
+ out_dir=None,
96
+ file_client_args=None,
97
+ **eval_kwargs):
98
+ if not isinstance(dataloader, DataLoader):
99
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
100
+ f'but got {type(dataloader)}')
101
+
102
+ if interval <= 0:
103
+ raise ValueError(f'interval must be a positive number, '
104
+ f'but got {interval}')
105
+
106
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
107
+
108
+ if start is not None and start < 0:
109
+ raise ValueError(f'The evaluation start epoch {start} is smaller '
110
+ f'than 0')
111
+
112
+ self.dataloader = dataloader
113
+ self.interval = interval
114
+ self.start = start
115
+ self.by_epoch = by_epoch
116
+
117
+ assert isinstance(save_best, str) or save_best is None, \
118
+ '""save_best"" should be a str or None ' \
119
+ f'rather than {type(save_best)}'
120
+ self.save_best = save_best
121
+ self.eval_kwargs = eval_kwargs
122
+ self.initial_flag = True
123
+
124
+ if test_fn is None:
125
+ from annotator.mmpkg.mmcv.engine import single_gpu_test
126
+ self.test_fn = single_gpu_test
127
+ else:
128
+ self.test_fn = test_fn
129
+
130
+ if greater_keys is None:
131
+ self.greater_keys = self._default_greater_keys
132
+ else:
133
+ if not isinstance(greater_keys, (list, tuple)):
134
+ greater_keys = (greater_keys, )
135
+ assert is_seq_of(greater_keys, str)
136
+ self.greater_keys = greater_keys
137
+
138
+ if less_keys is None:
139
+ self.less_keys = self._default_less_keys
140
+ else:
141
+ if not isinstance(less_keys, (list, tuple)):
142
+ less_keys = (less_keys, )
143
+ assert is_seq_of(less_keys, str)
144
+ self.less_keys = less_keys
145
+
146
+ if self.save_best is not None:
147
+ self.best_ckpt_path = None
148
+ self._init_rule(rule, self.save_best)
149
+
150
+ self.out_dir = out_dir
151
+ self.file_client_args = file_client_args
152
+
153
+ def _init_rule(self, rule, key_indicator):
154
+ """Initialize rule, key_indicator, comparison_func, and best score.
155
+
156
+ Here is the rule to determine which rule is used for key indicator
157
+ when the rule is not specific (note that the key indicator matching
158
+ is case-insensitive):
159
+ 1. If the key indicator is in ``self.greater_keys``, the rule will be
160
+ specified as 'greater'.
161
+ 2. Or if the key indicator is in ``self.less_keys``, the rule will be
162
+ specified as 'less'.
163
+ 3. Or if the key indicator is equal to the substring in any one item
164
+ in ``self.greater_keys``, the rule will be specified as 'greater'.
165
+ 4. Or if the key indicator is equal to the substring in any one item
166
+ in ``self.less_keys``, the rule will be specified as 'less'.
167
+
168
+ Args:
169
+ rule (str | None): Comparison rule for best score.
170
+ key_indicator (str | None): Key indicator to determine the
171
+ comparison rule.
172
+ """
173
+ if rule not in self.rule_map and rule is not None:
174
+ raise KeyError(f'rule must be greater, less or None, '
175
+ f'but got {rule}.')
176
+
177
+ if rule is None:
178
+ if key_indicator != 'auto':
179
+ # `_lc` here means we use the lower case of keys for
180
+ # case-insensitive matching
181
+ key_indicator_lc = key_indicator.lower()
182
+ greater_keys = [key.lower() for key in self.greater_keys]
183
+ less_keys = [key.lower() for key in self.less_keys]
184
+
185
+ if key_indicator_lc in greater_keys:
186
+ rule = 'greater'
187
+ elif key_indicator_lc in less_keys:
188
+ rule = 'less'
189
+ elif any(key in key_indicator_lc for key in greater_keys):
190
+ rule = 'greater'
191
+ elif any(key in key_indicator_lc for key in less_keys):
192
+ rule = 'less'
193
+ else:
194
+ raise ValueError(f'Cannot infer the rule for key '
195
+ f'{key_indicator}, thus a specific rule '
196
+ f'must be specified.')
197
+ self.rule = rule
198
+ self.key_indicator = key_indicator
199
+ if self.rule is not None:
200
+ self.compare_func = self.rule_map[self.rule]
201
+
202
+ def before_run(self, runner):
203
+ if not self.out_dir:
204
+ self.out_dir = runner.work_dir
205
+
206
+ self.file_client = FileClient.infer_client(self.file_client_args,
207
+ self.out_dir)
208
+
209
+ # if `self.out_dir` is not equal to `runner.work_dir`, it means that
210
+ # `self.out_dir` is set so the final `self.out_dir` is the
211
+ # concatenation of `self.out_dir` and the last level directory of
212
+ # `runner.work_dir`
213
+ if self.out_dir != runner.work_dir:
214
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
215
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
216
+ runner.logger.info(
217
+ (f'The best checkpoint will be saved to {self.out_dir} by '
218
+ f'{self.file_client.name}'))
219
+
220
+ if self.save_best is not None:
221
+ if runner.meta is None:
222
+ warnings.warn('runner.meta is None. Creating an empty one.')
223
+ runner.meta = dict()
224
+ runner.meta.setdefault('hook_msgs', dict())
225
+ self.best_ckpt_path = runner.meta['hook_msgs'].get(
226
+ 'best_ckpt', None)
227
+
228
+ def before_train_iter(self, runner):
229
+ """Evaluate the model only at the start of training by iteration."""
230
+ if self.by_epoch or not self.initial_flag:
231
+ return
232
+ if self.start is not None and runner.iter >= self.start:
233
+ self.after_train_iter(runner)
234
+ self.initial_flag = False
235
+
236
+ def before_train_epoch(self, runner):
237
+ """Evaluate the model only at the start of training by epoch."""
238
+ if not (self.by_epoch and self.initial_flag):
239
+ return
240
+ if self.start is not None and runner.epoch >= self.start:
241
+ self.after_train_epoch(runner)
242
+ self.initial_flag = False
243
+
244
+ def after_train_iter(self, runner):
245
+ """Called after every training iter to evaluate the results."""
246
+ if not self.by_epoch and self._should_evaluate(runner):
247
+ # Because the priority of EvalHook is higher than LoggerHook, the
248
+ # training log and the evaluating log are mixed. Therefore,
249
+ # we need to dump the training log and clear it before evaluating
250
+ # log is generated. In addition, this problem will only appear in
251
+ # `IterBasedRunner` whose `self.by_epoch` is False, because
252
+ # `EpochBasedRunner` whose `self.by_epoch` is True calls
253
+ # `_do_evaluate` in `after_train_epoch` stage, and at this stage
254
+ # the training log has been printed, so it will not cause any
255
+ # problem. more details at
256
+ # https://github.com/open-mmlab/mmsegmentation/issues/694
257
+ for hook in runner._hooks:
258
+ if isinstance(hook, LoggerHook):
259
+ hook.after_train_iter(runner)
260
+ runner.log_buffer.clear()
261
+
262
+ self._do_evaluate(runner)
263
+
264
+ def after_train_epoch(self, runner):
265
+ """Called after every training epoch to evaluate the results."""
266
+ if self.by_epoch and self._should_evaluate(runner):
267
+ self._do_evaluate(runner)
268
+
269
+ def _do_evaluate(self, runner):
270
+ """perform evaluation and save ckpt."""
271
+ results = self.test_fn(runner.model, self.dataloader)
272
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
273
+ key_score = self.evaluate(runner, results)
274
+ # the key_score may be `None` so it needs to skip the action to save
275
+ # the best checkpoint
276
+ if self.save_best and key_score:
277
+ self._save_ckpt(runner, key_score)
278
+
279
+ def _should_evaluate(self, runner):
280
+ """Judge whether to perform evaluation.
281
+
282
+ Here is the rule to judge whether to perform evaluation:
283
+ 1. It will not perform evaluation during the epoch/iteration interval,
284
+ which is determined by ``self.interval``.
285
+ 2. It will not perform evaluation if the start time is larger than
286
+ current time.
287
+ 3. It will not perform evaluation when current time is larger than
288
+ the start time but during epoch/iteration interval.
289
+
290
+ Returns:
291
+ bool: The flag indicating whether to perform evaluation.
292
+ """
293
+ if self.by_epoch:
294
+ current = runner.epoch
295
+ check_time = self.every_n_epochs
296
+ else:
297
+ current = runner.iter
298
+ check_time = self.every_n_iters
299
+
300
+ if self.start is None:
301
+ if not check_time(runner, self.interval):
302
+ # No evaluation during the interval.
303
+ return False
304
+ elif (current + 1) < self.start:
305
+ # No evaluation if start is larger than the current time.
306
+ return False
307
+ else:
308
+ # Evaluation only at epochs/iters 3, 5, 7...
309
+ # if start==3 and interval==2
310
+ if (current + 1 - self.start) % self.interval:
311
+ return False
312
+ return True
313
+
314
+ def _save_ckpt(self, runner, key_score):
315
+ """Save the best checkpoint.
316
+
317
+ It will compare the score according to the compare function, write
318
+ related information (best score, best checkpoint path) and save the
319
+ best checkpoint into ``work_dir``.
320
+ """
321
+ if self.by_epoch:
322
+ current = f'epoch_{runner.epoch + 1}'
323
+ cur_type, cur_time = 'epoch', runner.epoch + 1
324
+ else:
325
+ current = f'iter_{runner.iter + 1}'
326
+ cur_type, cur_time = 'iter', runner.iter + 1
327
+
328
+ best_score = runner.meta['hook_msgs'].get(
329
+ 'best_score', self.init_value_map[self.rule])
330
+ if self.compare_func(key_score, best_score):
331
+ best_score = key_score
332
+ runner.meta['hook_msgs']['best_score'] = best_score
333
+
334
+ if self.best_ckpt_path and self.file_client.isfile(
335
+ self.best_ckpt_path):
336
+ self.file_client.remove(self.best_ckpt_path)
337
+ runner.logger.info(
338
+ (f'The previous best checkpoint {self.best_ckpt_path} was '
339
+ 'removed'))
340
+
341
+ best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
342
+ self.best_ckpt_path = self.file_client.join_path(
343
+ self.out_dir, best_ckpt_name)
344
+ runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
345
+
346
+ runner.save_checkpoint(
347
+ self.out_dir, best_ckpt_name, create_symlink=False)
348
+ runner.logger.info(
349
+ f'Now best checkpoint is saved as {best_ckpt_name}.')
350
+ runner.logger.info(
351
+ f'Best {self.key_indicator} is {best_score:0.4f} '
352
+ f'at {cur_time} {cur_type}.')
353
+
354
+ def evaluate(self, runner, results):
355
+ """Evaluate the results.
356
+
357
+ Args:
358
+ runner (:obj:`mmcv.Runner`): The underlined training runner.
359
+ results (list): Output results.
360
+ """
361
+ eval_res = self.dataloader.dataset.evaluate(
362
+ results, logger=runner.logger, **self.eval_kwargs)
363
+
364
+ for name, val in eval_res.items():
365
+ runner.log_buffer.output[name] = val
366
+ runner.log_buffer.ready = True
367
+
368
+ if self.save_best is not None:
369
+ # If the performance of model is pool, the `eval_res` may be an
370
+ # empty dict and it will raise exception when `self.save_best` is
371
+ # not None. More details at
372
+ # https://github.com/open-mmlab/mmdetection/issues/6265.
373
+ if not eval_res:
374
+ warnings.warn(
375
+ 'Since `eval_res` is an empty dict, the behavior to save '
376
+ 'the best checkpoint will be skipped in this evaluation.')
377
+ return None
378
+
379
+ if self.key_indicator == 'auto':
380
+ # infer from eval_results
381
+ self._init_rule(self.rule, list(eval_res.keys())[0])
382
+ return eval_res[self.key_indicator]
383
+
384
+ return None
385
+
386
+
387
+ class DistEvalHook(EvalHook):
388
+ """Distributed evaluation hook.
389
+
390
+ This hook will regularly perform evaluation in a given interval when
391
+ performing in distributed environment.
392
+
393
+ Args:
394
+ dataloader (DataLoader): A PyTorch dataloader, whose dataset has
395
+ implemented ``evaluate`` function.
396
+ start (int | None, optional): Evaluation starting epoch. It enables
397
+ evaluation before the training starts if ``start`` <= the resuming
398
+ epoch. If None, whether to evaluate is merely decided by
399
+ ``interval``. Default: None.
400
+ interval (int): Evaluation interval. Default: 1.
401
+ by_epoch (bool): Determine perform evaluation by epoch or by iteration.
402
+ If set to True, it will perform by epoch. Otherwise, by iteration.
403
+ default: True.
404
+ save_best (str, optional): If a metric is specified, it would measure
405
+ the best checkpoint during evaluation. The information about best
406
+ checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
407
+ best score value and best checkpoint path, which will be also
408
+ loaded when resume checkpoint. Options are the evaluation metrics
409
+ on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
410
+ detection and instance segmentation. ``AR@100`` for proposal
411
+ recall. If ``save_best`` is ``auto``, the first key of the returned
412
+ ``OrderedDict`` result will be used. Default: None.
413
+ rule (str | None, optional): Comparison rule for best score. If set to
414
+ None, it will infer a reasonable rule. Keys such as 'acc', 'top'
415
+ .etc will be inferred by 'greater' rule. Keys contain 'loss' will
416
+ be inferred by 'less' rule. Options are 'greater', 'less', None.
417
+ Default: None.
418
+ test_fn (callable, optional): test a model with samples from a
419
+ dataloader in a multi-gpu manner, and return the test results. If
420
+ ``None``, the default test function ``mmcv.engine.multi_gpu_test``
421
+ will be used. (default: ``None``)
422
+ tmpdir (str | None): Temporary directory to save the results of all
423
+ processes. Default: None.
424
+ gpu_collect (bool): Whether to use gpu or cpu to collect results.
425
+ Default: False.
426
+ broadcast_bn_buffer (bool): Whether to broadcast the
427
+ buffer(running_mean and running_var) of rank 0 to other rank
428
+ before evaluation. Default: True.
429
+ out_dir (str, optional): The root directory to save checkpoints. If not
430
+ specified, `runner.work_dir` will be used by default. If specified,
431
+ the `out_dir` will be the concatenation of `out_dir` and the last
432
+ level directory of `runner.work_dir`.
433
+ file_client_args (dict): Arguments to instantiate a FileClient.
434
+ See :class:`mmcv.fileio.FileClient` for details. Default: None.
435
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
436
+ the dataset.
437
+ """
438
+
439
+ def __init__(self,
440
+ dataloader,
441
+ start=None,
442
+ interval=1,
443
+ by_epoch=True,
444
+ save_best=None,
445
+ rule=None,
446
+ test_fn=None,
447
+ greater_keys=None,
448
+ less_keys=None,
449
+ broadcast_bn_buffer=True,
450
+ tmpdir=None,
451
+ gpu_collect=False,
452
+ out_dir=None,
453
+ file_client_args=None,
454
+ **eval_kwargs):
455
+
456
+ if test_fn is None:
457
+ from annotator.mmpkg.mmcv.engine import multi_gpu_test
458
+ test_fn = multi_gpu_test
459
+
460
+ super().__init__(
461
+ dataloader,
462
+ start=start,
463
+ interval=interval,
464
+ by_epoch=by_epoch,
465
+ save_best=save_best,
466
+ rule=rule,
467
+ test_fn=test_fn,
468
+ greater_keys=greater_keys,
469
+ less_keys=less_keys,
470
+ out_dir=out_dir,
471
+ file_client_args=file_client_args,
472
+ **eval_kwargs)
473
+
474
+ self.broadcast_bn_buffer = broadcast_bn_buffer
475
+ self.tmpdir = tmpdir
476
+ self.gpu_collect = gpu_collect
477
+
478
+ def _do_evaluate(self, runner):
479
+ """perform evaluation and save ckpt."""
480
+ # Synchronization of BatchNorm's buffer (running_mean
481
+ # and running_var) is not supported in the DDP of pytorch,
482
+ # which may cause the inconsistent performance of models in
483
+ # different ranks, so we broadcast BatchNorm's buffers
484
+ # of rank 0 to other ranks to avoid this.
485
+ if self.broadcast_bn_buffer:
486
+ model = runner.model
487
+ for name, module in model.named_modules():
488
+ if isinstance(module,
489
+ _BatchNorm) and module.track_running_stats:
490
+ dist.broadcast(module.running_var, 0)
491
+ dist.broadcast(module.running_mean, 0)
492
+
493
+ tmpdir = self.tmpdir
494
+ if tmpdir is None:
495
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
496
+
497
+ results = self.test_fn(
498
+ runner.model,
499
+ self.dataloader,
500
+ tmpdir=tmpdir,
501
+ gpu_collect=self.gpu_collect)
502
+ if runner.rank == 0:
503
+ print('\n')
504
+ runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
505
+ key_score = self.evaluate(runner, results)
506
+ # the key_score may be `None` so it needs to skip the action to
507
+ # save the best checkpoint
508
+ if self.save_best and key_score:
509
+ self._save_ckpt(runner, key_score)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from annotator.mmpkg.mmcv.utils import Registry, is_method_overridden
3
+
4
+ HOOKS = Registry('hook')
5
+
6
+
7
+ class Hook:
8
+ stages = ('before_run', 'before_train_epoch', 'before_train_iter',
9
+ 'after_train_iter', 'after_train_epoch', 'before_val_epoch',
10
+ 'before_val_iter', 'after_val_iter', 'after_val_epoch',
11
+ 'after_run')
12
+
13
+ def before_run(self, runner):
14
+ pass
15
+
16
+ def after_run(self, runner):
17
+ pass
18
+
19
+ def before_epoch(self, runner):
20
+ pass
21
+
22
+ def after_epoch(self, runner):
23
+ pass
24
+
25
+ def before_iter(self, runner):
26
+ pass
27
+
28
+ def after_iter(self, runner):
29
+ pass
30
+
31
+ def before_train_epoch(self, runner):
32
+ self.before_epoch(runner)
33
+
34
+ def before_val_epoch(self, runner):
35
+ self.before_epoch(runner)
36
+
37
+ def after_train_epoch(self, runner):
38
+ self.after_epoch(runner)
39
+
40
+ def after_val_epoch(self, runner):
41
+ self.after_epoch(runner)
42
+
43
+ def before_train_iter(self, runner):
44
+ self.before_iter(runner)
45
+
46
+ def before_val_iter(self, runner):
47
+ self.before_iter(runner)
48
+
49
+ def after_train_iter(self, runner):
50
+ self.after_iter(runner)
51
+
52
+ def after_val_iter(self, runner):
53
+ self.after_iter(runner)
54
+
55
+ def every_n_epochs(self, runner, n):
56
+ return (runner.epoch + 1) % n == 0 if n > 0 else False
57
+
58
+ def every_n_inner_iters(self, runner, n):
59
+ return (runner.inner_iter + 1) % n == 0 if n > 0 else False
60
+
61
+ def every_n_iters(self, runner, n):
62
+ return (runner.iter + 1) % n == 0 if n > 0 else False
63
+
64
+ def end_of_epoch(self, runner):
65
+ return runner.inner_iter + 1 == len(runner.data_loader)
66
+
67
+ def is_last_epoch(self, runner):
68
+ return runner.epoch + 1 == runner._max_epochs
69
+
70
+ def is_last_iter(self, runner):
71
+ return runner.iter + 1 == runner._max_iters
72
+
73
+ def get_triggered_stages(self):
74
+ trigger_stages = set()
75
+ for stage in Hook.stages:
76
+ if is_method_overridden(stage, Hook, self):
77
+ trigger_stages.add(stage)
78
+
79
+ # some methods will be triggered in multi stages
80
+ # use this dict to map method to stages.
81
+ method_stages_map = {
82
+ 'before_epoch': ['before_train_epoch', 'before_val_epoch'],
83
+ 'after_epoch': ['after_train_epoch', 'after_val_epoch'],
84
+ 'before_iter': ['before_train_iter', 'before_val_iter'],
85
+ 'after_iter': ['after_train_iter', 'after_val_iter'],
86
+ }
87
+
88
+ for method, map_stages in method_stages_map.items():
89
+ if is_method_overridden(method, Hook, self):
90
+ trigger_stages.update(map_stages)
91
+
92
+ return [stage for stage in Hook.stages if stage in trigger_stages]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import time
3
+
4
+ from .hook import HOOKS, Hook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class IterTimerHook(Hook):
9
+
10
+ def before_epoch(self, runner):
11
+ self.t = time.time()
12
+
13
+ def before_iter(self, runner):
14
+ runner.log_buffer.update({'data_time': time.time() - self.t})
15
+
16
+ def after_iter(self, runner):
17
+ runner.log_buffer.update({'time': time.time() - self.t})
18
+ self.t = time.time()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .base import LoggerHook
3
+ from .dvclive import DvcliveLoggerHook
4
+ from .mlflow import MlflowLoggerHook
5
+ from .neptune import NeptuneLoggerHook
6
+ from .pavi import PaviLoggerHook
7
+ from .tensorboard import TensorboardLoggerHook
8
+ from .text import TextLoggerHook
9
+ from .wandb import WandbLoggerHook
10
+
11
+ __all__ = [
12
+ 'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
13
+ 'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
14
+ 'NeptuneLoggerHook', 'DvcliveLoggerHook'
15
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numbers
3
+ from abc import ABCMeta, abstractmethod
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from ..hook import Hook
9
+
10
+
11
+ class LoggerHook(Hook):
12
+ """Base class for logger hooks.
13
+
14
+ Args:
15
+ interval (int): Logging interval (every k iterations).
16
+ ignore_last (bool): Ignore the log of last iterations in each epoch
17
+ if less than `interval`.
18
+ reset_flag (bool): Whether to clear the output buffer after logging.
19
+ by_epoch (bool): Whether EpochBasedRunner is used.
20
+ """
21
+
22
+ __metaclass__ = ABCMeta
23
+
24
+ def __init__(self,
25
+ interval=10,
26
+ ignore_last=True,
27
+ reset_flag=False,
28
+ by_epoch=True):
29
+ self.interval = interval
30
+ self.ignore_last = ignore_last
31
+ self.reset_flag = reset_flag
32
+ self.by_epoch = by_epoch
33
+
34
+ @abstractmethod
35
+ def log(self, runner):
36
+ pass
37
+
38
+ @staticmethod
39
+ def is_scalar(val, include_np=True, include_torch=True):
40
+ """Tell the input variable is a scalar or not.
41
+
42
+ Args:
43
+ val: Input variable.
44
+ include_np (bool): Whether include 0-d np.ndarray as a scalar.
45
+ include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
46
+
47
+ Returns:
48
+ bool: True or False.
49
+ """
50
+ if isinstance(val, numbers.Number):
51
+ return True
52
+ elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
53
+ return True
54
+ elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
55
+ return True
56
+ else:
57
+ return False
58
+
59
+ def get_mode(self, runner):
60
+ if runner.mode == 'train':
61
+ if 'time' in runner.log_buffer.output:
62
+ mode = 'train'
63
+ else:
64
+ mode = 'val'
65
+ elif runner.mode == 'val':
66
+ mode = 'val'
67
+ else:
68
+ raise ValueError(f"runner mode should be 'train' or 'val', "
69
+ f'but got {runner.mode}')
70
+ return mode
71
+
72
+ def get_epoch(self, runner):
73
+ if runner.mode == 'train':
74
+ epoch = runner.epoch + 1
75
+ elif runner.mode == 'val':
76
+ # normal val mode
77
+ # runner.epoch += 1 has been done before val workflow
78
+ epoch = runner.epoch
79
+ else:
80
+ raise ValueError(f"runner mode should be 'train' or 'val', "
81
+ f'but got {runner.mode}')
82
+ return epoch
83
+
84
+ def get_iter(self, runner, inner_iter=False):
85
+ """Get the current training iteration step."""
86
+ if self.by_epoch and inner_iter:
87
+ current_iter = runner.inner_iter + 1
88
+ else:
89
+ current_iter = runner.iter + 1
90
+ return current_iter
91
+
92
+ def get_lr_tags(self, runner):
93
+ tags = {}
94
+ lrs = runner.current_lr()
95
+ if isinstance(lrs, dict):
96
+ for name, value in lrs.items():
97
+ tags[f'learning_rate/{name}'] = value[0]
98
+ else:
99
+ tags['learning_rate'] = lrs[0]
100
+ return tags
101
+
102
+ def get_momentum_tags(self, runner):
103
+ tags = {}
104
+ momentums = runner.current_momentum()
105
+ if isinstance(momentums, dict):
106
+ for name, value in momentums.items():
107
+ tags[f'momentum/{name}'] = value[0]
108
+ else:
109
+ tags['momentum'] = momentums[0]
110
+ return tags
111
+
112
+ def get_loggable_tags(self,
113
+ runner,
114
+ allow_scalar=True,
115
+ allow_text=False,
116
+ add_mode=True,
117
+ tags_to_skip=('time', 'data_time')):
118
+ tags = {}
119
+ for var, val in runner.log_buffer.output.items():
120
+ if var in tags_to_skip:
121
+ continue
122
+ if self.is_scalar(val) and not allow_scalar:
123
+ continue
124
+ if isinstance(val, str) and not allow_text:
125
+ continue
126
+ if add_mode:
127
+ var = f'{self.get_mode(runner)}/{var}'
128
+ tags[var] = val
129
+ tags.update(self.get_lr_tags(runner))
130
+ tags.update(self.get_momentum_tags(runner))
131
+ return tags
132
+
133
+ def before_run(self, runner):
134
+ for hook in runner.hooks[::-1]:
135
+ if isinstance(hook, LoggerHook):
136
+ hook.reset_flag = True
137
+ break
138
+
139
+ def before_epoch(self, runner):
140
+ runner.log_buffer.clear() # clear logs of last epoch
141
+
142
+ def after_train_iter(self, runner):
143
+ if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
144
+ runner.log_buffer.average(self.interval)
145
+ elif not self.by_epoch and self.every_n_iters(runner, self.interval):
146
+ runner.log_buffer.average(self.interval)
147
+ elif self.end_of_epoch(runner) and not self.ignore_last:
148
+ # not precise but more stable
149
+ runner.log_buffer.average(self.interval)
150
+
151
+ if runner.log_buffer.ready:
152
+ self.log(runner)
153
+ if self.reset_flag:
154
+ runner.log_buffer.clear_output()
155
+
156
+ def after_train_epoch(self, runner):
157
+ if runner.log_buffer.ready:
158
+ self.log(runner)
159
+ if self.reset_flag:
160
+ runner.log_buffer.clear_output()
161
+
162
+ def after_val_epoch(self, runner):
163
+ runner.log_buffer.average()
164
+ self.log(runner)
165
+ if self.reset_flag:
166
+ runner.log_buffer.clear_output()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ...dist_utils import master_only
3
+ from ..hook import HOOKS
4
+ from .base import LoggerHook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class DvcliveLoggerHook(LoggerHook):
9
+ """Class to log metrics with dvclive.
10
+
11
+ It requires `dvclive`_ to be installed.
12
+
13
+ Args:
14
+ path (str): Directory where dvclive will write TSV log files.
15
+ interval (int): Logging interval (every k iterations).
16
+ Default 10.
17
+ ignore_last (bool): Ignore the log of last iterations in each epoch
18
+ if less than `interval`.
19
+ Default: True.
20
+ reset_flag (bool): Whether to clear the output buffer after logging.
21
+ Default: True.
22
+ by_epoch (bool): Whether EpochBasedRunner is used.
23
+ Default: True.
24
+
25
+ .. _dvclive:
26
+ https://dvc.org/doc/dvclive
27
+ """
28
+
29
+ def __init__(self,
30
+ path,
31
+ interval=10,
32
+ ignore_last=True,
33
+ reset_flag=True,
34
+ by_epoch=True):
35
+
36
+ super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
37
+ reset_flag, by_epoch)
38
+ self.path = path
39
+ self.import_dvclive()
40
+
41
+ def import_dvclive(self):
42
+ try:
43
+ import dvclive
44
+ except ImportError:
45
+ raise ImportError(
46
+ 'Please run "pip install dvclive" to install dvclive')
47
+ self.dvclive = dvclive
48
+
49
+ @master_only
50
+ def before_run(self, runner):
51
+ self.dvclive.init(self.path)
52
+
53
+ @master_only
54
+ def log(self, runner):
55
+ tags = self.get_loggable_tags(runner)
56
+ if tags:
57
+ for k, v in tags.items():
58
+ self.dvclive.log(k, v, step=self.get_iter(runner))
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ...dist_utils import master_only
3
+ from ..hook import HOOKS
4
+ from .base import LoggerHook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class MlflowLoggerHook(LoggerHook):
9
+
10
+ def __init__(self,
11
+ exp_name=None,
12
+ tags=None,
13
+ log_model=True,
14
+ interval=10,
15
+ ignore_last=True,
16
+ reset_flag=False,
17
+ by_epoch=True):
18
+ """Class to log metrics and (optionally) a trained model to MLflow.
19
+
20
+ It requires `MLflow`_ to be installed.
21
+
22
+ Args:
23
+ exp_name (str, optional): Name of the experiment to be used.
24
+ Default None.
25
+ If not None, set the active experiment.
26
+ If experiment does not exist, an experiment with provided name
27
+ will be created.
28
+ tags (dict of str: str, optional): Tags for the current run.
29
+ Default None.
30
+ If not None, set tags for the current run.
31
+ log_model (bool, optional): Whether to log an MLflow artifact.
32
+ Default True.
33
+ If True, log runner.model as an MLflow artifact
34
+ for the current run.
35
+ interval (int): Logging interval (every k iterations).
36
+ ignore_last (bool): Ignore the log of last iterations in each epoch
37
+ if less than `interval`.
38
+ reset_flag (bool): Whether to clear the output buffer after logging
39
+ by_epoch (bool): Whether EpochBasedRunner is used.
40
+
41
+ .. _MLflow:
42
+ https://www.mlflow.org/docs/latest/index.html
43
+ """
44
+ super(MlflowLoggerHook, self).__init__(interval, ignore_last,
45
+ reset_flag, by_epoch)
46
+ self.import_mlflow()
47
+ self.exp_name = exp_name
48
+ self.tags = tags
49
+ self.log_model = log_model
50
+
51
+ def import_mlflow(self):
52
+ try:
53
+ import mlflow
54
+ import mlflow.pytorch as mlflow_pytorch
55
+ except ImportError:
56
+ raise ImportError(
57
+ 'Please run "pip install mlflow" to install mlflow')
58
+ self.mlflow = mlflow
59
+ self.mlflow_pytorch = mlflow_pytorch
60
+
61
+ @master_only
62
+ def before_run(self, runner):
63
+ super(MlflowLoggerHook, self).before_run(runner)
64
+ if self.exp_name is not None:
65
+ self.mlflow.set_experiment(self.exp_name)
66
+ if self.tags is not None:
67
+ self.mlflow.set_tags(self.tags)
68
+
69
+ @master_only
70
+ def log(self, runner):
71
+ tags = self.get_loggable_tags(runner)
72
+ if tags:
73
+ self.mlflow.log_metrics(tags, step=self.get_iter(runner))
74
+
75
+ @master_only
76
+ def after_run(self, runner):
77
+ if self.log_model:
78
+ self.mlflow_pytorch.log_model(runner.model, 'models')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ...dist_utils import master_only
3
+ from ..hook import HOOKS
4
+ from .base import LoggerHook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class NeptuneLoggerHook(LoggerHook):
9
+ """Class to log metrics to NeptuneAI.
10
+
11
+ It requires `neptune-client` to be installed.
12
+
13
+ Args:
14
+ init_kwargs (dict): a dict contains the initialization keys as below:
15
+ - project (str): Name of a project in a form of
16
+ namespace/project_name. If None, the value of
17
+ NEPTUNE_PROJECT environment variable will be taken.
18
+ - api_token (str): User’s API token.
19
+ If None, the value of NEPTUNE_API_TOKEN environment
20
+ variable will be taken. Note: It is strongly recommended
21
+ to use NEPTUNE_API_TOKEN environment variable rather than
22
+ placing your API token in plain text in your source code.
23
+ - name (str, optional, default is 'Untitled'): Editable name of
24
+ the run. Name is displayed in the run's Details and in
25
+ Runs table as a column.
26
+ Check https://docs.neptune.ai/api-reference/neptune#init for
27
+ more init arguments.
28
+ interval (int): Logging interval (every k iterations).
29
+ ignore_last (bool): Ignore the log of last iterations in each epoch
30
+ if less than `interval`.
31
+ reset_flag (bool): Whether to clear the output buffer after logging
32
+ by_epoch (bool): Whether EpochBasedRunner is used.
33
+
34
+ .. _NeptuneAI:
35
+ https://docs.neptune.ai/you-should-know/logging-metadata
36
+ """
37
+
38
+ def __init__(self,
39
+ init_kwargs=None,
40
+ interval=10,
41
+ ignore_last=True,
42
+ reset_flag=True,
43
+ with_step=True,
44
+ by_epoch=True):
45
+
46
+ super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
47
+ reset_flag, by_epoch)
48
+ self.import_neptune()
49
+ self.init_kwargs = init_kwargs
50
+ self.with_step = with_step
51
+
52
+ def import_neptune(self):
53
+ try:
54
+ import neptune.new as neptune
55
+ except ImportError:
56
+ raise ImportError(
57
+ 'Please run "pip install neptune-client" to install neptune')
58
+ self.neptune = neptune
59
+ self.run = None
60
+
61
+ @master_only
62
+ def before_run(self, runner):
63
+ if self.init_kwargs:
64
+ self.run = self.neptune.init(**self.init_kwargs)
65
+ else:
66
+ self.run = self.neptune.init()
67
+
68
+ @master_only
69
+ def log(self, runner):
70
+ tags = self.get_loggable_tags(runner)
71
+ if tags:
72
+ for tag_name, tag_value in tags.items():
73
+ if self.with_step:
74
+ self.run[tag_name].log(
75
+ tag_value, step=self.get_iter(runner))
76
+ else:
77
+ tags['global_step'] = self.get_iter(runner)
78
+ self.run[tag_name].log(tags)
79
+
80
+ @master_only
81
+ def after_run(self, runner):
82
+ self.run.stop()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+
6
+ import torch
7
+ import yaml
8
+
9
+ import annotator.mmpkg.mmcv as mmcv
10
+ from ....parallel.utils import is_module_wrapper
11
+ from ...dist_utils import master_only
12
+ from ..hook import HOOKS
13
+ from .base import LoggerHook
14
+
15
+
16
+ @HOOKS.register_module()
17
+ class PaviLoggerHook(LoggerHook):
18
+
19
+ def __init__(self,
20
+ init_kwargs=None,
21
+ add_graph=False,
22
+ add_last_ckpt=False,
23
+ interval=10,
24
+ ignore_last=True,
25
+ reset_flag=False,
26
+ by_epoch=True,
27
+ img_key='img_info'):
28
+ super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
29
+ by_epoch)
30
+ self.init_kwargs = init_kwargs
31
+ self.add_graph = add_graph
32
+ self.add_last_ckpt = add_last_ckpt
33
+ self.img_key = img_key
34
+
35
+ @master_only
36
+ def before_run(self, runner):
37
+ super(PaviLoggerHook, self).before_run(runner)
38
+ try:
39
+ from pavi import SummaryWriter
40
+ except ImportError:
41
+ raise ImportError('Please run "pip install pavi" to install pavi.')
42
+
43
+ self.run_name = runner.work_dir.split('/')[-1]
44
+
45
+ if not self.init_kwargs:
46
+ self.init_kwargs = dict()
47
+ self.init_kwargs['name'] = self.run_name
48
+ self.init_kwargs['model'] = runner._model_name
49
+ if runner.meta is not None:
50
+ if 'config_dict' in runner.meta:
51
+ config_dict = runner.meta['config_dict']
52
+ assert isinstance(
53
+ config_dict,
54
+ dict), ('meta["config_dict"] has to be of a dict, '
55
+ f'but got {type(config_dict)}')
56
+ elif 'config_file' in runner.meta:
57
+ config_file = runner.meta['config_file']
58
+ config_dict = dict(mmcv.Config.fromfile(config_file))
59
+ else:
60
+ config_dict = None
61
+ if config_dict is not None:
62
+ # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
63
+ # to properly set up the progress bar.
64
+ config_dict = config_dict.copy()
65
+ config_dict.setdefault('max_iter', runner.max_iters)
66
+ # non-serializable values are first converted in
67
+ # mmcv.dump to json
68
+ config_dict = json.loads(
69
+ mmcv.dump(config_dict, file_format='json'))
70
+ session_text = yaml.dump(config_dict)
71
+ self.init_kwargs['session_text'] = session_text
72
+ self.writer = SummaryWriter(**self.init_kwargs)
73
+
74
+ def get_step(self, runner):
75
+ """Get the total training step/epoch."""
76
+ if self.get_mode(runner) == 'val' and self.by_epoch:
77
+ return self.get_epoch(runner)
78
+ else:
79
+ return self.get_iter(runner)
80
+
81
+ @master_only
82
+ def log(self, runner):
83
+ tags = self.get_loggable_tags(runner, add_mode=False)
84
+ if tags:
85
+ self.writer.add_scalars(
86
+ self.get_mode(runner), tags, self.get_step(runner))
87
+
88
+ @master_only
89
+ def after_run(self, runner):
90
+ if self.add_last_ckpt:
91
+ ckpt_path = osp.join(runner.work_dir, 'latest.pth')
92
+ if osp.islink(ckpt_path):
93
+ ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
94
+
95
+ if osp.isfile(ckpt_path):
96
+ # runner.epoch += 1 has been done before `after_run`.
97
+ iteration = runner.epoch if self.by_epoch else runner.iter
98
+ return self.writer.add_snapshot_file(
99
+ tag=self.run_name,
100
+ snapshot_file_path=ckpt_path,
101
+ iteration=iteration)
102
+
103
+ # flush the buffer and send a task ending signal to Pavi
104
+ self.writer.close()
105
+
106
+ @master_only
107
+ def before_epoch(self, runner):
108
+ if runner.epoch == 0 and self.add_graph:
109
+ if is_module_wrapper(runner.model):
110
+ _model = runner.model.module
111
+ else:
112
+ _model = runner.model
113
+ device = next(_model.parameters()).device
114
+ data = next(iter(runner.data_loader))
115
+ image = data[self.img_key][0:1].to(device)
116
+ with torch.no_grad():
117
+ self.writer.add_graph(_model, image)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+
4
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
5
+ from ...dist_utils import master_only
6
+ from ..hook import HOOKS
7
+ from .base import LoggerHook
8
+
9
+
10
+ @HOOKS.register_module()
11
+ class TensorboardLoggerHook(LoggerHook):
12
+
13
+ def __init__(self,
14
+ log_dir=None,
15
+ interval=10,
16
+ ignore_last=True,
17
+ reset_flag=False,
18
+ by_epoch=True):
19
+ super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
20
+ reset_flag, by_epoch)
21
+ self.log_dir = log_dir
22
+
23
+ @master_only
24
+ def before_run(self, runner):
25
+ super(TensorboardLoggerHook, self).before_run(runner)
26
+ if (TORCH_VERSION == 'parrots'
27
+ or digit_version(TORCH_VERSION) < digit_version('1.1')):
28
+ try:
29
+ from tensorboardX import SummaryWriter
30
+ except ImportError:
31
+ raise ImportError('Please install tensorboardX to use '
32
+ 'TensorboardLoggerHook.')
33
+ else:
34
+ try:
35
+ from torch.utils.tensorboard import SummaryWriter
36
+ except ImportError:
37
+ raise ImportError(
38
+ 'Please run "pip install future tensorboard" to install '
39
+ 'the dependencies to use torch.utils.tensorboard '
40
+ '(applicable to PyTorch 1.1 or higher)')
41
+
42
+ if self.log_dir is None:
43
+ self.log_dir = osp.join(runner.work_dir, 'tf_logs')
44
+ self.writer = SummaryWriter(self.log_dir)
45
+
46
+ @master_only
47
+ def log(self, runner):
48
+ tags = self.get_loggable_tags(runner, allow_text=True)
49
+ for tag, val in tags.items():
50
+ if isinstance(val, str):
51
+ self.writer.add_text(tag, val, self.get_iter(runner))
52
+ else:
53
+ self.writer.add_scalar(tag, val, self.get_iter(runner))
54
+
55
+ @master_only
56
+ def after_run(self, runner):
57
+ self.writer.close()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import datetime
3
+ import os
4
+ import os.path as osp
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+ import annotator.mmpkg.mmcv as mmcv
11
+ from annotator.mmpkg.mmcv.fileio.file_client import FileClient
12
+ from annotator.mmpkg.mmcv.utils import is_tuple_of, scandir
13
+ from ..hook import HOOKS
14
+ from .base import LoggerHook
15
+
16
+
17
+ @HOOKS.register_module()
18
+ class TextLoggerHook(LoggerHook):
19
+ """Logger hook in text.
20
+
21
+ In this logger hook, the information will be printed on terminal and
22
+ saved in json file.
23
+
24
+ Args:
25
+ by_epoch (bool, optional): Whether EpochBasedRunner is used.
26
+ Default: True.
27
+ interval (int, optional): Logging interval (every k iterations).
28
+ Default: 10.
29
+ ignore_last (bool, optional): Ignore the log of last iterations in each
30
+ epoch if less than :attr:`interval`. Default: True.
31
+ reset_flag (bool, optional): Whether to clear the output buffer after
32
+ logging. Default: False.
33
+ interval_exp_name (int, optional): Logging interval for experiment
34
+ name. This feature is to help users conveniently get the experiment
35
+ information from screen or log file. Default: 1000.
36
+ out_dir (str, optional): Logs are saved in ``runner.work_dir`` default.
37
+ If ``out_dir`` is specified, logs will be copied to a new directory
38
+ which is the concatenation of ``out_dir`` and the last level
39
+ directory of ``runner.work_dir``. Default: None.
40
+ `New in version 1.3.16.`
41
+ out_suffix (str or tuple[str], optional): Those filenames ending with
42
+ ``out_suffix`` will be copied to ``out_dir``.
43
+ Default: ('.log.json', '.log', '.py').
44
+ `New in version 1.3.16.`
45
+ keep_local (bool, optional): Whether to keep local log when
46
+ :attr:`out_dir` is specified. If False, the local log will be
47
+ removed. Default: True.
48
+ `New in version 1.3.16.`
49
+ file_client_args (dict, optional): Arguments to instantiate a
50
+ FileClient. See :class:`mmcv.fileio.FileClient` for details.
51
+ Default: None.
52
+ `New in version 1.3.16.`
53
+ """
54
+
55
+ def __init__(self,
56
+ by_epoch=True,
57
+ interval=10,
58
+ ignore_last=True,
59
+ reset_flag=False,
60
+ interval_exp_name=1000,
61
+ out_dir=None,
62
+ out_suffix=('.log.json', '.log', '.py'),
63
+ keep_local=True,
64
+ file_client_args=None):
65
+ super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
66
+ by_epoch)
67
+ self.by_epoch = by_epoch
68
+ self.time_sec_tot = 0
69
+ self.interval_exp_name = interval_exp_name
70
+
71
+ if out_dir is None and file_client_args is not None:
72
+ raise ValueError(
73
+ 'file_client_args should be "None" when `out_dir` is not'
74
+ 'specified.')
75
+ self.out_dir = out_dir
76
+
77
+ if not (out_dir is None or isinstance(out_dir, str)
78
+ or is_tuple_of(out_dir, str)):
79
+ raise TypeError('out_dir should be "None" or string or tuple of '
80
+ 'string, but got {out_dir}')
81
+ self.out_suffix = out_suffix
82
+
83
+ self.keep_local = keep_local
84
+ self.file_client_args = file_client_args
85
+ if self.out_dir is not None:
86
+ self.file_client = FileClient.infer_client(file_client_args,
87
+ self.out_dir)
88
+
89
+ def before_run(self, runner):
90
+ super(TextLoggerHook, self).before_run(runner)
91
+
92
+ if self.out_dir is not None:
93
+ self.file_client = FileClient.infer_client(self.file_client_args,
94
+ self.out_dir)
95
+ # The final `self.out_dir` is the concatenation of `self.out_dir`
96
+ # and the last level directory of `runner.work_dir`
97
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
98
+ self.out_dir = self.file_client.join_path(self.out_dir, basename)
99
+ runner.logger.info(
100
+ (f'Text logs will be saved to {self.out_dir} by '
101
+ f'{self.file_client.name} after the training process.'))
102
+
103
+ self.start_iter = runner.iter
104
+ self.json_log_path = osp.join(runner.work_dir,
105
+ f'{runner.timestamp}.log.json')
106
+ if runner.meta is not None:
107
+ self._dump_log(runner.meta, runner)
108
+
109
+ def _get_max_memory(self, runner):
110
+ device = getattr(runner.model, 'output_device', None)
111
+ mem = torch.cuda.max_memory_allocated(device=device)
112
+ mem_mb = torch.tensor([mem / (1024 * 1024)],
113
+ dtype=torch.int,
114
+ device=device)
115
+ if runner.world_size > 1:
116
+ dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
117
+ return mem_mb.item()
118
+
119
+ def _log_info(self, log_dict, runner):
120
+ # print exp name for users to distinguish experiments
121
+ # at every ``interval_exp_name`` iterations and the end of each epoch
122
+ if runner.meta is not None and 'exp_name' in runner.meta:
123
+ if (self.every_n_iters(runner, self.interval_exp_name)) or (
124
+ self.by_epoch and self.end_of_epoch(runner)):
125
+ exp_info = f'Exp name: {runner.meta["exp_name"]}'
126
+ runner.logger.info(exp_info)
127
+
128
+ if log_dict['mode'] == 'train':
129
+ if isinstance(log_dict['lr'], dict):
130
+ lr_str = []
131
+ for k, val in log_dict['lr'].items():
132
+ lr_str.append(f'lr_{k}: {val:.3e}')
133
+ lr_str = ' '.join(lr_str)
134
+ else:
135
+ lr_str = f'lr: {log_dict["lr"]:.3e}'
136
+
137
+ # by epoch: Epoch [4][100/1000]
138
+ # by iter: Iter [100/100000]
139
+ if self.by_epoch:
140
+ log_str = f'Epoch [{log_dict["epoch"]}]' \
141
+ f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
142
+ else:
143
+ log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
144
+ log_str += f'{lr_str}, '
145
+
146
+ if 'time' in log_dict.keys():
147
+ self.time_sec_tot += (log_dict['time'] * self.interval)
148
+ time_sec_avg = self.time_sec_tot / (
149
+ runner.iter - self.start_iter + 1)
150
+ eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
151
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
152
+ log_str += f'eta: {eta_str}, '
153
+ log_str += f'time: {log_dict["time"]:.3f}, ' \
154
+ f'data_time: {log_dict["data_time"]:.3f}, '
155
+ # statistic memory
156
+ if torch.cuda.is_available():
157
+ log_str += f'memory: {log_dict["memory"]}, '
158
+ else:
159
+ # val/test time
160
+ # here 1000 is the length of the val dataloader
161
+ # by epoch: Epoch[val] [4][1000]
162
+ # by iter: Iter[val] [1000]
163
+ if self.by_epoch:
164
+ log_str = f'Epoch({log_dict["mode"]}) ' \
165
+ f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
166
+ else:
167
+ log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
168
+
169
+ log_items = []
170
+ for name, val in log_dict.items():
171
+ # TODO: resolve this hack
172
+ # these items have been in log_str
173
+ if name in [
174
+ 'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
175
+ 'memory', 'epoch'
176
+ ]:
177
+ continue
178
+ if isinstance(val, float):
179
+ val = f'{val:.4f}'
180
+ log_items.append(f'{name}: {val}')
181
+ log_str += ', '.join(log_items)
182
+
183
+ runner.logger.info(log_str)
184
+
185
+ def _dump_log(self, log_dict, runner):
186
+ # dump log in json format
187
+ json_log = OrderedDict()
188
+ for k, v in log_dict.items():
189
+ json_log[k] = self._round_float(v)
190
+ # only append log at last line
191
+ if runner.rank == 0:
192
+ with open(self.json_log_path, 'a+') as f:
193
+ mmcv.dump(json_log, f, file_format='json')
194
+ f.write('\n')
195
+
196
+ def _round_float(self, items):
197
+ if isinstance(items, list):
198
+ return [self._round_float(item) for item in items]
199
+ elif isinstance(items, float):
200
+ return round(items, 5)
201
+ else:
202
+ return items
203
+
204
+ def log(self, runner):
205
+ if 'eval_iter_num' in runner.log_buffer.output:
206
+ # this doesn't modify runner.iter and is regardless of by_epoch
207
+ cur_iter = runner.log_buffer.output.pop('eval_iter_num')
208
+ else:
209
+ cur_iter = self.get_iter(runner, inner_iter=True)
210
+
211
+ log_dict = OrderedDict(
212
+ mode=self.get_mode(runner),
213
+ epoch=self.get_epoch(runner),
214
+ iter=cur_iter)
215
+
216
+ # only record lr of the first param group
217
+ cur_lr = runner.current_lr()
218
+ if isinstance(cur_lr, list):
219
+ log_dict['lr'] = cur_lr[0]
220
+ else:
221
+ assert isinstance(cur_lr, dict)
222
+ log_dict['lr'] = {}
223
+ for k, lr_ in cur_lr.items():
224
+ assert isinstance(lr_, list)
225
+ log_dict['lr'].update({k: lr_[0]})
226
+
227
+ if 'time' in runner.log_buffer.output:
228
+ # statistic memory
229
+ if torch.cuda.is_available():
230
+ log_dict['memory'] = self._get_max_memory(runner)
231
+
232
+ log_dict = dict(log_dict, **runner.log_buffer.output)
233
+
234
+ self._log_info(log_dict, runner)
235
+ self._dump_log(log_dict, runner)
236
+ return log_dict
237
+
238
+ def after_run(self, runner):
239
+ # copy or upload logs to self.out_dir
240
+ if self.out_dir is not None:
241
+ for filename in scandir(runner.work_dir, self.out_suffix, True):
242
+ local_filepath = osp.join(runner.work_dir, filename)
243
+ out_filepath = self.file_client.join_path(
244
+ self.out_dir, filename)
245
+ with open(local_filepath, 'r') as f:
246
+ self.file_client.put_text(f.read(), out_filepath)
247
+
248
+ runner.logger.info(
249
+ (f'The file {local_filepath} has been uploaded to '
250
+ f'{out_filepath}.'))
251
+
252
+ if not self.keep_local:
253
+ os.remove(local_filepath)
254
+ runner.logger.info(
255
+ (f'{local_filepath} was removed due to the '
256
+ '`self.keep_local=False`'))
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ...dist_utils import master_only
3
+ from ..hook import HOOKS
4
+ from .base import LoggerHook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class WandbLoggerHook(LoggerHook):
9
+
10
+ def __init__(self,
11
+ init_kwargs=None,
12
+ interval=10,
13
+ ignore_last=True,
14
+ reset_flag=False,
15
+ commit=True,
16
+ by_epoch=True,
17
+ with_step=True):
18
+ super(WandbLoggerHook, self).__init__(interval, ignore_last,
19
+ reset_flag, by_epoch)
20
+ self.import_wandb()
21
+ self.init_kwargs = init_kwargs
22
+ self.commit = commit
23
+ self.with_step = with_step
24
+
25
+ def import_wandb(self):
26
+ try:
27
+ import wandb
28
+ except ImportError:
29
+ raise ImportError(
30
+ 'Please run "pip install wandb" to install wandb')
31
+ self.wandb = wandb
32
+
33
+ @master_only
34
+ def before_run(self, runner):
35
+ super(WandbLoggerHook, self).before_run(runner)
36
+ if self.wandb is None:
37
+ self.import_wandb()
38
+ if self.init_kwargs:
39
+ self.wandb.init(**self.init_kwargs)
40
+ else:
41
+ self.wandb.init()
42
+
43
+ @master_only
44
+ def log(self, runner):
45
+ tags = self.get_loggable_tags(runner)
46
+ if tags:
47
+ if self.with_step:
48
+ self.wandb.log(
49
+ tags, step=self.get_iter(runner), commit=self.commit)
50
+ else:
51
+ tags['global_step'] = self.get_iter(runner)
52
+ self.wandb.log(tags, commit=self.commit)
53
+
54
+ @master_only
55
+ def after_run(self, runner):
56
+ self.wandb.join()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py ADDED
@@ -0,0 +1,670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numbers
3
+ from math import cos, pi
4
+
5
+ import annotator.mmpkg.mmcv as mmcv
6
+ from .hook import HOOKS, Hook
7
+
8
+
9
+ class LrUpdaterHook(Hook):
10
+ """LR Scheduler in MMCV.
11
+
12
+ Args:
13
+ by_epoch (bool): LR changes epoch by epoch
14
+ warmup (string): Type of warmup used. It can be None(use no warmup),
15
+ 'constant', 'linear' or 'exp'
16
+ warmup_iters (int): The number of iterations or epochs that warmup
17
+ lasts
18
+ warmup_ratio (float): LR used at the beginning of warmup equals to
19
+ warmup_ratio * initial_lr
20
+ warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
21
+ means the number of epochs that warmup lasts, otherwise means the
22
+ number of iteration that warmup lasts
23
+ """
24
+
25
+ def __init__(self,
26
+ by_epoch=True,
27
+ warmup=None,
28
+ warmup_iters=0,
29
+ warmup_ratio=0.1,
30
+ warmup_by_epoch=False):
31
+ # validate the "warmup" argument
32
+ if warmup is not None:
33
+ if warmup not in ['constant', 'linear', 'exp']:
34
+ raise ValueError(
35
+ f'"{warmup}" is not a supported type for warming up, valid'
36
+ ' types are "constant" and "linear"')
37
+ if warmup is not None:
38
+ assert warmup_iters > 0, \
39
+ '"warmup_iters" must be a positive integer'
40
+ assert 0 < warmup_ratio <= 1.0, \
41
+ '"warmup_ratio" must be in range (0,1]'
42
+
43
+ self.by_epoch = by_epoch
44
+ self.warmup = warmup
45
+ self.warmup_iters = warmup_iters
46
+ self.warmup_ratio = warmup_ratio
47
+ self.warmup_by_epoch = warmup_by_epoch
48
+
49
+ if self.warmup_by_epoch:
50
+ self.warmup_epochs = self.warmup_iters
51
+ self.warmup_iters = None
52
+ else:
53
+ self.warmup_epochs = None
54
+
55
+ self.base_lr = [] # initial lr for all param groups
56
+ self.regular_lr = [] # expected lr if no warming up is performed
57
+
58
+ def _set_lr(self, runner, lr_groups):
59
+ if isinstance(runner.optimizer, dict):
60
+ for k, optim in runner.optimizer.items():
61
+ for param_group, lr in zip(optim.param_groups, lr_groups[k]):
62
+ param_group['lr'] = lr
63
+ else:
64
+ for param_group, lr in zip(runner.optimizer.param_groups,
65
+ lr_groups):
66
+ param_group['lr'] = lr
67
+
68
+ def get_lr(self, runner, base_lr):
69
+ raise NotImplementedError
70
+
71
+ def get_regular_lr(self, runner):
72
+ if isinstance(runner.optimizer, dict):
73
+ lr_groups = {}
74
+ for k in runner.optimizer.keys():
75
+ _lr_group = [
76
+ self.get_lr(runner, _base_lr)
77
+ for _base_lr in self.base_lr[k]
78
+ ]
79
+ lr_groups.update({k: _lr_group})
80
+
81
+ return lr_groups
82
+ else:
83
+ return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
84
+
85
+ def get_warmup_lr(self, cur_iters):
86
+
87
+ def _get_warmup_lr(cur_iters, regular_lr):
88
+ if self.warmup == 'constant':
89
+ warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
90
+ elif self.warmup == 'linear':
91
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
92
+ self.warmup_ratio)
93
+ warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
94
+ elif self.warmup == 'exp':
95
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
96
+ warmup_lr = [_lr * k for _lr in regular_lr]
97
+ return warmup_lr
98
+
99
+ if isinstance(self.regular_lr, dict):
100
+ lr_groups = {}
101
+ for key, regular_lr in self.regular_lr.items():
102
+ lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
103
+ return lr_groups
104
+ else:
105
+ return _get_warmup_lr(cur_iters, self.regular_lr)
106
+
107
+ def before_run(self, runner):
108
+ # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
109
+ # it will be set according to the optimizer params
110
+ if isinstance(runner.optimizer, dict):
111
+ self.base_lr = {}
112
+ for k, optim in runner.optimizer.items():
113
+ for group in optim.param_groups:
114
+ group.setdefault('initial_lr', group['lr'])
115
+ _base_lr = [
116
+ group['initial_lr'] for group in optim.param_groups
117
+ ]
118
+ self.base_lr.update({k: _base_lr})
119
+ else:
120
+ for group in runner.optimizer.param_groups:
121
+ group.setdefault('initial_lr', group['lr'])
122
+ self.base_lr = [
123
+ group['initial_lr'] for group in runner.optimizer.param_groups
124
+ ]
125
+
126
+ def before_train_epoch(self, runner):
127
+ if self.warmup_iters is None:
128
+ epoch_len = len(runner.data_loader)
129
+ self.warmup_iters = self.warmup_epochs * epoch_len
130
+
131
+ if not self.by_epoch:
132
+ return
133
+
134
+ self.regular_lr = self.get_regular_lr(runner)
135
+ self._set_lr(runner, self.regular_lr)
136
+
137
+ def before_train_iter(self, runner):
138
+ cur_iter = runner.iter
139
+ if not self.by_epoch:
140
+ self.regular_lr = self.get_regular_lr(runner)
141
+ if self.warmup is None or cur_iter >= self.warmup_iters:
142
+ self._set_lr(runner, self.regular_lr)
143
+ else:
144
+ warmup_lr = self.get_warmup_lr(cur_iter)
145
+ self._set_lr(runner, warmup_lr)
146
+ elif self.by_epoch:
147
+ if self.warmup is None or cur_iter > self.warmup_iters:
148
+ return
149
+ elif cur_iter == self.warmup_iters:
150
+ self._set_lr(runner, self.regular_lr)
151
+ else:
152
+ warmup_lr = self.get_warmup_lr(cur_iter)
153
+ self._set_lr(runner, warmup_lr)
154
+
155
+
156
+ @HOOKS.register_module()
157
+ class FixedLrUpdaterHook(LrUpdaterHook):
158
+
159
+ def __init__(self, **kwargs):
160
+ super(FixedLrUpdaterHook, self).__init__(**kwargs)
161
+
162
+ def get_lr(self, runner, base_lr):
163
+ return base_lr
164
+
165
+
166
+ @HOOKS.register_module()
167
+ class StepLrUpdaterHook(LrUpdaterHook):
168
+ """Step LR scheduler with min_lr clipping.
169
+
170
+ Args:
171
+ step (int | list[int]): Step to decay the LR. If an int value is given,
172
+ regard it as the decay interval. If a list is given, decay LR at
173
+ these steps.
174
+ gamma (float, optional): Decay LR ratio. Default: 0.1.
175
+ min_lr (float, optional): Minimum LR value to keep. If LR after decay
176
+ is lower than `min_lr`, it will be clipped to this value. If None
177
+ is given, we don't perform lr clipping. Default: None.
178
+ """
179
+
180
+ def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
181
+ if isinstance(step, list):
182
+ assert mmcv.is_list_of(step, int)
183
+ assert all([s > 0 for s in step])
184
+ elif isinstance(step, int):
185
+ assert step > 0
186
+ else:
187
+ raise TypeError('"step" must be a list or integer')
188
+ self.step = step
189
+ self.gamma = gamma
190
+ self.min_lr = min_lr
191
+ super(StepLrUpdaterHook, self).__init__(**kwargs)
192
+
193
+ def get_lr(self, runner, base_lr):
194
+ progress = runner.epoch if self.by_epoch else runner.iter
195
+
196
+ # calculate exponential term
197
+ if isinstance(self.step, int):
198
+ exp = progress // self.step
199
+ else:
200
+ exp = len(self.step)
201
+ for i, s in enumerate(self.step):
202
+ if progress < s:
203
+ exp = i
204
+ break
205
+
206
+ lr = base_lr * (self.gamma**exp)
207
+ if self.min_lr is not None:
208
+ # clip to a minimum value
209
+ lr = max(lr, self.min_lr)
210
+ return lr
211
+
212
+
213
+ @HOOKS.register_module()
214
+ class ExpLrUpdaterHook(LrUpdaterHook):
215
+
216
+ def __init__(self, gamma, **kwargs):
217
+ self.gamma = gamma
218
+ super(ExpLrUpdaterHook, self).__init__(**kwargs)
219
+
220
+ def get_lr(self, runner, base_lr):
221
+ progress = runner.epoch if self.by_epoch else runner.iter
222
+ return base_lr * self.gamma**progress
223
+
224
+
225
+ @HOOKS.register_module()
226
+ class PolyLrUpdaterHook(LrUpdaterHook):
227
+
228
+ def __init__(self, power=1., min_lr=0., **kwargs):
229
+ self.power = power
230
+ self.min_lr = min_lr
231
+ super(PolyLrUpdaterHook, self).__init__(**kwargs)
232
+
233
+ def get_lr(self, runner, base_lr):
234
+ if self.by_epoch:
235
+ progress = runner.epoch
236
+ max_progress = runner.max_epochs
237
+ else:
238
+ progress = runner.iter
239
+ max_progress = runner.max_iters
240
+ coeff = (1 - progress / max_progress)**self.power
241
+ return (base_lr - self.min_lr) * coeff + self.min_lr
242
+
243
+
244
+ @HOOKS.register_module()
245
+ class InvLrUpdaterHook(LrUpdaterHook):
246
+
247
+ def __init__(self, gamma, power=1., **kwargs):
248
+ self.gamma = gamma
249
+ self.power = power
250
+ super(InvLrUpdaterHook, self).__init__(**kwargs)
251
+
252
+ def get_lr(self, runner, base_lr):
253
+ progress = runner.epoch if self.by_epoch else runner.iter
254
+ return base_lr * (1 + self.gamma * progress)**(-self.power)
255
+
256
+
257
+ @HOOKS.register_module()
258
+ class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
259
+
260
+ def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
261
+ assert (min_lr is None) ^ (min_lr_ratio is None)
262
+ self.min_lr = min_lr
263
+ self.min_lr_ratio = min_lr_ratio
264
+ super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
265
+
266
+ def get_lr(self, runner, base_lr):
267
+ if self.by_epoch:
268
+ progress = runner.epoch
269
+ max_progress = runner.max_epochs
270
+ else:
271
+ progress = runner.iter
272
+ max_progress = runner.max_iters
273
+
274
+ if self.min_lr_ratio is not None:
275
+ target_lr = base_lr * self.min_lr_ratio
276
+ else:
277
+ target_lr = self.min_lr
278
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
279
+
280
+
281
+ @HOOKS.register_module()
282
+ class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
283
+ """Flat + Cosine lr schedule.
284
+
285
+ Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
286
+
287
+ Args:
288
+ start_percent (float): When to start annealing the learning rate
289
+ after the percentage of the total training steps.
290
+ The value should be in range [0, 1).
291
+ Default: 0.75
292
+ min_lr (float, optional): The minimum lr. Default: None.
293
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
294
+ Either `min_lr` or `min_lr_ratio` should be specified.
295
+ Default: None.
296
+ """
297
+
298
+ def __init__(self,
299
+ start_percent=0.75,
300
+ min_lr=None,
301
+ min_lr_ratio=None,
302
+ **kwargs):
303
+ assert (min_lr is None) ^ (min_lr_ratio is None)
304
+ if start_percent < 0 or start_percent > 1 or not isinstance(
305
+ start_percent, float):
306
+ raise ValueError(
307
+ 'expected float between 0 and 1 start_percent, but '
308
+ f'got {start_percent}')
309
+ self.start_percent = start_percent
310
+ self.min_lr = min_lr
311
+ self.min_lr_ratio = min_lr_ratio
312
+ super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
313
+
314
+ def get_lr(self, runner, base_lr):
315
+ if self.by_epoch:
316
+ start = round(runner.max_epochs * self.start_percent)
317
+ progress = runner.epoch - start
318
+ max_progress = runner.max_epochs - start
319
+ else:
320
+ start = round(runner.max_iters * self.start_percent)
321
+ progress = runner.iter - start
322
+ max_progress = runner.max_iters - start
323
+
324
+ if self.min_lr_ratio is not None:
325
+ target_lr = base_lr * self.min_lr_ratio
326
+ else:
327
+ target_lr = self.min_lr
328
+
329
+ if progress < 0:
330
+ return base_lr
331
+ else:
332
+ return annealing_cos(base_lr, target_lr, progress / max_progress)
333
+
334
+
335
+ @HOOKS.register_module()
336
+ class CosineRestartLrUpdaterHook(LrUpdaterHook):
337
+ """Cosine annealing with restarts learning rate scheme.
338
+
339
+ Args:
340
+ periods (list[int]): Periods for each cosine anneling cycle.
341
+ restart_weights (list[float], optional): Restart weights at each
342
+ restart iteration. Default: [1].
343
+ min_lr (float, optional): The minimum lr. Default: None.
344
+ min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
345
+ Either `min_lr` or `min_lr_ratio` should be specified.
346
+ Default: None.
347
+ """
348
+
349
+ def __init__(self,
350
+ periods,
351
+ restart_weights=[1],
352
+ min_lr=None,
353
+ min_lr_ratio=None,
354
+ **kwargs):
355
+ assert (min_lr is None) ^ (min_lr_ratio is None)
356
+ self.periods = periods
357
+ self.min_lr = min_lr
358
+ self.min_lr_ratio = min_lr_ratio
359
+ self.restart_weights = restart_weights
360
+ assert (len(self.periods) == len(self.restart_weights)
361
+ ), 'periods and restart_weights should have the same length.'
362
+ super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
363
+
364
+ self.cumulative_periods = [
365
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
366
+ ]
367
+
368
+ def get_lr(self, runner, base_lr):
369
+ if self.by_epoch:
370
+ progress = runner.epoch
371
+ else:
372
+ progress = runner.iter
373
+
374
+ if self.min_lr_ratio is not None:
375
+ target_lr = base_lr * self.min_lr_ratio
376
+ else:
377
+ target_lr = self.min_lr
378
+
379
+ idx = get_position_from_periods(progress, self.cumulative_periods)
380
+ current_weight = self.restart_weights[idx]
381
+ nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
382
+ current_periods = self.periods[idx]
383
+
384
+ alpha = min((progress - nearest_restart) / current_periods, 1)
385
+ return annealing_cos(base_lr, target_lr, alpha, current_weight)
386
+
387
+
388
+ def get_position_from_periods(iteration, cumulative_periods):
389
+ """Get the position from a period list.
390
+
391
+ It will return the index of the right-closest number in the period list.
392
+ For example, the cumulative_periods = [100, 200, 300, 400],
393
+ if iteration == 50, return 0;
394
+ if iteration == 210, return 2;
395
+ if iteration == 300, return 3.
396
+
397
+ Args:
398
+ iteration (int): Current iteration.
399
+ cumulative_periods (list[int]): Cumulative period list.
400
+
401
+ Returns:
402
+ int: The position of the right-closest number in the period list.
403
+ """
404
+ for i, period in enumerate(cumulative_periods):
405
+ if iteration < period:
406
+ return i
407
+ raise ValueError(f'Current iteration {iteration} exceeds '
408
+ f'cumulative_periods {cumulative_periods}')
409
+
410
+
411
+ @HOOKS.register_module()
412
+ class CyclicLrUpdaterHook(LrUpdaterHook):
413
+ """Cyclic LR Scheduler.
414
+
415
+ Implement the cyclical learning rate policy (CLR) described in
416
+ https://arxiv.org/pdf/1506.01186.pdf
417
+
418
+ Different from the original paper, we use cosine annealing rather than
419
+ triangular policy inside a cycle. This improves the performance in the
420
+ 3D detection area.
421
+
422
+ Args:
423
+ by_epoch (bool): Whether to update LR by epoch.
424
+ target_ratio (tuple[float]): Relative ratio of the highest LR and the
425
+ lowest LR to the initial LR.
426
+ cyclic_times (int): Number of cycles during training
427
+ step_ratio_up (float): The ratio of the increasing process of LR in
428
+ the total cycle.
429
+ anneal_strategy (str): {'cos', 'linear'}
430
+ Specifies the annealing strategy: 'cos' for cosine annealing,
431
+ 'linear' for linear annealing. Default: 'cos'.
432
+ """
433
+
434
+ def __init__(self,
435
+ by_epoch=False,
436
+ target_ratio=(10, 1e-4),
437
+ cyclic_times=1,
438
+ step_ratio_up=0.4,
439
+ anneal_strategy='cos',
440
+ **kwargs):
441
+ if isinstance(target_ratio, float):
442
+ target_ratio = (target_ratio, target_ratio / 1e5)
443
+ elif isinstance(target_ratio, tuple):
444
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
445
+ if len(target_ratio) == 1 else target_ratio
446
+ else:
447
+ raise ValueError('target_ratio should be either float '
448
+ f'or tuple, got {type(target_ratio)}')
449
+
450
+ assert len(target_ratio) == 2, \
451
+ '"target_ratio" must be list or tuple of two floats'
452
+ assert 0 <= step_ratio_up < 1.0, \
453
+ '"step_ratio_up" must be in range [0,1)'
454
+
455
+ self.target_ratio = target_ratio
456
+ self.cyclic_times = cyclic_times
457
+ self.step_ratio_up = step_ratio_up
458
+ self.lr_phases = [] # init lr_phases
459
+ # validate anneal_strategy
460
+ if anneal_strategy not in ['cos', 'linear']:
461
+ raise ValueError('anneal_strategy must be one of "cos" or '
462
+ f'"linear", instead got {anneal_strategy}')
463
+ elif anneal_strategy == 'cos':
464
+ self.anneal_func = annealing_cos
465
+ elif anneal_strategy == 'linear':
466
+ self.anneal_func = annealing_linear
467
+
468
+ assert not by_epoch, \
469
+ 'currently only support "by_epoch" = False'
470
+ super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
471
+
472
+ def before_run(self, runner):
473
+ super(CyclicLrUpdaterHook, self).before_run(runner)
474
+ # initiate lr_phases
475
+ # total lr_phases are separated as up and down
476
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
477
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
478
+ self.lr_phases.append(
479
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
480
+ self.lr_phases.append([
481
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
482
+ self.target_ratio[0], self.target_ratio[1]
483
+ ])
484
+
485
+ def get_lr(self, runner, base_lr):
486
+ curr_iter = runner.iter
487
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
488
+ end_ratio) in self.lr_phases:
489
+ curr_iter %= max_iter_per_phase
490
+ if start_iter <= curr_iter < end_iter:
491
+ progress = curr_iter - start_iter
492
+ return self.anneal_func(base_lr * start_ratio,
493
+ base_lr * end_ratio,
494
+ progress / (end_iter - start_iter))
495
+
496
+
497
+ @HOOKS.register_module()
498
+ class OneCycleLrUpdaterHook(LrUpdaterHook):
499
+ """One Cycle LR Scheduler.
500
+
501
+ The 1cycle learning rate policy changes the learning rate after every
502
+ batch. The one cycle learning rate policy is described in
503
+ https://arxiv.org/pdf/1708.07120.pdf
504
+
505
+ Args:
506
+ max_lr (float or list): Upper learning rate boundaries in the cycle
507
+ for each parameter group.
508
+ total_steps (int, optional): The total number of steps in the cycle.
509
+ Note that if a value is not provided here, it will be the max_iter
510
+ of runner. Default: None.
511
+ pct_start (float): The percentage of the cycle (in number of steps)
512
+ spent increasing the learning rate.
513
+ Default: 0.3
514
+ anneal_strategy (str): {'cos', 'linear'}
515
+ Specifies the annealing strategy: 'cos' for cosine annealing,
516
+ 'linear' for linear annealing.
517
+ Default: 'cos'
518
+ div_factor (float): Determines the initial learning rate via
519
+ initial_lr = max_lr/div_factor
520
+ Default: 25
521
+ final_div_factor (float): Determines the minimum learning rate via
522
+ min_lr = initial_lr/final_div_factor
523
+ Default: 1e4
524
+ three_phase (bool): If three_phase is True, use a third phase of the
525
+ schedule to annihilate the learning rate according to
526
+ final_div_factor instead of modifying the second phase (the first
527
+ two phases will be symmetrical about the step indicated by
528
+ pct_start).
529
+ Default: False
530
+ """
531
+
532
+ def __init__(self,
533
+ max_lr,
534
+ total_steps=None,
535
+ pct_start=0.3,
536
+ anneal_strategy='cos',
537
+ div_factor=25,
538
+ final_div_factor=1e4,
539
+ three_phase=False,
540
+ **kwargs):
541
+ # validate by_epoch, currently only support by_epoch = False
542
+ if 'by_epoch' not in kwargs:
543
+ kwargs['by_epoch'] = False
544
+ else:
545
+ assert not kwargs['by_epoch'], \
546
+ 'currently only support "by_epoch" = False'
547
+ if not isinstance(max_lr, (numbers.Number, list, dict)):
548
+ raise ValueError('the type of max_lr must be the one of list or '
549
+ f'dict, but got {type(max_lr)}')
550
+ self._max_lr = max_lr
551
+ if total_steps is not None:
552
+ if not isinstance(total_steps, int):
553
+ raise ValueError('the type of total_steps must be int, but'
554
+ f'got {type(total_steps)}')
555
+ self.total_steps = total_steps
556
+ # validate pct_start
557
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
558
+ raise ValueError('expected float between 0 and 1 pct_start, but '
559
+ f'got {pct_start}')
560
+ self.pct_start = pct_start
561
+ # validate anneal_strategy
562
+ if anneal_strategy not in ['cos', 'linear']:
563
+ raise ValueError('anneal_strategy must be one of "cos" or '
564
+ f'"linear", instead got {anneal_strategy}')
565
+ elif anneal_strategy == 'cos':
566
+ self.anneal_func = annealing_cos
567
+ elif anneal_strategy == 'linear':
568
+ self.anneal_func = annealing_linear
569
+ self.div_factor = div_factor
570
+ self.final_div_factor = final_div_factor
571
+ self.three_phase = three_phase
572
+ self.lr_phases = [] # init lr_phases
573
+ super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
574
+
575
+ def before_run(self, runner):
576
+ if hasattr(self, 'total_steps'):
577
+ total_steps = self.total_steps
578
+ else:
579
+ total_steps = runner.max_iters
580
+ if total_steps < runner.max_iters:
581
+ raise ValueError(
582
+ 'The total steps must be greater than or equal to max '
583
+ f'iterations {runner.max_iters} of runner, but total steps '
584
+ f'is {total_steps}.')
585
+
586
+ if isinstance(runner.optimizer, dict):
587
+ self.base_lr = {}
588
+ for k, optim in runner.optimizer.items():
589
+ _max_lr = format_param(k, optim, self._max_lr)
590
+ self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
591
+ for group, lr in zip(optim.param_groups, self.base_lr[k]):
592
+ group.setdefault('initial_lr', lr)
593
+ else:
594
+ k = type(runner.optimizer).__name__
595
+ _max_lr = format_param(k, runner.optimizer, self._max_lr)
596
+ self.base_lr = [lr / self.div_factor for lr in _max_lr]
597
+ for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
598
+ group.setdefault('initial_lr', lr)
599
+
600
+ if self.three_phase:
601
+ self.lr_phases.append(
602
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
603
+ self.lr_phases.append([
604
+ float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
605
+ ])
606
+ self.lr_phases.append(
607
+ [total_steps - 1, 1, 1 / self.final_div_factor])
608
+ else:
609
+ self.lr_phases.append(
610
+ [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
611
+ self.lr_phases.append(
612
+ [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
613
+
614
+ def get_lr(self, runner, base_lr):
615
+ curr_iter = runner.iter
616
+ start_iter = 0
617
+ for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
618
+ if curr_iter <= end_iter:
619
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
620
+ lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
621
+ pct)
622
+ break
623
+ start_iter = end_iter
624
+ return lr
625
+
626
+
627
+ def annealing_cos(start, end, factor, weight=1):
628
+ """Calculate annealing cos learning rate.
629
+
630
+ Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
631
+ percentage goes from 0.0 to 1.0.
632
+
633
+ Args:
634
+ start (float): The starting learning rate of the cosine annealing.
635
+ end (float): The ending learing rate of the cosine annealing.
636
+ factor (float): The coefficient of `pi` when calculating the current
637
+ percentage. Range from 0.0 to 1.0.
638
+ weight (float, optional): The combination factor of `start` and `end`
639
+ when calculating the actual starting learning rate. Default to 1.
640
+ """
641
+ cos_out = cos(pi * factor) + 1
642
+ return end + 0.5 * weight * (start - end) * cos_out
643
+
644
+
645
+ def annealing_linear(start, end, factor):
646
+ """Calculate annealing linear learning rate.
647
+
648
+ Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
649
+
650
+ Args:
651
+ start (float): The starting learning rate of the linear annealing.
652
+ end (float): The ending learing rate of the linear annealing.
653
+ factor (float): The coefficient of `pi` when calculating the current
654
+ percentage. Range from 0.0 to 1.0.
655
+ """
656
+ return start + (end - start) * factor
657
+
658
+
659
+ def format_param(name, optim, param):
660
+ if isinstance(param, numbers.Number):
661
+ return [param] * len(optim.param_groups)
662
+ elif isinstance(param, (list, tuple)): # multi param groups
663
+ if len(param) != len(optim.param_groups):
664
+ raise ValueError(f'expected {len(optim.param_groups)} '
665
+ f'values for {name}, got {len(param)}')
666
+ return param
667
+ else: # multi optimizers
668
+ if name not in param:
669
+ raise KeyError(f'{name} is not found in {param.keys()}')
670
+ return param[name]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+
4
+ from .hook import HOOKS, Hook
5
+
6
+
7
+ @HOOKS.register_module()
8
+ class EmptyCacheHook(Hook):
9
+
10
+ def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
11
+ self._before_epoch = before_epoch
12
+ self._after_epoch = after_epoch
13
+ self._after_iter = after_iter
14
+
15
+ def after_iter(self, runner):
16
+ if self._after_iter:
17
+ torch.cuda.empty_cache()
18
+
19
+ def before_epoch(self, runner):
20
+ if self._before_epoch:
21
+ torch.cuda.empty_cache()
22
+
23
+ def after_epoch(self, runner):
24
+ if self._after_epoch:
25
+ torch.cuda.empty_cache()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import annotator.mmpkg.mmcv as mmcv
3
+ from .hook import HOOKS, Hook
4
+ from .lr_updater import annealing_cos, annealing_linear, format_param
5
+
6
+
7
+ class MomentumUpdaterHook(Hook):
8
+
9
+ def __init__(self,
10
+ by_epoch=True,
11
+ warmup=None,
12
+ warmup_iters=0,
13
+ warmup_ratio=0.9):
14
+ # validate the "warmup" argument
15
+ if warmup is not None:
16
+ if warmup not in ['constant', 'linear', 'exp']:
17
+ raise ValueError(
18
+ f'"{warmup}" is not a supported type for warming up, valid'
19
+ ' types are "constant" and "linear"')
20
+ if warmup is not None:
21
+ assert warmup_iters > 0, \
22
+ '"warmup_iters" must be a positive integer'
23
+ assert 0 < warmup_ratio <= 1.0, \
24
+ '"warmup_momentum" must be in range (0,1]'
25
+
26
+ self.by_epoch = by_epoch
27
+ self.warmup = warmup
28
+ self.warmup_iters = warmup_iters
29
+ self.warmup_ratio = warmup_ratio
30
+
31
+ self.base_momentum = [] # initial momentum for all param groups
32
+ self.regular_momentum = [
33
+ ] # expected momentum if no warming up is performed
34
+
35
+ def _set_momentum(self, runner, momentum_groups):
36
+ if isinstance(runner.optimizer, dict):
37
+ for k, optim in runner.optimizer.items():
38
+ for param_group, mom in zip(optim.param_groups,
39
+ momentum_groups[k]):
40
+ if 'momentum' in param_group.keys():
41
+ param_group['momentum'] = mom
42
+ elif 'betas' in param_group.keys():
43
+ param_group['betas'] = (mom, param_group['betas'][1])
44
+ else:
45
+ for param_group, mom in zip(runner.optimizer.param_groups,
46
+ momentum_groups):
47
+ if 'momentum' in param_group.keys():
48
+ param_group['momentum'] = mom
49
+ elif 'betas' in param_group.keys():
50
+ param_group['betas'] = (mom, param_group['betas'][1])
51
+
52
+ def get_momentum(self, runner, base_momentum):
53
+ raise NotImplementedError
54
+
55
+ def get_regular_momentum(self, runner):
56
+ if isinstance(runner.optimizer, dict):
57
+ momentum_groups = {}
58
+ for k in runner.optimizer.keys():
59
+ _momentum_group = [
60
+ self.get_momentum(runner, _base_momentum)
61
+ for _base_momentum in self.base_momentum[k]
62
+ ]
63
+ momentum_groups.update({k: _momentum_group})
64
+ return momentum_groups
65
+ else:
66
+ return [
67
+ self.get_momentum(runner, _base_momentum)
68
+ for _base_momentum in self.base_momentum
69
+ ]
70
+
71
+ def get_warmup_momentum(self, cur_iters):
72
+
73
+ def _get_warmup_momentum(cur_iters, regular_momentum):
74
+ if self.warmup == 'constant':
75
+ warmup_momentum = [
76
+ _momentum / self.warmup_ratio
77
+ for _momentum in self.regular_momentum
78
+ ]
79
+ elif self.warmup == 'linear':
80
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
81
+ self.warmup_ratio)
82
+ warmup_momentum = [
83
+ _momentum / (1 - k) for _momentum in self.regular_mom
84
+ ]
85
+ elif self.warmup == 'exp':
86
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
87
+ warmup_momentum = [
88
+ _momentum / k for _momentum in self.regular_mom
89
+ ]
90
+ return warmup_momentum
91
+
92
+ if isinstance(self.regular_momentum, dict):
93
+ momentum_groups = {}
94
+ for key, regular_momentum in self.regular_momentum.items():
95
+ momentum_groups[key] = _get_warmup_momentum(
96
+ cur_iters, regular_momentum)
97
+ return momentum_groups
98
+ else:
99
+ return _get_warmup_momentum(cur_iters, self.regular_momentum)
100
+
101
+ def before_run(self, runner):
102
+ # NOTE: when resuming from a checkpoint,
103
+ # if 'initial_momentum' is not saved,
104
+ # it will be set according to the optimizer params
105
+ if isinstance(runner.optimizer, dict):
106
+ self.base_momentum = {}
107
+ for k, optim in runner.optimizer.items():
108
+ for group in optim.param_groups:
109
+ if 'momentum' in group.keys():
110
+ group.setdefault('initial_momentum', group['momentum'])
111
+ else:
112
+ group.setdefault('initial_momentum', group['betas'][0])
113
+ _base_momentum = [
114
+ group['initial_momentum'] for group in optim.param_groups
115
+ ]
116
+ self.base_momentum.update({k: _base_momentum})
117
+ else:
118
+ for group in runner.optimizer.param_groups:
119
+ if 'momentum' in group.keys():
120
+ group.setdefault('initial_momentum', group['momentum'])
121
+ else:
122
+ group.setdefault('initial_momentum', group['betas'][0])
123
+ self.base_momentum = [
124
+ group['initial_momentum']
125
+ for group in runner.optimizer.param_groups
126
+ ]
127
+
128
+ def before_train_epoch(self, runner):
129
+ if not self.by_epoch:
130
+ return
131
+ self.regular_mom = self.get_regular_momentum(runner)
132
+ self._set_momentum(runner, self.regular_mom)
133
+
134
+ def before_train_iter(self, runner):
135
+ cur_iter = runner.iter
136
+ if not self.by_epoch:
137
+ self.regular_mom = self.get_regular_momentum(runner)
138
+ if self.warmup is None or cur_iter >= self.warmup_iters:
139
+ self._set_momentum(runner, self.regular_mom)
140
+ else:
141
+ warmup_momentum = self.get_warmup_momentum(cur_iter)
142
+ self._set_momentum(runner, warmup_momentum)
143
+ elif self.by_epoch:
144
+ if self.warmup is None or cur_iter > self.warmup_iters:
145
+ return
146
+ elif cur_iter == self.warmup_iters:
147
+ self._set_momentum(runner, self.regular_mom)
148
+ else:
149
+ warmup_momentum = self.get_warmup_momentum(cur_iter)
150
+ self._set_momentum(runner, warmup_momentum)
151
+
152
+
153
+ @HOOKS.register_module()
154
+ class StepMomentumUpdaterHook(MomentumUpdaterHook):
155
+ """Step momentum scheduler with min value clipping.
156
+
157
+ Args:
158
+ step (int | list[int]): Step to decay the momentum. If an int value is
159
+ given, regard it as the decay interval. If a list is given, decay
160
+ momentum at these steps.
161
+ gamma (float, optional): Decay momentum ratio. Default: 0.5.
162
+ min_momentum (float, optional): Minimum momentum value to keep. If
163
+ momentum after decay is lower than this value, it will be clipped
164
+ accordingly. If None is given, we don't perform lr clipping.
165
+ Default: None.
166
+ """
167
+
168
+ def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
169
+ if isinstance(step, list):
170
+ assert mmcv.is_list_of(step, int)
171
+ assert all([s > 0 for s in step])
172
+ elif isinstance(step, int):
173
+ assert step > 0
174
+ else:
175
+ raise TypeError('"step" must be a list or integer')
176
+ self.step = step
177
+ self.gamma = gamma
178
+ self.min_momentum = min_momentum
179
+ super(StepMomentumUpdaterHook, self).__init__(**kwargs)
180
+
181
+ def get_momentum(self, runner, base_momentum):
182
+ progress = runner.epoch if self.by_epoch else runner.iter
183
+
184
+ # calculate exponential term
185
+ if isinstance(self.step, int):
186
+ exp = progress // self.step
187
+ else:
188
+ exp = len(self.step)
189
+ for i, s in enumerate(self.step):
190
+ if progress < s:
191
+ exp = i
192
+ break
193
+
194
+ momentum = base_momentum * (self.gamma**exp)
195
+ if self.min_momentum is not None:
196
+ # clip to a minimum value
197
+ momentum = max(momentum, self.min_momentum)
198
+ return momentum
199
+
200
+
201
+ @HOOKS.register_module()
202
+ class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
203
+
204
+ def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
205
+ assert (min_momentum is None) ^ (min_momentum_ratio is None)
206
+ self.min_momentum = min_momentum
207
+ self.min_momentum_ratio = min_momentum_ratio
208
+ super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
209
+
210
+ def get_momentum(self, runner, base_momentum):
211
+ if self.by_epoch:
212
+ progress = runner.epoch
213
+ max_progress = runner.max_epochs
214
+ else:
215
+ progress = runner.iter
216
+ max_progress = runner.max_iters
217
+ if self.min_momentum_ratio is not None:
218
+ target_momentum = base_momentum * self.min_momentum_ratio
219
+ else:
220
+ target_momentum = self.min_momentum
221
+ return annealing_cos(base_momentum, target_momentum,
222
+ progress / max_progress)
223
+
224
+
225
+ @HOOKS.register_module()
226
+ class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
227
+ """Cyclic momentum Scheduler.
228
+
229
+ Implement the cyclical momentum scheduler policy described in
230
+ https://arxiv.org/pdf/1708.07120.pdf
231
+
232
+ This momentum scheduler usually used together with the CyclicLRUpdater
233
+ to improve the performance in the 3D detection area.
234
+
235
+ Attributes:
236
+ target_ratio (tuple[float]): Relative ratio of the lowest momentum and
237
+ the highest momentum to the initial momentum.
238
+ cyclic_times (int): Number of cycles during training
239
+ step_ratio_up (float): The ratio of the increasing process of momentum
240
+ in the total cycle.
241
+ by_epoch (bool): Whether to update momentum by epoch.
242
+ """
243
+
244
+ def __init__(self,
245
+ by_epoch=False,
246
+ target_ratio=(0.85 / 0.95, 1),
247
+ cyclic_times=1,
248
+ step_ratio_up=0.4,
249
+ **kwargs):
250
+ if isinstance(target_ratio, float):
251
+ target_ratio = (target_ratio, target_ratio / 1e5)
252
+ elif isinstance(target_ratio, tuple):
253
+ target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
254
+ if len(target_ratio) == 1 else target_ratio
255
+ else:
256
+ raise ValueError('target_ratio should be either float '
257
+ f'or tuple, got {type(target_ratio)}')
258
+
259
+ assert len(target_ratio) == 2, \
260
+ '"target_ratio" must be list or tuple of two floats'
261
+ assert 0 <= step_ratio_up < 1.0, \
262
+ '"step_ratio_up" must be in range [0,1)'
263
+
264
+ self.target_ratio = target_ratio
265
+ self.cyclic_times = cyclic_times
266
+ self.step_ratio_up = step_ratio_up
267
+ self.momentum_phases = [] # init momentum_phases
268
+ # currently only support by_epoch=False
269
+ assert not by_epoch, \
270
+ 'currently only support "by_epoch" = False'
271
+ super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
272
+
273
+ def before_run(self, runner):
274
+ super(CyclicMomentumUpdaterHook, self).before_run(runner)
275
+ # initiate momentum_phases
276
+ # total momentum_phases are separated as up and down
277
+ max_iter_per_phase = runner.max_iters // self.cyclic_times
278
+ iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
279
+ self.momentum_phases.append(
280
+ [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
281
+ self.momentum_phases.append([
282
+ iter_up_phase, max_iter_per_phase, max_iter_per_phase,
283
+ self.target_ratio[0], self.target_ratio[1]
284
+ ])
285
+
286
+ def get_momentum(self, runner, base_momentum):
287
+ curr_iter = runner.iter
288
+ for (start_iter, end_iter, max_iter_per_phase, start_ratio,
289
+ end_ratio) in self.momentum_phases:
290
+ curr_iter %= max_iter_per_phase
291
+ if start_iter <= curr_iter < end_iter:
292
+ progress = curr_iter - start_iter
293
+ return annealing_cos(base_momentum * start_ratio,
294
+ base_momentum * end_ratio,
295
+ progress / (end_iter - start_iter))
296
+
297
+
298
+ @HOOKS.register_module()
299
+ class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
300
+ """OneCycle momentum Scheduler.
301
+
302
+ This momentum scheduler usually used together with the OneCycleLrUpdater
303
+ to improve the performance.
304
+
305
+ Args:
306
+ base_momentum (float or list): Lower momentum boundaries in the cycle
307
+ for each parameter group. Note that momentum is cycled inversely
308
+ to learning rate; at the peak of a cycle, momentum is
309
+ 'base_momentum' and learning rate is 'max_lr'.
310
+ Default: 0.85
311
+ max_momentum (float or list): Upper momentum boundaries in the cycle
312
+ for each parameter group. Functionally,
313
+ it defines the cycle amplitude (max_momentum - base_momentum).
314
+ Note that momentum is cycled inversely
315
+ to learning rate; at the start of a cycle, momentum is
316
+ 'max_momentum' and learning rate is 'base_lr'
317
+ Default: 0.95
318
+ pct_start (float): The percentage of the cycle (in number of steps)
319
+ spent increasing the learning rate.
320
+ Default: 0.3
321
+ anneal_strategy (str): {'cos', 'linear'}
322
+ Specifies the annealing strategy: 'cos' for cosine annealing,
323
+ 'linear' for linear annealing.
324
+ Default: 'cos'
325
+ three_phase (bool): If three_phase is True, use a third phase of the
326
+ schedule to annihilate the learning rate according to
327
+ final_div_factor instead of modifying the second phase (the first
328
+ two phases will be symmetrical about the step indicated by
329
+ pct_start).
330
+ Default: False
331
+ """
332
+
333
+ def __init__(self,
334
+ base_momentum=0.85,
335
+ max_momentum=0.95,
336
+ pct_start=0.3,
337
+ anneal_strategy='cos',
338
+ three_phase=False,
339
+ **kwargs):
340
+ # validate by_epoch, currently only support by_epoch=False
341
+ if 'by_epoch' not in kwargs:
342
+ kwargs['by_epoch'] = False
343
+ else:
344
+ assert not kwargs['by_epoch'], \
345
+ 'currently only support "by_epoch" = False'
346
+ if not isinstance(base_momentum, (float, list, dict)):
347
+ raise ValueError('base_momentum must be the type among of float,'
348
+ 'list or dict.')
349
+ self._base_momentum = base_momentum
350
+ if not isinstance(max_momentum, (float, list, dict)):
351
+ raise ValueError('max_momentum must be the type among of float,'
352
+ 'list or dict.')
353
+ self._max_momentum = max_momentum
354
+ # validate pct_start
355
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
356
+ raise ValueError('Expected float between 0 and 1 pct_start, but '
357
+ f'got {pct_start}')
358
+ self.pct_start = pct_start
359
+ # validate anneal_strategy
360
+ if anneal_strategy not in ['cos', 'linear']:
361
+ raise ValueError('anneal_strategy must by one of "cos" or '
362
+ f'"linear", instead got {anneal_strategy}')
363
+ elif anneal_strategy == 'cos':
364
+ self.anneal_func = annealing_cos
365
+ elif anneal_strategy == 'linear':
366
+ self.anneal_func = annealing_linear
367
+ self.three_phase = three_phase
368
+ self.momentum_phases = [] # init momentum_phases
369
+ super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
370
+
371
+ def before_run(self, runner):
372
+ if isinstance(runner.optimizer, dict):
373
+ for k, optim in runner.optimizer.items():
374
+ if ('momentum' not in optim.defaults
375
+ and 'betas' not in optim.defaults):
376
+ raise ValueError('optimizer must support momentum with'
377
+ 'option enabled')
378
+ self.use_beta1 = 'betas' in optim.defaults
379
+ _base_momentum = format_param(k, optim, self._base_momentum)
380
+ _max_momentum = format_param(k, optim, self._max_momentum)
381
+ for group, b_momentum, m_momentum in zip(
382
+ optim.param_groups, _base_momentum, _max_momentum):
383
+ if self.use_beta1:
384
+ _, beta2 = group['betas']
385
+ group['betas'] = (m_momentum, beta2)
386
+ else:
387
+ group['momentum'] = m_momentum
388
+ group['base_momentum'] = b_momentum
389
+ group['max_momentum'] = m_momentum
390
+ else:
391
+ optim = runner.optimizer
392
+ if ('momentum' not in optim.defaults
393
+ and 'betas' not in optim.defaults):
394
+ raise ValueError('optimizer must support momentum with'
395
+ 'option enabled')
396
+ self.use_beta1 = 'betas' in optim.defaults
397
+ k = type(optim).__name__
398
+ _base_momentum = format_param(k, optim, self._base_momentum)
399
+ _max_momentum = format_param(k, optim, self._max_momentum)
400
+ for group, b_momentum, m_momentum in zip(optim.param_groups,
401
+ _base_momentum,
402
+ _max_momentum):
403
+ if self.use_beta1:
404
+ _, beta2 = group['betas']
405
+ group['betas'] = (m_momentum, beta2)
406
+ else:
407
+ group['momentum'] = m_momentum
408
+ group['base_momentum'] = b_momentum
409
+ group['max_momentum'] = m_momentum
410
+
411
+ if self.three_phase:
412
+ self.momentum_phases.append({
413
+ 'end_iter':
414
+ float(self.pct_start * runner.max_iters) - 1,
415
+ 'start_momentum':
416
+ 'max_momentum',
417
+ 'end_momentum':
418
+ 'base_momentum'
419
+ })
420
+ self.momentum_phases.append({
421
+ 'end_iter':
422
+ float(2 * self.pct_start * runner.max_iters) - 2,
423
+ 'start_momentum':
424
+ 'base_momentum',
425
+ 'end_momentum':
426
+ 'max_momentum'
427
+ })
428
+ self.momentum_phases.append({
429
+ 'end_iter': runner.max_iters - 1,
430
+ 'start_momentum': 'max_momentum',
431
+ 'end_momentum': 'max_momentum'
432
+ })
433
+ else:
434
+ self.momentum_phases.append({
435
+ 'end_iter':
436
+ float(self.pct_start * runner.max_iters) - 1,
437
+ 'start_momentum':
438
+ 'max_momentum',
439
+ 'end_momentum':
440
+ 'base_momentum'
441
+ })
442
+ self.momentum_phases.append({
443
+ 'end_iter': runner.max_iters - 1,
444
+ 'start_momentum': 'base_momentum',
445
+ 'end_momentum': 'max_momentum'
446
+ })
447
+
448
+ def _set_momentum(self, runner, momentum_groups):
449
+ if isinstance(runner.optimizer, dict):
450
+ for k, optim in runner.optimizer.items():
451
+ for param_group, mom in zip(optim.param_groups,
452
+ momentum_groups[k]):
453
+ if 'momentum' in param_group.keys():
454
+ param_group['momentum'] = mom
455
+ elif 'betas' in param_group.keys():
456
+ param_group['betas'] = (mom, param_group['betas'][1])
457
+ else:
458
+ for param_group, mom in zip(runner.optimizer.param_groups,
459
+ momentum_groups):
460
+ if 'momentum' in param_group.keys():
461
+ param_group['momentum'] = mom
462
+ elif 'betas' in param_group.keys():
463
+ param_group['betas'] = (mom, param_group['betas'][1])
464
+
465
+ def get_momentum(self, runner, param_group):
466
+ curr_iter = runner.iter
467
+ start_iter = 0
468
+ for i, phase in enumerate(self.momentum_phases):
469
+ end_iter = phase['end_iter']
470
+ if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
471
+ pct = (curr_iter - start_iter) / (end_iter - start_iter)
472
+ momentum = self.anneal_func(
473
+ param_group[phase['start_momentum']],
474
+ param_group[phase['end_momentum']], pct)
475
+ break
476
+ start_iter = end_iter
477
+ return momentum
478
+
479
+ def get_regular_momentum(self, runner):
480
+ if isinstance(runner.optimizer, dict):
481
+ momentum_groups = {}
482
+ for k, optim in runner.optimizer.items():
483
+ _momentum_group = [
484
+ self.get_momentum(runner, param_group)
485
+ for param_group in optim.param_groups
486
+ ]
487
+ momentum_groups.update({k: _momentum_group})
488
+ return momentum_groups
489
+ else:
490
+ momentum_groups = []
491
+ for param_group in runner.optimizer.param_groups:
492
+ momentum_groups.append(self.get_momentum(runner, param_group))
493
+ return momentum_groups
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ from collections import defaultdict
4
+ from itertools import chain
5
+
6
+ from torch.nn.utils import clip_grad
7
+
8
+ from annotator.mmpkg.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
9
+ from ..dist_utils import allreduce_grads
10
+ from ..fp16_utils import LossScaler, wrap_fp16_model
11
+ from .hook import HOOKS, Hook
12
+
13
+ try:
14
+ # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
15
+ # and used; otherwise, auto fp16 will adopt mmcv's implementation.
16
+ from torch.cuda.amp import GradScaler
17
+ except ImportError:
18
+ pass
19
+
20
+
21
+ @HOOKS.register_module()
22
+ class OptimizerHook(Hook):
23
+
24
+ def __init__(self, grad_clip=None):
25
+ self.grad_clip = grad_clip
26
+
27
+ def clip_grads(self, params):
28
+ params = list(
29
+ filter(lambda p: p.requires_grad and p.grad is not None, params))
30
+ if len(params) > 0:
31
+ return clip_grad.clip_grad_norm_(params, **self.grad_clip)
32
+
33
+ def after_train_iter(self, runner):
34
+ runner.optimizer.zero_grad()
35
+ runner.outputs['loss'].backward()
36
+ if self.grad_clip is not None:
37
+ grad_norm = self.clip_grads(runner.model.parameters())
38
+ if grad_norm is not None:
39
+ # Add grad norm to the logger
40
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
41
+ runner.outputs['num_samples'])
42
+ runner.optimizer.step()
43
+
44
+
45
+ @HOOKS.register_module()
46
+ class GradientCumulativeOptimizerHook(OptimizerHook):
47
+ """Optimizer Hook implements multi-iters gradient cumulating.
48
+
49
+ Args:
50
+ cumulative_iters (int, optional): Num of gradient cumulative iters.
51
+ The optimizer will step every `cumulative_iters` iters.
52
+ Defaults to 1.
53
+
54
+ Examples:
55
+ >>> # Use cumulative_iters to simulate a large batch size
56
+ >>> # It is helpful when the hardware cannot handle a large batch size.
57
+ >>> loader = DataLoader(data, batch_size=64)
58
+ >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
59
+ >>> # almost equals to
60
+ >>> loader = DataLoader(data, batch_size=256)
61
+ >>> optim_hook = OptimizerHook()
62
+ """
63
+
64
+ def __init__(self, cumulative_iters=1, **kwargs):
65
+ super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
66
+
67
+ assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
68
+ f'cumulative_iters only accepts positive int, but got ' \
69
+ f'{type(cumulative_iters)} instead.'
70
+
71
+ self.cumulative_iters = cumulative_iters
72
+ self.divisible_iters = 0
73
+ self.remainder_iters = 0
74
+ self.initialized = False
75
+
76
+ def has_batch_norm(self, module):
77
+ if isinstance(module, _BatchNorm):
78
+ return True
79
+ for m in module.children():
80
+ if self.has_batch_norm(m):
81
+ return True
82
+ return False
83
+
84
+ def _init(self, runner):
85
+ if runner.iter % self.cumulative_iters != 0:
86
+ runner.logger.warning(
87
+ 'Resume iter number is not divisible by cumulative_iters in '
88
+ 'GradientCumulativeOptimizerHook, which means the gradient of '
89
+ 'some iters is lost and the result may be influenced slightly.'
90
+ )
91
+
92
+ if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
93
+ runner.logger.warning(
94
+ 'GradientCumulativeOptimizerHook may slightly decrease '
95
+ 'performance if the model has BatchNorm layers.')
96
+
97
+ residual_iters = runner.max_iters - runner.iter
98
+
99
+ self.divisible_iters = (
100
+ residual_iters // self.cumulative_iters * self.cumulative_iters)
101
+ self.remainder_iters = residual_iters - self.divisible_iters
102
+
103
+ self.initialized = True
104
+
105
+ def after_train_iter(self, runner):
106
+ if not self.initialized:
107
+ self._init(runner)
108
+
109
+ if runner.iter < self.divisible_iters:
110
+ loss_factor = self.cumulative_iters
111
+ else:
112
+ loss_factor = self.remainder_iters
113
+ loss = runner.outputs['loss']
114
+ loss = loss / loss_factor
115
+ loss.backward()
116
+
117
+ if (self.every_n_iters(runner, self.cumulative_iters)
118
+ or self.is_last_iter(runner)):
119
+
120
+ if self.grad_clip is not None:
121
+ grad_norm = self.clip_grads(runner.model.parameters())
122
+ if grad_norm is not None:
123
+ # Add grad norm to the logger
124
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
125
+ runner.outputs['num_samples'])
126
+ runner.optimizer.step()
127
+ runner.optimizer.zero_grad()
128
+
129
+
130
+ if (TORCH_VERSION != 'parrots'
131
+ and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
132
+
133
+ @HOOKS.register_module()
134
+ class Fp16OptimizerHook(OptimizerHook):
135
+ """FP16 optimizer hook (using PyTorch's implementation).
136
+
137
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
138
+ to take care of the optimization procedure.
139
+
140
+ Args:
141
+ loss_scale (float | str | dict): Scale factor configuration.
142
+ If loss_scale is a float, static loss scaling will be used with
143
+ the specified scale. If loss_scale is a string, it must be
144
+ 'dynamic', then dynamic loss scaling will be used.
145
+ It can also be a dict containing arguments of GradScalar.
146
+ Defaults to 512. For Pytorch >= 1.6, mmcv uses official
147
+ implementation of GradScaler. If you use a dict version of
148
+ loss_scale to create GradScaler, please refer to:
149
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
150
+ for the parameters.
151
+
152
+ Examples:
153
+ >>> loss_scale = dict(
154
+ ... init_scale=65536.0,
155
+ ... growth_factor=2.0,
156
+ ... backoff_factor=0.5,
157
+ ... growth_interval=2000
158
+ ... )
159
+ >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
160
+ """
161
+
162
+ def __init__(self,
163
+ grad_clip=None,
164
+ coalesce=True,
165
+ bucket_size_mb=-1,
166
+ loss_scale=512.,
167
+ distributed=True):
168
+ self.grad_clip = grad_clip
169
+ self.coalesce = coalesce
170
+ self.bucket_size_mb = bucket_size_mb
171
+ self.distributed = distributed
172
+ self._scale_update_param = None
173
+ if loss_scale == 'dynamic':
174
+ self.loss_scaler = GradScaler()
175
+ elif isinstance(loss_scale, float):
176
+ self._scale_update_param = loss_scale
177
+ self.loss_scaler = GradScaler(init_scale=loss_scale)
178
+ elif isinstance(loss_scale, dict):
179
+ self.loss_scaler = GradScaler(**loss_scale)
180
+ else:
181
+ raise ValueError('loss_scale must be of type float, dict, or '
182
+ f'"dynamic", got {loss_scale}')
183
+
184
+ def before_run(self, runner):
185
+ """Preparing steps before Mixed Precision Training."""
186
+ # wrap model mode to fp16
187
+ wrap_fp16_model(runner.model)
188
+ # resume from state dict
189
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
190
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
191
+ self.loss_scaler.load_state_dict(scaler_state_dict)
192
+
193
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
194
+ """Copy gradients from fp16 model to fp32 weight copy."""
195
+ for fp32_param, fp16_param in zip(fp32_weights,
196
+ fp16_net.parameters()):
197
+ if fp16_param.grad is not None:
198
+ if fp32_param.grad is None:
199
+ fp32_param.grad = fp32_param.data.new(
200
+ fp32_param.size())
201
+ fp32_param.grad.copy_(fp16_param.grad)
202
+
203
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
204
+ """Copy updated params from fp32 weight copy to fp16 model."""
205
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
206
+ fp32_weights):
207
+ fp16_param.data.copy_(fp32_param.data)
208
+
209
+ def after_train_iter(self, runner):
210
+ """Backward optimization steps for Mixed Precision Training. For
211
+ dynamic loss scaling, please refer to
212
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
213
+
214
+ 1. Scale the loss by a scale factor.
215
+ 2. Backward the loss to obtain the gradients.
216
+ 3. Unscale the optimizer’s gradient tensors.
217
+ 4. Call optimizer.step() and update scale factor.
218
+ 5. Save loss_scaler state_dict for resume purpose.
219
+ """
220
+ # clear grads of last iteration
221
+ runner.model.zero_grad()
222
+ runner.optimizer.zero_grad()
223
+
224
+ self.loss_scaler.scale(runner.outputs['loss']).backward()
225
+ self.loss_scaler.unscale_(runner.optimizer)
226
+ # grad clip
227
+ if self.grad_clip is not None:
228
+ grad_norm = self.clip_grads(runner.model.parameters())
229
+ if grad_norm is not None:
230
+ # Add grad norm to the logger
231
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
232
+ runner.outputs['num_samples'])
233
+ # backward and update scaler
234
+ self.loss_scaler.step(runner.optimizer)
235
+ self.loss_scaler.update(self._scale_update_param)
236
+
237
+ # save state_dict of loss_scaler
238
+ runner.meta.setdefault(
239
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
240
+
241
+ @HOOKS.register_module()
242
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
243
+ Fp16OptimizerHook):
244
+ """Fp16 optimizer Hook (using PyTorch's implementation) implements
245
+ multi-iters gradient cumulating.
246
+
247
+ If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
248
+ to take care of the optimization procedure.
249
+ """
250
+
251
+ def __init__(self, *args, **kwargs):
252
+ super(GradientCumulativeFp16OptimizerHook,
253
+ self).__init__(*args, **kwargs)
254
+
255
+ def after_train_iter(self, runner):
256
+ if not self.initialized:
257
+ self._init(runner)
258
+
259
+ if runner.iter < self.divisible_iters:
260
+ loss_factor = self.cumulative_iters
261
+ else:
262
+ loss_factor = self.remainder_iters
263
+ loss = runner.outputs['loss']
264
+ loss = loss / loss_factor
265
+
266
+ self.loss_scaler.scale(loss).backward()
267
+
268
+ if (self.every_n_iters(runner, self.cumulative_iters)
269
+ or self.is_last_iter(runner)):
270
+
271
+ # copy fp16 grads in the model to fp32 params in the optimizer
272
+ self.loss_scaler.unscale_(runner.optimizer)
273
+
274
+ if self.grad_clip is not None:
275
+ grad_norm = self.clip_grads(runner.model.parameters())
276
+ if grad_norm is not None:
277
+ # Add grad norm to the logger
278
+ runner.log_buffer.update(
279
+ {'grad_norm': float(grad_norm)},
280
+ runner.outputs['num_samples'])
281
+
282
+ # backward and update scaler
283
+ self.loss_scaler.step(runner.optimizer)
284
+ self.loss_scaler.update(self._scale_update_param)
285
+
286
+ # save state_dict of loss_scaler
287
+ runner.meta.setdefault(
288
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
289
+
290
+ # clear grads
291
+ runner.model.zero_grad()
292
+ runner.optimizer.zero_grad()
293
+
294
+ else:
295
+
296
+ @HOOKS.register_module()
297
+ class Fp16OptimizerHook(OptimizerHook):
298
+ """FP16 optimizer hook (mmcv's implementation).
299
+
300
+ The steps of fp16 optimizer is as follows.
301
+ 1. Scale the loss value.
302
+ 2. BP in the fp16 model.
303
+ 2. Copy gradients from fp16 model to fp32 weights.
304
+ 3. Update fp32 weights.
305
+ 4. Copy updated parameters from fp32 weights to fp16 model.
306
+
307
+ Refer to https://arxiv.org/abs/1710.03740 for more details.
308
+
309
+ Args:
310
+ loss_scale (float | str | dict): Scale factor configuration.
311
+ If loss_scale is a float, static loss scaling will be used with
312
+ the specified scale. If loss_scale is a string, it must be
313
+ 'dynamic', then dynamic loss scaling will be used.
314
+ It can also be a dict containing arguments of LossScaler.
315
+ Defaults to 512.
316
+ """
317
+
318
+ def __init__(self,
319
+ grad_clip=None,
320
+ coalesce=True,
321
+ bucket_size_mb=-1,
322
+ loss_scale=512.,
323
+ distributed=True):
324
+ self.grad_clip = grad_clip
325
+ self.coalesce = coalesce
326
+ self.bucket_size_mb = bucket_size_mb
327
+ self.distributed = distributed
328
+ if loss_scale == 'dynamic':
329
+ self.loss_scaler = LossScaler(mode='dynamic')
330
+ elif isinstance(loss_scale, float):
331
+ self.loss_scaler = LossScaler(
332
+ init_scale=loss_scale, mode='static')
333
+ elif isinstance(loss_scale, dict):
334
+ self.loss_scaler = LossScaler(**loss_scale)
335
+ else:
336
+ raise ValueError('loss_scale must be of type float, dict, or '
337
+ f'"dynamic", got {loss_scale}')
338
+
339
+ def before_run(self, runner):
340
+ """Preparing steps before Mixed Precision Training.
341
+
342
+ 1. Make a master copy of fp32 weights for optimization.
343
+ 2. Convert the main model from fp32 to fp16.
344
+ """
345
+ # keep a copy of fp32 weights
346
+ old_groups = runner.optimizer.param_groups
347
+ runner.optimizer.param_groups = copy.deepcopy(
348
+ runner.optimizer.param_groups)
349
+ state = defaultdict(dict)
350
+ p_map = {
351
+ old_p: p
352
+ for old_p, p in zip(
353
+ chain(*(g['params'] for g in old_groups)),
354
+ chain(*(g['params']
355
+ for g in runner.optimizer.param_groups)))
356
+ }
357
+ for k, v in runner.optimizer.state.items():
358
+ state[p_map[k]] = v
359
+ runner.optimizer.state = state
360
+ # convert model to fp16
361
+ wrap_fp16_model(runner.model)
362
+ # resume from state dict
363
+ if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
364
+ scaler_state_dict = runner.meta['fp16']['loss_scaler']
365
+ self.loss_scaler.load_state_dict(scaler_state_dict)
366
+
367
+ def copy_grads_to_fp32(self, fp16_net, fp32_weights):
368
+ """Copy gradients from fp16 model to fp32 weight copy."""
369
+ for fp32_param, fp16_param in zip(fp32_weights,
370
+ fp16_net.parameters()):
371
+ if fp16_param.grad is not None:
372
+ if fp32_param.grad is None:
373
+ fp32_param.grad = fp32_param.data.new(
374
+ fp32_param.size())
375
+ fp32_param.grad.copy_(fp16_param.grad)
376
+
377
+ def copy_params_to_fp16(self, fp16_net, fp32_weights):
378
+ """Copy updated params from fp32 weight copy to fp16 model."""
379
+ for fp16_param, fp32_param in zip(fp16_net.parameters(),
380
+ fp32_weights):
381
+ fp16_param.data.copy_(fp32_param.data)
382
+
383
+ def after_train_iter(self, runner):
384
+ """Backward optimization steps for Mixed Precision Training. For
385
+ dynamic loss scaling, please refer `loss_scalar.py`
386
+
387
+ 1. Scale the loss by a scale factor.
388
+ 2. Backward the loss to obtain the gradients (fp16).
389
+ 3. Copy gradients from the model to the fp32 weight copy.
390
+ 4. Scale the gradients back and update the fp32 weight copy.
391
+ 5. Copy back the params from fp32 weight copy to the fp16 model.
392
+ 6. Save loss_scaler state_dict for resume purpose.
393
+ """
394
+ # clear grads of last iteration
395
+ runner.model.zero_grad()
396
+ runner.optimizer.zero_grad()
397
+ # scale the loss value
398
+ scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
399
+ scaled_loss.backward()
400
+ # copy fp16 grads in the model to fp32 params in the optimizer
401
+
402
+ fp32_weights = []
403
+ for param_group in runner.optimizer.param_groups:
404
+ fp32_weights += param_group['params']
405
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
406
+ # allreduce grads
407
+ if self.distributed:
408
+ allreduce_grads(fp32_weights, self.coalesce,
409
+ self.bucket_size_mb)
410
+
411
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
412
+ # if has overflow, skip this iteration
413
+ if not has_overflow:
414
+ # scale the gradients back
415
+ for param in fp32_weights:
416
+ if param.grad is not None:
417
+ param.grad.div_(self.loss_scaler.loss_scale)
418
+ if self.grad_clip is not None:
419
+ grad_norm = self.clip_grads(fp32_weights)
420
+ if grad_norm is not None:
421
+ # Add grad norm to the logger
422
+ runner.log_buffer.update(
423
+ {'grad_norm': float(grad_norm)},
424
+ runner.outputs['num_samples'])
425
+ # update fp32 params
426
+ runner.optimizer.step()
427
+ # copy fp32 params to the fp16 model
428
+ self.copy_params_to_fp16(runner.model, fp32_weights)
429
+ self.loss_scaler.update_scale(has_overflow)
430
+ if has_overflow:
431
+ runner.logger.warning('Check overflow, downscale loss scale '
432
+ f'to {self.loss_scaler.cur_scale}')
433
+
434
+ # save state_dict of loss_scaler
435
+ runner.meta.setdefault(
436
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
437
+
438
+ @HOOKS.register_module()
439
+ class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
440
+ Fp16OptimizerHook):
441
+ """Fp16 optimizer Hook (using mmcv implementation) implements multi-
442
+ iters gradient cumulating."""
443
+
444
+ def __init__(self, *args, **kwargs):
445
+ super(GradientCumulativeFp16OptimizerHook,
446
+ self).__init__(*args, **kwargs)
447
+
448
+ def after_train_iter(self, runner):
449
+ if not self.initialized:
450
+ self._init(runner)
451
+
452
+ if runner.iter < self.divisible_iters:
453
+ loss_factor = self.cumulative_iters
454
+ else:
455
+ loss_factor = self.remainder_iters
456
+
457
+ loss = runner.outputs['loss']
458
+ loss = loss / loss_factor
459
+
460
+ # scale the loss value
461
+ scaled_loss = loss * self.loss_scaler.loss_scale
462
+ scaled_loss.backward()
463
+
464
+ if (self.every_n_iters(runner, self.cumulative_iters)
465
+ or self.is_last_iter(runner)):
466
+
467
+ # copy fp16 grads in the model to fp32 params in the optimizer
468
+ fp32_weights = []
469
+ for param_group in runner.optimizer.param_groups:
470
+ fp32_weights += param_group['params']
471
+ self.copy_grads_to_fp32(runner.model, fp32_weights)
472
+ # allreduce grads
473
+ if self.distributed:
474
+ allreduce_grads(fp32_weights, self.coalesce,
475
+ self.bucket_size_mb)
476
+
477
+ has_overflow = self.loss_scaler.has_overflow(fp32_weights)
478
+ # if has overflow, skip this iteration
479
+ if not has_overflow:
480
+ # scale the gradients back
481
+ for param in fp32_weights:
482
+ if param.grad is not None:
483
+ param.grad.div_(self.loss_scaler.loss_scale)
484
+ if self.grad_clip is not None:
485
+ grad_norm = self.clip_grads(fp32_weights)
486
+ if grad_norm is not None:
487
+ # Add grad norm to the logger
488
+ runner.log_buffer.update(
489
+ {'grad_norm': float(grad_norm)},
490
+ runner.outputs['num_samples'])
491
+ # update fp32 params
492
+ runner.optimizer.step()
493
+ # copy fp32 params to the fp16 model
494
+ self.copy_params_to_fp16(runner.model, fp32_weights)
495
+ else:
496
+ runner.logger.warning(
497
+ 'Check overflow, downscale loss scale '
498
+ f'to {self.loss_scaler.cur_scale}')
499
+
500
+ self.loss_scaler.update_scale(has_overflow)
501
+
502
+ # save state_dict of loss_scaler
503
+ runner.meta.setdefault(
504
+ 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
505
+
506
+ # clear grads
507
+ runner.model.zero_grad()
508
+ runner.optimizer.zero_grad()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import torch
6
+
7
+ from ..dist_utils import master_only
8
+ from .hook import HOOKS, Hook
9
+
10
+
11
+ @HOOKS.register_module()
12
+ class ProfilerHook(Hook):
13
+ """Profiler to analyze performance during training.
14
+
15
+ PyTorch Profiler is a tool that allows the collection of the performance
16
+ metrics during the training. More details on Profiler can be found at
17
+ https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
18
+
19
+ Args:
20
+ by_epoch (bool): Profile performance by epoch or by iteration.
21
+ Default: True.
22
+ profile_iters (int): Number of iterations for profiling.
23
+ If ``by_epoch=True``, profile_iters indicates that they are the
24
+ first profile_iters epochs at the beginning of the
25
+ training, otherwise it indicates the first profile_iters
26
+ iterations. Default: 1.
27
+ activities (list[str]): List of activity groups (CPU, CUDA) to use in
28
+ profiling. Default: ['cpu', 'cuda'].
29
+ schedule (dict, optional): Config of generating the callable schedule.
30
+ if schedule is None, profiler will not add step markers into the
31
+ trace and table view. Default: None.
32
+ on_trace_ready (callable, dict): Either a handler or a dict of generate
33
+ handler. Default: None.
34
+ record_shapes (bool): Save information about operator's input shapes.
35
+ Default: False.
36
+ profile_memory (bool): Track tensor memory allocation/deallocation.
37
+ Default: False.
38
+ with_stack (bool): Record source information (file and line number)
39
+ for the ops. Default: False.
40
+ with_flops (bool): Use formula to estimate the FLOPS of specific
41
+ operators (matrix multiplication and 2D convolution).
42
+ Default: False.
43
+ json_trace_path (str, optional): Exports the collected trace in Chrome
44
+ JSON format. Default: None.
45
+
46
+ Example:
47
+ >>> runner = ... # instantiate a Runner
48
+ >>> # tensorboard trace
49
+ >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
50
+ >>> profiler_config = dict(on_trace_ready=trace_config)
51
+ >>> runner.register_profiler_hook(profiler_config)
52
+ >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
53
+ """
54
+
55
+ def __init__(self,
56
+ by_epoch: bool = True,
57
+ profile_iters: int = 1,
58
+ activities: List[str] = ['cpu', 'cuda'],
59
+ schedule: Optional[dict] = None,
60
+ on_trace_ready: Optional[Union[Callable, dict]] = None,
61
+ record_shapes: bool = False,
62
+ profile_memory: bool = False,
63
+ with_stack: bool = False,
64
+ with_flops: bool = False,
65
+ json_trace_path: Optional[str] = None) -> None:
66
+ try:
67
+ from torch import profiler # torch version >= 1.8.1
68
+ except ImportError:
69
+ raise ImportError('profiler is the new feature of torch1.8.1, '
70
+ f'but your version is {torch.__version__}')
71
+
72
+ assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
73
+ self.by_epoch = by_epoch
74
+
75
+ if profile_iters < 1:
76
+ raise ValueError('profile_iters should be greater than 0, but got '
77
+ f'{profile_iters}')
78
+ self.profile_iters = profile_iters
79
+
80
+ if not isinstance(activities, list):
81
+ raise ValueError(
82
+ f'activities should be list, but got {type(activities)}')
83
+ self.activities = []
84
+ for activity in activities:
85
+ activity = activity.lower()
86
+ if activity == 'cpu':
87
+ self.activities.append(profiler.ProfilerActivity.CPU)
88
+ elif activity == 'cuda':
89
+ self.activities.append(profiler.ProfilerActivity.CUDA)
90
+ else:
91
+ raise ValueError(
92
+ f'activity should be "cpu" or "cuda", but got {activity}')
93
+
94
+ if schedule is not None:
95
+ self.schedule = profiler.schedule(**schedule)
96
+ else:
97
+ self.schedule = None
98
+
99
+ self.on_trace_ready = on_trace_ready
100
+ self.record_shapes = record_shapes
101
+ self.profile_memory = profile_memory
102
+ self.with_stack = with_stack
103
+ self.with_flops = with_flops
104
+ self.json_trace_path = json_trace_path
105
+
106
+ @master_only
107
+ def before_run(self, runner):
108
+ if self.by_epoch and runner.max_epochs < self.profile_iters:
109
+ raise ValueError('self.profile_iters should not be greater than '
110
+ f'{runner.max_epochs}')
111
+
112
+ if not self.by_epoch and runner.max_iters < self.profile_iters:
113
+ raise ValueError('self.profile_iters should not be greater than '
114
+ f'{runner.max_iters}')
115
+
116
+ if callable(self.on_trace_ready): # handler
117
+ _on_trace_ready = self.on_trace_ready
118
+ elif isinstance(self.on_trace_ready, dict): # config of handler
119
+ trace_cfg = self.on_trace_ready.copy()
120
+ trace_type = trace_cfg.pop('type') # log_trace handler
121
+ if trace_type == 'log_trace':
122
+
123
+ def _log_handler(prof):
124
+ print(prof.key_averages().table(**trace_cfg))
125
+
126
+ _on_trace_ready = _log_handler
127
+ elif trace_type == 'tb_trace': # tensorboard_trace handler
128
+ try:
129
+ import torch_tb_profiler # noqa: F401
130
+ except ImportError:
131
+ raise ImportError('please run "pip install '
132
+ 'torch-tb-profiler" to install '
133
+ 'torch_tb_profiler')
134
+ _on_trace_ready = torch.profiler.tensorboard_trace_handler(
135
+ **trace_cfg)
136
+ else:
137
+ raise ValueError('trace_type should be "log_trace" or '
138
+ f'"tb_trace", but got {trace_type}')
139
+ elif self.on_trace_ready is None:
140
+ _on_trace_ready = None # type: ignore
141
+ else:
142
+ raise ValueError('on_trace_ready should be handler, dict or None, '
143
+ f'but got {type(self.on_trace_ready)}')
144
+
145
+ if runner.max_epochs > 1:
146
+ warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
147
+ 'instead of 1 epoch. Since profiler will slow down '
148
+ 'the training, it is recommended to train 1 epoch '
149
+ 'with ProfilerHook and adjust your setting according'
150
+ ' to the profiler summary. During normal training '
151
+ '(epoch > 1), you may disable the ProfilerHook.')
152
+
153
+ self.profiler = torch.profiler.profile(
154
+ activities=self.activities,
155
+ schedule=self.schedule,
156
+ on_trace_ready=_on_trace_ready,
157
+ record_shapes=self.record_shapes,
158
+ profile_memory=self.profile_memory,
159
+ with_stack=self.with_stack,
160
+ with_flops=self.with_flops)
161
+
162
+ self.profiler.__enter__()
163
+ runner.logger.info('profiler is profiling...')
164
+
165
+ @master_only
166
+ def after_train_epoch(self, runner):
167
+ if self.by_epoch and runner.epoch == self.profile_iters - 1:
168
+ runner.logger.info('profiler may take a few minutes...')
169
+ self.profiler.__exit__(None, None, None)
170
+ if self.json_trace_path is not None:
171
+ self.profiler.export_chrome_trace(self.json_trace_path)
172
+
173
+ @master_only
174
+ def after_train_iter(self, runner):
175
+ self.profiler.step()
176
+ if not self.by_epoch and runner.iter == self.profile_iters - 1:
177
+ runner.logger.info('profiler may take a few minutes...')
178
+ self.profiler.__exit__(None, None, None)
179
+ if self.json_trace_path is not None:
180
+ self.profiler.export_chrome_trace(self.json_trace_path)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .hook import HOOKS, Hook
3
+
4
+
5
+ @HOOKS.register_module()
6
+ class DistSamplerSeedHook(Hook):
7
+ """Data-loading sampler for distributed training.
8
+
9
+ When distributed training, it is only useful in conjunction with
10
+ :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
11
+ purpose with :obj:`IterLoader`.
12
+ """
13
+
14
+ def before_epoch(self, runner):
15
+ if hasattr(runner.data_loader.sampler, 'set_epoch'):
16
+ # in case the data loader uses `SequentialSampler` in Pytorch
17
+ runner.data_loader.sampler.set_epoch(runner.epoch)
18
+ elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
19
+ # batch sampler in pytorch warps the sampler as its attributes.
20
+ runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ..dist_utils import allreduce_params
3
+ from .hook import HOOKS, Hook
4
+
5
+
6
+ @HOOKS.register_module()
7
+ class SyncBuffersHook(Hook):
8
+ """Synchronize model buffers such as running_mean and running_var in BN at
9
+ the end of each epoch.
10
+
11
+ Args:
12
+ distributed (bool): Whether distributed training is used. It is
13
+ effective only for distributed training. Defaults to True.
14
+ """
15
+
16
+ def __init__(self, distributed=True):
17
+ self.distributed = distributed
18
+
19
+ def after_epoch(self, runner):
20
+ """All-reduce model buffers at the end of each epoch."""
21
+ if self.distributed:
22
+ allreduce_params(runner.model.buffers())
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import platform
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ from torch.optim import Optimizer
10
+
11
+ import annotator.mmpkg.mmcv as mmcv
12
+ from .base_runner import BaseRunner
13
+ from .builder import RUNNERS
14
+ from .checkpoint import save_checkpoint
15
+ from .hooks import IterTimerHook
16
+ from .utils import get_host_info
17
+
18
+
19
+ class IterLoader:
20
+
21
+ def __init__(self, dataloader):
22
+ self._dataloader = dataloader
23
+ self.iter_loader = iter(self._dataloader)
24
+ self._epoch = 0
25
+
26
+ @property
27
+ def epoch(self):
28
+ return self._epoch
29
+
30
+ def __next__(self):
31
+ try:
32
+ data = next(self.iter_loader)
33
+ except StopIteration:
34
+ self._epoch += 1
35
+ if hasattr(self._dataloader.sampler, 'set_epoch'):
36
+ self._dataloader.sampler.set_epoch(self._epoch)
37
+ time.sleep(2) # Prevent possible deadlock during epoch transition
38
+ self.iter_loader = iter(self._dataloader)
39
+ data = next(self.iter_loader)
40
+
41
+ return data
42
+
43
+ def __len__(self):
44
+ return len(self._dataloader)
45
+
46
+
47
+ @RUNNERS.register_module()
48
+ class IterBasedRunner(BaseRunner):
49
+ """Iteration-based Runner.
50
+
51
+ This runner train models iteration by iteration.
52
+ """
53
+
54
+ def train(self, data_loader, **kwargs):
55
+ self.model.train()
56
+ self.mode = 'train'
57
+ self.data_loader = data_loader
58
+ self._epoch = data_loader.epoch
59
+ data_batch = next(data_loader)
60
+ self.call_hook('before_train_iter')
61
+ outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
62
+ if not isinstance(outputs, dict):
63
+ raise TypeError('model.train_step() must return a dict')
64
+ if 'log_vars' in outputs:
65
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
66
+ self.outputs = outputs
67
+ self.call_hook('after_train_iter')
68
+ self._inner_iter += 1
69
+ self._iter += 1
70
+
71
+ @torch.no_grad()
72
+ def val(self, data_loader, **kwargs):
73
+ self.model.eval()
74
+ self.mode = 'val'
75
+ self.data_loader = data_loader
76
+ data_batch = next(data_loader)
77
+ self.call_hook('before_val_iter')
78
+ outputs = self.model.val_step(data_batch, **kwargs)
79
+ if not isinstance(outputs, dict):
80
+ raise TypeError('model.val_step() must return a dict')
81
+ if 'log_vars' in outputs:
82
+ self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
83
+ self.outputs = outputs
84
+ self.call_hook('after_val_iter')
85
+ self._inner_iter += 1
86
+
87
+ def run(self, data_loaders, workflow, max_iters=None, **kwargs):
88
+ """Start running.
89
+
90
+ Args:
91
+ data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
92
+ and validation.
93
+ workflow (list[tuple]): A list of (phase, iters) to specify the
94
+ running order and iterations. E.g, [('train', 10000),
95
+ ('val', 1000)] means running 10000 iterations for training and
96
+ 1000 iterations for validation, iteratively.
97
+ """
98
+ assert isinstance(data_loaders, list)
99
+ assert mmcv.is_list_of(workflow, tuple)
100
+ assert len(data_loaders) == len(workflow)
101
+ if max_iters is not None:
102
+ warnings.warn(
103
+ 'setting max_iters in run is deprecated, '
104
+ 'please set max_iters in runner_config', DeprecationWarning)
105
+ self._max_iters = max_iters
106
+ assert self._max_iters is not None, (
107
+ 'max_iters must be specified during instantiation')
108
+
109
+ work_dir = self.work_dir if self.work_dir is not None else 'NONE'
110
+ self.logger.info('Start running, host: %s, work_dir: %s',
111
+ get_host_info(), work_dir)
112
+ self.logger.info('Hooks will be executed in the following order:\n%s',
113
+ self.get_hook_info())
114
+ self.logger.info('workflow: %s, max: %d iters', workflow,
115
+ self._max_iters)
116
+ self.call_hook('before_run')
117
+
118
+ iter_loaders = [IterLoader(x) for x in data_loaders]
119
+
120
+ self.call_hook('before_epoch')
121
+
122
+ while self.iter < self._max_iters:
123
+ for i, flow in enumerate(workflow):
124
+ self._inner_iter = 0
125
+ mode, iters = flow
126
+ if not isinstance(mode, str) or not hasattr(self, mode):
127
+ raise ValueError(
128
+ 'runner has no method named "{}" to run a workflow'.
129
+ format(mode))
130
+ iter_runner = getattr(self, mode)
131
+ for _ in range(iters):
132
+ if mode == 'train' and self.iter >= self._max_iters:
133
+ break
134
+ iter_runner(iter_loaders[i], **kwargs)
135
+
136
+ time.sleep(1) # wait for some hooks like loggers to finish
137
+ self.call_hook('after_epoch')
138
+ self.call_hook('after_run')
139
+
140
+ def resume(self,
141
+ checkpoint,
142
+ resume_optimizer=True,
143
+ map_location='default'):
144
+ """Resume model from checkpoint.
145
+
146
+ Args:
147
+ checkpoint (str): Checkpoint to resume from.
148
+ resume_optimizer (bool, optional): Whether resume the optimizer(s)
149
+ if the checkpoint file includes optimizer(s). Default to True.
150
+ map_location (str, optional): Same as :func:`torch.load`.
151
+ Default to 'default'.
152
+ """
153
+ if map_location == 'default':
154
+ device_id = torch.cuda.current_device()
155
+ checkpoint = self.load_checkpoint(
156
+ checkpoint,
157
+ map_location=lambda storage, loc: storage.cuda(device_id))
158
+ else:
159
+ checkpoint = self.load_checkpoint(
160
+ checkpoint, map_location=map_location)
161
+
162
+ self._epoch = checkpoint['meta']['epoch']
163
+ self._iter = checkpoint['meta']['iter']
164
+ self._inner_iter = checkpoint['meta']['iter']
165
+ if 'optimizer' in checkpoint and resume_optimizer:
166
+ if isinstance(self.optimizer, Optimizer):
167
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
168
+ elif isinstance(self.optimizer, dict):
169
+ for k in self.optimizer.keys():
170
+ self.optimizer[k].load_state_dict(
171
+ checkpoint['optimizer'][k])
172
+ else:
173
+ raise TypeError(
174
+ 'Optimizer should be dict or torch.optim.Optimizer '
175
+ f'but got {type(self.optimizer)}')
176
+
177
+ self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
178
+
179
+ def save_checkpoint(self,
180
+ out_dir,
181
+ filename_tmpl='iter_{}.pth',
182
+ meta=None,
183
+ save_optimizer=True,
184
+ create_symlink=True):
185
+ """Save checkpoint to file.
186
+
187
+ Args:
188
+ out_dir (str): Directory to save checkpoint files.
189
+ filename_tmpl (str, optional): Checkpoint file template.
190
+ Defaults to 'iter_{}.pth'.
191
+ meta (dict, optional): Metadata to be saved in checkpoint.
192
+ Defaults to None.
193
+ save_optimizer (bool, optional): Whether save optimizer.
194
+ Defaults to True.
195
+ create_symlink (bool, optional): Whether create symlink to the
196
+ latest checkpoint file. Defaults to True.
197
+ """
198
+ if meta is None:
199
+ meta = {}
200
+ elif not isinstance(meta, dict):
201
+ raise TypeError(
202
+ f'meta should be a dict or None, but got {type(meta)}')
203
+ if self.meta is not None:
204
+ meta.update(self.meta)
205
+ # Note: meta.update(self.meta) should be done before
206
+ # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
207
+ # there will be problems with resumed checkpoints.
208
+ # More details in https://github.com/open-mmlab/mmcv/pull/1108
209
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
210
+
211
+ filename = filename_tmpl.format(self.iter + 1)
212
+ filepath = osp.join(out_dir, filename)
213
+ optimizer = self.optimizer if save_optimizer else None
214
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
215
+ # in some environments, `os.symlink` is not supported, you may need to
216
+ # set `create_symlink` to False
217
+ if create_symlink:
218
+ dst_file = osp.join(out_dir, 'latest.pth')
219
+ if platform.system() != 'Windows':
220
+ mmcv.symlink(filename, dst_file)
221
+ else:
222
+ shutil.copy(filepath, dst_file)
223
+
224
+ def register_training_hooks(self,
225
+ lr_config,
226
+ optimizer_config=None,
227
+ checkpoint_config=None,
228
+ log_config=None,
229
+ momentum_config=None,
230
+ custom_hooks_config=None):
231
+ """Register default hooks for iter-based training.
232
+
233
+ Checkpoint hook, optimizer stepper hook and logger hooks will be set to
234
+ `by_epoch=False` by default.
235
+
236
+ Default hooks include:
237
+
238
+ +----------------------+-------------------------+
239
+ | Hooks | Priority |
240
+ +======================+=========================+
241
+ | LrUpdaterHook | VERY_HIGH (10) |
242
+ +----------------------+-------------------------+
243
+ | MomentumUpdaterHook | HIGH (30) |
244
+ +----------------------+-------------------------+
245
+ | OptimizerStepperHook | ABOVE_NORMAL (40) |
246
+ +----------------------+-------------------------+
247
+ | CheckpointSaverHook | NORMAL (50) |
248
+ +----------------------+-------------------------+
249
+ | IterTimerHook | LOW (70) |
250
+ +----------------------+-------------------------+
251
+ | LoggerHook(s) | VERY_LOW (90) |
252
+ +----------------------+-------------------------+
253
+ | CustomHook(s) | defaults to NORMAL (50) |
254
+ +----------------------+-------------------------+
255
+
256
+ If custom hooks have same priority with default hooks, custom hooks
257
+ will be triggered after default hooks.
258
+ """
259
+ if checkpoint_config is not None:
260
+ checkpoint_config.setdefault('by_epoch', False)
261
+ if lr_config is not None:
262
+ lr_config.setdefault('by_epoch', False)
263
+ if log_config is not None:
264
+ for info in log_config['hooks']:
265
+ info.setdefault('by_epoch', False)
266
+ super(IterBasedRunner, self).register_training_hooks(
267
+ lr_config=lr_config,
268
+ momentum_config=momentum_config,
269
+ optimizer_config=optimizer_config,
270
+ checkpoint_config=checkpoint_config,
271
+ log_config=log_config,
272
+ timer_config=IterTimerHook(),
273
+ custom_hooks_config=custom_hooks_config)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from collections import OrderedDict
3
+
4
+ import numpy as np
5
+
6
+
7
+ class LogBuffer:
8
+
9
+ def __init__(self):
10
+ self.val_history = OrderedDict()
11
+ self.n_history = OrderedDict()
12
+ self.output = OrderedDict()
13
+ self.ready = False
14
+
15
+ def clear(self):
16
+ self.val_history.clear()
17
+ self.n_history.clear()
18
+ self.clear_output()
19
+
20
+ def clear_output(self):
21
+ self.output.clear()
22
+ self.ready = False
23
+
24
+ def update(self, vars, count=1):
25
+ assert isinstance(vars, dict)
26
+ for key, var in vars.items():
27
+ if key not in self.val_history:
28
+ self.val_history[key] = []
29
+ self.n_history[key] = []
30
+ self.val_history[key].append(var)
31
+ self.n_history[key].append(count)
32
+
33
+ def average(self, n=0):
34
+ """Average latest n values or all values."""
35
+ assert n >= 0
36
+ for key in self.val_history:
37
+ values = np.array(self.val_history[key][-n:])
38
+ nums = np.array(self.n_history[key][-n:])
39
+ avg = np.sum(values * nums) / np.sum(nums)
40
+ self.output[key] = avg
41
+ self.ready = True
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
3
+ build_optimizer_constructor)
4
+ from .default_constructor import DefaultOptimizerConstructor
5
+
6
+ __all__ = [
7
+ 'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
8
+ 'build_optimizer', 'build_optimizer_constructor'
9
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import inspect
4
+
5
+ import torch
6
+
7
+ from ...utils import Registry, build_from_cfg
8
+
9
+ OPTIMIZERS = Registry('optimizer')
10
+ OPTIMIZER_BUILDERS = Registry('optimizer builder')
11
+
12
+
13
+ def register_torch_optimizers():
14
+ torch_optimizers = []
15
+ for module_name in dir(torch.optim):
16
+ if module_name.startswith('__'):
17
+ continue
18
+ _optim = getattr(torch.optim, module_name)
19
+ if inspect.isclass(_optim) and issubclass(_optim,
20
+ torch.optim.Optimizer):
21
+ OPTIMIZERS.register_module()(_optim)
22
+ torch_optimizers.append(module_name)
23
+ return torch_optimizers
24
+
25
+
26
+ TORCH_OPTIMIZERS = register_torch_optimizers()
27
+
28
+
29
+ def build_optimizer_constructor(cfg):
30
+ return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
31
+
32
+
33
+ def build_optimizer(model, cfg):
34
+ optimizer_cfg = copy.deepcopy(cfg)
35
+ constructor_type = optimizer_cfg.pop('constructor',
36
+ 'DefaultOptimizerConstructor')
37
+ paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
38
+ optim_constructor = build_optimizer_constructor(
39
+ dict(
40
+ type=constructor_type,
41
+ optimizer_cfg=optimizer_cfg,
42
+ paramwise_cfg=paramwise_cfg))
43
+ optimizer = optim_constructor(model)
44
+ return optimizer
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import torch
5
+ from torch.nn import GroupNorm, LayerNorm
6
+
7
+ from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
8
+ from annotator.mmpkg.mmcv.utils.ext_loader import check_ops_exist
9
+ from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
10
+
11
+
12
+ @OPTIMIZER_BUILDERS.register_module()
13
+ class DefaultOptimizerConstructor:
14
+ """Default constructor for optimizers.
15
+
16
+ By default each parameter share the same optimizer settings, and we
17
+ provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
18
+ It is a dict and may contain the following fields:
19
+
20
+ - ``custom_keys`` (dict): Specified parameters-wise settings by keys. If
21
+ one of the keys in ``custom_keys`` is a substring of the name of one
22
+ parameter, then the setting of the parameter will be specified by
23
+ ``custom_keys[key]`` and other setting like ``bias_lr_mult`` etc. will
24
+ be ignored. It should be noted that the aforementioned ``key`` is the
25
+ longest key that is a substring of the name of the parameter. If there
26
+ are multiple matched keys with the same length, then the key with lower
27
+ alphabet order will be chosen.
28
+ ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
29
+ and ``decay_mult``. See Example 2 below.
30
+ - ``bias_lr_mult`` (float): It will be multiplied to the learning
31
+ rate for all bias parameters (except for those in normalization
32
+ layers and offset layers of DCN).
33
+ - ``bias_decay_mult`` (float): It will be multiplied to the weight
34
+ decay for all bias parameters (except for those in
35
+ normalization layers, depthwise conv layers, offset layers of DCN).
36
+ - ``norm_decay_mult`` (float): It will be multiplied to the weight
37
+ decay for all weight and bias parameters of normalization
38
+ layers.
39
+ - ``dwconv_decay_mult`` (float): It will be multiplied to the weight
40
+ decay for all weight and bias parameters of depthwise conv
41
+ layers.
42
+ - ``dcn_offset_lr_mult`` (float): It will be multiplied to the learning
43
+ rate for parameters of offset layer in the deformable convs
44
+ of a model.
45
+ - ``bypass_duplicate`` (bool): If true, the duplicate parameters
46
+ would not be added into optimizer. Default: False.
47
+
48
+ Note:
49
+ 1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
50
+ override the effect of ``bias_lr_mult`` in the bias of offset
51
+ layer. So be careful when using both ``bias_lr_mult`` and
52
+ ``dcn_offset_lr_mult``. If you wish to apply both of them to the
53
+ offset layer in deformable convs, set ``dcn_offset_lr_mult``
54
+ to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
55
+ 2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
56
+ apply it to all the DCN layers in the model. So be careful when
57
+ the model contains multiple DCN layers in places other than
58
+ backbone.
59
+
60
+ Args:
61
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
62
+ optimizer_cfg (dict): The config dict of the optimizer.
63
+ Positional fields are
64
+
65
+ - `type`: class name of the optimizer.
66
+
67
+ Optional fields are
68
+
69
+ - any arguments of the corresponding optimizer type, e.g.,
70
+ lr, weight_decay, momentum, etc.
71
+ paramwise_cfg (dict, optional): Parameter-wise options.
72
+
73
+ Example 1:
74
+ >>> model = torch.nn.modules.Conv1d(1, 1, 1)
75
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
76
+ >>> weight_decay=0.0001)
77
+ >>> paramwise_cfg = dict(norm_decay_mult=0.)
78
+ >>> optim_builder = DefaultOptimizerConstructor(
79
+ >>> optimizer_cfg, paramwise_cfg)
80
+ >>> optimizer = optim_builder(model)
81
+
82
+ Example 2:
83
+ >>> # assume model have attribute model.backbone and model.cls_head
84
+ >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
85
+ >>> paramwise_cfg = dict(custom_keys={
86
+ '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
87
+ >>> optim_builder = DefaultOptimizerConstructor(
88
+ >>> optimizer_cfg, paramwise_cfg)
89
+ >>> optimizer = optim_builder(model)
90
+ >>> # Then the `lr` and `weight_decay` for model.backbone is
91
+ >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
92
+ >>> # model.cls_head is (0.01, 0.95).
93
+ """
94
+
95
+ def __init__(self, optimizer_cfg, paramwise_cfg=None):
96
+ if not isinstance(optimizer_cfg, dict):
97
+ raise TypeError('optimizer_cfg should be a dict',
98
+ f'but got {type(optimizer_cfg)}')
99
+ self.optimizer_cfg = optimizer_cfg
100
+ self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
101
+ self.base_lr = optimizer_cfg.get('lr', None)
102
+ self.base_wd = optimizer_cfg.get('weight_decay', None)
103
+ self._validate_cfg()
104
+
105
+ def _validate_cfg(self):
106
+ if not isinstance(self.paramwise_cfg, dict):
107
+ raise TypeError('paramwise_cfg should be None or a dict, '
108
+ f'but got {type(self.paramwise_cfg)}')
109
+
110
+ if 'custom_keys' in self.paramwise_cfg:
111
+ if not isinstance(self.paramwise_cfg['custom_keys'], dict):
112
+ raise TypeError(
113
+ 'If specified, custom_keys must be a dict, '
114
+ f'but got {type(self.paramwise_cfg["custom_keys"])}')
115
+ if self.base_wd is None:
116
+ for key in self.paramwise_cfg['custom_keys']:
117
+ if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
118
+ raise ValueError('base_wd should not be None')
119
+
120
+ # get base lr and weight decay
121
+ # weight_decay must be explicitly specified if mult is specified
122
+ if ('bias_decay_mult' in self.paramwise_cfg
123
+ or 'norm_decay_mult' in self.paramwise_cfg
124
+ or 'dwconv_decay_mult' in self.paramwise_cfg):
125
+ if self.base_wd is None:
126
+ raise ValueError('base_wd should not be None')
127
+
128
+ def _is_in(self, param_group, param_group_list):
129
+ assert is_list_of(param_group_list, dict)
130
+ param = set(param_group['params'])
131
+ param_set = set()
132
+ for group in param_group_list:
133
+ param_set.update(set(group['params']))
134
+
135
+ return not param.isdisjoint(param_set)
136
+
137
+ def add_params(self, params, module, prefix='', is_dcn_module=None):
138
+ """Add all parameters of module to the params list.
139
+
140
+ The parameters of the given module will be added to the list of param
141
+ groups, with specific rules defined by paramwise_cfg.
142
+
143
+ Args:
144
+ params (list[dict]): A list of param groups, it will be modified
145
+ in place.
146
+ module (nn.Module): The module to be added.
147
+ prefix (str): The prefix of the module
148
+ is_dcn_module (int|float|None): If the current module is a
149
+ submodule of DCN, `is_dcn_module` will be passed to
150
+ control conv_offset layer's learning rate. Defaults to None.
151
+ """
152
+ # get param-wise options
153
+ custom_keys = self.paramwise_cfg.get('custom_keys', {})
154
+ # first sort with alphabet order and then sort with reversed len of str
155
+ sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
156
+
157
+ bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
158
+ bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
159
+ norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
160
+ dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
161
+ bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
162
+ dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
163
+
164
+ # special rules for norm layers and depth-wise conv layers
165
+ is_norm = isinstance(module,
166
+ (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
167
+ is_dwconv = (
168
+ isinstance(module, torch.nn.Conv2d)
169
+ and module.in_channels == module.groups)
170
+
171
+ for name, param in module.named_parameters(recurse=False):
172
+ param_group = {'params': [param]}
173
+ if not param.requires_grad:
174
+ params.append(param_group)
175
+ continue
176
+ if bypass_duplicate and self._is_in(param_group, params):
177
+ warnings.warn(f'{prefix} is duplicate. It is skipped since '
178
+ f'bypass_duplicate={bypass_duplicate}')
179
+ continue
180
+ # if the parameter match one of the custom keys, ignore other rules
181
+ is_custom = False
182
+ for key in sorted_keys:
183
+ if key in f'{prefix}.{name}':
184
+ is_custom = True
185
+ lr_mult = custom_keys[key].get('lr_mult', 1.)
186
+ param_group['lr'] = self.base_lr * lr_mult
187
+ if self.base_wd is not None:
188
+ decay_mult = custom_keys[key].get('decay_mult', 1.)
189
+ param_group['weight_decay'] = self.base_wd * decay_mult
190
+ break
191
+
192
+ if not is_custom:
193
+ # bias_lr_mult affects all bias parameters
194
+ # except for norm.bias dcn.conv_offset.bias
195
+ if name == 'bias' and not (is_norm or is_dcn_module):
196
+ param_group['lr'] = self.base_lr * bias_lr_mult
197
+
198
+ if (prefix.find('conv_offset') != -1 and is_dcn_module
199
+ and isinstance(module, torch.nn.Conv2d)):
200
+ # deal with both dcn_offset's bias & weight
201
+ param_group['lr'] = self.base_lr * dcn_offset_lr_mult
202
+
203
+ # apply weight decay policies
204
+ if self.base_wd is not None:
205
+ # norm decay
206
+ if is_norm:
207
+ param_group[
208
+ 'weight_decay'] = self.base_wd * norm_decay_mult
209
+ # depth-wise conv
210
+ elif is_dwconv:
211
+ param_group[
212
+ 'weight_decay'] = self.base_wd * dwconv_decay_mult
213
+ # bias lr and decay
214
+ elif name == 'bias' and not is_dcn_module:
215
+ # TODO: current bias_decay_mult will have affect on DCN
216
+ param_group[
217
+ 'weight_decay'] = self.base_wd * bias_decay_mult
218
+ params.append(param_group)
219
+
220
+ if check_ops_exist():
221
+ from annotator.mmpkg.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
222
+ is_dcn_module = isinstance(module,
223
+ (DeformConv2d, ModulatedDeformConv2d))
224
+ else:
225
+ is_dcn_module = False
226
+ for child_name, child_mod in module.named_children():
227
+ child_prefix = f'{prefix}.{child_name}' if prefix else child_name
228
+ self.add_params(
229
+ params,
230
+ child_mod,
231
+ prefix=child_prefix,
232
+ is_dcn_module=is_dcn_module)
233
+
234
+ def __call__(self, model):
235
+ if hasattr(model, 'module'):
236
+ model = model.module
237
+
238
+ optimizer_cfg = self.optimizer_cfg.copy()
239
+ # if no paramwise option is specified, just use the global setting
240
+ if not self.paramwise_cfg:
241
+ optimizer_cfg['params'] = model.parameters()
242
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
243
+
244
+ # set param-wise lr and weight decay recursively
245
+ params = []
246
+ self.add_params(params, model)
247
+ optimizer_cfg['params'] = params
248
+
249
+ return build_from_cfg(optimizer_cfg, OPTIMIZERS)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/priority.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from enum import Enum
3
+
4
+
5
+ class Priority(Enum):
6
+ """Hook priority levels.
7
+
8
+ +--------------+------------+
9
+ | Level | Value |
10
+ +==============+============+
11
+ | HIGHEST | 0 |
12
+ +--------------+------------+
13
+ | VERY_HIGH | 10 |
14
+ +--------------+------------+
15
+ | HIGH | 30 |
16
+ +--------------+------------+
17
+ | ABOVE_NORMAL | 40 |
18
+ +--------------+------------+
19
+ | NORMAL | 50 |
20
+ +--------------+------------+
21
+ | BELOW_NORMAL | 60 |
22
+ +--------------+------------+
23
+ | LOW | 70 |
24
+ +--------------+------------+
25
+ | VERY_LOW | 90 |
26
+ +--------------+------------+
27
+ | LOWEST | 100 |
28
+ +--------------+------------+
29
+ """
30
+
31
+ HIGHEST = 0
32
+ VERY_HIGH = 10
33
+ HIGH = 30
34
+ ABOVE_NORMAL = 40
35
+ NORMAL = 50
36
+ BELOW_NORMAL = 60
37
+ LOW = 70
38
+ VERY_LOW = 90
39
+ LOWEST = 100
40
+
41
+
42
+ def get_priority(priority):
43
+ """Get priority value.
44
+
45
+ Args:
46
+ priority (int or str or :obj:`Priority`): Priority.
47
+
48
+ Returns:
49
+ int: The priority value.
50
+ """
51
+ if isinstance(priority, int):
52
+ if priority < 0 or priority > 100:
53
+ raise ValueError('priority must be between 0 and 100')
54
+ return priority
55
+ elif isinstance(priority, Priority):
56
+ return priority.value
57
+ elif isinstance(priority, str):
58
+ return Priority[priority.upper()].value
59
+ else:
60
+ raise TypeError('priority must be an integer or Priority enum value')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import random
4
+ import sys
5
+ import time
6
+ import warnings
7
+ from getpass import getuser
8
+ from socket import gethostname
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ import annotator.mmpkg.mmcv as mmcv
14
+
15
+
16
+ def get_host_info():
17
+ """Get hostname and username.
18
+
19
+ Return empty string if exception raised, e.g. ``getpass.getuser()`` will
20
+ lead to error in docker container
21
+ """
22
+ host = ''
23
+ try:
24
+ host = f'{getuser()}@{gethostname()}'
25
+ except Exception as e:
26
+ warnings.warn(f'Host or user not found: {str(e)}')
27
+ finally:
28
+ return host
29
+
30
+
31
+ def get_time_str():
32
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
33
+
34
+
35
+ def obj_from_dict(info, parent=None, default_args=None):
36
+ """Initialize an object from dict.
37
+
38
+ The dict must contain the key "type", which indicates the object type, it
39
+ can be either a string or type, such as "list" or ``list``. Remaining
40
+ fields are treated as the arguments for constructing the object.
41
+
42
+ Args:
43
+ info (dict): Object types and arguments.
44
+ parent (:class:`module`): Module which may containing expected object
45
+ classes.
46
+ default_args (dict, optional): Default arguments for initializing the
47
+ object.
48
+
49
+ Returns:
50
+ any type: Object built from the dict.
51
+ """
52
+ assert isinstance(info, dict) and 'type' in info
53
+ assert isinstance(default_args, dict) or default_args is None
54
+ args = info.copy()
55
+ obj_type = args.pop('type')
56
+ if mmcv.is_str(obj_type):
57
+ if parent is not None:
58
+ obj_type = getattr(parent, obj_type)
59
+ else:
60
+ obj_type = sys.modules[obj_type]
61
+ elif not isinstance(obj_type, type):
62
+ raise TypeError('type must be a str or valid type, but '
63
+ f'got {type(obj_type)}')
64
+ if default_args is not None:
65
+ for name, value in default_args.items():
66
+ args.setdefault(name, value)
67
+ return obj_type(**args)
68
+
69
+
70
+ def set_random_seed(seed, deterministic=False, use_rank_shift=False):
71
+ """Set random seed.
72
+
73
+ Args:
74
+ seed (int): Seed to be used.
75
+ deterministic (bool): Whether to set the deterministic option for
76
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
77
+ to True and `torch.backends.cudnn.benchmark` to False.
78
+ Default: False.
79
+ rank_shift (bool): Whether to add rank number to the random seed to
80
+ have different random seed in different threads. Default: False.
81
+ """
82
+ if use_rank_shift:
83
+ rank, _ = mmcv.runner.get_dist_info()
84
+ seed += rank
85
+ random.seed(seed)
86
+ np.random.seed(seed)
87
+ torch.manual_seed(seed)
88
+ torch.cuda.manual_seed(seed)
89
+ torch.cuda.manual_seed_all(seed)
90
+ os.environ['PYTHONHASHSEED'] = str(seed)
91
+ if deterministic:
92
+ torch.backends.cudnn.deterministic = True
93
+ torch.backends.cudnn.benchmark = False
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/__init__.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+ # Copyright (c) OpenMMLab. All rights reserved.
3
+ from .config import Config, ConfigDict, DictAction
4
+ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
5
+ has_method, import_modules_from_strings, is_list_of,
6
+ is_method_overridden, is_seq_of, is_str, is_tuple_of,
7
+ iter_cast, list_cast, requires_executable, requires_package,
8
+ slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
9
+ to_ntuple, tuple_cast)
10
+ from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
11
+ scandir, symlink)
12
+ from .progressbar import (ProgressBar, track_iter_progress,
13
+ track_parallel_progress, track_progress)
14
+ from .testing import (assert_attrs_equal, assert_dict_contains_subset,
15
+ assert_dict_has_keys, assert_is_norm_layer,
16
+ assert_keys_equal, assert_params_all_zeros,
17
+ check_python_script)
18
+ from .timer import Timer, TimerError, check_time
19
+ from .version_utils import digit_version, get_git_hash
20
+
21
+ try:
22
+ import torch
23
+ except ImportError:
24
+ __all__ = [
25
+ 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
26
+ 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
27
+ 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
28
+ 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
29
+ 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
30
+ 'track_progress', 'track_iter_progress', 'track_parallel_progress',
31
+ 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
32
+ 'digit_version', 'get_git_hash', 'import_modules_from_strings',
33
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
34
+ 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
35
+ 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
36
+ 'is_method_overridden', 'has_method'
37
+ ]
38
+ else:
39
+ from .env import collect_env
40
+ from .logging import get_logger, print_log
41
+ from .parrots_jit import jit, skip_no_elena
42
+ from .parrots_wrapper import (
43
+ TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
44
+ PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
45
+ _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
46
+ _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
47
+ from .registry import Registry, build_from_cfg
48
+ from .trace import is_jit_tracing
49
+ __all__ = [
50
+ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
51
+ 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
52
+ 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
53
+ 'check_prerequisites', 'requires_package', 'requires_executable',
54
+ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
55
+ 'symlink', 'scandir', 'ProgressBar', 'track_progress',
56
+ 'track_iter_progress', 'track_parallel_progress', 'Registry',
57
+ 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
58
+ '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
59
+ '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
60
+ 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
61
+ 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
62
+ 'deprecated_api_warning', 'digit_version', 'get_git_hash',
63
+ 'import_modules_from_strings', 'jit', 'skip_no_elena',
64
+ 'assert_dict_contains_subset', 'assert_attrs_equal',
65
+ 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
66
+ 'assert_params_all_zeros', 'check_python_script',
67
+ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
68
+ '_get_cuda_home', 'has_method'
69
+ ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/config.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import ast
3
+ import copy
4
+ import os
5
+ import os.path as osp
6
+ import platform
7
+ import shutil
8
+ import sys
9
+ import tempfile
10
+ import uuid
11
+ import warnings
12
+ from argparse import Action, ArgumentParser
13
+ from collections import abc
14
+ from importlib import import_module
15
+
16
+ from addict import Dict
17
+ from yapf.yapflib.yapf_api import FormatCode
18
+
19
+ from .misc import import_modules_from_strings
20
+ from .path import check_file_exist
21
+
22
+ if platform.system() == 'Windows':
23
+ import regex as re
24
+ else:
25
+ import re
26
+
27
+ BASE_KEY = '_base_'
28
+ DELETE_KEY = '_delete_'
29
+ DEPRECATION_KEY = '_deprecation_'
30
+ RESERVED_KEYS = ['filename', 'text', 'pretty_text']
31
+
32
+
33
+ class ConfigDict(Dict):
34
+
35
+ def __missing__(self, name):
36
+ raise KeyError(name)
37
+
38
+ def __getattr__(self, name):
39
+ try:
40
+ value = super(ConfigDict, self).__getattr__(name)
41
+ except KeyError:
42
+ ex = AttributeError(f"'{self.__class__.__name__}' object has no "
43
+ f"attribute '{name}'")
44
+ except Exception as e:
45
+ ex = e
46
+ else:
47
+ return value
48
+ raise ex
49
+
50
+
51
+ def add_args(parser, cfg, prefix=''):
52
+ for k, v in cfg.items():
53
+ if isinstance(v, str):
54
+ parser.add_argument('--' + prefix + k)
55
+ elif isinstance(v, int):
56
+ parser.add_argument('--' + prefix + k, type=int)
57
+ elif isinstance(v, float):
58
+ parser.add_argument('--' + prefix + k, type=float)
59
+ elif isinstance(v, bool):
60
+ parser.add_argument('--' + prefix + k, action='store_true')
61
+ elif isinstance(v, dict):
62
+ add_args(parser, v, prefix + k + '.')
63
+ elif isinstance(v, abc.Iterable):
64
+ parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
65
+ else:
66
+ print(f'cannot parse key {prefix + k} of type {type(v)}')
67
+ return parser
68
+
69
+
70
+ class Config:
71
+ """A facility for config and config files.
72
+
73
+ It supports common file formats as configs: python/json/yaml. The interface
74
+ is the same as a dict object and also allows access config values as
75
+ attributes.
76
+
77
+ Example:
78
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
79
+ >>> cfg.a
80
+ 1
81
+ >>> cfg.b
82
+ {'b1': [0, 1]}
83
+ >>> cfg.b.b1
84
+ [0, 1]
85
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
86
+ >>> cfg.filename
87
+ "/home/kchen/projects/mmcv/tests/data/config/a.py"
88
+ >>> cfg.item4
89
+ 'test'
90
+ >>> cfg
91
+ "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
92
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
93
+ """
94
+
95
+ @staticmethod
96
+ def _validate_py_syntax(filename):
97
+ with open(filename, 'r', encoding='utf-8') as f:
98
+ # Setting encoding explicitly to resolve coding issue on windows
99
+ content = f.read()
100
+ try:
101
+ ast.parse(content)
102
+ except SyntaxError as e:
103
+ raise SyntaxError('There are syntax errors in config '
104
+ f'file {filename}: {e}')
105
+
106
+ @staticmethod
107
+ def _substitute_predefined_vars(filename, temp_config_name):
108
+ file_dirname = osp.dirname(filename)
109
+ file_basename = osp.basename(filename)
110
+ file_basename_no_extension = osp.splitext(file_basename)[0]
111
+ file_extname = osp.splitext(filename)[1]
112
+ support_templates = dict(
113
+ fileDirname=file_dirname,
114
+ fileBasename=file_basename,
115
+ fileBasenameNoExtension=file_basename_no_extension,
116
+ fileExtname=file_extname)
117
+ with open(filename, 'r', encoding='utf-8') as f:
118
+ # Setting encoding explicitly to resolve coding issue on windows
119
+ config_file = f.read()
120
+ for key, value in support_templates.items():
121
+ regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
122
+ value = value.replace('\\', '/')
123
+ config_file = re.sub(regexp, value, config_file)
124
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
125
+ tmp_config_file.write(config_file)
126
+
127
+ @staticmethod
128
+ def _pre_substitute_base_vars(filename, temp_config_name):
129
+ """Substitute base variable placehoders to string, so that parsing
130
+ would work."""
131
+ with open(filename, 'r', encoding='utf-8') as f:
132
+ # Setting encoding explicitly to resolve coding issue on windows
133
+ config_file = f.read()
134
+ base_var_dict = {}
135
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
136
+ base_vars = set(re.findall(regexp, config_file))
137
+ for base_var in base_vars:
138
+ randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
139
+ base_var_dict[randstr] = base_var
140
+ regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
141
+ config_file = re.sub(regexp, f'"{randstr}"', config_file)
142
+ with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
143
+ tmp_config_file.write(config_file)
144
+ return base_var_dict
145
+
146
+ @staticmethod
147
+ def _substitute_base_vars(cfg, base_var_dict, base_cfg):
148
+ """Substitute variable strings to their actual values."""
149
+ cfg = copy.deepcopy(cfg)
150
+
151
+ if isinstance(cfg, dict):
152
+ for k, v in cfg.items():
153
+ if isinstance(v, str) and v in base_var_dict:
154
+ new_v = base_cfg
155
+ for new_k in base_var_dict[v].split('.'):
156
+ new_v = new_v[new_k]
157
+ cfg[k] = new_v
158
+ elif isinstance(v, (list, tuple, dict)):
159
+ cfg[k] = Config._substitute_base_vars(
160
+ v, base_var_dict, base_cfg)
161
+ elif isinstance(cfg, tuple):
162
+ cfg = tuple(
163
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
164
+ for c in cfg)
165
+ elif isinstance(cfg, list):
166
+ cfg = [
167
+ Config._substitute_base_vars(c, base_var_dict, base_cfg)
168
+ for c in cfg
169
+ ]
170
+ elif isinstance(cfg, str) and cfg in base_var_dict:
171
+ new_v = base_cfg
172
+ for new_k in base_var_dict[cfg].split('.'):
173
+ new_v = new_v[new_k]
174
+ cfg = new_v
175
+
176
+ return cfg
177
+
178
+ @staticmethod
179
+ def _file2dict(filename, use_predefined_variables=True):
180
+ filename = osp.abspath(osp.expanduser(filename))
181
+ check_file_exist(filename)
182
+ fileExtname = osp.splitext(filename)[1]
183
+ if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
184
+ raise IOError('Only py/yml/yaml/json type are supported now!')
185
+
186
+ with tempfile.TemporaryDirectory() as temp_config_dir:
187
+ temp_config_file = tempfile.NamedTemporaryFile(
188
+ dir=temp_config_dir, suffix=fileExtname)
189
+ if platform.system() == 'Windows':
190
+ temp_config_file.close()
191
+ temp_config_name = osp.basename(temp_config_file.name)
192
+ # Substitute predefined variables
193
+ if use_predefined_variables:
194
+ Config._substitute_predefined_vars(filename,
195
+ temp_config_file.name)
196
+ else:
197
+ shutil.copyfile(filename, temp_config_file.name)
198
+ # Substitute base variables from placeholders to strings
199
+ base_var_dict = Config._pre_substitute_base_vars(
200
+ temp_config_file.name, temp_config_file.name)
201
+
202
+ if filename.endswith('.py'):
203
+ temp_module_name = osp.splitext(temp_config_name)[0]
204
+ sys.path.insert(0, temp_config_dir)
205
+ Config._validate_py_syntax(filename)
206
+ mod = import_module(temp_module_name)
207
+ sys.path.pop(0)
208
+ cfg_dict = {
209
+ name: value
210
+ for name, value in mod.__dict__.items()
211
+ if not name.startswith('__')
212
+ }
213
+ # delete imported module
214
+ del sys.modules[temp_module_name]
215
+ elif filename.endswith(('.yml', '.yaml', '.json')):
216
+ import annotator.mmpkg.mmcv as mmcv
217
+ cfg_dict = mmcv.load(temp_config_file.name)
218
+ # close temp file
219
+ temp_config_file.close()
220
+
221
+ # check deprecation information
222
+ if DEPRECATION_KEY in cfg_dict:
223
+ deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
224
+ warning_msg = f'The config file {filename} will be deprecated ' \
225
+ 'in the future.'
226
+ if 'expected' in deprecation_info:
227
+ warning_msg += f' Please use {deprecation_info["expected"]} ' \
228
+ 'instead.'
229
+ if 'reference' in deprecation_info:
230
+ warning_msg += ' More information can be found at ' \
231
+ f'{deprecation_info["reference"]}'
232
+ warnings.warn(warning_msg)
233
+
234
+ cfg_text = filename + '\n'
235
+ with open(filename, 'r', encoding='utf-8') as f:
236
+ # Setting encoding explicitly to resolve coding issue on windows
237
+ cfg_text += f.read()
238
+
239
+ if BASE_KEY in cfg_dict:
240
+ cfg_dir = osp.dirname(filename)
241
+ base_filename = cfg_dict.pop(BASE_KEY)
242
+ base_filename = base_filename if isinstance(
243
+ base_filename, list) else [base_filename]
244
+
245
+ cfg_dict_list = list()
246
+ cfg_text_list = list()
247
+ for f in base_filename:
248
+ _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
249
+ cfg_dict_list.append(_cfg_dict)
250
+ cfg_text_list.append(_cfg_text)
251
+
252
+ base_cfg_dict = dict()
253
+ for c in cfg_dict_list:
254
+ duplicate_keys = base_cfg_dict.keys() & c.keys()
255
+ if len(duplicate_keys) > 0:
256
+ raise KeyError('Duplicate key is not allowed among bases. '
257
+ f'Duplicate keys: {duplicate_keys}')
258
+ base_cfg_dict.update(c)
259
+
260
+ # Substitute base variables from strings to their actual values
261
+ cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
262
+ base_cfg_dict)
263
+
264
+ base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
265
+ cfg_dict = base_cfg_dict
266
+
267
+ # merge cfg_text
268
+ cfg_text_list.append(cfg_text)
269
+ cfg_text = '\n'.join(cfg_text_list)
270
+
271
+ return cfg_dict, cfg_text
272
+
273
+ @staticmethod
274
+ def _merge_a_into_b(a, b, allow_list_keys=False):
275
+ """merge dict ``a`` into dict ``b`` (non-inplace).
276
+
277
+ Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
278
+ in-place modifications.
279
+
280
+ Args:
281
+ a (dict): The source dict to be merged into ``b``.
282
+ b (dict): The origin dict to be fetch keys from ``a``.
283
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
284
+ are allowed in source ``a`` and will replace the element of the
285
+ corresponding index in b if b is a list. Default: False.
286
+
287
+ Returns:
288
+ dict: The modified dict of ``b`` using ``a``.
289
+
290
+ Examples:
291
+ # Normally merge a into b.
292
+ >>> Config._merge_a_into_b(
293
+ ... dict(obj=dict(a=2)), dict(obj=dict(a=1)))
294
+ {'obj': {'a': 2}}
295
+
296
+ # Delete b first and merge a into b.
297
+ >>> Config._merge_a_into_b(
298
+ ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
299
+ {'obj': {'a': 2}}
300
+
301
+ # b is a list
302
+ >>> Config._merge_a_into_b(
303
+ ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
304
+ [{'a': 2}, {'b': 2}]
305
+ """
306
+ b = b.copy()
307
+ for k, v in a.items():
308
+ if allow_list_keys and k.isdigit() and isinstance(b, list):
309
+ k = int(k)
310
+ if len(b) <= k:
311
+ raise KeyError(f'Index {k} exceeds the length of list {b}')
312
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
313
+ elif isinstance(v,
314
+ dict) and k in b and not v.pop(DELETE_KEY, False):
315
+ allowed_types = (dict, list) if allow_list_keys else dict
316
+ if not isinstance(b[k], allowed_types):
317
+ raise TypeError(
318
+ f'{k}={v} in child config cannot inherit from base '
319
+ f'because {k} is a dict in the child config but is of '
320
+ f'type {type(b[k])} in base config. You may set '
321
+ f'`{DELETE_KEY}=True` to ignore the base config')
322
+ b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
323
+ else:
324
+ b[k] = v
325
+ return b
326
+
327
+ @staticmethod
328
+ def fromfile(filename,
329
+ use_predefined_variables=True,
330
+ import_custom_modules=True):
331
+ cfg_dict, cfg_text = Config._file2dict(filename,
332
+ use_predefined_variables)
333
+ if import_custom_modules and cfg_dict.get('custom_imports', None):
334
+ import_modules_from_strings(**cfg_dict['custom_imports'])
335
+ return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
336
+
337
+ @staticmethod
338
+ def fromstring(cfg_str, file_format):
339
+ """Generate config from config str.
340
+
341
+ Args:
342
+ cfg_str (str): Config str.
343
+ file_format (str): Config file format corresponding to the
344
+ config str. Only py/yml/yaml/json type are supported now!
345
+
346
+ Returns:
347
+ obj:`Config`: Config obj.
348
+ """
349
+ if file_format not in ['.py', '.json', '.yaml', '.yml']:
350
+ raise IOError('Only py/yml/yaml/json type are supported now!')
351
+ if file_format != '.py' and 'dict(' in cfg_str:
352
+ # check if users specify a wrong suffix for python
353
+ warnings.warn(
354
+ 'Please check "file_format", the file format may be .py')
355
+ with tempfile.NamedTemporaryFile(
356
+ 'w', encoding='utf-8', suffix=file_format,
357
+ delete=False) as temp_file:
358
+ temp_file.write(cfg_str)
359
+ # on windows, previous implementation cause error
360
+ # see PR 1077 for details
361
+ cfg = Config.fromfile(temp_file.name)
362
+ os.remove(temp_file.name)
363
+ return cfg
364
+
365
+ @staticmethod
366
+ def auto_argparser(description=None):
367
+ """Generate argparser from config file automatically (experimental)"""
368
+ partial_parser = ArgumentParser(description=description)
369
+ partial_parser.add_argument('config', help='config file path')
370
+ cfg_file = partial_parser.parse_known_args()[0].config
371
+ cfg = Config.fromfile(cfg_file)
372
+ parser = ArgumentParser(description=description)
373
+ parser.add_argument('config', help='config file path')
374
+ add_args(parser, cfg)
375
+ return parser, cfg
376
+
377
+ def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
378
+ if cfg_dict is None:
379
+ cfg_dict = dict()
380
+ elif not isinstance(cfg_dict, dict):
381
+ raise TypeError('cfg_dict must be a dict, but '
382
+ f'got {type(cfg_dict)}')
383
+ for key in cfg_dict:
384
+ if key in RESERVED_KEYS:
385
+ raise KeyError(f'{key} is reserved for config file')
386
+
387
+ super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
388
+ super(Config, self).__setattr__('_filename', filename)
389
+ if cfg_text:
390
+ text = cfg_text
391
+ elif filename:
392
+ with open(filename, 'r') as f:
393
+ text = f.read()
394
+ else:
395
+ text = ''
396
+ super(Config, self).__setattr__('_text', text)
397
+
398
+ @property
399
+ def filename(self):
400
+ return self._filename
401
+
402
+ @property
403
+ def text(self):
404
+ return self._text
405
+
406
+ @property
407
+ def pretty_text(self):
408
+
409
+ indent = 4
410
+
411
+ def _indent(s_, num_spaces):
412
+ s = s_.split('\n')
413
+ if len(s) == 1:
414
+ return s_
415
+ first = s.pop(0)
416
+ s = [(num_spaces * ' ') + line for line in s]
417
+ s = '\n'.join(s)
418
+ s = first + '\n' + s
419
+ return s
420
+
421
+ def _format_basic_types(k, v, use_mapping=False):
422
+ if isinstance(v, str):
423
+ v_str = f"'{v}'"
424
+ else:
425
+ v_str = str(v)
426
+
427
+ if use_mapping:
428
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
429
+ attr_str = f'{k_str}: {v_str}'
430
+ else:
431
+ attr_str = f'{str(k)}={v_str}'
432
+ attr_str = _indent(attr_str, indent)
433
+
434
+ return attr_str
435
+
436
+ def _format_list(k, v, use_mapping=False):
437
+ # check if all items in the list are dict
438
+ if all(isinstance(_, dict) for _ in v):
439
+ v_str = '[\n'
440
+ v_str += '\n'.join(
441
+ f'dict({_indent(_format_dict(v_), indent)}),'
442
+ for v_ in v).rstrip(',')
443
+ if use_mapping:
444
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
445
+ attr_str = f'{k_str}: {v_str}'
446
+ else:
447
+ attr_str = f'{str(k)}={v_str}'
448
+ attr_str = _indent(attr_str, indent) + ']'
449
+ else:
450
+ attr_str = _format_basic_types(k, v, use_mapping)
451
+ return attr_str
452
+
453
+ def _contain_invalid_identifier(dict_str):
454
+ contain_invalid_identifier = False
455
+ for key_name in dict_str:
456
+ contain_invalid_identifier |= \
457
+ (not str(key_name).isidentifier())
458
+ return contain_invalid_identifier
459
+
460
+ def _format_dict(input_dict, outest_level=False):
461
+ r = ''
462
+ s = []
463
+
464
+ use_mapping = _contain_invalid_identifier(input_dict)
465
+ if use_mapping:
466
+ r += '{'
467
+ for idx, (k, v) in enumerate(input_dict.items()):
468
+ is_last = idx >= len(input_dict) - 1
469
+ end = '' if outest_level or is_last else ','
470
+ if isinstance(v, dict):
471
+ v_str = '\n' + _format_dict(v)
472
+ if use_mapping:
473
+ k_str = f"'{k}'" if isinstance(k, str) else str(k)
474
+ attr_str = f'{k_str}: dict({v_str}'
475
+ else:
476
+ attr_str = f'{str(k)}=dict({v_str}'
477
+ attr_str = _indent(attr_str, indent) + ')' + end
478
+ elif isinstance(v, list):
479
+ attr_str = _format_list(k, v, use_mapping) + end
480
+ else:
481
+ attr_str = _format_basic_types(k, v, use_mapping) + end
482
+
483
+ s.append(attr_str)
484
+ r += '\n'.join(s)
485
+ if use_mapping:
486
+ r += '}'
487
+ return r
488
+
489
+ cfg_dict = self._cfg_dict.to_dict()
490
+ text = _format_dict(cfg_dict, outest_level=True)
491
+ # copied from setup.cfg
492
+ yapf_style = dict(
493
+ based_on_style='pep8',
494
+ blank_line_before_nested_class_or_def=True,
495
+ split_before_expression_after_opening_paren=True)
496
+ text, _ = FormatCode(text, style_config=yapf_style, verify=True)
497
+
498
+ return text
499
+
500
+ def __repr__(self):
501
+ return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
502
+
503
+ def __len__(self):
504
+ return len(self._cfg_dict)
505
+
506
+ def __getattr__(self, name):
507
+ return getattr(self._cfg_dict, name)
508
+
509
+ def __getitem__(self, name):
510
+ return self._cfg_dict.__getitem__(name)
511
+
512
+ def __setattr__(self, name, value):
513
+ if isinstance(value, dict):
514
+ value = ConfigDict(value)
515
+ self._cfg_dict.__setattr__(name, value)
516
+
517
+ def __setitem__(self, name, value):
518
+ if isinstance(value, dict):
519
+ value = ConfigDict(value)
520
+ self._cfg_dict.__setitem__(name, value)
521
+
522
+ def __iter__(self):
523
+ return iter(self._cfg_dict)
524
+
525
+ def __getstate__(self):
526
+ return (self._cfg_dict, self._filename, self._text)
527
+
528
+ def __setstate__(self, state):
529
+ _cfg_dict, _filename, _text = state
530
+ super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
531
+ super(Config, self).__setattr__('_filename', _filename)
532
+ super(Config, self).__setattr__('_text', _text)
533
+
534
+ def dump(self, file=None):
535
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
536
+ if self.filename.endswith('.py'):
537
+ if file is None:
538
+ return self.pretty_text
539
+ else:
540
+ with open(file, 'w', encoding='utf-8') as f:
541
+ f.write(self.pretty_text)
542
+ else:
543
+ import annotator.mmpkg.mmcv as mmcv
544
+ if file is None:
545
+ file_format = self.filename.split('.')[-1]
546
+ return mmcv.dump(cfg_dict, file_format=file_format)
547
+ else:
548
+ mmcv.dump(cfg_dict, file)
549
+
550
+ def merge_from_dict(self, options, allow_list_keys=True):
551
+ """Merge list into cfg_dict.
552
+
553
+ Merge the dict parsed by MultipleKVAction into this cfg.
554
+
555
+ Examples:
556
+ >>> options = {'model.backbone.depth': 50,
557
+ ... 'model.backbone.with_cp':True}
558
+ >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
559
+ >>> cfg.merge_from_dict(options)
560
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
561
+ >>> assert cfg_dict == dict(
562
+ ... model=dict(backbone=dict(depth=50, with_cp=True)))
563
+
564
+ # Merge list element
565
+ >>> cfg = Config(dict(pipeline=[
566
+ ... dict(type='LoadImage'), dict(type='LoadAnnotations')]))
567
+ >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
568
+ >>> cfg.merge_from_dict(options, allow_list_keys=True)
569
+ >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
570
+ >>> assert cfg_dict == dict(pipeline=[
571
+ ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
572
+
573
+ Args:
574
+ options (dict): dict of configs to merge from.
575
+ allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
576
+ are allowed in ``options`` and will replace the element of the
577
+ corresponding index in the config if the config is a list.
578
+ Default: True.
579
+ """
580
+ option_cfg_dict = {}
581
+ for full_key, v in options.items():
582
+ d = option_cfg_dict
583
+ key_list = full_key.split('.')
584
+ for subkey in key_list[:-1]:
585
+ d.setdefault(subkey, ConfigDict())
586
+ d = d[subkey]
587
+ subkey = key_list[-1]
588
+ d[subkey] = v
589
+
590
+ cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
591
+ super(Config, self).__setattr__(
592
+ '_cfg_dict',
593
+ Config._merge_a_into_b(
594
+ option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
595
+
596
+
597
+ class DictAction(Action):
598
+ """
599
+ argparse action to split an argument into KEY=VALUE form
600
+ on the first = and append to a dictionary. List options can
601
+ be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
602
+ brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
603
+ list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
604
+ """
605
+
606
+ @staticmethod
607
+ def _parse_int_float_bool(val):
608
+ try:
609
+ return int(val)
610
+ except ValueError:
611
+ pass
612
+ try:
613
+ return float(val)
614
+ except ValueError:
615
+ pass
616
+ if val.lower() in ['true', 'false']:
617
+ return True if val.lower() == 'true' else False
618
+ return val
619
+
620
+ @staticmethod
621
+ def _parse_iterable(val):
622
+ """Parse iterable values in the string.
623
+
624
+ All elements inside '()' or '[]' are treated as iterable values.
625
+
626
+ Args:
627
+ val (str): Value string.
628
+
629
+ Returns:
630
+ list | tuple: The expanded list or tuple from the string.
631
+
632
+ Examples:
633
+ >>> DictAction._parse_iterable('1,2,3')
634
+ [1, 2, 3]
635
+ >>> DictAction._parse_iterable('[a, b, c]')
636
+ ['a', 'b', 'c']
637
+ >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
638
+ [(1, 2, 3), ['a', 'b'], 'c']
639
+ """
640
+
641
+ def find_next_comma(string):
642
+ """Find the position of next comma in the string.
643
+
644
+ If no ',' is found in the string, return the string length. All
645
+ chars inside '()' and '[]' are treated as one element and thus ','
646
+ inside these brackets are ignored.
647
+ """
648
+ assert (string.count('(') == string.count(')')) and (
649
+ string.count('[') == string.count(']')), \
650
+ f'Imbalanced brackets exist in {string}'
651
+ end = len(string)
652
+ for idx, char in enumerate(string):
653
+ pre = string[:idx]
654
+ # The string before this ',' is balanced
655
+ if ((char == ',') and (pre.count('(') == pre.count(')'))
656
+ and (pre.count('[') == pre.count(']'))):
657
+ end = idx
658
+ break
659
+ return end
660
+
661
+ # Strip ' and " characters and replace whitespace.
662
+ val = val.strip('\'\"').replace(' ', '')
663
+ is_tuple = False
664
+ if val.startswith('(') and val.endswith(')'):
665
+ is_tuple = True
666
+ val = val[1:-1]
667
+ elif val.startswith('[') and val.endswith(']'):
668
+ val = val[1:-1]
669
+ elif ',' not in val:
670
+ # val is a single value
671
+ return DictAction._parse_int_float_bool(val)
672
+
673
+ values = []
674
+ while len(val) > 0:
675
+ comma_idx = find_next_comma(val)
676
+ element = DictAction._parse_iterable(val[:comma_idx])
677
+ values.append(element)
678
+ val = val[comma_idx + 1:]
679
+ if is_tuple:
680
+ values = tuple(values)
681
+ return values
682
+
683
+ def __call__(self, parser, namespace, values, option_string=None):
684
+ options = {}
685
+ for kv in values:
686
+ key, val = kv.split('=', maxsplit=1)
687
+ options[key] = self._parse_iterable(val)
688
+ setattr(namespace, self.dest, options)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/env.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ """This file holding some environment constant for sharing by other files."""
3
+
4
+ import os.path as osp
5
+ import subprocess
6
+ import sys
7
+ from collections import defaultdict
8
+
9
+ import cv2
10
+ import torch
11
+
12
+ import annotator.mmpkg.mmcv as mmcv
13
+ from .parrots_wrapper import get_build_config
14
+
15
+
16
+ def collect_env():
17
+ """Collect the information of the running environments.
18
+
19
+ Returns:
20
+ dict: The environment information. The following fields are contained.
21
+
22
+ - sys.platform: The variable of ``sys.platform``.
23
+ - Python: Python version.
24
+ - CUDA available: Bool, indicating if CUDA is available.
25
+ - GPU devices: Device type of each GPU.
26
+ - CUDA_HOME (optional): The env var ``CUDA_HOME``.
27
+ - NVCC (optional): NVCC version.
28
+ - GCC: GCC version, "n/a" if GCC is not installed.
29
+ - PyTorch: PyTorch version.
30
+ - PyTorch compiling details: The output of \
31
+ ``torch.__config__.show()``.
32
+ - TorchVision (optional): TorchVision version.
33
+ - OpenCV: OpenCV version.
34
+ - MMCV: MMCV version.
35
+ - MMCV Compiler: The GCC version for compiling MMCV ops.
36
+ - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
37
+ """
38
+ env_info = {}
39
+ env_info['sys.platform'] = sys.platform
40
+ env_info['Python'] = sys.version.replace('\n', '')
41
+
42
+ cuda_available = torch.cuda.is_available()
43
+ env_info['CUDA available'] = cuda_available
44
+
45
+ if cuda_available:
46
+ devices = defaultdict(list)
47
+ for k in range(torch.cuda.device_count()):
48
+ devices[torch.cuda.get_device_name(k)].append(str(k))
49
+ for name, device_ids in devices.items():
50
+ env_info['GPU ' + ','.join(device_ids)] = name
51
+
52
+ from annotator.mmpkg.mmcv.utils.parrots_wrapper import _get_cuda_home
53
+ CUDA_HOME = _get_cuda_home()
54
+ env_info['CUDA_HOME'] = CUDA_HOME
55
+
56
+ if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
57
+ try:
58
+ nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
59
+ nvcc = subprocess.check_output(
60
+ f'"{nvcc}" -V | tail -n1', shell=True)
61
+ nvcc = nvcc.decode('utf-8').strip()
62
+ except subprocess.SubprocessError:
63
+ nvcc = 'Not Available'
64
+ env_info['NVCC'] = nvcc
65
+
66
+ try:
67
+ gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
68
+ gcc = gcc.decode('utf-8').strip()
69
+ env_info['GCC'] = gcc
70
+ except subprocess.CalledProcessError: # gcc is unavailable
71
+ env_info['GCC'] = 'n/a'
72
+
73
+ env_info['PyTorch'] = torch.__version__
74
+ env_info['PyTorch compiling details'] = get_build_config()
75
+
76
+ try:
77
+ import torchvision
78
+ env_info['TorchVision'] = torchvision.__version__
79
+ except ModuleNotFoundError:
80
+ pass
81
+
82
+ env_info['OpenCV'] = cv2.__version__
83
+
84
+ env_info['MMCV'] = mmcv.__version__
85
+
86
+ try:
87
+ from annotator.mmpkg.mmcv.ops import get_compiler_version, get_compiling_cuda_version
88
+ except ModuleNotFoundError:
89
+ env_info['MMCV Compiler'] = 'n/a'
90
+ env_info['MMCV CUDA Compiler'] = 'n/a'
91
+ else:
92
+ env_info['MMCV Compiler'] = get_compiler_version()
93
+ env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()
94
+
95
+ return env_info