14c20fdaa90f290a63f8278d3172dcf42927bd5cd67fb86f306067688f2c86f2
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/collate.py +84 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py +89 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py +89 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py +112 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py +70 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/registry.py +8 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py +59 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/utils.py +20 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/__init__.py +47 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_module.py +195 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py +542 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/builder.py +24 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py +707 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py +44 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py +164 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py +187 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py +410 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py +29 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py +167 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py +11 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py +89 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py +509 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py +92 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py +18 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py +15 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py +166 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py +58 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py +78 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py +82 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py +117 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py +57 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py +256 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py +56 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py +670 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py +25 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py +493 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py +508 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py +180 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py +20 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py +22 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py +273 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py +41 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py +9 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py +44 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py +249 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/priority.py +60 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/utils.py +93 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/__init__.py +69 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/config.py +688 -0
- extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/env.py +95 -0
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/collate.py
ADDED
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+    """Puts each data field into a tensor/DataContainer with outer dimension
+    batch size.
+
+    Extend default_collate to add support for
+    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+
+    1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., image tensors
+    3. cpu_only = False, stack = False, e.g., gt bboxes
+    """
+
+    if not isinstance(batch, Sequence):
+        raise TypeError(f'{batch.dtype} is not supported.')
+
+    if isinstance(batch[0], DataContainer):
+        stacked = []
+        if batch[0].cpu_only:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+            return DataContainer(
+                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+        elif batch[0].stack:
+            for i in range(0, len(batch), samples_per_gpu):
+                assert isinstance(batch[i].data, torch.Tensor)
+
+                if batch[i].pad_dims is not None:
+                    ndim = batch[i].dim()
+                    assert ndim > batch[i].pad_dims
+                    max_shape = [0 for _ in range(batch[i].pad_dims)]
+                    for dim in range(1, batch[i].pad_dims + 1):
+                        max_shape[dim - 1] = batch[i].size(-dim)
+                    for sample in batch[i:i + samples_per_gpu]:
+                        for dim in range(0, ndim - batch[i].pad_dims):
+                            assert batch[i].size(dim) == sample.size(dim)
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            max_shape[dim - 1] = max(max_shape[dim - 1],
+                                                     sample.size(-dim))
+                    padded_samples = []
+                    for sample in batch[i:i + samples_per_gpu]:
+                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            pad[2 * dim -
+                                1] = max_shape[dim - 1] - sample.size(-dim)
+                        padded_samples.append(
+                            F.pad(
+                                sample.data, pad, value=sample.padding_value))
+                    stacked.append(default_collate(padded_samples))
+                elif batch[i].pad_dims is None:
+                    stacked.append(
+                        default_collate([
+                            sample.data
+                            for sample in batch[i:i + samples_per_gpu]
+                        ]))
+                else:
+                    raise ValueError(
+                        'pad_dims should be either None or integers (1-3)')
+
+        else:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+    elif isinstance(batch[0], Sequence):
+        transposed = zip(*batch)
+        return [collate(samples, samples_per_gpu) for samples in transposed]
+    elif isinstance(batch[0], Mapping):
+        return {
+            key: collate([d[key] for d in batch], samples_per_gpu)
+            for key in batch[0]
+        }
+    else:
+        return default_collate(batch)
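
A minimal usage sketch of collate() as added above (a sketch under assumptions: it imports straight from the module files in this commit, and the annotator.mmpkg package root is importable from the extension directory):

# Sketch: collating a batch of dicts that mixes a stacked, padded image
# tensor with cpu_only metadata. Widths differ, so pad_dims=2 padding applies.
import torch
from annotator.mmpkg.mmcv.parallel.collate import collate
from annotator.mmpkg.mmcv.parallel.data_container import DataContainer

batch = [
    dict(
        img=DataContainer(torch.rand(3, 8, 10 + i), stack=True, pad_dims=2),
        meta=DataContainer(dict(idx=i), cpu_only=True),
    ) for i in range(4)
]
out = collate(batch, samples_per_gpu=2)
# out['img'].data holds one stacked tensor per group of 2 samples;
# each group is padded to its own maximum width before default_collate.
print([t.shape for t in out['img'].data])
# [torch.Size([2, 3, 8, 11]), torch.Size([2, 3, 8, 13])]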
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_container.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if not isinstance(args[0].data, torch.Tensor):
+            raise AttributeError(
+                f'{args[0].__class__.__name__} has no attribute '
+                f'{func.__name__} for type {args[0].datatype}')
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+class DataContainer:
+    """A container for any type of objects.
+
+    Typically tensors will be stacked in the collate function and sliced along
+    some dimension in the scatter function. This behavior has some limitations.
+    1. All tensors have to be the same size.
+    2. Types are limited (numpy array or Tensor).
+
+    We design `DataContainer` and `MMDataParallel` to overcome these
+    limitations. The behavior can be either of the following.
+
+    - copy to GPU, pad all tensors to the same size and stack them
+    - copy to GPU without stacking
+    - leave the objects as is and pass it to the model
+    - pad_dims specifies the number of last few dimensions to do padding
+    """
+
+    def __init__(self,
+                 data,
+                 stack=False,
+                 padding_value=0,
+                 cpu_only=False,
+                 pad_dims=2):
+        self._data = data
+        self._cpu_only = cpu_only
+        self._stack = stack
+        self._padding_value = padding_value
+        assert pad_dims in [None, 1, 2, 3]
+        self._pad_dims = pad_dims
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({repr(self.data)})'
+
+    def __len__(self):
+        return len(self._data)
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def datatype(self):
+        if isinstance(self.data, torch.Tensor):
+            return self.data.type()
+        else:
+            return type(self.data)
+
+    @property
+    def cpu_only(self):
+        return self._cpu_only
+
+    @property
+    def stack(self):
+        return self._stack
+
+    @property
+    def padding_value(self):
+        return self._padding_value
+
+    @property
+    def pad_dims(self):
+        return self._pad_dims
+
+    @assert_tensor_type
+    def size(self, *args, **kwargs):
+        return self.data.size(*args, **kwargs)
+
+    @assert_tensor_type
+    def dim(self):
+        return self.data.dim()
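
A short sketch of DataContainer's accessors (same import assumption as the collate example above): size() and dim() are guarded by assert_tensor_type, so they raise AttributeError for non-tensor payloads.

import torch
from annotator.mmpkg.mmcv.parallel.data_container import DataContainer

dc = DataContainer(torch.zeros(3, 224, 224), stack=True)
print(dc.dim(), dc.size(-1))  # 3 224
meta = DataContainer(dict(scale=1.0), cpu_only=True)
print(meta.datatype)          # <class 'dict'>
# meta.size() would raise AttributeError, since the payload is a dict.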
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/data_parallel.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MMDataParallel(DataParallel):
+    """The DataParallel module that supports DataContainer.
+
+    MMDataParallel has two main differences with PyTorch DataParallel:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data during both GPU and CPU inference.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.
+
+    Args:
+        module (:class:`nn.Module`): Module to be encapsulated.
+        device_ids (list[int]): Device IDS of modules to be scattered to.
+            Defaults to None when GPU is not available.
+        output_device (str | int): Device ID for output. Defaults to None.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+    """
+
+    def __init__(self, *args, dim=0, **kwargs):
+        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        self.dim = dim
+
+    def forward(self, *inputs, **kwargs):
+        """Override the original forward function.
+
+        The main difference lies in the CPU inference where the data in
+        :class:`DataContainers` will still be gathered.
+        """
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module(*inputs[0], **kwargs[0])
+        else:
+            return super().forward(*inputs, **kwargs)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module.train_step(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.train_step(*inputs[0], **kwargs[0])
+
+    def val_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            # We add the following line thus the module could gather and
+            # convert data containers as those in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module.val_step(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.val_step(*inputs[0], **kwargs[0])
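
A hedged sketch of the CPU-inference path this class adds (it assumes a machine without visible GPUs, where torch's DataParallel resolves device_ids to an empty list at construction time):

import torch.nn as nn
from annotator.mmpkg.mmcv.parallel.data_parallel import MMDataParallel

model = MMDataParallel(nn.Linear(4, 2))
# With no GPUs, model.device_ids is empty, so forward()/train_step()
# take the branch that scatters with [-1], unwrapping DataContainers
# before calling into the underlying module on CPU.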
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed.py
ADDED
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+                                           _find_tensors)
+
+from annotator.mmpkg.mmcv import print_log
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+    """The DDP module that supports DataContainer.
+
+    MMDDP has two main differences with PyTorch DDP:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data.
+    - It implements two APIs ``train_step()`` and ``val_step()``.
+    """
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+        # to move all tensors to device_id
+        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        """train_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.train_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.train_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.train_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        """val_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.val_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.val_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.val_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
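
The intended usage is one wrapper per GPU process; a sketch under the assumption that torch.distributed has already been initialized (e.g. via init_dist from the runner package added below):

import torch
from annotator.mmpkg.mmcv.parallel.distributed import MMDistributedDataParallel

def wrap_model(model):
    # Move the model to this process's GPU, then wrap it so train_step()
    # and val_step() participate in DDP gradient synchronization.
    device_id = torch.cuda.current_device()
    return MMDistributedDataParallel(
        model.cuda(device_id),
        device_ids=[device_id],
        broadcast_buffers=False)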
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/distributed_deprecated.py
ADDED
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+
+
+@MODULE_WRAPPERS.register_module()
+class MMDistributedDataParallel(nn.Module):
+
+    def __init__(self,
+                 module,
+                 dim=0,
+                 broadcast_buffers=True,
+                 bucket_cap_mb=25):
+        super(MMDistributedDataParallel, self).__init__()
+        self.module = module
+        self.dim = dim
+        self.broadcast_buffers = broadcast_buffers
+
+        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+        self._sync_params()
+
+    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+        for tensors in _take_tensors(tensors, buffer_size):
+            flat_tensors = _flatten_dense_tensors(tensors)
+            dist.broadcast(flat_tensors, 0)
+            for tensor, synced in zip(
+                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
+                tensor.copy_(synced)
+
+    def _sync_params(self):
+        module_states = list(self.module.state_dict().values())
+        if len(module_states) > 0:
+            self._dist_broadcast_coalesced(module_states,
+                                           self.broadcast_bucket_size)
+        if self.broadcast_buffers:
+            if (TORCH_VERSION != 'parrots'
+                    and digit_version(TORCH_VERSION) < digit_version('1.0')):
+                buffers = [b.data for b in self.module._all_buffers()]
+            else:
+                buffers = [b.data for b in self.module.buffers()]
+            if len(buffers) > 0:
+                self._dist_broadcast_coalesced(buffers,
+                                               self.broadcast_bucket_size)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def forward(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        return self.module(*inputs[0], **kwargs[0])
+
+    def train_step(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        output = self.module.train_step(*inputs[0], **kwargs[0])
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        output = self.module.val_step(*inputs[0], **kwargs[0])
+        return output
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/registry.py
ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from annotator.mmpkg.mmcv.utils import Registry
+
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(module=DataParallel)
+MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
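
A sketch of extending the registry (Registry.register_module() used as a bare decorator, matching the calls above); registering a custom wrapper makes is_module_wrapper() in utils.py below treat it as transparent:

import torch.nn as nn
from annotator.mmpkg.mmcv.parallel.registry import MODULE_WRAPPERS

@MODULE_WRAPPERS.register_module()
class MyWrapper(nn.Module):
    # A hypothetical wrapper used only for illustration.
    def __init__(self, module):
        super().__init__()
        self.module = module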
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/scatter_gather.py
ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+    """Scatter inputs to target gpus.
+
+    The only difference from original :func:`scatter` is to add support for
+    :type:`~mmcv.parallel.DataContainer`.
+    """
+
+    def scatter_map(obj):
+        if isinstance(obj, torch.Tensor):
+            if target_gpus != [-1]:
+                return OrigScatter.apply(target_gpus, None, dim, obj)
+            else:
+                # for CPU inference we use self-implemented scatter
+                return Scatter.forward(target_gpus, obj)
+        if isinstance(obj, DataContainer):
+            if obj.cpu_only:
+                return obj.data
+            else:
+                return Scatter.forward(target_gpus, obj.data)
+        if isinstance(obj, tuple) and len(obj) > 0:
+            return list(zip(*map(scatter_map, obj)))
+        if isinstance(obj, list) and len(obj) > 0:
+            out = list(map(list, zip(*map(scatter_map, obj))))
+            return out
+        if isinstance(obj, dict) and len(obj) > 0:
+            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+            return out
+        return [obj for targets in target_gpus]
+
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+    """Scatter with support for kwargs dictionary."""
+    inputs = scatter(inputs, target_gpus, dim) if inputs else []
+    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+    if len(inputs) < len(kwargs):
+        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+    elif len(kwargs) < len(inputs):
+        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+    inputs = tuple(inputs)
+    kwargs = tuple(kwargs)
+    return inputs, kwargs
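
A sketch of scatter_kwargs' length-padding behavior with plain Python objects (no tensors involved; it assumes the sibling ._functions module referenced above is present in the package so the import succeeds):

from annotator.mmpkg.mmcv.parallel.scatter_gather import scatter_kwargs

# Non-tensor objects are replicated once per target; inputs and kwargs
# are then padded with empty tuples/dicts to the same length.
inputs, kwargs = scatter_kwargs((), dict(flag=True), target_gpus=[-1])
print(inputs, kwargs)  # ((),) ({'flag': True},)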
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/parallel/utils.py
ADDED
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+
+
+def is_module_wrapper(module):
+    """Check if a module is a module wrapper.
+
+    The following 3 modules in MMCV (and their subclasses) are regarded as
+    module wrappers: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version). You may add your own
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: True if the input module is a module wrapper.
+    """
+    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+    return isinstance(module, module_wrappers)
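
A typical use of is_module_wrapper() is unwrapping before touching model-specific attributes; a small sketch:

import torch.nn as nn
from torch.nn.parallel import DataParallel
from annotator.mmpkg.mmcv.parallel.utils import is_module_wrapper

net = DataParallel(nn.Linear(4, 2))
inner = net.module if is_module_wrapper(net) else net
print(type(inner).__name__)  # Linear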
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/__init__.py
ADDED
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+                         _load_checkpoint_with_prefix, load_checkpoint,
+                         load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+                         init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+                    GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+                    NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+                    SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+                    WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+                        DefaultOptimizerConstructor, build_optimizer,
+                        build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+
+__all__ = [
+    'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+    'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+    'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+    'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+    'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+    'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+    'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+    'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+    'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+    'ModuleList', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_module.py
ADDED
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from annotator.mmpkg.mmcv.runner.dist_utils import master_only
+from annotator.mmpkg.mmcv.utils.logging import get_logger, logger_initialized, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+    """Base module for all modules in openmmlab.
+
+    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+    functionality of parameter initialization. Compared with
+    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+
+    - ``init_cfg``: the config to control the initialization.
+    - ``init_weights``: The function of parameter
+        initialization and recording initialization
+        information.
+    - ``_params_init_info``: Used to track the parameter
+        initialization information. This attribute only
+        exists during executing the ``init_weights``.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, init_cfg=None):
+        """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+        # NOTE init_cfg can be defined in different levels, but init_cfg
+        # in low levels has a higher priority.
+
+        super(BaseModule, self).__init__()
+        # define default value of init_cfg instead of hard code
+        # in init_weights() function
+        self._is_init = False
+
+        self.init_cfg = copy.deepcopy(init_cfg)
+
+        # Backward compatibility in derived classes
+        # if pretrained is not None:
+        #     warnings.warn('DeprecationWarning: pretrained is a deprecated \
+        #         key, please consider using init_cfg')
+        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+    @property
+    def is_init(self):
+        return self._is_init
+
+    def init_weights(self):
+        """Initialize the weights."""
+
+        is_top_level_module = False
+        # check if it is top-level module
+        if not hasattr(self, '_params_init_info'):
+            # The `_params_init_info` is used to record the initialization
+            # information of the parameters
+            # the key should be the obj:`nn.Parameter` of model and the value
+            # should be a dict containing
+            # - init_info (str): The string that describes the initialization.
+            # - tmp_mean_value (FloatTensor): The mean of the parameter,
+            #       which indicates whether the parameter has been modified.
+            # this attribute would be deleted after all parameters
+            # are initialized.
+            self._params_init_info = defaultdict(dict)
+            is_top_level_module = True
+
+            # Initialize the `_params_init_info`,
+            # When detecting the `tmp_mean_value` of
+            # the corresponding parameter is changed, update related
+            # initialization information
+            for name, param in self.named_parameters():
+                self._params_init_info[param][
+                    'init_info'] = f'The value is the same before and ' \
+                                   f'after calling `init_weights` ' \
+                                   f'of {self.__class__.__name__} '
+                self._params_init_info[param][
+                    'tmp_mean_value'] = param.data.mean()
+
+            # pass `params_init_info` to all submodules
+            # All submodules share the same `params_init_info`,
+            # so it will be updated when parameters are
+            # modified at any level of the model.
+            for sub_module in self.modules():
+                sub_module._params_init_info = self._params_init_info
+
+        # Get the initialized logger, if not exist,
+        # create a logger named `mmcv`
+        logger_names = list(logger_initialized.keys())
+        logger_name = logger_names[0] if logger_names else 'mmcv'
+
+        from ..cnn import initialize
+        from ..cnn.utils.weight_init import update_init_info
+        module_name = self.__class__.__name__
+        if not self._is_init:
+            if self.init_cfg:
+                print_log(
+                    f'initialize {module_name} with init_cfg {self.init_cfg}',
+                    logger=logger_name)
+                initialize(self, self.init_cfg)
+                if isinstance(self.init_cfg, dict):
+                    # prevent the parameters of
+                    # the pre-trained model
+                    # from being overwritten by
+                    # the `init_weights`
+                    if self.init_cfg['type'] == 'Pretrained':
+                        return
+
+            for m in self.children():
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+                    # users may overload the `init_weights`
+                    update_init_info(
+                        m,
+                        init_info=f'Initialized by '
+                        f'user-defined `init_weights`'
+                        f' in {m.__class__.__name__} ')
+
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')
+
+        if is_top_level_module:
+            self._dump_init_info(logger_name)
+
+            for sub_module in self.modules():
+                del sub_module._params_init_info
+
+    @master_only
+    def _dump_init_info(self, logger_name):
+        """Dump the initialization information to a file named
+        `initialization.log.json` in workdir.
+
+        Args:
+            logger_name (str): The name of logger.
+        """
+
+        logger = get_logger(logger_name)
+
+        with_file_handler = False
+        # dump the information to the logger file if there is a `FileHandler`
+        for handler in logger.handlers:
+            if isinstance(handler, FileHandler):
+                handler.stream.write(
+                    'Name of parameter - Initialization information\n')
+                for name, param in self.named_parameters():
+                    handler.stream.write(
+                        f'\n{name} - {param.shape}: '
+                        f"\n{self._params_init_info[param]['init_info']} \n")
+                handler.stream.flush()
+                with_file_handler = True
+        if not with_file_handler:
+            for name, param in self.named_parameters():
+                print_log(
+                    f'\n{name} - {param.shape}: '
+                    f"\n{self._params_init_info[param]['init_info']} \n ",
+                    logger=logger_name)
+
+    def __repr__(self):
+        s = super().__repr__()
+        if self.init_cfg:
+            s += f'\ninit_cfg={self.init_cfg}'
+        return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in openmmlab.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in openmmlab.
+
+    Args:
+        modules (iterable, optional): an iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, modules=None, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
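
A sketch of the subclassing pattern BaseModule is designed for (the 'Kaiming' initializer name is an assumption based on mmcv.cnn's standard weight-init registry, which init_weights() imports from):

import torch.nn as nn
from annotator.mmpkg.mmcv.runner.base_module import BaseModule

class ConvBlock(BaseModule):
    def __init__(self, init_cfg=None):
        # init_cfg defined at a lower level takes priority over outer configs.
        if init_cfg is None:
            init_cfg = dict(type='Kaiming', layer='Conv2d')
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(3, 16, 3)

block = ConvBlock()
block.init_weights()  # applies init_cfg once, then logs per-parameter info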
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/base_runner.py
ADDED
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.mmpkg.mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+
+
+class BaseRunner(metaclass=ABCMeta):
+    """The base class of Runner, a training helper for PyTorch.
+
+    All subclasses should implement the following APIs:
+
+    - ``run()``
+    - ``train()``
+    - ``val()``
+    - ``save_checkpoint()``
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that process a data
+            batch. The interface of this method should be
+            `batch_processor(model, data, train_mode) -> dict`
+        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+            optimizer (in most cases) or a dict of optimizers (in models that
+            requires more than one optimizer, e.g., GAN).
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs. Defaults to None.
+        logger (:obj:`logging.Logger`): Logger used during training.
+            Defaults to None. (The default value is just for backward
+            compatibility)
+        meta (dict | None): A dict records some important information such as
+            environment info and seed, which will be logged in logger hook.
+            Defaults to None.
+        max_epochs (int, optional): Total training epochs.
+        max_iters (int, optional): Total training iterations.
+    """
+
+    def __init__(self,
+                 model,
+                 batch_processor=None,
+                 optimizer=None,
+                 work_dir=None,
+                 logger=None,
+                 meta=None,
+                 max_iters=None,
+                 max_epochs=None):
+        if batch_processor is not None:
+            if not callable(batch_processor):
+                raise TypeError('batch_processor must be callable, '
+                                f'but got {type(batch_processor)}')
+            warnings.warn('batch_processor is deprecated, please implement '
+                          'train_step() and val_step() in the model instead.')
+            # raise an error if `batch_processor` is not None and
+            # `model.train_step()` exists.
+            if is_module_wrapper(model):
+                _model = model.module
+            else:
+                _model = model
+            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+                raise RuntimeError(
+                    'batch_processor and model.train_step()/model.val_step() '
+                    'cannot be both available.')
+        else:
+            assert hasattr(model, 'train_step')
+
+        # check the type of `optimizer`
+        if isinstance(optimizer, dict):
+            for name, optim in optimizer.items():
+                if not isinstance(optim, Optimizer):
+                    raise TypeError(
+                        f'optimizer must be a dict of torch.optim.Optimizers, '
+                        f'but optimizer["{name}"] is a {type(optim)}')
+        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+            raise TypeError(
+                f'optimizer must be a torch.optim.Optimizer object '
+                f'or dict or None, but got {type(optimizer)}')
+
+        # check the type of `logger`
+        if not isinstance(logger, logging.Logger):
+            raise TypeError(f'logger must be a logging.Logger object, '
+                            f'but got {type(logger)}')
+
+        # check the type of `meta`
+        if meta is not None and not isinstance(meta, dict):
+            raise TypeError(
+                f'meta must be a dict or None, but got {type(meta)}')
+
+        self.model = model
+        self.batch_processor = batch_processor
+        self.optimizer = optimizer
+        self.logger = logger
+        self.meta = meta
+        # create work_dir
+        if mmcv.is_str(work_dir):
+            self.work_dir = osp.abspath(work_dir)
+            mmcv.mkdir_or_exist(self.work_dir)
+        elif work_dir is None:
+            self.work_dir = None
+        else:
+            raise TypeError('"work_dir" must be a str or None')
+
+        # get model name from the model class
+        if hasattr(self.model, 'module'):
+            self._model_name = self.model.module.__class__.__name__
+        else:
+            self._model_name = self.model.__class__.__name__
+
+        self._rank, self._world_size = get_dist_info()
+        self.timestamp = get_time_str()
+        self.mode = None
+        self._hooks = []
+        self._epoch = 0
+        self._iter = 0
+        self._inner_iter = 0
+
+        if max_epochs is not None and max_iters is not None:
+            raise ValueError(
+                'Only one of `max_epochs` or `max_iters` can be set.')
+
+        self._max_epochs = max_epochs
+        self._max_iters = max_iters
+        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+        self.log_buffer = LogBuffer()
+
+    @property
+    def model_name(self):
+        """str: Name of the model, usually the module class name."""
+        return self._model_name
+
+    @property
+    def rank(self):
+        """int: Rank of current process. (distributed training)"""
+        return self._rank
+
+    @property
+    def world_size(self):
+        """int: Number of processes participating in the job.
+        (distributed training)"""
+        return self._world_size
+
+    @property
+    def hooks(self):
+        """list[:obj:`Hook`]: A list of registered hooks."""
+        return self._hooks
+
+    @property
+    def epoch(self):
+        """int: Current epoch."""
+        return self._epoch
+
+    @property
+    def iter(self):
+        """int: Current iteration."""
+        return self._iter
+
+    @property
+    def inner_iter(self):
+        """int: Iteration in an epoch."""
+        return self._inner_iter
+
+    @property
+    def max_epochs(self):
+        """int: Maximum training epochs."""
+        return self._max_epochs
+
+    @property
+    def max_iters(self):
+        """int: Maximum training iterations."""
+        return self._max_iters
+
+    @abstractmethod
+    def train(self):
+        pass
+
+    @abstractmethod
+    def val(self):
+        pass
+
+    @abstractmethod
+    def run(self, data_loaders, workflow, **kwargs):
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl,
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        pass
+
+    def current_lr(self):
+        """Get current learning rates.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current learning rates of
+                all param groups. If the runner has a dict of optimizers,
+                this method will return a dict.
+        """
+        if isinstance(self.optimizer, torch.optim.Optimizer):
+            lr = [group['lr'] for group in self.optimizer.param_groups]
+        elif isinstance(self.optimizer, dict):
+            lr = dict()
+            for name, optim in self.optimizer.items():
+                lr[name] = [group['lr'] for group in optim.param_groups]
+        else:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
+        return lr
+
+    def current_momentum(self):
+        """Get current momentums.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current momentums of all
+                param groups. If the runner has a dict of optimizers, this
+                method will return a dict.
+        """
+
+        def _get_momentum(optimizer):
+            momentums = []
+            for group in optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    momentums.append(group['momentum'])
+                elif 'betas' in group.keys():
+                    momentums.append(group['betas'][0])
+                else:
+                    momentums.append(0)
+            return momentums
+
+        if self.optimizer is None:
+            raise RuntimeError(
+                'momentum is not applicable because optimizer does not exist.')
+        elif isinstance(self.optimizer, torch.optim.Optimizer):
+            momentums = _get_momentum(self.optimizer)
+        elif isinstance(self.optimizer, dict):
+            momentums = dict()
+            for name, optim in self.optimizer.items():
+                momentums[name] = _get_momentum(optim)
+        return momentums
+
+    def register_hook(self, hook, priority='NORMAL'):
+        """Register a hook into the hook list.
+
+        The hook will be inserted into a priority queue, with the specified
+        priority (See :class:`Priority` for details of priorities).
+        For hooks with the same priority, they will be triggered in the same
+        order as they are registered.
+
+        Args:
+            hook (:obj:`Hook`): The hook to be registered.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
+        """
+        assert isinstance(hook, Hook)
+        if hasattr(hook, 'priority'):
+            raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
+        hook.priority = priority
+        # insert the hook to a sorted list
+        inserted = False
+        for i in range(len(self._hooks) - 1, -1, -1):
+            if priority >= self._hooks[i].priority:
+                self._hooks.insert(i + 1, hook)
+                inserted = True
+                break
+        if not inserted:
+            self._hooks.insert(0, hook)
+
+    def register_hook_from_cfg(self, hook_cfg):
+        """Register a hook from its cfg.
+
+        Args:
+            hook_cfg (dict): Hook config. It should have at least keys 'type'
+                and 'priority' indicating its type and priority.
+
+        Notes:
+            The specific hook class to register should not use 'type' and
+            'priority' arguments during initialization.
+        """
+        hook_cfg = hook_cfg.copy()
+        priority = hook_cfg.pop('priority', 'NORMAL')
+        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+        self.register_hook(hook, priority=priority)
+
+    def call_hook(self, fn_name):
+        """Call all hooks.
+
+        Args:
+            fn_name (str): The function name in each hook to be called, such
+                as "before_train_epoch".
+        """
+        for hook in self._hooks:
+            getattr(hook, fn_name)(self)
+
+    def get_hook_info(self):
+        # Get hooks info in each stage
+        stage_hook_map = {stage: [] for stage in Hook.stages}
+        for hook in self.hooks:
+            try:
+                priority = Priority(hook.priority).name
+            except ValueError:
+                priority = hook.priority
+            classname = hook.__class__.__name__
+            hook_info = f'({priority:<12}) {classname:<35}'
+            for trigger_stage in hook.get_triggered_stages():
+                stage_hook_map[trigger_stage].append(hook_info)
+
+        stage_hook_infos = []
+        for stage in Hook.stages:
+            hook_infos = stage_hook_map[stage]
+            if len(hook_infos) > 0:
+                info = f'{stage}:\n'
+                info += '\n'.join(hook_infos)
+                info += '\n -------------------- '
+                stage_hook_infos.append(info)
+        return '\n'.join(stage_hook_infos)
+
+    def load_checkpoint(self,
+                        filename,
+                        map_location='cpu',
+                        strict=False,
+                        revise_keys=[(r'^module.', '')]):
+        return load_checkpoint(
+            self.model,
+            filename,
+            map_location,
+            strict,
+            self.logger,
+            revise_keys=revise_keys)
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default'):
+        if map_location == 'default':
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
+        else:
+            checkpoint = self.load_checkpoint(
+                checkpoint, map_location=map_location)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        if self.meta is None:
+            self.meta = {}
+        self.meta.setdefault('hook_msgs', {})
+        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+
+        # Re-calculate the number of iterations when resuming
+        # models with different number of GPUs
+        if 'config' in checkpoint['meta']:
+            config = mmcv.Config.fromstring(
+                checkpoint['meta']['config'], file_format='.py')
+            previous_gpu_ids = config.get('gpu_ids', None)
+            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+                    previous_gpu_ids) != self.world_size:
+                self._iter = int(self._iter * len(previous_gpu_ids) /
+                                 self.world_size)
+                self.logger.info('the iteration number is changed due to '
+                                 'change of GPU number')
+
+        # resume meta information meta
+        self.meta = checkpoint['meta']
+
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+    def register_lr_hook(self, lr_config):
+        if lr_config is None:
+            return
+        elif isinstance(lr_config, dict):
+            assert 'policy' in lr_config
+            policy_type = lr_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+            # This is for the convenient usage of Lr updater.
+            # Since this is not applicable for `CosineAnnealingLrUpdater`,
+            # the string will not be changed if it contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'LrUpdaterHook'
+            lr_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(lr_config, HOOKS)
+        else:
+            hook = lr_config
+        self.register_hook(hook, priority='VERY_HIGH')
+
+    def register_momentum_hook(self, momentum_config):
+        if momentum_config is None:
+            return
+        if isinstance(momentum_config, dict):
+            assert 'policy' in momentum_config
+            policy_type = momentum_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+            # This is for the convenient usage of momentum updater.
+            # Since this is not applicable for
+            # `CosineAnnealingMomentumUpdater`,
+            # the string will not be changed if it contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'MomentumUpdaterHook'
+            momentum_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+        else:
+            hook = momentum_config
+        self.register_hook(hook, priority='HIGH')
+
+    def register_optimizer_hook(self, optimizer_config):
+        if optimizer_config is None:
+            return
+        if isinstance(optimizer_config, dict):
+            optimizer_config.setdefault('type', 'OptimizerHook')
+            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+        else:
+            hook = optimizer_config
+        self.register_hook(hook, priority='ABOVE_NORMAL')
+
+    def register_checkpoint_hook(self, checkpoint_config):
+        if checkpoint_config is None:
+            return
+        if isinstance(checkpoint_config, dict):
+            checkpoint_config.setdefault('type', 'CheckpointHook')
+            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+        else:
+            hook = checkpoint_config
|
459 |
+
self.register_hook(hook, priority='NORMAL')
|
460 |
+
|
461 |
+
def register_logger_hooks(self, log_config):
|
462 |
+
if log_config is None:
|
463 |
+
return
|
464 |
+
log_interval = log_config['interval']
|
465 |
+
for info in log_config['hooks']:
|
466 |
+
logger_hook = mmcv.build_from_cfg(
|
467 |
+
info, HOOKS, default_args=dict(interval=log_interval))
|
468 |
+
self.register_hook(logger_hook, priority='VERY_LOW')
|
469 |
+
|
470 |
+
def register_timer_hook(self, timer_config):
|
471 |
+
if timer_config is None:
|
472 |
+
return
|
473 |
+
if isinstance(timer_config, dict):
|
474 |
+
timer_config_ = copy.deepcopy(timer_config)
|
475 |
+
hook = mmcv.build_from_cfg(timer_config_, HOOKS)
|
476 |
+
else:
|
477 |
+
hook = timer_config
|
478 |
+
self.register_hook(hook, priority='LOW')
|
479 |
+
|
480 |
+
def register_custom_hooks(self, custom_config):
|
481 |
+
if custom_config is None:
|
482 |
+
return
|
483 |
+
|
484 |
+
if not isinstance(custom_config, list):
|
485 |
+
custom_config = [custom_config]
|
486 |
+
|
487 |
+
for item in custom_config:
|
488 |
+
if isinstance(item, dict):
|
489 |
+
self.register_hook_from_cfg(item)
|
490 |
+
else:
|
491 |
+
self.register_hook(item, priority='NORMAL')
|
492 |
+
|
493 |
+
def register_profiler_hook(self, profiler_config):
|
494 |
+
if profiler_config is None:
|
495 |
+
return
|
496 |
+
if isinstance(profiler_config, dict):
|
497 |
+
profiler_config.setdefault('type', 'ProfilerHook')
|
498 |
+
hook = mmcv.build_from_cfg(profiler_config, HOOKS)
|
499 |
+
else:
|
500 |
+
hook = profiler_config
|
501 |
+
self.register_hook(hook)
|
502 |
+
|
503 |
+
def register_training_hooks(self,
|
504 |
+
lr_config,
|
505 |
+
optimizer_config=None,
|
506 |
+
checkpoint_config=None,
|
507 |
+
log_config=None,
|
508 |
+
momentum_config=None,
|
509 |
+
timer_config=dict(type='IterTimerHook'),
|
510 |
+
custom_hooks_config=None):
|
511 |
+
"""Register default and custom hooks for training.
|
512 |
+
|
513 |
+
Default and custom hooks include:
|
514 |
+
|
515 |
+
+----------------------+-------------------------+
|
516 |
+
| Hooks | Priority |
|
517 |
+
+======================+=========================+
|
518 |
+
| LrUpdaterHook | VERY_HIGH (10) |
|
519 |
+
+----------------------+-------------------------+
|
520 |
+
| MomentumUpdaterHook | HIGH (30) |
|
521 |
+
+----------------------+-------------------------+
|
522 |
+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
|
523 |
+
+----------------------+-------------------------+
|
524 |
+
| CheckpointSaverHook | NORMAL (50) |
|
525 |
+
+----------------------+-------------------------+
|
526 |
+
| IterTimerHook | LOW (70) |
|
527 |
+
+----------------------+-------------------------+
|
528 |
+
| LoggerHook(s) | VERY_LOW (90) |
|
529 |
+
+----------------------+-------------------------+
|
530 |
+
| CustomHook(s) | defaults to NORMAL (50) |
|
531 |
+
+----------------------+-------------------------+
|
532 |
+
|
533 |
+
If custom hooks have same priority with default hooks, custom hooks
|
534 |
+
will be triggered after default hooks.
|
535 |
+
"""
|
536 |
+
self.register_lr_hook(lr_config)
|
537 |
+
self.register_momentum_hook(momentum_config)
|
538 |
+
self.register_optimizer_hook(optimizer_config)
|
539 |
+
self.register_checkpoint_hook(checkpoint_config)
|
540 |
+
self.register_timer_hook(timer_config)
|
541 |
+
self.register_logger_hooks(log_config)
|
542 |
+
self.register_custom_hooks(custom_hooks_config)
|
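As a rough illustration of how these registration methods fit together, below is a minimal sketch of wiring the default training hooks onto a runner instance. The config values are illustrative assumptions (following the usual mmcv conventions, e.g. policy='step' resolving to StepLrUpdaterHook), not values taken from this commit:

    # Hypothetical configs; `runner` is assumed to be an already-built runner.
    lr_config = dict(policy='step', step=[8, 11])   # -> StepLrUpdaterHook
    optimizer_config = dict(grad_clip=None)          # -> OptimizerHook
    checkpoint_config = dict(interval=1)             # -> CheckpointHook
    log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])

    runner.register_training_hooks(
        lr_config=lr_config,
        optimizer_config=optimizer_config,
        checkpoint_config=checkpoint_config,
        log_config=log_config)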
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/builder.py
ADDED
@@ -0,0 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

from ..utils import Registry

RUNNERS = Registry('runner')
RUNNER_BUILDERS = Registry('runner builder')


def build_runner_constructor(cfg):
    return RUNNER_BUILDERS.build(cfg)


def build_runner(cfg, default_args=None):
    runner_cfg = copy.deepcopy(cfg)
    constructor_type = runner_cfg.pop('constructor',
                                      'DefaultRunnerConstructor')
    runner_constructor = build_runner_constructor(
        dict(
            type=constructor_type,
            runner_cfg=runner_cfg,
            default_args=default_args))
    runner = runner_constructor()
    return runner
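For context, `build_runner` is normally driven by a config dict naming a runner type registered in `RUNNERS`. A hedged sketch, assuming an `EpochBasedRunner` (registered in `epoch_based_runner.py` below) and that `model`, `optimizer`, `work_dir` and `logger` already exist in the calling scope:

    # Sketch: build a runner from config; all names outside this snippet are
    # assumed to be defined by the caller.
    runner = build_runner(
        dict(type='EpochBasedRunner', max_epochs=12),
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=work_dir,
            logger=logger,
            meta=None))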
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/checkpoint.py
ADDED
@@ -0,0 +1,707 @@
# Copyright (c) OpenMMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import re
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo

import annotator.mmpkg.mmcv as mmcv
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint


class CheckpointLoader:
    """A general checkpoint loader to manage all schemes."""

    _schemes = {}

    @classmethod
    def _register_scheme(cls, prefixes, loader, force=False):
        if isinstance(prefixes, str):
            prefixes = [prefixes]
        else:
            assert isinstance(prefixes, (list, tuple))
        for prefix in prefixes:
            if (prefix not in cls._schemes) or force:
                cls._schemes[prefix] = loader
            else:
                raise KeyError(
                    f'{prefix} is already registered as a loader backend, '
                    'add "force=True" if you want to override it')
        # sort, longer prefixes take priority
        cls._schemes = OrderedDict(
            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))

    @classmethod
    def register_scheme(cls, prefixes, loader=None, force=False):
        """Register a loader to CheckpointLoader.

        This method can be used as a normal class method or a decorator.

        Args:
            prefixes (str or list[str] or tuple[str]):
                The prefix of the registered loader.
            loader (function, optional): The loader function to be registered.
                When this method is used as a decorator, loader is None.
                Defaults to None.
            force (bool, optional): Whether to override the loader
                if the prefix has already been registered. Defaults to False.
        """

        if loader is not None:
            cls._register_scheme(prefixes, loader, force=force)
            return

        def _register(loader_cls):
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path):
        """Finds a loader that supports the given path. Falls back to the
        local loader if no other loader is found.

        Args:
            path (str): checkpoint path

        Returns:
            loader (function): checkpoint loader
        """

        for p in cls._schemes:
            if path.startswith(p):
                return cls._schemes[p]

    @classmethod
    def load_checkpoint(cls, filename, map_location=None, logger=None):
        """Load checkpoint through a URL scheme path.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str, optional): Same as :func:`torch.load`.
                Default: None
            logger (:mod:`logging.Logger`, optional): The logger for message.
                Default: None

        Returns:
            dict or OrderedDict: The loaded checkpoint.
        """

        checkpoint_loader = cls._get_checkpoint_loader(filename)
        class_name = checkpoint_loader.__name__
        mmcv.print_log(
            f'load checkpoint from {class_name[10:]} path: {filename}', logger)
        return checkpoint_loader(filename, map_location)


@CheckpointLoader.register_scheme(prefixes='')
def load_from_local(filename, map_location):
    """Load checkpoint from a local file path.

    Args:
        filename (str): local checkpoint file path
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    if not osp.isfile(filename):
        raise IOError(f'{filename} is not a checkpoint file')
    checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
def load_from_http(filename, map_location=None, model_dir=None):
    """Load checkpoint through an HTTP or HTTPS scheme path. In the
    distributed setting, this function only downloads the checkpoint at
    local rank 0.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): Same as :func:`torch.load`.
        model_dir (str, optional): directory in which to save the object.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(
            filename, model_dir=model_dir, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(
                filename, model_dir=model_dir, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='pavi://')
def load_from_pavi(filename, map_location=None):
    """Load checkpoint through the file path prefixed with pavi. In the
    distributed setting, this function downloads the checkpoint at all ranks
    to different temporary directories.

    Args:
        filename (str): checkpoint file path with pavi prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    assert filename.startswith('pavi://'), \
        f'Expected filename startswith `pavi://`, but got {filename}'
    model_path = filename[7:]

    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')

    model = modelcloud.get(model_path)
    with TemporaryDirectory() as tmp_dir:
        downloaded_file = osp.join(tmp_dir, model.name)
        model.download(downloaded_file)
        checkpoint = torch.load(downloaded_file, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='s3://')
def load_from_ceph(filename, map_location=None, backend='petrel'):
    """Load checkpoint through the file path prefixed with s3. In the
    distributed setting, this function downloads the checkpoint at all ranks
    to different temporary directories.

    Args:
        filename (str): checkpoint file path with s3 prefix
        map_location (str, optional): Same as :func:`torch.load`.
        backend (str, optional): The storage backend type. Options are 'ceph',
            'petrel'. Default: 'petrel'.

    .. warning::
        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    allowed_backends = ['ceph', 'petrel']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')

    if backend == 'ceph':
        warnings.warn(
            'CephBackend will be deprecated, please use PetrelBackend instead')

    # CephClient and PetrelBackend have the same prefix 's3://' and the latter
    # will be chosen as default. If PetrelBackend can not be instantiated
    # successfully, the CephClient will be chosen.
    try:
        file_client = FileClient(backend=backend)
    except ImportError:
        allowed_backends.remove(backend)
        file_client = FileClient(backend=allowed_backends[0])

    with io.BytesIO(file_client.get(filename)) as buffer:
        checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
def load_from_torchvision(filename, map_location=None):
    """Load checkpoint through the file path prefixed with modelzoo or
    torchvision.

    Args:
        filename (str): checkpoint file path with modelzoo or
            torchvision prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    model_urls = get_torchvision_models()
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_name = filename[11:]
    else:
        model_name = filename[14:]
    return load_from_http(model_urls[model_name], map_location=map_location)


@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
def load_from_openmmlab(filename, map_location=None):
    """Load checkpoint through the file path prefixed with open-mmlab or
    openmmlab.

    Args:
        filename (str): checkpoint file path with open-mmlab or
            openmmlab prefix
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    model_urls = get_external_models()
    prefix_str = 'open-mmlab://'
    if filename.startswith(prefix_str):
        model_name = filename[13:]
    else:
        model_name = filename[12:]
        prefix_str = 'openmmlab://'

    deprecated_urls = get_deprecated_model_names()
    if model_name in deprecated_urls:
        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
                      f'of {prefix_str}{deprecated_urls[model_name]}')
        model_name = deprecated_urls[model_name]
    model_url = model_urls[model_name]
    # check if is url
    if model_url.startswith(('http://', 'https://')):
        checkpoint = load_from_http(model_url, map_location=map_location)
    else:
        filename = osp.join(_get_mmcv_home(), model_url)
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


@CheckpointLoader.register_scheme(prefixes='mmcls://')
def load_from_mmcls(filename, map_location=None):
    """Load checkpoint through the file path prefixed with mmcls.

    Args:
        filename (str): checkpoint file path with mmcls prefix
        map_location (str, optional): Same as :func:`torch.load`.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    model_urls = get_mmcls_models()
    model_name = filename[8:]
    checkpoint = load_from_http(
        model_urls[model_name], map_location=map_location)
    checkpoint = _process_mmcls_checkpoint(checkpoint)
    return checkpoint


def _load_checkpoint(filename, map_location=None, logger=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str, optional): Same as :func:`torch.load`.
            Default: None.
        logger (:mod:`logging.Logger`, optional): The logger for error
            message. Default: None

    Returns:
        dict or OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    return CheckpointLoader.load_checkpoint(filename, map_location, logger)


def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
    """Load partial pretrained model with specific prefix.

    Args:
        prefix (str): The prefix of sub-module.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """

    checkpoint = _load_checkpoint(filename, map_location=map_location)

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict


def load_checkpoint(model,
                    filename,
                    map_location=None,
                    strict=False,
                    logger=None,
                    revise_keys=[(r'^module\.', '')]):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.
        revise_keys (list): A list of customized keywords to modify the
            state_dict in checkpoint. Each item is a (pattern, replacement)
            pair of the regular expression operations. Default: strip
            the prefix 'module.' by [(r'^module\\.', '')].

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location, logger)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    # strip prefix of state_dict
    metadata = getattr(state_dict, '_metadata', OrderedDict())
    for p, r in revise_keys:
        state_dict = OrderedDict(
            {re.sub(p, r, k): v
             for k, v in state_dict.items()})
    # Keep metadata in state_dict
    state_dict._metadata = metadata

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    # Keep metadata in state_dict
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a
    complicated structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    meta=None,
                    file_client_args=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    if filename.startswith('pavi://'):
        if file_client_args is not None:
            raise ValueError(
                'file_client_args should be "None" if filename starts with '
                f'"pavi://", but got {file_client_args}')
        try:
            from pavi import modelcloud
            from pavi import exception
        except ImportError:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except exception.NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        file_client = FileClient.infer_client(file_client_args, filename)
        with io.BytesIO() as f:
            torch.save(checkpoint, f)
            file_client.put(f.getvalue(), filename)
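A quick hedged sketch of the public surface of this module: a local save/load round trip, plus registering a custom URI scheme via the decorator form of `register_scheme`. The 'myproto://' prefix, its loader, and the `/tmp/demo.pth` path are hypothetical illustrations, not part of this commit:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for a real model
    save_checkpoint(model, '/tmp/demo.pth', meta=dict(note='demo'))
    ckpt = load_checkpoint(model, '/tmp/demo.pth', map_location='cpu')

    # Hypothetical custom scheme; the decorator registers the loader so that
    # any filename starting with 'myproto://' is routed here.
    @CheckpointLoader.register_scheme(prefixes='myproto://')
    def load_from_myproto(filename, map_location=None):
        # resolve the URI to a local path however your storage works
        local_path = filename[len('myproto://'):]
        return torch.load(local_path, map_location=map_location)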
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/default_constructor.py
ADDED
@@ -0,0 +1,44 @@
from .builder import RUNNER_BUILDERS, RUNNERS


@RUNNER_BUILDERS.register_module()
class DefaultRunnerConstructor:
    """Default constructor for runners.

    Customize an existing `Runner` such as `EpochBasedRunner` through a
    `RunnerConstructor`. For example, we can inject new properties and
    functions into a `Runner`.

    Example:
        >>> from annotator.mmpkg.mmcv.runner import RUNNER_BUILDERS, build_runner
        >>> # Define a new RunnerConstructor
        >>> @RUNNER_BUILDERS.register_module()
        >>> class MyRunnerConstructor:
        ...     def __init__(self, runner_cfg, default_args=None):
        ...         if not isinstance(runner_cfg, dict):
        ...             raise TypeError('runner_cfg should be a dict',
        ...                             f'but got {type(runner_cfg)}')
        ...         self.runner_cfg = runner_cfg
        ...         self.default_args = default_args
        ...
        ...     def __call__(self):
        ...         runner = RUNNERS.build(self.runner_cfg,
        ...                                default_args=self.default_args)
        ...         # Add new properties for existing runner
        ...         runner.my_name = 'my_runner'
        ...         runner.my_function = lambda self: print(self.my_name)
        ...         ...
        >>> # build your runner
        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
        ...                   constructor='MyRunnerConstructor')
        >>> runner = build_runner(runner_cfg)
    """

    def __init__(self, runner_cfg, default_args=None):
        if not isinstance(runner_cfg, dict):
            raise TypeError('runner_cfg should be a dict',
                            f'but got {type(runner_cfg)}')
        self.runner_cfg = runner_cfg
        self.default_args = default_args

    def __call__(self):
        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/dist_utils.py
ADDED
@@ -0,0 +1,164 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import os
import subprocess
from collections import OrderedDict

import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be the
    system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in
    the system environment variables, then a default port ``29500`` will be
    used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce parameters.

    Args:
        params (list[torch.nn.Parameter]): List of parameters or buffers of a
            model.
        coalesce (bool, optional): Whether to allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params, world_size, bucket_size_mb)
    else:
        for tensor in params:
            dist.all_reduce(tensor.div_(world_size))


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Allreduce gradients.

    Args:
        params (list[torch.nn.Parameter]): List of parameters of a model.
        coalesce (bool, optional): Whether to allreduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    grads = [
        param.grad.data for param in params
        if param.requires_grad and param.grad is not None
    ]
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    if coalesce:
        _allreduce_coalesced(grads, world_size, bucket_size_mb)
    else:
        for tensor in grads:
            dist.all_reduce(tensor.div_(world_size))


def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        flat_tensors = _flatten_dense_tensors(bucket)
        dist.all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)
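For orientation, a minimal sketch of typical single-node usage with the PyTorch launcher, assuming the process was started by `torchrun` or `torch.distributed.launch` (which set the RANK/WORLD_SIZE environment variables that `_init_dist_pytorch` reads):

    init_dist('pytorch', backend='nccl')
    rank, world_size = get_dist_info()

    @master_only
    def log_once(msg):
        # executes only on rank 0; on other ranks the wrapper returns None
        print(f'[rank {rank}/{world_size}] {msg}')

    log_once('distributed environment initialized')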
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/epoch_based_runner.py
ADDED
@@ -0,0 +1,187 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import platform
import shutil
import time
import warnings

import torch

import annotator.mmpkg.mmcv as mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .utils import get_host_info


@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.
    """

    def run_iter(self, data_batch, train_mode, **kwargs):
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) should be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
            # there will be problems with resumed checkpoints.
            # More details in https://github.com/open-mmlab/mmcv/pull/1108
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)


@RUNNERS.register_module()
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner was deprecated, please use EpochBasedRunner instead')
        super().__init__(*args, **kwargs)
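A hedged sketch of driving this runner, assuming `runner`, `train_loader` and `val_loader` already exist; the workflow [('train', 1), ('val', 1)] alternates one training epoch with one validation epoch until `max_epochs` is reached:

    runner.run(
        data_loaders=[train_loader, val_loader],
        workflow=[('train', 1), ('val', 1)])
    runner.save_checkpoint('./work_dir')  # writes epoch_N.pth and latest.pth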
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/fp16_utils.py
ADDED
@@ -0,0 +1,410 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import warnings
from collections import abc
from inspect import getfullargspec

import numpy as np
import torch
import torch.nn as nn

from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads

try:
    # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
    # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
    # manually, so the behavior may not be consistent with real amp.
    from torch.cuda.amp import autocast
except ImportError:
    pass


def cast_tensor_type(inputs, src_type, dst_type):
    """Recursively convert Tensors in inputs from src_type to dst_type.

    Args:
        inputs: Inputs to be cast.
        src_type (torch.dtype): Source type.
        dst_type (torch.dtype): Destination type.

    Returns:
        The same type as inputs, but all contained Tensors have been cast.
    """
    if isinstance(inputs, nn.Module):
        return inputs
    elif isinstance(inputs, torch.Tensor):
        return inputs.to(dst_type)
    elif isinstance(inputs, str):
        return inputs
    elif isinstance(inputs, np.ndarray):
        return inputs
    elif isinstance(inputs, abc.Mapping):
        return type(inputs)({
            k: cast_tensor_type(v, src_type, dst_type)
            for k, v in inputs.items()
        })
    elif isinstance(inputs, abc.Iterable):
        return type(inputs)(
            cast_tensor_type(item, src_type, dst_type) for item in inputs)
    else:
        return inputs


def auto_fp16(apply_to=None, out_fp32=False):
    """Decorator to enable fp16 training automatically.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If input arguments are fp32 tensors, they will
    be converted to fp16 automatically. Arguments other than fp32 tensors are
    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, the original mmcv implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp32 (bool): Whether to convert the output back to fp32.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp16
        >>>     @auto_fp16()
        >>>     def forward(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp16
        >>>     @auto_fp16(apply_to=('pred', ))
        >>>     def do_something(self, pred, others):
        >>>         pass
    """

    def auto_fp16_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`; if
            # not, just fall back to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@auto_fp16 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)

            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be cast
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            # NOTE: default args are not taken into consideration
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.float, torch.half))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = {}
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.float, torch.half)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=True):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp32 if necessary
            if out_fp32:
                output = cast_tensor_type(output, torch.half, torch.float)
            return output

        return new_func

    return auto_fp16_wrapper


def force_fp32(apply_to=None, out_fp16=False):
    """Decorator to convert input arguments to fp32 in force.

    This decorator is useful when you write custom modules and want to support
    mixed precision training. If there are some inputs that must be processed
    in fp32 mode, then this decorator can handle it. If input arguments are
    fp16 tensors, they will be converted to fp32 automatically. Arguments
    other than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
    torch.cuda.amp is used as the backend, otherwise, the original mmcv
    implementation will be adopted.

    Args:
        apply_to (Iterable, optional): The argument names to be converted.
            `None` indicates all arguments.
        out_fp16 (bool): Whether to convert the output back to fp16.

    Example:

        >>> import torch.nn as nn
        >>> class MyModule1(nn.Module):
        >>>
        >>>     # Convert x and y to fp32
        >>>     @force_fp32()
        >>>     def loss(self, x, y):
        >>>         pass

        >>> import torch.nn as nn
        >>> class MyModule2(nn.Module):
        >>>
        >>>     # convert pred to fp32
        >>>     @force_fp32(apply_to=('pred', ))
        >>>     def post_process(self, pred, others):
        >>>         pass
    """

    def force_fp32_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # check if the module has set the attribute `fp16_enabled`; if
            # not, just fall back to the original method.
            if not isinstance(args[0], torch.nn.Module):
                raise TypeError('@force_fp32 can only be used to decorate the '
                                'method of nn.Module')
            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
                return old_func(*args, **kwargs)
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get the argument names to be cast
            args_to_cast = args_info.args if apply_to is None else apply_to
            # convert the args that need to be processed
            new_args = []
            if args:
                arg_names = args_info.args[:len(args)]
                for i, arg_name in enumerate(arg_names):
                    if arg_name in args_to_cast:
                        new_args.append(
                            cast_tensor_type(args[i], torch.half, torch.float))
                    else:
                        new_args.append(args[i])
            # convert the kwargs that need to be processed
            new_kwargs = dict()
            if kwargs:
                for arg_name, arg_value in kwargs.items():
                    if arg_name in args_to_cast:
                        new_kwargs[arg_name] = cast_tensor_type(
                            arg_value, torch.half, torch.float)
                    else:
                        new_kwargs[arg_name] = arg_value
            # apply converted arguments to the decorated method
            if (TORCH_VERSION != 'parrots' and
                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
                with autocast(enabled=False):
                    output = old_func(*new_args, **new_kwargs)
            else:
                output = old_func(*new_args, **new_kwargs)
            # cast the results back to fp16 if necessary
            if out_fp16:
                output = cast_tensor_type(output, torch.float, torch.half)
            return output

        return new_func

    return force_fp32_wrapper


def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    warnings.warn(
        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads".')
    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)


def wrap_fp16_model(model):
    """Wrap the FP32 model to FP16.

    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
    backend, otherwise, the original mmcv implementation will be adopted.

    For PyTorch >= 1.6, this function will
    1. Set the fp16 flag inside the model to True.

    Otherwise:
    1. Convert the FP32 model to FP16.
    2. Keep some necessary layers in FP32, e.g., normalization layers.
    3. Set the `fp16_enabled` flag inside the model to True.

    Args:
        model (nn.Module): Model in FP32.
    """
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
        # convert model to fp16
        model.half()
        # patch the normalization layers to make them work in fp32 mode
        patch_norm_fp32(model)
    # set `fp16_enabled` flag
    for m in model.modules():
        if hasattr(m, 'fp16_enabled'):
            m.fp16_enabled = True


def patch_norm_fp32(module):
    """Recursively convert normalization layers from FP16 to FP32.

    Args:
        module (nn.Module): The modules to be converted in FP16.

    Returns:
        nn.Module: The converted module, whose normalization layers have been
            converted to FP32.
    """
    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
        module.float()
        if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
            module.forward = patch_forward_method(module.forward, torch.half,
                                                  torch.float)
    for child in module.children():
        patch_norm_fp32(child)
    return module


def patch_forward_method(func, src_type, dst_type, convert_output=True):
    """Patch the forward method of a module.

    Args:
        func (callable): The original forward method.
        src_type (torch.dtype): Type of input arguments to be converted from.
        dst_type (torch.dtype): Type of input arguments to be converted to.
        convert_output (bool): Whether to convert the output back to src_type.

    Returns:
        callable: The patched forward method.
    """

    def new_forward(*args, **kwargs):
        output = func(*cast_tensor_type(args, src_type, dst_type),
                      **cast_tensor_type(kwargs, src_type, dst_type))
        if convert_output:
            output = cast_tensor_type(output, dst_type, src_type)
        return output

    return new_forward


class LossScaler:
    """Class that manages loss scaling in mixed precision training which
    supports both dynamic and static mode.

    The implementation refers to
    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
    Dynamic loss scaling is selected by supplying ``mode='dynamic'``.
    It's important to understand how :class:`LossScaler` operates.
    Loss scaling is designed to combat the problem of underflowing
    gradients encountered late in the training of fp16 networks.
    Dynamic loss scaling begins by attempting a very high loss
    scale. Ironically, this may result in OVERflowing gradients.
    If overflowing gradients are encountered, :class:`FP16_Optimizer` then
    skips the update step for this particular iteration/minibatch,
    and :class:`LossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients
    detected, :class:`LossScaler` increases the loss scale once more.
    In this way :class:`LossScaler` attempts to "ride the edge" of always
    using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float): Initial loss scale value, default: 2**32.
        scale_factor (float): Factor used when adjusting the loss scale.
            Default: 2.
        mode (str): Loss scaling mode. 'dynamic' or 'static'.
        scale_window (int): Number of consecutive iterations without an
            overflow to wait before increasing the loss scale. Default: 1000.
    """

    def __init__(self,
                 init_scale=2**32,
                 mode='dynamic',
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        assert mode in ('dynamic',
                        'static'), 'mode can only be dynamic or static'
        self.mode = mode
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    def has_overflow(self, params):
        """Check if params contain overflow."""
        if self.mode != 'dynamic':
            return False
        for p in params:
            if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
                return True
        return False

    @staticmethod
    def _has_inf_or_nan(x):
        """Check if x contains inf or NaN values."""
        try:
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            if 'value cannot be converted' not in instance.args[0]:
                raise
            return True
        else:
            if cpu_sum == float('inf') or cpu_sum == -float('inf') \
                    or cpu_sum != cpu_sum:
                return True
            return False

    def update_scale(self, overflow):
        """Update the current loss scale: decrease it on overflow, increase
        it after ``scale_window`` consecutive overflow-free iterations."""
        if self.mode != 'dynamic':
            return
        if overflow:
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % \
                    self.scale_window == 0:
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    def state_dict(self):
        """Returns the state of the scaler as a :class:`dict`."""
        return dict(
            cur_scale=self.cur_scale,
            cur_iter=self.cur_iter,
            mode=self.mode,
            last_overflow_iter=self.last_overflow_iter,
            scale_factor=self.scale_factor,
            scale_window=self.scale_window)

    def load_state_dict(self, state_dict):
        """Loads the loss_scaler state dict.

        Args:
            state_dict (dict): scaler state.
        """
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        self.mode = state_dict['mode']
        self.last_overflow_iter = state_dict['last_overflow_iter']
        self.scale_factor = state_dict['scale_factor']
        self.scale_window = state_dict['scale_window']

    @property
    def loss_scale(self):
        return self.cur_scale
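A minimal sketch of how these decorators are typically combined with `wrap_fp16_model`. `ToyHead` and its shapes are illustrative assumptions, not part of the diff; only the decorator names and the `fp16_enabled` flag come from the file above. The forward pass itself would need a CUDA device for autocast to take effect.

import torch.nn as nn

class ToyHead(nn.Module):  # hypothetical module

    def __init__(self):
        super().__init__()
        self.fp16_enabled = False  # the decorators check this flag
        self.fc = nn.Linear(8, 4)

    @auto_fp16(apply_to=('x', ))
    def forward(self, x):
        # x arrives as fp16 once fp16_enabled is True
        return self.fc(x)

    @force_fp32(apply_to=('pred', ))
    def loss(self, pred, target):
        # losses are kept in fp32 for numerical stability
        return nn.functional.mse_loss(pred, target)

model = ToyHead()
wrap_fp16_model(model)  # flips fp16_enabled (and halves weights on old PyTorch)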
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/__init__.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .checkpoint import CheckpointHook
from .closure import ClosureHook
from .ema import EMAHook
from .evaluation import DistEvalHook, EvalHook
from .hook import HOOKS, Hook
from .iter_timer import IterTimerHook
from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
                     NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
from .lr_updater import LrUpdaterHook
from .memory import EmptyCacheHook
from .momentum_updater import MomentumUpdaterHook
from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
                        GradientCumulativeOptimizerHook, OptimizerHook)
from .profiler import ProfilerHook
from .sampler_seed import DistSamplerSeedHook
from .sync_buffer import SyncBuffersHook

__all__ = [
    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
    'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
    'GradientCumulativeFp16OptimizerHook'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/checkpoint.py
ADDED
@@ -0,0 +1,167 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings

from annotator.mmpkg.mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of
            ``out_dir`` and the last level directory of ``runner.work_dir``.
            `Changed in version 1.3.16.`
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save disk space.
            Default: -1, which means unlimited.
        save_last (bool, optional): Whether to force the last checkpoint to be
            saved regardless of interval. Default: True.
        sync_buffer (bool, optional): Whether to synchronize buffers in
            different gpus. Default: False.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`

    .. warning::
        Before v1.3.16, the ``out_dir`` argument indicates the path where the
        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates
        the root directory and the final path to save checkpoint is the
        concatenation of ``out_dir`` and the last level directory of
        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
        and the value of ``runner.work_dir`` is "/path/of/B", then the final
        path will be "/path/of/A/B".
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
                            f'{self.file_client.name}.'))

        # disable the create_symlink option because some file backends do not
        # allow creating a symlink
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                warnings.warn(
                    ('create_symlink is set as True by the user but is '
                     'changed to be False because creating symbolic links '
                     f'is not allowed in {self.file_client.name}'))
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        if not self.by_epoch:
            return

        # save checkpoint for the following cases:
        # 1. every ``self.interval`` epochs
        # 2. reaching the last epoch of training
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete outdated checkpoints."""
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
        if runner.meta is not None:
            if self.by_epoch:
                cur_ckpt_filename = self.args.get(
                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
            else:
                cur_ckpt_filename = self.args.get(
                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, cur_ckpt_filename)
        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            if self.by_epoch:
                name = 'epoch_{}.pth'
                current_ckpt = runner.epoch + 1
            else:
                name = 'iter_{}.pth'
                current_ckpt = runner.iter + 1
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
            filename_tmpl = self.args.get('filename_tmpl', name)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    break

    def after_train_iter(self, runner):
        if self.by_epoch:
            return

        # save checkpoint for the following cases:
        # 1. every ``self.interval`` iterations
        # 2. reaching the last iteration of training
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
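As a rough sketch, this hook is usually attached either through the conventional `checkpoint_config` dict in an mmcv-style config, or imperatively on an existing runner. The `runner` object and the interval values here are illustrative assumptions.

# config-style, as consumed by mmcv-based training scripts
checkpoint_config = dict(interval=1, max_keep_ckpts=3, save_optimizer=True)

# or imperatively on a hypothetical runner instance
runner.register_hook(CheckpointHook(interval=1, max_keep_ckpts=3))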
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/closure.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hook import HOOKS, Hook


@HOOKS.register_module()
class ClosureHook(Hook):

    def __init__(self, fn_name, fn):
        assert hasattr(self, fn_name)
        assert callable(fn)
        setattr(self, fn_name, fn)
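A small sketch of what ClosureHook enables: injecting an ad-hoc callback at a chosen stage without writing a full Hook subclass. The `runner` instance and the callback body are illustrative assumptions.

def log_lr(runner):
    # called at the chosen stage with the runner as its only argument
    print(runner.current_lr())

runner.register_hook(ClosureHook('after_train_iter', log_lr))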
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/ema.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ...parallel import is_module_wrapper
from ..hooks.hook import HOOKS, Hook


@HOOKS.register_module()
class EMAHook(Hook):
    r"""Exponential Moving Average Hook.

    Use Exponential Moving Average on all parameters of the model during
    training. Every parameter has an EMA backup, which is updated by the
    formula below. EMAHook takes priority over EvalHook and
    CheckpointSaverHook.

    .. math::

        X_{\text{ema},t+1} = (1 - \text{momentum}) \times X_{\text{ema},t}
            + \text{momentum} \times X_t

    Args:
        momentum (float): The momentum used for updating ema parameters.
            Defaults to 0.0002.
        interval (int): Update ema parameters every ``interval`` iterations.
            Defaults to 1.
        warm_up (int): During the first ``warm_up`` steps, a smaller momentum
            is used so that the ema parameters update more slowly.
            Defaults to 100.
        resume_from (str): The checkpoint path. Defaults to None.
    """

    def __init__(self,
                 momentum=0.0002,
                 interval=1,
                 warm_up=100,
                 resume_from=None):
        assert isinstance(interval, int) and interval > 0
        self.warm_up = warm_up
        self.interval = interval
        assert momentum > 0 and momentum < 1
        self.momentum = momentum**interval
        self.checkpoint = resume_from

    def before_run(self, runner):
        """Register ema parameters as ``named_buffer`` on the model so that
        the model can be resumed together with its ema parameters."""
        model = runner.model
        if is_module_wrapper(model):
            model = model.module
        self.param_ema_buffer = {}
        self.model_parameters = dict(model.named_parameters(recurse=True))
        for name, value in self.model_parameters.items():
            # "." is not allowed in a module's buffer name
            buffer_name = f"ema_{name.replace('.', '_')}"
            self.param_ema_buffer[name] = buffer_name
            model.register_buffer(buffer_name, value.data.clone())
        self.model_buffers = dict(model.named_buffers(recurse=True))
        if self.checkpoint is not None:
            runner.resume(self.checkpoint)

    def after_train_iter(self, runner):
        """Update ema parameters every ``self.interval`` iterations."""
        curr_step = runner.iter
        # warm up the momentum to account for instability at the beginning
        momentum = min(self.momentum,
                       (1 + curr_step) / (self.warm_up + curr_step))
        if curr_step % self.interval != 0:
            return
        for name, parameter in self.model_parameters.items():
            buffer_name = self.param_ema_buffer[name]
            buffer_parameter = self.model_buffers[buffer_name]
            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)

    def after_train_epoch(self, runner):
        """Load parameter values from the ema backup into the model before
        the EvalHook runs."""
        self._swap_ema_parameters()

    def before_train_epoch(self, runner):
        """Restore the model's own parameters from the ema backup after the
        last epoch's EvalHook."""
        self._swap_ema_parameters()

    def _swap_ema_parameters(self):
        """Swap the parameters of the model with the parameters in
        ema_buffer."""
        for name, value in self.model_parameters.items():
            temp = value.data.clone()
            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
            value.data.copy_(ema_buffer.data)
            ema_buffer.data.copy_(temp)
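The buffer update above is the standard EMA recurrence. A minimal numeric sketch of the same arithmetic (values chosen arbitrarily) shows how slowly the backup drifts at the default momentum:

momentum = 0.0002
x_ema, x = 1.0, 0.0
for _ in range(3):
    # mirrors buffer_parameter.mul_(1 - momentum).add_(momentum, parameter)
    x_ema = (1 - momentum) * x_ema + momentum * x
print(x_ema)  # ~0.9994 after three steps toward a live weight of 0.0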
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/evaluation.py
ADDED
@@ -0,0 +1,509 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from math import inf

import torch.distributed as dist
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from annotator.mmpkg.mmcv.fileio import FileClient
from annotator.mmpkg.mmcv.utils import is_seq_of
from .hook import Hook
from .logger import LoggerHook


class EvalHook(Hook):
    """Non-Distributed evaluation hook.

    This hook will regularly perform evaluation in a given interval when
    running in a non-distributed environment.

    Args:
        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
            implemented the ``evaluate`` function.
        start (int | None, optional): Evaluation starting epoch. It enables
            evaluation before the training starts if ``start`` <= the
            resuming epoch. If None, whether to evaluate is merely decided by
            ``interval``. Default: None.
        interval (int): Evaluation interval. Default: 1.
        by_epoch (bool): Whether to perform evaluation by epoch or by
            iteration. If set to True, it will perform by epoch. Otherwise,
            by iteration. Default: True.
        save_best (str, optional): If a metric is specified, it would measure
            the best checkpoint during evaluation. The information about the
            best checkpoint would be saved in ``runner.meta['hook_msgs']`` to
            keep the best score value and best checkpoint path, which will
            also be loaded when resuming a checkpoint. Options are the
            evaluation metrics on the test dataset, e.g., ``bbox_mAP``,
            ``segm_mAP`` for bbox detection and instance segmentation,
            ``AR@100`` for proposal recall. If ``save_best`` is ``auto``, the
            first key of the returned ``OrderedDict`` result will be used.
            Default: None.
        rule (str | None, optional): Comparison rule for the best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc',
            'top', etc. will be inferred by the 'greater' rule. Keys
            containing 'loss' will be inferred by the 'less' rule. Options
            are 'greater', 'less', None. Default: None.
        test_fn (callable, optional): Function to test a model with samples
            from a dataloader and return the test results. If ``None``, the
            default test function ``mmcv.engine.single_gpu_test`` will be
            used. (default: ``None``)
        greater_keys (List[str] | None, optional): Metric keys that will be
            inferred by the 'greater' comparison rule. If ``None``,
            _default_greater_keys will be used. (default: ``None``)
        less_keys (List[str] | None, optional): Metric keys that will be
            inferred by the 'less' comparison rule. If ``None``,
            _default_less_keys will be used. (default: ``None``)
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, `runner.work_dir` will be used by default. If
            specified, the `out_dir` will be the concatenation of `out_dir`
            and the last level directory of `runner.work_dir`.
            `New in version 1.3.16.`
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details. Default: None.
            `New in version 1.3.16.`
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.

    Notes:
        If new arguments are added for EvalHook, tools/test.py and
        tools/eval_metric.py may be affected.
    """

    # Since the keys for determining greater or less are related to the
    # downstream tasks, downstream repos may need to overwrite the following
    # inner variables accordingly.

    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=None,
                 less_keys=None,
                 out_dir=None,
                 file_client_args=None,
                 **eval_kwargs):
        if not isinstance(dataloader, DataLoader):
            raise TypeError(f'dataloader must be a pytorch DataLoader, '
                            f'but got {type(dataloader)}')

        if interval <= 0:
            raise ValueError(f'interval must be a positive number, '
                             f'but got {interval}')

        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'

        if start is not None and start < 0:
            raise ValueError(f'The evaluation start epoch {start} is smaller '
                             f'than 0')

        self.dataloader = dataloader
        self.interval = interval
        self.start = start
        self.by_epoch = by_epoch

        assert isinstance(save_best, str) or save_best is None, \
            '``save_best`` should be a str or None ' \
            f'rather than {type(save_best)}'
        self.save_best = save_best
        self.eval_kwargs = eval_kwargs
        self.initial_flag = True

        if test_fn is None:
            from annotator.mmpkg.mmcv.engine import single_gpu_test
            self.test_fn = single_gpu_test
        else:
            self.test_fn = test_fn

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys

        if self.save_best is not None:
            self.best_ckpt_path = None
            self._init_rule(rule, self.save_best)

        self.out_dir = out_dir
        self.file_client_args = file_client_args

    def _init_rule(self, rule, key_indicator):
        """Initialize rule, key_indicator, comparison_func, and best score.

        Here is the rule to determine which comparison rule is used for the
        key indicator when the rule is not specified (note that the key
        indicator matching is case-insensitive):
        1. If the key indicator is in ``self.greater_keys``, the rule will be
           specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule will be
           specified as 'less'.
        3. Or if any item in ``self.greater_keys`` is a substring of the key
           indicator, the rule will be specified as 'greater'.
        4. Or if any item in ``self.less_keys`` is a substring of the key
           indicator, the rule will be specified as 'less'.

        Args:
            rule (str | None): Comparison rule for the best score.
            key_indicator (str | None): Key indicator to determine the
                comparison rule.
        """
        if rule not in self.rule_map and rule is not None:
            raise KeyError(f'rule must be greater, less or None, '
                           f'but got {rule}.')

        if rule is None:
            if key_indicator != 'auto':
                # `_lc` here means we use the lower case of keys for
                # case-insensitive matching
                key_indicator_lc = key_indicator.lower()
                greater_keys = [key.lower() for key in self.greater_keys]
                less_keys = [key.lower() for key in self.less_keys]

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError(f'Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     f'must be specified.')
        self.rule = rule
        self.key_indicator = key_indicator
        if self.rule is not None:
            self.compare_func = self.rule_map[self.rule]

    def before_run(self, runner):
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last level directory of
        # `runner.work_dir`
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)
            runner.logger.info(
                (f'The best checkpoint will be saved to {self.out_dir} by '
                 f'{self.file_client.name}'))

        if self.save_best is not None:
            if runner.meta is None:
                warnings.warn('runner.meta is None. Creating an empty one.')
                runner.meta = dict()
            runner.meta.setdefault('hook_msgs', dict())
            self.best_ckpt_path = runner.meta['hook_msgs'].get(
                'best_ckpt', None)

    def before_train_iter(self, runner):
        """Evaluate the model only at the start of training by iteration."""
        if self.by_epoch or not self.initial_flag:
            return
        if self.start is not None and runner.iter >= self.start:
            self.after_train_iter(runner)
        self.initial_flag = False

    def before_train_epoch(self, runner):
        """Evaluate the model only at the start of training by epoch."""
        if not (self.by_epoch and self.initial_flag):
            return
        if self.start is not None and runner.epoch >= self.start:
            self.after_train_epoch(runner)
        self.initial_flag = False

    def after_train_iter(self, runner):
        """Called after every training iter to evaluate the results."""
        if not self.by_epoch and self._should_evaluate(runner):
            # Because the priority of EvalHook is higher than LoggerHook, the
            # training log and the evaluating log are mixed. Therefore,
            # we need to dump the training log and clear it before the
            # evaluating log is generated. In addition, this problem will
            # only appear in `IterBasedRunner` whose `self.by_epoch` is
            # False, because `EpochBasedRunner` whose `self.by_epoch` is True
            # calls `_do_evaluate` in the `after_train_epoch` stage, and at
            # that stage the training log has already been printed, so it
            # will not cause any problem. More details at
            # https://github.com/open-mmlab/mmsegmentation/issues/694
            for hook in runner._hooks:
                if isinstance(hook, LoggerHook):
                    hook.after_train_iter(runner)
            runner.log_buffer.clear()

            self._do_evaluate(runner)

    def after_train_epoch(self, runner):
        """Called after every training epoch to evaluate the results."""
        if self.by_epoch and self._should_evaluate(runner):
            self._do_evaluate(runner)

    def _do_evaluate(self, runner):
        """Perform evaluation and save the checkpoint."""
        results = self.test_fn(runner.model, self.dataloader)
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        # the key_score may be `None`, in which case the action to save the
        # best checkpoint needs to be skipped
        if self.save_best and key_score:
            self._save_ckpt(runner, key_score)

    def _should_evaluate(self, runner):
        """Judge whether to perform evaluation.

        Here is the rule to judge whether to perform evaluation:
        1. It will not perform evaluation during the epoch/iteration
           interval, which is determined by ``self.interval``.
        2. It will not perform evaluation if the start time is larger than
           the current time.
        3. It will not perform evaluation when the current time is larger
           than the start time but within the epoch/iteration interval.

        Returns:
            bool: The flag indicating whether to perform evaluation.
        """
        if self.by_epoch:
            current = runner.epoch
            check_time = self.every_n_epochs
        else:
            current = runner.iter
            check_time = self.every_n_iters

        if self.start is None:
            if not check_time(runner, self.interval):
                # No evaluation during the interval.
                return False
        elif (current + 1) < self.start:
            # No evaluation if start is larger than the current time.
            return False
        else:
            # Evaluation only at epochs/iters 3, 5, 7...
            # if start==3 and interval==2
            if (current + 1 - self.start) % self.interval:
                return False
        return True

    def _save_ckpt(self, runner, key_score):
        """Save the best checkpoint.

        It will compare the score according to the compare function, write
        related information (best score, best checkpoint path) and save the
        best checkpoint into ``work_dir``.
        """
        if self.by_epoch:
            current = f'epoch_{runner.epoch + 1}'
            cur_type, cur_time = 'epoch', runner.epoch + 1
        else:
            current = f'iter_{runner.iter + 1}'
            cur_type, cur_time = 'iter', runner.iter + 1

        best_score = runner.meta['hook_msgs'].get(
            'best_score', self.init_value_map[self.rule])
        if self.compare_func(key_score, best_score):
            best_score = key_score
            runner.meta['hook_msgs']['best_score'] = best_score

            if self.best_ckpt_path and self.file_client.isfile(
                    self.best_ckpt_path):
                self.file_client.remove(self.best_ckpt_path)
                runner.logger.info(
                    (f'The previous best checkpoint {self.best_ckpt_path} was '
                     'removed'))

            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
            self.best_ckpt_path = self.file_client.join_path(
                self.out_dir, best_ckpt_name)
            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path

            runner.save_checkpoint(
                self.out_dir, best_ckpt_name, create_symlink=False)
            runner.logger.info(
                f'Now best checkpoint is saved as {best_ckpt_name}.')
            runner.logger.info(
                f'Best {self.key_indicator} is {best_score:0.4f} '
                f'at {cur_time} {cur_type}.')

    def evaluate(self, runner, results):
        """Evaluate the results.

        Args:
            runner (:obj:`mmcv.Runner`): The underlying training runner.
            results (list): Output results.
        """
        eval_res = self.dataloader.dataset.evaluate(
            results, logger=runner.logger, **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        runner.log_buffer.ready = True

        if self.save_best is not None:
            # If the performance of the model is poor, `eval_res` may be an
            # empty dict, which would raise an exception when `self.save_best`
            # is not None. More details at
            # https://github.com/open-mmlab/mmdetection/issues/6265.
            if not eval_res:
                warnings.warn(
                    'Since `eval_res` is an empty dict, the behavior to save '
                    'the best checkpoint will be skipped in this evaluation.')
                return None

            if self.key_indicator == 'auto':
                # infer from eval_results
                self._init_rule(self.rule, list(eval_res.keys())[0])
            return eval_res[self.key_indicator]

        return None


class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    This hook will regularly perform evaluation in a given interval when
    running in a distributed environment.

    Args:
        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
            implemented the ``evaluate`` function.
        start (int | None, optional): Evaluation starting epoch. It enables
            evaluation before the training starts if ``start`` <= the
            resuming epoch. If None, whether to evaluate is merely decided by
            ``interval``. Default: None.
        interval (int): Evaluation interval. Default: 1.
        by_epoch (bool): Whether to perform evaluation by epoch or by
            iteration. If set to True, it will perform by epoch. Otherwise,
            by iteration. Default: True.
        save_best (str, optional): If a metric is specified, it would measure
            the best checkpoint during evaluation. The information about the
            best checkpoint would be saved in ``runner.meta['hook_msgs']`` to
            keep the best score value and best checkpoint path, which will
            also be loaded when resuming a checkpoint. Options are the
            evaluation metrics on the test dataset, e.g., ``bbox_mAP``,
            ``segm_mAP`` for bbox detection and instance segmentation,
            ``AR@100`` for proposal recall. If ``save_best`` is ``auto``, the
            first key of the returned ``OrderedDict`` result will be used.
            Default: None.
        rule (str | None, optional): Comparison rule for the best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc',
            'top', etc. will be inferred by the 'greater' rule. Keys
            containing 'loss' will be inferred by the 'less' rule. Options
            are 'greater', 'less', None. Default: None.
        test_fn (callable, optional): Function to test a model with samples
            from a dataloader in a multi-gpu manner and return the test
            results. If ``None``, the default test function
            ``mmcv.engine.multi_gpu_test`` will be used. (default: ``None``)
        tmpdir (str | None): Temporary directory to save the results of all
            processes. Default: None.
        gpu_collect (bool): Whether to use gpu or cpu to collect results.
            Default: False.
        broadcast_bn_buffer (bool): Whether to broadcast the
            buffers (running_mean and running_var) of rank 0 to other ranks
            before evaluation. Default: True.
        out_dir (str, optional): The root directory to save checkpoints. If
            not specified, `runner.work_dir` will be used by default. If
            specified, the `out_dir` will be the concatenation of `out_dir`
            and the last level directory of `runner.work_dir`.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details. Default: None.
        **eval_kwargs: Evaluation arguments fed into the evaluate function of
            the dataset.
    """

    def __init__(self,
                 dataloader,
                 start=None,
                 interval=1,
                 by_epoch=True,
                 save_best=None,
                 rule=None,
                 test_fn=None,
                 greater_keys=None,
                 less_keys=None,
                 broadcast_bn_buffer=True,
                 tmpdir=None,
                 gpu_collect=False,
                 out_dir=None,
                 file_client_args=None,
                 **eval_kwargs):

        if test_fn is None:
            from annotator.mmpkg.mmcv.engine import multi_gpu_test
            test_fn = multi_gpu_test

        super().__init__(
            dataloader,
            start=start,
            interval=interval,
            by_epoch=by_epoch,
            save_best=save_best,
            rule=rule,
            test_fn=test_fn,
            greater_keys=greater_keys,
            less_keys=less_keys,
            out_dir=out_dir,
            file_client_args=file_client_args,
            **eval_kwargs)

        self.broadcast_bn_buffer = broadcast_bn_buffer
        self.tmpdir = tmpdir
        self.gpu_collect = gpu_collect

    def _do_evaluate(self, runner):
        """Perform evaluation and save the checkpoint."""
        # Synchronization of BatchNorm's buffers (running_mean
        # and running_var) is not supported in the DDP of pytorch,
        # which may cause inconsistent performance of models on
        # different ranks, so we broadcast BatchNorm's buffers
        # of rank 0 to the other ranks to avoid this.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')

        results = self.test_fn(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect)
        if runner.rank == 0:
            print('\n')
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)
            # the key_score may be `None`, in which case the action to
            # save the best checkpoint needs to be skipped
            if self.save_best and key_score:
                self._save_ckpt(runner, key_score)
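A rough usage sketch: the hook is registered on a runner, with the dataloader assumed to wrap a dataset that implements ``evaluate``. The `val_loader` and `runner` names are illustrative assumptions; `'mIoU'` matches one of the default greater keys above.

eval_hook = EvalHook(val_loader, interval=1, save_best='mIoU')
# LOW priority keeps it after the optimizer step but before loggers flush
runner.register_hook(eval_hook, priority='LOW')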
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/hook.py
ADDED
@@ -0,0 +1,92 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from annotator.mmpkg.mmcv.utils import Registry, is_method_overridden
|
3 |
+
|
4 |
+
HOOKS = Registry('hook')
|
5 |
+
|
6 |
+
|
7 |
+
class Hook:
|
8 |
+
stages = ('before_run', 'before_train_epoch', 'before_train_iter',
|
9 |
+
'after_train_iter', 'after_train_epoch', 'before_val_epoch',
|
10 |
+
'before_val_iter', 'after_val_iter', 'after_val_epoch',
|
11 |
+
'after_run')
|
12 |
+
|
13 |
+
def before_run(self, runner):
|
14 |
+
pass
|
15 |
+
|
16 |
+
def after_run(self, runner):
|
17 |
+
pass
|
18 |
+
|
19 |
+
def before_epoch(self, runner):
|
20 |
+
pass
|
21 |
+
|
22 |
+
def after_epoch(self, runner):
|
23 |
+
pass
|
24 |
+
|
25 |
+
def before_iter(self, runner):
|
26 |
+
pass
|
27 |
+
|
28 |
+
def after_iter(self, runner):
|
29 |
+
pass
|
30 |
+
|
31 |
+
def before_train_epoch(self, runner):
|
32 |
+
self.before_epoch(runner)
|
33 |
+
|
34 |
+
def before_val_epoch(self, runner):
|
35 |
+
self.before_epoch(runner)
|
36 |
+
|
37 |
+
def after_train_epoch(self, runner):
|
38 |
+
self.after_epoch(runner)
|
39 |
+
|
40 |
+
def after_val_epoch(self, runner):
|
41 |
+
self.after_epoch(runner)
|
42 |
+
|
43 |
+
def before_train_iter(self, runner):
|
44 |
+
self.before_iter(runner)
|
45 |
+
|
46 |
+
def before_val_iter(self, runner):
|
47 |
+
self.before_iter(runner)
|
48 |
+
|
49 |
+
def after_train_iter(self, runner):
|
50 |
+
self.after_iter(runner)
|
51 |
+
|
52 |
+
def after_val_iter(self, runner):
|
53 |
+
self.after_iter(runner)
|
54 |
+
|
55 |
+
def every_n_epochs(self, runner, n):
|
56 |
+
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
57 |
+
|
58 |
+
def every_n_inner_iters(self, runner, n):
|
59 |
+
return (runner.inner_iter + 1) % n == 0 if n > 0 else False
|
60 |
+
|
61 |
+
def every_n_iters(self, runner, n):
|
62 |
+
return (runner.iter + 1) % n == 0 if n > 0 else False
|
63 |
+
|
64 |
+
def end_of_epoch(self, runner):
|
65 |
+
return runner.inner_iter + 1 == len(runner.data_loader)
|
66 |
+
|
67 |
+
def is_last_epoch(self, runner):
|
68 |
+
return runner.epoch + 1 == runner._max_epochs
|
69 |
+
|
70 |
+
def is_last_iter(self, runner):
|
71 |
+
return runner.iter + 1 == runner._max_iters
|
72 |
+
|
73 |
+
def get_triggered_stages(self):
|
74 |
+
trigger_stages = set()
|
75 |
+
for stage in Hook.stages:
|
76 |
+
if is_method_overridden(stage, Hook, self):
|
77 |
+
trigger_stages.add(stage)
|
78 |
+
|
79 |
+
# some methods will be triggered in multi stages
|
80 |
+
# use this dict to map method to stages.
|
81 |
+
method_stages_map = {
|
82 |
+
'before_epoch': ['before_train_epoch', 'before_val_epoch'],
|
83 |
+
'after_epoch': ['after_train_epoch', 'after_val_epoch'],
|
84 |
+
'before_iter': ['before_train_iter', 'before_val_iter'],
|
85 |
+
'after_iter': ['after_train_iter', 'after_val_iter'],
|
86 |
+
}
|
87 |
+
|
88 |
+
for method, map_stages in method_stages_map.items():
|
89 |
+
if is_method_overridden(method, Hook, self):
|
90 |
+
trigger_stages.update(map_stages)
|
91 |
+
|
92 |
+
return [stage for stage in Hook.stages if stage in trigger_stages]
|
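The registry plus the stage-dispatch methods above are the whole extension surface: a subclass overrides only the stages it cares about, and `get_triggered_stages` reports exactly those. A minimal sketch (the hook name and print body are illustrative only):

# Illustrative custom hook built on the Hook base class above.
@HOOKS.register_module()
class PrintEpochHook(Hook):

    def after_train_epoch(self, runner):
        # runner.epoch is 0-based, so add 1 for human-readable output
        print(f'finished epoch {runner.epoch + 1}')

# PrintEpochHook().get_triggered_stages() -> ['after_train_epoch']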
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/iter_timer.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time

from .hook import HOOKS, Hook


@HOOKS.register_module()
class IterTimerHook(Hook):

    def before_epoch(self, runner):
        self.t = time.time()

    def before_iter(self, runner):
        runner.log_buffer.update({'data_time': time.time() - self.t})

    def after_iter(self, runner):
        runner.log_buffer.update({'time': time.time() - self.t})
        self.t = time.time()
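Because `before_iter` fires after data loading and `after_iter` resets the clock, 'data_time' measures dataloader latency while 'time' covers the full iteration. A rough timeline, assuming one training iteration:

# t0 = before_epoch (or previous after_iter)  -> self.t = t0
# ... dataloader produces the batch ...
# t1 = before_iter  -> logs data_time = t1 - t0
# ... forward/backward/step run ...
# t2 = after_iter   -> logs time = t2 - t0, then resets self.t = t2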
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import LoggerHook
from .dvclive import DvcliveLoggerHook
from .mlflow import MlflowLoggerHook
from .neptune import NeptuneLoggerHook
from .pavi import PaviLoggerHook
from .tensorboard import TensorboardLoggerHook
from .text import TextLoggerHook
from .wandb import WandbLoggerHook

__all__ = [
    'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
    'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
    'NeptuneLoggerHook', 'DvcliveLoggerHook'
]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/base.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numbers
from abc import ABCMeta, abstractmethod

import numpy as np
import torch

from ..hook import Hook


class LoggerHook(Hook):
    """Base class for logger hooks.

    Args:
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
        reset_flag (bool): Whether to clear the output buffer after logging.
        by_epoch (bool): Whether EpochBasedRunner is used.
    """

    __metaclass__ = ABCMeta

    def __init__(self,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        self.interval = interval
        self.ignore_last = ignore_last
        self.reset_flag = reset_flag
        self.by_epoch = by_epoch

    @abstractmethod
    def log(self, runner):
        pass

    @staticmethod
    def is_scalar(val, include_np=True, include_torch=True):
        """Tell whether the input variable is a scalar.

        Args:
            val: Input variable.
            include_np (bool): Whether to include 0-d np.ndarray as a scalar.
            include_torch (bool): Whether to include 0-d torch.Tensor as a
                scalar.

        Returns:
            bool: True or False.
        """
        if isinstance(val, numbers.Number):
            return True
        elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
            return True
        elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
            return True
        else:
            return False

    def get_mode(self, runner):
        if runner.mode == 'train':
            if 'time' in runner.log_buffer.output:
                mode = 'train'
            else:
                mode = 'val'
        elif runner.mode == 'val':
            mode = 'val'
        else:
            raise ValueError(f"runner mode should be 'train' or 'val', "
                             f'but got {runner.mode}')
        return mode

    def get_epoch(self, runner):
        if runner.mode == 'train':
            epoch = runner.epoch + 1
        elif runner.mode == 'val':
            # normal val mode
            # runner.epoch += 1 has been done before val workflow
            epoch = runner.epoch
        else:
            raise ValueError(f"runner mode should be 'train' or 'val', "
                             f'but got {runner.mode}')
        return epoch

    def get_iter(self, runner, inner_iter=False):
        """Get the current training iteration step."""
        if self.by_epoch and inner_iter:
            current_iter = runner.inner_iter + 1
        else:
            current_iter = runner.iter + 1
        return current_iter

    def get_lr_tags(self, runner):
        tags = {}
        lrs = runner.current_lr()
        if isinstance(lrs, dict):
            for name, value in lrs.items():
                tags[f'learning_rate/{name}'] = value[0]
        else:
            tags['learning_rate'] = lrs[0]
        return tags

    def get_momentum_tags(self, runner):
        tags = {}
        momentums = runner.current_momentum()
        if isinstance(momentums, dict):
            for name, value in momentums.items():
                tags[f'momentum/{name}'] = value[0]
        else:
            tags['momentum'] = momentums[0]
        return tags

    def get_loggable_tags(self,
                          runner,
                          allow_scalar=True,
                          allow_text=False,
                          add_mode=True,
                          tags_to_skip=('time', 'data_time')):
        tags = {}
        for var, val in runner.log_buffer.output.items():
            if var in tags_to_skip:
                continue
            if self.is_scalar(val) and not allow_scalar:
                continue
            if isinstance(val, str) and not allow_text:
                continue
            if add_mode:
                var = f'{self.get_mode(runner)}/{var}'
            tags[var] = val
        tags.update(self.get_lr_tags(runner))
        tags.update(self.get_momentum_tags(runner))
        return tags

    def before_run(self, runner):
        for hook in runner.hooks[::-1]:
            if isinstance(hook, LoggerHook):
                hook.reset_flag = True
                break

    def before_epoch(self, runner):
        runner.log_buffer.clear()  # clear logs of last epoch

    def after_train_iter(self, runner):
        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
            runner.log_buffer.average(self.interval)
        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
            runner.log_buffer.average(self.interval)
        elif self.end_of_epoch(runner) and not self.ignore_last:
            # not precise but more stable
            runner.log_buffer.average(self.interval)

        if runner.log_buffer.ready:
            self.log(runner)
            if self.reset_flag:
                runner.log_buffer.clear_output()

    def after_train_epoch(self, runner):
        if runner.log_buffer.ready:
            self.log(runner)
            if self.reset_flag:
                runner.log_buffer.clear_output()

    def after_val_epoch(self, runner):
        runner.log_buffer.average()
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()
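Concrete loggers only need to implement `log`; the base class handles interval bookkeeping, buffer averaging, and tag collection. A minimal subclass sketch (the class name and print body are hypothetical):

# Minimal sketch of a concrete logger on top of LoggerHook.
@HOOKS.register_module()
class StdoutLoggerHook(LoggerHook):

    def log(self, runner):
        # get_loggable_tags merges averaged buffer values with lr/momentum
        tags = self.get_loggable_tags(runner)
        print(f'iter {self.get_iter(runner)}: {tags}')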
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/dvclive.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
    """Class to log metrics with dvclive.

    It requires `dvclive`_ to be installed.

    Args:
        path (str): Directory where dvclive will write TSV log files.
        interval (int): Logging interval (every k iterations).
            Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.

    .. _dvclive:
        https://dvc.org/doc/dvclive
    """

    def __init__(self,
                 path,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 by_epoch=True):

        super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
                                                reset_flag, by_epoch)
        self.path = path
        self.import_dvclive()

    def import_dvclive(self):
        try:
            import dvclive
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        self.dvclive = dvclive

    @master_only
    def before_run(self, runner):
        self.dvclive.init(self.path)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        if tags:
            for k, v in tags.items():
                self.dvclive.log(k, v, step=self.get_iter(runner))
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/mlflow.py
ADDED
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):

    def __init__(self,
                 exp_name=None,
                 tags=None,
                 log_model=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        """Class to log metrics and (optionally) a trained model to MLflow.

        It requires `MLflow`_ to be installed.

        Args:
            exp_name (str, optional): Name of the experiment to be used.
                Default: None. If not None, set the active experiment.
                If the experiment does not exist, an experiment with the
                provided name will be created.
            tags (dict of str: str, optional): Tags for the current run.
                Default: None. If not None, set tags for the current run.
            log_model (bool, optional): Whether to log an MLflow artifact.
                Default: True. If True, log runner.model as an MLflow
                artifact for the current run.
            interval (int): Logging interval (every k iterations).
            ignore_last (bool): Ignore the log of last iterations in each
                epoch if less than `interval`.
            reset_flag (bool): Whether to clear the output buffer after
                logging.
            by_epoch (bool): Whether EpochBasedRunner is used.

        .. _MLflow:
            https://www.mlflow.org/docs/latest/index.html
        """
        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
                                               reset_flag, by_epoch)
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.log_model = log_model

    def import_mlflow(self):
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError:
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow')
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    @master_only
    def before_run(self, runner):
        super(MlflowLoggerHook, self).before_run(runner)
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        if tags:
            self.mlflow.log_metrics(tags, step=self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        if self.log_model:
            self.mlflow_pytorch.log_model(runner.model, 'models')
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/neptune.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class NeptuneLoggerHook(LoggerHook):
    """Class to log metrics to NeptuneAI.

    It requires `neptune-client` to be installed.

    Args:
        init_kwargs (dict): a dict containing the initialization keys below:

            - project (str): Name of a project in the form
              namespace/project_name. If None, the value of the
              NEPTUNE_PROJECT environment variable will be taken.
            - api_token (str): User's API token. If None, the value of the
              NEPTUNE_API_TOKEN environment variable will be taken. Note: it
              is strongly recommended to use the NEPTUNE_API_TOKEN
              environment variable rather than placing your API token in
              plain text in your source code.
            - name (str, optional, default is 'Untitled'): Editable name of
              the run. The name is displayed in the run's Details and in the
              Runs table as a column.

            Check https://docs.neptune.ai/api-reference/neptune#init for
            more init arguments.
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
        reset_flag (bool): Whether to clear the output buffer after logging.
        by_epoch (bool): Whether EpochBasedRunner is used.

    .. _NeptuneAI:
        https://docs.neptune.ai/you-should-know/logging-metadata
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=True,
                 with_step=True,
                 by_epoch=True):

        super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
                                                reset_flag, by_epoch)
        self.import_neptune()
        self.init_kwargs = init_kwargs
        self.with_step = with_step

    def import_neptune(self):
        try:
            import neptune.new as neptune
        except ImportError:
            raise ImportError(
                'Please run "pip install neptune-client" to install neptune')
        self.neptune = neptune
        self.run = None

    @master_only
    def before_run(self, runner):
        if self.init_kwargs:
            self.run = self.neptune.init(**self.init_kwargs)
        else:
            self.run = self.neptune.init()

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        if tags:
            for tag_name, tag_value in tags.items():
                if self.with_step:
                    self.run[tag_name].log(
                        tag_value, step=self.get_iter(runner))
                else:
                    tags['global_step'] = self.get_iter(runner)
                    self.run[tag_name].log(tags)

    @master_only
    def after_run(self, runner):
        self.run.stop()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/pavi.py
ADDED
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import os.path as osp

import torch
import yaml

import annotator.mmpkg.mmcv as mmcv
from ....parallel.utils import is_module_wrapper
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class PaviLoggerHook(LoggerHook):

    def __init__(self,
                 init_kwargs=None,
                 add_graph=False,
                 add_last_ckpt=False,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 img_key='img_info'):
        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
                                             by_epoch)
        self.init_kwargs = init_kwargs
        self.add_graph = add_graph
        self.add_last_ckpt = add_last_ckpt
        self.img_key = img_key

    @master_only
    def before_run(self, runner):
        super(PaviLoggerHook, self).before_run(runner)
        try:
            from pavi import SummaryWriter
        except ImportError:
            raise ImportError('Please run "pip install pavi" to install pavi.')

        self.run_name = runner.work_dir.split('/')[-1]

        if not self.init_kwargs:
            self.init_kwargs = dict()
        self.init_kwargs['name'] = self.run_name
        self.init_kwargs['model'] = runner._model_name
        if runner.meta is not None:
            if 'config_dict' in runner.meta:
                config_dict = runner.meta['config_dict']
                assert isinstance(
                    config_dict,
                    dict), ('meta["config_dict"] has to be a dict, '
                            f'but got {type(config_dict)}')
            elif 'config_file' in runner.meta:
                config_file = runner.meta['config_file']
                config_dict = dict(mmcv.Config.fromfile(config_file))
            else:
                config_dict = None
            if config_dict is not None:
                # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
                # to properly set up the progress bar.
                config_dict = config_dict.copy()
                config_dict.setdefault('max_iter', runner.max_iters)
                # non-serializable values are first converted in
                # mmcv.dump to json
                config_dict = json.loads(
                    mmcv.dump(config_dict, file_format='json'))
                session_text = yaml.dump(config_dict)
                self.init_kwargs['session_text'] = session_text
        self.writer = SummaryWriter(**self.init_kwargs)

    def get_step(self, runner):
        """Get the total training step/epoch."""
        if self.get_mode(runner) == 'val' and self.by_epoch:
            return self.get_epoch(runner)
        else:
            return self.get_iter(runner)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner, add_mode=False)
        if tags:
            self.writer.add_scalars(
                self.get_mode(runner), tags, self.get_step(runner))

    @master_only
    def after_run(self, runner):
        if self.add_last_ckpt:
            ckpt_path = osp.join(runner.work_dir, 'latest.pth')
            if osp.islink(ckpt_path):
                ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))

            if osp.isfile(ckpt_path):
                # runner.epoch += 1 has been done before `after_run`.
                iteration = runner.epoch if self.by_epoch else runner.iter
                return self.writer.add_snapshot_file(
                    tag=self.run_name,
                    snapshot_file_path=ckpt_path,
                    iteration=iteration)

        # flush the buffer and send a task ending signal to Pavi
        self.writer.close()

    @master_only
    def before_epoch(self, runner):
        if runner.epoch == 0 and self.add_graph:
            if is_module_wrapper(runner.model):
                _model = runner.model.module
            else:
                _model = runner.model
            device = next(_model.parameters()).device
            data = next(iter(runner.data_loader))
            image = data[self.img_key][0:1].to(device)
            with torch.no_grad():
                self.writer.add_graph(_model, image)
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/tensorboard.py
ADDED
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        super(TensorboardLoggerHook, self).before_run(runner)
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.1')):
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.')
        else:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)')

        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner, allow_text=True)
        for tag, val in tags.items():
            if isinstance(val, str):
                self.writer.add_text(tag, val, self.get_iter(runner))
            else:
                self.writer.add_scalar(tag, val, self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        self.writer.close()
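A typical configuration sketch; with `log_dir=None` the writer falls back to `<work_dir>/tf_logs` as shown in `before_run` (the runner wiring here is assumed, not part of this file):

# Hypothetical usage sketch for TensorboardLoggerHook.
runner.register_hook(
    TensorboardLoggerHook(
        log_dir=None,   # defaults to osp.join(runner.work_dir, 'tf_logs')
        interval=50,    # write scalars every 50 iterations
        by_epoch=True))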
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/text.py
ADDED
@@ -0,0 +1,256 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import os
import os.path as osp
from collections import OrderedDict

import torch
import torch.distributed as dist

import annotator.mmpkg.mmcv as mmcv
from annotator.mmpkg.mmcv.fileio.file_client import FileClient
from annotator.mmpkg.mmcv.utils import is_tuple_of, scandir
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class TextLoggerHook(LoggerHook):
    """Logger hook in text.

    In this logger hook, the information will be printed on the terminal and
    saved in a json file.

    Args:
        by_epoch (bool, optional): Whether EpochBasedRunner is used.
            Default: True.
        interval (int, optional): Logging interval (every k iterations).
            Default: 10.
        ignore_last (bool, optional): Ignore the log of last iterations in
            each epoch if less than :attr:`interval`. Default: True.
        reset_flag (bool, optional): Whether to clear the output buffer after
            logging. Default: False.
        interval_exp_name (int, optional): Logging interval for experiment
            name. This feature is to help users conveniently get the
            experiment information from screen or log file. Default: 1000.
        out_dir (str, optional): Logs are saved in ``runner.work_dir`` by
            default. If ``out_dir`` is specified, logs will be copied to a new
            directory which is the concatenation of ``out_dir`` and the last
            level directory of ``runner.work_dir``. Default: None.
            `New in version 1.3.16.`
        out_suffix (str or tuple[str], optional): Those filenames ending with
            ``out_suffix`` will be copied to ``out_dir``.
            Default: ('.log.json', '.log', '.py').
            `New in version 1.3.16.`
        keep_local (bool, optional): Whether to keep local log when
            :attr:`out_dir` is specified. If False, the local log will be
            removed. Default: True.
            `New in version 1.3.16.`
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.
            `New in version 1.3.16.`
    """

    def __init__(self,
                 by_epoch=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 interval_exp_name=1000,
                 out_dir=None,
                 out_suffix=('.log.json', '.log', '.py'),
                 keep_local=True,
                 file_client_args=None):
        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
                                             by_epoch)
        self.by_epoch = by_epoch
        self.time_sec_tot = 0
        self.interval_exp_name = interval_exp_name

        if out_dir is None and file_client_args is not None:
            raise ValueError(
                'file_client_args should be "None" when `out_dir` is not '
                'specified.')
        self.out_dir = out_dir

        if not (out_dir is None or isinstance(out_dir, str)
                or is_tuple_of(out_dir, str)):
            raise TypeError('out_dir should be "None" or string or tuple of '
                            f'string, but got {out_dir}')
        self.out_suffix = out_suffix

        self.keep_local = keep_local
        self.file_client_args = file_client_args
        if self.out_dir is not None:
            self.file_client = FileClient.infer_client(file_client_args,
                                                       self.out_dir)

    def before_run(self, runner):
        super(TextLoggerHook, self).before_run(runner)

        if self.out_dir is not None:
            self.file_client = FileClient.infer_client(self.file_client_args,
                                                       self.out_dir)
            # The final `self.out_dir` is the concatenation of `self.out_dir`
            # and the last level directory of `runner.work_dir`
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)
            runner.logger.info(
                (f'Text logs will be saved to {self.out_dir} by '
                 f'{self.file_client.name} after the training process.'))

        self.start_iter = runner.iter
        self.json_log_path = osp.join(runner.work_dir,
                                      f'{runner.timestamp}.log.json')
        if runner.meta is not None:
            self._dump_log(runner.meta, runner)

    def _get_max_memory(self, runner):
        device = getattr(runner.model, 'output_device', None)
        mem = torch.cuda.max_memory_allocated(device=device)
        mem_mb = torch.tensor([mem / (1024 * 1024)],
                              dtype=torch.int,
                              device=device)
        if runner.world_size > 1:
            dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
        return mem_mb.item()

    def _log_info(self, log_dict, runner):
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        if log_dict['mode'] == 'train':
            if isinstance(log_dict['lr'], dict):
                lr_str = []
                for k, val in log_dict['lr'].items():
                    lr_str.append(f'lr_{k}: {val:.3e}')
                lr_str = ' '.join(lr_str)
            else:
                lr_str = f'lr: {log_dict["lr"]:.3e}'

            # by epoch: Epoch [4][100/1000]
            # by iter: Iter [100/100000]
            if self.by_epoch:
                log_str = f'Epoch [{log_dict["epoch"]}]' \
                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
            else:
                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
            log_str += f'{lr_str}, '

            if 'time' in log_dict.keys():
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (
                    runner.iter - self.start_iter + 1)
                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                log_str += f'eta: {eta_str}, '
                log_str += f'time: {log_dict["time"]:.3f}, ' \
                           f'data_time: {log_dict["data_time"]:.3f}, '
                # statistic memory
                if torch.cuda.is_available():
                    log_str += f'memory: {log_dict["memory"]}, '
        else:
            # val/test time
            # here 1000 is the length of the val dataloader
            # by epoch: Epoch[val] [4][1000]
            # by iter: Iter[val] [1000]
            if self.by_epoch:
                log_str = f'Epoch({log_dict["mode"]}) ' \
                    f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
            else:
                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

        log_items = []
        for name, val in log_dict.items():
            # TODO: resolve this hack
            # these items have been in log_str
            if name in [
                    'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
                    'memory', 'epoch'
            ]:
                continue
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)

        runner.logger.info(log_str)

    def _dump_log(self, log_dict, runner):
        # dump log in json format
        json_log = OrderedDict()
        for k, v in log_dict.items():
            json_log[k] = self._round_float(v)
        # only append log at last line
        if runner.rank == 0:
            with open(self.json_log_path, 'a+') as f:
                mmcv.dump(json_log, f, file_format='json')
                f.write('\n')

    def _round_float(self, items):
        if isinstance(items, list):
            return [self._round_float(item) for item in items]
        elif isinstance(items, float):
            return round(items, 5)
        else:
            return items

    def log(self, runner):
        if 'eval_iter_num' in runner.log_buffer.output:
            # this doesn't modify runner.iter and is regardless of by_epoch
            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
        else:
            cur_iter = self.get_iter(runner, inner_iter=True)

        log_dict = OrderedDict(
            mode=self.get_mode(runner),
            epoch=self.get_epoch(runner),
            iter=cur_iter)

        # only record lr of the first param group
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['lr'] = cur_lr[0]
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'] = {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['lr'].update({k: lr_[0]})

        if 'time' in runner.log_buffer.output:
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)

        log_dict = dict(log_dict, **runner.log_buffer.output)

        self._log_info(log_dict, runner)
        self._dump_log(log_dict, runner)
        return log_dict

    def after_run(self, runner):
        # copy or upload logs to self.out_dir
        if self.out_dir is not None:
            for filename in scandir(runner.work_dir, self.out_suffix, True):
                local_filepath = osp.join(runner.work_dir, filename)
                out_filepath = self.file_client.join_path(
                    self.out_dir, filename)
                with open(local_filepath, 'r') as f:
                    self.file_client.put_text(f.read(), out_filepath)

                runner.logger.info(
                    (f'The file {local_filepath} has been uploaded to '
                     f'{out_filepath}.'))

                if not self.keep_local:
                    os.remove(local_filepath)
                    runner.logger.info(
                        (f'{local_filepath} was removed because '
                         '`keep_local=False`'))
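The `out_dir` handling deserves a concrete example: the final destination appends the last component of `work_dir`, so logs from different experiments do not collide. A sketch with assumed paths (none of them come from this commit):

# Assumed paths for illustration only.
# runner.work_dir == '/exp/runs/segformer_b0'
hook = TextLoggerHook(out_dir='/mnt/shared_logs', keep_local=True)
# after before_run:  hook.out_dir == '/mnt/shared_logs/segformer_b0'
# after after_run:   every *.log.json / *.log / *.py file under work_dir
#                    is copied there via the inferred FileClient.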
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/logger/wandb.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook


@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 commit=True,
                 by_epoch=True,
                 with_step=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
        self.import_wandb()
        self.init_kwargs = init_kwargs
        self.commit = commit
        self.with_step = with_step

    def import_wandb(self):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        super(WandbLoggerHook, self).before_run(runner)
        if self.wandb is None:
            self.import_wandb()
        if self.init_kwargs:
            self.wandb.init(**self.init_kwargs)
        else:
            self.wandb.init()

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        if tags:
            if self.with_step:
                self.wandb.log(
                    tags, step=self.get_iter(runner), commit=self.commit)
            else:
                tags['global_step'] = self.get_iter(runner)
                self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner):
        self.wandb.join()
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/lr_updater.py
ADDED
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import numbers
|
3 |
+
from math import cos, pi
|
4 |
+
|
5 |
+
import annotator.mmpkg.mmcv as mmcv
|
6 |
+
from .hook import HOOKS, Hook
|
7 |
+
|
8 |
+
|
9 |
+
class LrUpdaterHook(Hook):
|
10 |
+
"""LR Scheduler in MMCV.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
by_epoch (bool): LR changes epoch by epoch
|
14 |
+
warmup (string): Type of warmup used. It can be None(use no warmup),
|
15 |
+
'constant', 'linear' or 'exp'
|
16 |
+
warmup_iters (int): The number of iterations or epochs that warmup
|
17 |
+
lasts
|
18 |
+
warmup_ratio (float): LR used at the beginning of warmup equals to
|
19 |
+
warmup_ratio * initial_lr
|
20 |
+
warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
|
21 |
+
means the number of epochs that warmup lasts, otherwise means the
|
22 |
+
number of iteration that warmup lasts
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self,
|
26 |
+
by_epoch=True,
|
27 |
+
warmup=None,
|
28 |
+
warmup_iters=0,
|
29 |
+
warmup_ratio=0.1,
|
30 |
+
warmup_by_epoch=False):
|
31 |
+
# validate the "warmup" argument
|
32 |
+
if warmup is not None:
|
33 |
+
if warmup not in ['constant', 'linear', 'exp']:
|
34 |
+
raise ValueError(
|
35 |
+
f'"{warmup}" is not a supported type for warming up, valid'
|
36 |
+
' types are "constant" and "linear"')
|
37 |
+
if warmup is not None:
|
38 |
+
assert warmup_iters > 0, \
|
39 |
+
'"warmup_iters" must be a positive integer'
|
40 |
+
assert 0 < warmup_ratio <= 1.0, \
|
41 |
+
'"warmup_ratio" must be in range (0,1]'
|
42 |
+
|
43 |
+
self.by_epoch = by_epoch
|
44 |
+
self.warmup = warmup
|
45 |
+
self.warmup_iters = warmup_iters
|
46 |
+
self.warmup_ratio = warmup_ratio
|
47 |
+
self.warmup_by_epoch = warmup_by_epoch
|
48 |
+
|
49 |
+
if self.warmup_by_epoch:
|
50 |
+
self.warmup_epochs = self.warmup_iters
|
51 |
+
self.warmup_iters = None
|
52 |
+
else:
|
53 |
+
self.warmup_epochs = None
|
54 |
+
|
55 |
+
self.base_lr = [] # initial lr for all param groups
|
56 |
+
self.regular_lr = [] # expected lr if no warming up is performed
|
57 |
+
|
58 |
+
def _set_lr(self, runner, lr_groups):
|
59 |
+
if isinstance(runner.optimizer, dict):
|
60 |
+
for k, optim in runner.optimizer.items():
|
61 |
+
for param_group, lr in zip(optim.param_groups, lr_groups[k]):
|
62 |
+
param_group['lr'] = lr
|
63 |
+
else:
|
64 |
+
for param_group, lr in zip(runner.optimizer.param_groups,
|
65 |
+
lr_groups):
|
66 |
+
param_group['lr'] = lr
|
67 |
+
|
68 |
+
def get_lr(self, runner, base_lr):
|
69 |
+
raise NotImplementedError
|
70 |
+
|
71 |
+
def get_regular_lr(self, runner):
|
72 |
+
if isinstance(runner.optimizer, dict):
|
73 |
+
lr_groups = {}
|
74 |
+
for k in runner.optimizer.keys():
|
75 |
+
_lr_group = [
|
76 |
+
self.get_lr(runner, _base_lr)
|
77 |
+
for _base_lr in self.base_lr[k]
|
78 |
+
]
|
79 |
+
lr_groups.update({k: _lr_group})
|
80 |
+
|
81 |
+
return lr_groups
|
82 |
+
else:
|
83 |
+
return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
|
84 |
+
|
85 |
+
def get_warmup_lr(self, cur_iters):
|
86 |
+
|
87 |
+
def _get_warmup_lr(cur_iters, regular_lr):
|
88 |
+
if self.warmup == 'constant':
|
89 |
+
warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
|
90 |
+
elif self.warmup == 'linear':
|
91 |
+
k = (1 - cur_iters / self.warmup_iters) * (1 -
|
92 |
+
self.warmup_ratio)
|
93 |
+
warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
|
94 |
+
elif self.warmup == 'exp':
|
95 |
+
k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
|
96 |
+
warmup_lr = [_lr * k for _lr in regular_lr]
|
97 |
+
return warmup_lr
|
98 |
+
|
99 |
+
if isinstance(self.regular_lr, dict):
|
100 |
+
lr_groups = {}
|
101 |
+
for key, regular_lr in self.regular_lr.items():
|
102 |
+
lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
|
103 |
+
return lr_groups
|
104 |
+
else:
|
105 |
+
return _get_warmup_lr(cur_iters, self.regular_lr)
|
106 |
+
|
107 |
+
def before_run(self, runner):
|
108 |
+
# NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
|
109 |
+
# it will be set according to the optimizer params
|
110 |
+
if isinstance(runner.optimizer, dict):
|
111 |
+
self.base_lr = {}
|
112 |
+
for k, optim in runner.optimizer.items():
|
113 |
+
for group in optim.param_groups:
|
114 |
+
group.setdefault('initial_lr', group['lr'])
|
115 |
+
_base_lr = [
|
116 |
+
group['initial_lr'] for group in optim.param_groups
|
117 |
+
]
|
118 |
+
self.base_lr.update({k: _base_lr})
|
119 |
+
else:
|
120 |
+
for group in runner.optimizer.param_groups:
|
121 |
+
group.setdefault('initial_lr', group['lr'])
|
122 |
+
self.base_lr = [
|
123 |
+
group['initial_lr'] for group in runner.optimizer.param_groups
|
124 |
+
]
|
125 |
+
|
126 |
+
def before_train_epoch(self, runner):
|
127 |
+
if self.warmup_iters is None:
|
128 |
+
epoch_len = len(runner.data_loader)
|
129 |
+
self.warmup_iters = self.warmup_epochs * epoch_len
|
130 |
+
|
131 |
+
if not self.by_epoch:
|
132 |
+
return
|
133 |
+
|
134 |
+
self.regular_lr = self.get_regular_lr(runner)
|
135 |
+
self._set_lr(runner, self.regular_lr)
|
136 |
+
|
137 |
+
def before_train_iter(self, runner):
|
138 |
+
cur_iter = runner.iter
|
139 |
+
if not self.by_epoch:
|
140 |
+
self.regular_lr = self.get_regular_lr(runner)
|
141 |
+
if self.warmup is None or cur_iter >= self.warmup_iters:
|
142 |
+
self._set_lr(runner, self.regular_lr)
|
143 |
+
else:
|
144 |
+
warmup_lr = self.get_warmup_lr(cur_iter)
|
145 |
+
self._set_lr(runner, warmup_lr)
|
146 |
+
elif self.by_epoch:
|
147 |
+
if self.warmup is None or cur_iter > self.warmup_iters:
|
148 |
+
return
|
149 |
+
elif cur_iter == self.warmup_iters:
|
150 |
+
self._set_lr(runner, self.regular_lr)
|
151 |
+
else:
|
152 |
+
warmup_lr = self.get_warmup_lr(cur_iter)
|
153 |
+
self._set_lr(runner, warmup_lr)
|
154 |
+
|
155 |
+
|
156 |
+
@HOOKS.register_module()
|
157 |
+
class FixedLrUpdaterHook(LrUpdaterHook):
|
158 |
+
|
159 |
+
def __init__(self, **kwargs):
|
160 |
+
super(FixedLrUpdaterHook, self).__init__(**kwargs)
|
161 |
+
|
162 |
+
def get_lr(self, runner, base_lr):
|
163 |
+
return base_lr
|
164 |
+
|
165 |
+
|
166 |
+
@HOOKS.register_module()
|
167 |
+
class StepLrUpdaterHook(LrUpdaterHook):
|
168 |
+
"""Step LR scheduler with min_lr clipping.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
step (int | list[int]): Step to decay the LR. If an int value is given,
|
172 |
+
regard it as the decay interval. If a list is given, decay LR at
|
173 |
+
these steps.
|
174 |
+
gamma (float, optional): Decay LR ratio. Default: 0.1.
|
175 |
+
min_lr (float, optional): Minimum LR value to keep. If LR after decay
|
176 |
+
is lower than `min_lr`, it will be clipped to this value. If None
|
177 |
+
is given, we don't perform lr clipping. Default: None.
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
|
181 |
+
if isinstance(step, list):
|
182 |
+
assert mmcv.is_list_of(step, int)
|
183 |
+
assert all([s > 0 for s in step])
|
184 |
+
elif isinstance(step, int):
|
185 |
+
assert step > 0
|
186 |
+
else:
|
187 |
+
raise TypeError('"step" must be a list or integer')
|
188 |
+
self.step = step
|
189 |
+
self.gamma = gamma
|
190 |
+
self.min_lr = min_lr
|
191 |
+
super(StepLrUpdaterHook, self).__init__(**kwargs)
|
192 |
+
|
193 |
+
def get_lr(self, runner, base_lr):
|
194 |
+
progress = runner.epoch if self.by_epoch else runner.iter
|
195 |
+
|
196 |
+
# calculate exponential term
|
197 |
+
if isinstance(self.step, int):
|
198 |
+
exp = progress // self.step
|
199 |
+
else:
|
200 |
+
exp = len(self.step)
|
201 |
+
for i, s in enumerate(self.step):
|
202 |
+
if progress < s:
|
203 |
+
exp = i
|
204 |
+
break
|
205 |
+
|
206 |
+
lr = base_lr * (self.gamma**exp)
|
207 |
+
if self.min_lr is not None:
|
208 |
+
# clip to a minimum value
|
209 |
+
lr = max(lr, self.min_lr)
|
210 |
+
return lr
|
211 |
+
|
212 |
+
|
213 |
+
@HOOKS.register_module()
|
214 |
+
class ExpLrUpdaterHook(LrUpdaterHook):
|
215 |
+
|
216 |
+
def __init__(self, gamma, **kwargs):
|
217 |
+
self.gamma = gamma
|
218 |
+
super(ExpLrUpdaterHook, self).__init__(**kwargs)
|
219 |
+
|
220 |
+
def get_lr(self, runner, base_lr):
|
221 |
+
progress = runner.epoch if self.by_epoch else runner.iter
|
222 |
+
return base_lr * self.gamma**progress
|
223 |
+
|
224 |
+
|
225 |
+
@HOOKS.register_module()
|
226 |
+
class PolyLrUpdaterHook(LrUpdaterHook):
|
227 |
+
|
228 |
+
def __init__(self, power=1., min_lr=0., **kwargs):
|
229 |
+
self.power = power
|
230 |
+
self.min_lr = min_lr
|
231 |
+
super(PolyLrUpdaterHook, self).__init__(**kwargs)
|
232 |
+
|
233 |
+
def get_lr(self, runner, base_lr):
|
234 |
+
if self.by_epoch:
|
235 |
+
progress = runner.epoch
|
236 |
+
max_progress = runner.max_epochs
|
237 |
+
else:
|
238 |
+
progress = runner.iter
|
239 |
+
max_progress = runner.max_iters
|
240 |
+
coeff = (1 - progress / max_progress)**self.power
|
241 |
+
return (base_lr - self.min_lr) * coeff + self.min_lr
|
242 |
+
|
243 |
+
|
244 |
+
@HOOKS.register_module()
|
245 |
+
class InvLrUpdaterHook(LrUpdaterHook):
|
246 |
+
|
247 |
+
def __init__(self, gamma, power=1., **kwargs):
|
248 |
+
self.gamma = gamma
|
249 |
+
self.power = power
|
250 |
+
super(InvLrUpdaterHook, self).__init__(**kwargs)
|
251 |
+
|
252 |
+
def get_lr(self, runner, base_lr):
|
253 |
+
progress = runner.epoch if self.by_epoch else runner.iter
|
254 |
+
return base_lr * (1 + self.gamma * progress)**(-self.power)
|
255 |
+
|
256 |
+
|
257 |
+
@HOOKS.register_module()
|
258 |
+
class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
259 |
+
|
260 |
+
def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
|
261 |
+
assert (min_lr is None) ^ (min_lr_ratio is None)
|
262 |
+
self.min_lr = min_lr
|
263 |
+
self.min_lr_ratio = min_lr_ratio
|
264 |
+
super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
265 |
+
|
266 |
+
def get_lr(self, runner, base_lr):
|
267 |
+
if self.by_epoch:
|
268 |
+
progress = runner.epoch
|
269 |
+
max_progress = runner.max_epochs
|
270 |
+
else:
|
271 |
+
progress = runner.iter
|
272 |
+
max_progress = runner.max_iters
|
273 |
+
|
274 |
+
if self.min_lr_ratio is not None:
|
275 |
+
target_lr = base_lr * self.min_lr_ratio
|
276 |
+
else:
|
277 |
+
target_lr = self.min_lr
|
278 |
+
return annealing_cos(base_lr, target_lr, progress / max_progress)
|
279 |
+
|
280 |
+
|
281 |
+
@HOOKS.register_module()
|
282 |
+
class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
|
283 |
+
"""Flat + Cosine lr schedule.
|
284 |
+
|
285 |
+
Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
|
286 |
+
|
287 |
+
Args:
|
288 |
+
start_percent (float): When to start annealing the learning rate
|
289 |
+
after the percentage of the total training steps.
|
290 |
+
The value should be in range [0, 1).
|
291 |
+
Default: 0.75
|
292 |
+
min_lr (float, optional): The minimum lr. Default: None.
|
293 |
+
min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
|
294 |
+
Either `min_lr` or `min_lr_ratio` should be specified.
|
295 |
+
Default: None.
|
296 |
+
"""
|
297 |
+
|
298 |
+
def __init__(self,
|
299 |
+
start_percent=0.75,
|
300 |
+
min_lr=None,
|
301 |
+
min_lr_ratio=None,
|
302 |
+
**kwargs):
|
303 |
+
assert (min_lr is None) ^ (min_lr_ratio is None)
|
304 |
+
if start_percent < 0 or start_percent > 1 or not isinstance(
|
305 |
+
start_percent, float):
|
306 |
+
raise ValueError(
|
307 |
+
'expected float between 0 and 1 start_percent, but '
|
308 |
+
f'got {start_percent}')
|
309 |
+
self.start_percent = start_percent
|
310 |
+
self.min_lr = min_lr
|
311 |
+
self.min_lr_ratio = min_lr_ratio
|
312 |
+
super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
|
313 |
+
|
314 |
+
def get_lr(self, runner, base_lr):
|
315 |
+
if self.by_epoch:
|
316 |
+
start = round(runner.max_epochs * self.start_percent)
|
317 |
+
progress = runner.epoch - start
|
318 |
+
max_progress = runner.max_epochs - start
|
319 |
+
else:
|
320 |
+
start = round(runner.max_iters * self.start_percent)
|
321 |
+
progress = runner.iter - start
|
322 |
+
max_progress = runner.max_iters - start
|
323 |
+
|
324 |
+
if self.min_lr_ratio is not None:
|
325 |
+
target_lr = base_lr * self.min_lr_ratio
|
326 |
+
else:
|
327 |
+
target_lr = self.min_lr
|
328 |
+
|
329 |
+
if progress < 0:
|
330 |
+
return base_lr
|
331 |
+
else:
|
332 |
+
return annealing_cos(base_lr, target_lr, progress / max_progress)
|
333 |
+
|
334 |
+
|
335 |
+
+@HOOKS.register_module()
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+    """Cosine annealing with restarts learning rate scheme.
+
+    Args:
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float], optional): Restart weights at each
+            restart iteration. Default: [1].
+        min_lr (float, optional): The minimum lr. Default: None.
+        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+            Either `min_lr` or `min_lr_ratio` should be specified.
+            Default: None.
+    """
+
+    def __init__(self,
+                 periods,
+                 restart_weights=[1],
+                 min_lr=None,
+                 min_lr_ratio=None,
+                 **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        self.periods = periods
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+
+        self.cumulative_periods = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+
+    def get_lr(self, runner, base_lr):
+        if self.by_epoch:
+            progress = runner.epoch
+        else:
+            progress = runner.iter
+
+        if self.min_lr_ratio is not None:
+            target_lr = base_lr * self.min_lr_ratio
+        else:
+            target_lr = self.min_lr
+
+        idx = get_position_from_periods(progress, self.cumulative_periods)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+
+        alpha = min((progress - nearest_restart) / current_periods, 1)
+        return annealing_cos(base_lr, target_lr, alpha, current_weight)
+
+
+def get_position_from_periods(iteration, cumulative_periods):
+    """Get the position from a period list.
+
+    It will return the index of the right-closest number in the period list.
+    For example, the cumulative_periods = [100, 200, 300, 400],
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 3.
+
+    Args:
+        iteration (int): Current iteration.
+        cumulative_periods (list[int]): Cumulative period list.
+
+    Returns:
+        int: The position of the right-closest number in the period list.
+    """
+    for i, period in enumerate(cumulative_periods):
+        if iteration < period:
+            return i
+    raise ValueError(f'Current iteration {iteration} exceeds '
+                     f'cumulative_periods {cumulative_periods}')
+
+
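The docstring's examples can be checked directly; here is a minimal, self-contained sketch that re-implements the lookup above for illustration (not an import from this module):

def get_position_from_periods(iteration, cumulative_periods):
    # Return the index of the first cumulative period strictly
    # greater than `iteration`.
    for i, period in enumerate(cumulative_periods):
        if iteration < period:
            return i
    raise ValueError(f'Current iteration {iteration} exceeds '
                     f'cumulative_periods {cumulative_periods}')

cumulative = [100, 200, 300, 400]
assert get_position_from_periods(50, cumulative) == 0
assert get_position_from_periods(210, cumulative) == 2
assert get_position_from_periods(300, cumulative) == 3  # 300 is not < 300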
+@HOOKS.register_module()
+class CyclicLrUpdaterHook(LrUpdaterHook):
+    """Cyclic LR Scheduler.
+
+    Implement the cyclical learning rate policy (CLR) described in
+    https://arxiv.org/pdf/1506.01186.pdf
+
+    Different from the original paper, we use cosine annealing rather than
+    triangular policy inside a cycle. This improves the performance in the
+    3D detection area.
+
+    Args:
+        by_epoch (bool): Whether to update LR by epoch.
+        target_ratio (tuple[float]): Relative ratio of the highest LR and the
+            lowest LR to the initial LR.
+        cyclic_times (int): Number of cycles during training.
+        step_ratio_up (float): The ratio of the increasing process of LR in
+            the total cycle.
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(10, 1e-4),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 anneal_strategy='cos',
+                 **kwargs):
+        if isinstance(target_ratio, float):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+                if len(target_ratio) == 1 else target_ratio
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.lr_phases = []  # init lr_phases
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        super(CyclicLrUpdaterHook, self).before_run(runner)
+        # initiate lr_phases
+        # total lr_phases are separated as up and down
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        self.lr_phases.append(
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+        self.lr_phases.append([
+            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+            self.target_ratio[0], self.target_ratio[1]
+        ])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.lr_phases:
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return self.anneal_func(base_lr * start_ratio,
+                                        base_lr * end_ratio,
+                                        progress / (end_iter - start_iter))
+
+
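A hedged sketch of a config selecting this policy, assuming the usual lowercase-policy title-casing rule ('cyclic' -> CyclicLrUpdaterHook); the numbers are hypothetical. With, say, 4000 max_iters and cyclic_times=1, step_ratio_up=0.4 gives a 1600-iter up-phase followed by a 2400-iter down-phase:

lr_config = dict(
    policy='cyclic',          # -> CyclicLrUpdaterHook
    by_epoch=False,
    target_ratio=(10, 1e-4),  # peak lr = 10x base, floor = 1e-4x base
    cyclic_times=1,
    step_ratio_up=0.4)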
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """One Cycle LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        total_steps (int, optional): The total number of steps in the cycle.
+            Note that if a value is not provided here, it will be the max_iter
+            of runner. Default: None.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+    """
+
+    def __init__(self,
+                 max_lr,
+                 total_steps=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('the type of max_lr must be number, list or '
+                             f'dict, but got {type(max_lr)}')
+        self._max_lr = max_lr
+        if total_steps is not None:
+            if not isinstance(total_steps, int):
+                raise ValueError('the type of total_steps must be int, but '
+                                 f'got {type(total_steps)}')
+            self.total_steps = total_steps
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # init lr_phases
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if hasattr(self, 'total_steps'):
+            total_steps = self.total_steps
+        else:
+            total_steps = runner.max_iters
+        if total_steps < runner.max_iters:
+            raise ValueError(
+                'The total steps must be greater than or equal to max '
+                f'iterations {runner.max_iters} of runner, but total steps '
+                f'is {total_steps}.')
+
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                _max_lr = format_param(k, optim, self._max_lr)
+                self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+                for group, lr in zip(optim.param_groups, self.base_lr[k]):
+                    group.setdefault('initial_lr', lr)
+        else:
+            k = type(runner.optimizer).__name__
+            _max_lr = format_param(k, runner.optimizer, self._max_lr)
+            self.base_lr = [lr / self.div_factor for lr in _max_lr]
+            for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+                group.setdefault('initial_lr', lr)
+
+        if self.three_phase:
+            self.lr_phases.append(
+                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+            self.lr_phases.append([
+                float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
+            ])
+            self.lr_phases.append(
+                [total_steps - 1, 1, 1 / self.final_div_factor])
+        else:
+            self.lr_phases.append(
+                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+            self.lr_phases.append(
+                [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            if curr_iter <= end_iter:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+
+
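To make the two-phase construction above concrete, here is a minimal, self-contained numeric sketch with hypothetical values (total_steps=100, pct_start=0.3, div_factor=25, final_div_factor=1e4); note base_lr here is max_lr/div_factor, matching how before_run derives it:

from math import cos, pi

def annealing_cos(start, end, factor, weight=1):
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

max_lr, div_factor, final_div_factor = 0.01, 25, 1e4
total_steps, pct_start = 100, 0.3
base_lr = max_lr / div_factor  # 4e-4, the warmup starting point
phases = [(pct_start * total_steps - 1, 1, div_factor),
          (total_steps - 1, div_factor, 1 / final_div_factor)]

def lr_at(it):
    start_iter = 0
    for end_iter, s, e in phases:
        if it <= end_iter:
            pct = (it - start_iter) / (end_iter - start_iter)
            return annealing_cos(base_lr * s, base_lr * e, pct)
        start_iter = end_iter

print(lr_at(0))   # ~4e-4 (max_lr / div_factor)
print(lr_at(29))  # ~1e-2 (the peak, max_lr)
print(lr_at(99))  # ~4e-8 (initial_lr / final_div_factor)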
+def annealing_cos(start, end, factor, weight=1):
+    """Calculate annealing cos learning rate.
+
+    Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
+    percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the cosine annealing.
+        end (float): The ending learning rate of the cosine annealing.
+        factor (float): The coefficient of `pi` when calculating the current
+            percentage. Range from 0.0 to 1.0.
+        weight (float, optional): The combination factor of `start` and `end`
+            when calculating the actual starting learning rate. Default to 1.
+    """
+    cos_out = cos(pi * factor) + 1
+    return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+    """Calculate annealing linear learning rate.
+
+    Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): The current percentage of annealing progress.
+            Range from 0.0 to 1.0.
+    """
+    return start + (end - start) * factor
+
+
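A quick self-contained check of the cosine formula's endpoints (re-implemented here for illustration): factor=0 returns `weight * start + (1 - weight) * end`, factor=1 returns `end`, and factor=0.5 returns the midpoint.

from math import cos, pi, isclose

def annealing_cos(start, end, factor, weight=1):
    # end + weight * (start - end) * (cos(pi * factor) + 1) / 2
    cos_out = cos(pi * factor) + 1
    return end + 0.5 * weight * (start - end) * cos_out

assert isclose(annealing_cos(0.1, 0.001, 0.0), 0.1)
assert isclose(annealing_cos(0.1, 0.001, 1.0), 0.001)
assert isclose(annealing_cos(0.1, 0.001, 0.5), 0.0505)
# weight < 1 lowers the effective starting point:
assert isclose(annealing_cos(0.1, 0.001, 0.0, weight=0.5), 0.0505)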
+def format_param(name, optim, param):
+    if isinstance(param, numbers.Number):
+        return [param] * len(optim.param_groups)
+    elif isinstance(param, (list, tuple)):  # multi param groups
+        if len(param) != len(optim.param_groups):
+            raise ValueError(f'expected {len(optim.param_groups)} '
+                             f'values for {name}, got {len(param)}')
+        return param
+    else:  # multi optimizers
+        if name not in param:
+            raise KeyError(f'{name} is not found in {param.keys()}')
+        return param[name]
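format_param normalizes a scalar, per-group list, or per-optimizer dict into one value per param group. A self-contained sketch of its three branches, using a stand-in object with a `param_groups` attribute in place of a real optimizer:

import numbers
from types import SimpleNamespace

def format_param(name, optim, param):
    if isinstance(param, numbers.Number):      # scalar -> broadcast
        return [param] * len(optim.param_groups)
    elif isinstance(param, (list, tuple)):     # one value per param group
        if len(param) != len(optim.param_groups):
            raise ValueError(f'expected {len(optim.param_groups)} '
                             f'values for {name}, got {len(param)}')
        return param
    else:                                      # dict keyed by optimizer name
        if name not in param:
            raise KeyError(f'{name} is not found in {param.keys()}')
        return param[name]

fake_optim = SimpleNamespace(param_groups=[{}, {}])  # stand-in optimizer
assert format_param('SGD', fake_optim, 0.01) == [0.01, 0.01]
assert format_param('SGD', fake_optim, [0.01, 0.001]) == [0.01, 0.001]
assert format_param('SGD', fake_optim, {'SGD': [0.01, 0.001]}) == [0.01, 0.001]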
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/memory.py
ADDED
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+
+    def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+        self._before_epoch = before_epoch
+        self._after_epoch = after_epoch
+        self._after_iter = after_iter
+
+    def after_iter(self, runner):
+        if self._after_iter:
+            torch.cuda.empty_cache()
+
+    def before_epoch(self, runner):
+        if self._before_epoch:
+            torch.cuda.empty_cache()
+
+    def after_epoch(self, runner):
+        if self._after_epoch:
+            torch.cuda.empty_cache()
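EmptyCacheHook just calls torch.cuda.empty_cache() at the configured points; it trims the CUDA allocator's cache (useful when sharing a GPU) but does not free memory held by live tensors. A hedged sketch of enabling it, assuming this vendored runner wires hooks the same way upstream mmcv's custom_hooks mechanism does:

custom_hooks = [
    # Release cached CUDA blocks back to the driver after each epoch.
    dict(type='EmptyCacheHook', before_epoch=False, after_epoch=True),
]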
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/momentum_updater.py
ADDED
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import annotator.mmpkg.mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+
+
+class MomentumUpdaterHook(Hook):
+
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.9):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+
+        self.base_momentum = []  # initial momentum for all param groups
+        self.regular_momentum = [
+        ]  # expected momentum if no warming up is performed
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, base_momentum):
+        raise NotImplementedError
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k in runner.optimizer.keys():
+                _momentum_group = [
+                    self.get_momentum(runner, _base_momentum)
+                    for _base_momentum in self.base_momentum[k]
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            return [
+                self.get_momentum(runner, _base_momentum)
+                for _base_momentum in self.base_momentum
+            ]
+
+    def get_warmup_momentum(self, cur_iters):
+
+        def _get_warmup_momentum(cur_iters, regular_momentum):
+            # All three branches anneal the `regular_momentum` passed in,
+            # so the dict (multi-optimizer) case is handled per key.
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+        if isinstance(self.regular_momentum, dict):
+            momentum_groups = {}
+            for key, regular_momentum in self.regular_momentum.items():
+                momentum_groups[key] = _get_warmup_momentum(
+                    cur_iters, regular_momentum)
+            return momentum_groups
+        else:
+            return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+    def before_run(self, runner):
+        # NOTE: when resuming from a checkpoint,
+        # if 'initial_momentum' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_momentum = {}
+            for k, optim in runner.optimizer.items():
+                for group in optim.param_groups:
+                    if 'momentum' in group.keys():
+                        group.setdefault('initial_momentum', group['momentum'])
+                    else:
+                        group.setdefault('initial_momentum', group['betas'][0])
+                _base_momentum = [
+                    group['initial_momentum'] for group in optim.param_groups
+                ]
+                self.base_momentum.update({k: _base_momentum})
+        else:
+            for group in runner.optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    group.setdefault('initial_momentum', group['momentum'])
+                else:
+                    group.setdefault('initial_momentum', group['betas'][0])
+            self.base_momentum = [
+                group['initial_momentum']
+                for group in runner.optimizer.param_groups
+            ]
+
+    def before_train_epoch(self, runner):
+        if not self.by_epoch:
+            return
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)
+
+    def before_train_iter(self, runner):
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self.regular_momentum = self.get_regular_momentum(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+        elif self.by_epoch:
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+
+
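The 'linear' warmup branch starts the momentum at regular / warmup_ratio and decays it to the regular value over warmup_iters. A self-contained sketch of that math (a re-implementation for illustration, with hypothetical numbers):

def linear_warmup_momentum(regular, cur_iters, warmup_iters, warmup_ratio):
    # Mirrors the 'linear' branch above: k shrinks to 0 as cur_iters
    # approaches warmup_iters, so the result converges to `regular`.
    k = (1 - cur_iters / warmup_iters) * (1 - warmup_ratio)
    return [m / (1 - k) for m in regular]

regular = [0.9]
warmup_iters, warmup_ratio = 100, 0.9
assert abs(linear_warmup_momentum(regular, 0, warmup_iters, warmup_ratio)[0]
           - 1.0) < 1e-12   # starts at regular / warmup_ratio = 0.9 / 0.9
assert abs(linear_warmup_momentum(regular, 100, warmup_iters, warmup_ratio)[0]
           - 0.9) < 1e-12   # meets the regular momentum at warmup end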
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+    """Step momentum scheduler with min value clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the momentum. If an int value is
+            given, regard it as the decay interval. If a list is given, decay
+            momentum at these steps.
+        gamma (float, optional): Decay momentum ratio. Default: 0.5.
+        min_momentum (float, optional): Minimum momentum value to keep. If
+            momentum after decay is lower than this value, it will be clipped
+            accordingly. If None is given, we don't perform momentum clipping.
+            Default: None.
+    """
+
+    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all([s > 0 for s in step])
+        elif isinstance(step, int):
+            assert step > 0
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_momentum = min_momentum
+        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        progress = runner.epoch if self.by_epoch else runner.iter
+
+        # calculate exponential term
+        if isinstance(self.step, int):
+            exp = progress // self.step
+        else:
+            exp = len(self.step)
+            for i, s in enumerate(self.step):
+                if progress < s:
+                    exp = i
+                    break
+
+        momentum = base_momentum * (self.gamma**exp)
+        if self.min_momentum is not None:
+            # clip to a minimum value
+            momentum = max(momentum, self.min_momentum)
+        return momentum
+
+
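A small self-contained sketch of the exponent logic above, with hypothetical milestones [30, 60] and gamma=0.5:

def step_exponent(progress, step):
    # Mirrors StepMomentumUpdaterHook.get_momentum's exponent calculation.
    if isinstance(step, int):
        return progress // step
    exp = len(step)
    for i, s in enumerate(step):
        if progress < s:
            exp = i
            break
    return exp

gamma, base = 0.5, 0.9
for epoch in (10, 45, 80):
    print(epoch, base * gamma ** step_exponent(epoch, [30, 60]))
# 10 -> 0.9 (exp 0), 45 -> 0.45 (exp 1), 80 -> 0.225 (exp 2)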
+@HOOKS.register_module()
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+
+    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+        assert (min_momentum is None) ^ (min_momentum_ratio is None)
+        self.min_momentum = min_momentum
+        self.min_momentum_ratio = min_momentum_ratio
+        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        if self.by_epoch:
+            progress = runner.epoch
+            max_progress = runner.max_epochs
+        else:
+            progress = runner.iter
+            max_progress = runner.max_iters
+        if self.min_momentum_ratio is not None:
+            target_momentum = base_momentum * self.min_momentum_ratio
+        else:
+            target_momentum = self.min_momentum
+        return annealing_cos(base_momentum, target_momentum,
+                             progress / max_progress)
+
+
+@HOOKS.register_module()
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cyclic momentum Scheduler.
+
+    Implement the cyclical momentum scheduler policy described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    This momentum scheduler is usually used together with the CyclicLRUpdater
+    to improve the performance in the 3D detection area.
+
+    Attributes:
+        target_ratio (tuple[float]): Relative ratio of the lowest momentum and
+            the highest momentum to the initial momentum.
+        cyclic_times (int): Number of cycles during training.
+        step_ratio_up (float): The ratio of the increasing process of momentum
+            in the total cycle.
+        by_epoch (bool): Whether to update momentum by epoch.
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(0.85 / 0.95, 1),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 **kwargs):
+        if isinstance(target_ratio, float):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+                if len(target_ratio) == 1 else target_ratio
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.momentum_phases = []  # init momentum_phases
+        # currently only support by_epoch=False
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        # initiate momentum_phases
+        # total momentum_phases are separated as up and down
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        self.momentum_phases.append(
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+        self.momentum_phases.append([
+            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+            self.target_ratio[0], self.target_ratio[1]
+        ])
+
+    def get_momentum(self, runner, base_momentum):
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.momentum_phases:
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return annealing_cos(base_momentum * start_ratio,
+                                     base_momentum * end_ratio,
+                                     progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+    """OneCycle momentum Scheduler.
+
+    This momentum scheduler is usually used together with the OneCycleLrUpdater
+    to improve the performance.
+
+    Args:
+        base_momentum (float or list): Lower momentum boundaries in the cycle
+            for each parameter group. Note that momentum is cycled inversely
+            to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is
+            'max_momentum' and learning rate is 'base_lr'
+            Default: 0.95
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be a float, '
+                             'list or dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be a float, '
+                             'list or dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('Expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # init momentum_phases
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                if ('momentum' not in optim.defaults
+                        and 'betas' not in optim.defaults):
+                    raise ValueError('optimizer must support momentum with '
+                                     'option enabled')
+                self.use_beta1 = 'betas' in optim.defaults
+                _base_momentum = format_param(k, optim, self._base_momentum)
+                _max_momentum = format_param(k, optim, self._max_momentum)
+                for group, b_momentum, m_momentum in zip(
+                        optim.param_groups, _base_momentum, _max_momentum):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+                    group['max_momentum'] = m_momentum
+        else:
+            optim = runner.optimizer
+            if ('momentum' not in optim.defaults
+                    and 'betas' not in optim.defaults):
+                raise ValueError('optimizer must support momentum with '
+                                 'option enabled')
+            self.use_beta1 = 'betas' in optim.defaults
+            k = type(optim).__name__
+            _base_momentum = format_param(k, optim, self._base_momentum)
+            _max_momentum = format_param(k, optim, self._max_momentum)
+            for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                     _base_momentum,
+                                                     _max_momentum):
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (m_momentum, beta2)
+                else:
+                    group['momentum'] = m_momentum
+                group['base_momentum'] = b_momentum
+                group['max_momentum'] = m_momentum
+
+        if self.three_phase:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter':
+                float(2 * self.pct_start * runner.max_iters) - 2,
+                'start_momentum':
+                'base_momentum',
+                'end_momentum':
+                'max_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'max_momentum'
+            })
+        else:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'base_momentum',
+                'end_momentum': 'max_momentum'
+            })
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, param_group):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k, optim in runner.optimizer.items():
+                _momentum_group = [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            momentum_groups = []
+            for param_group in runner.optimizer.param_groups:
+                momentum_groups.append(self.get_momentum(runner, param_group))
+            return momentum_groups
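Since lr and momentum move in opposite directions over the cycle, the two OneCycle hooks are normally configured together. A hedged config sketch, assuming the usual rule that policy strings are expanded to `<policy>LrUpdaterHook` / `<policy>MomentumUpdaterHook`; all values here are illustrative:

lr_config = dict(
    policy='OneCycle',    # -> OneCycleLrUpdaterHook
    max_lr=0.01, pct_start=0.3, anneal_strategy='cos')
momentum_config = dict(
    policy='OneCycle',    # -> OneCycleMomentumUpdaterHook
    base_momentum=0.85, max_momentum=0.95, pct_start=0.3)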
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/optimizer.py
ADDED
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+
+from torch.nn.utils import clip_grad
+
+from annotator.mmpkg.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+
+try:
+    # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+    from torch.cuda.amp import GradScaler
+except ImportError:
+    pass
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
+
+    def __init__(self, grad_clip=None):
+        self.grad_clip = grad_clip
+
+    def clip_grads(self, params):
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+    def after_train_iter(self, runner):
+        runner.optimizer.zero_grad()
+        runner.outputs['loss'].backward()
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(runner.model.parameters())
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
+
+
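Note that `grad_clip` is unpacked directly into torch.nn.utils.clip_grad.clip_grad_norm_, so its keys are that function's keyword arguments. A hedged config sketch with illustrative values:

optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))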
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+    """Optimizer Hook implements multi-iters gradient cumulating.
+
+    Args:
+        cumulative_iters (int, optional): Num of gradient cumulative iters.
+            The optimizer will step every `cumulative_iters` iters.
+            Defaults to 1.
+
+    Examples:
+        >>> # Use cumulative_iters to simulate a large batch size
+        >>> # It is helpful when the hardware cannot handle a large batch size.
+        >>> loader = DataLoader(data, batch_size=64)
+        >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+        >>> # almost equals to
+        >>> loader = DataLoader(data, batch_size=256)
+        >>> optim_hook = OptimizerHook()
+    """
+
+    def __init__(self, cumulative_iters=1, **kwargs):
+        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+            f'cumulative_iters only accepts positive int, but got ' \
+            f'{type(cumulative_iters)} instead.'
+
+        self.cumulative_iters = cumulative_iters
+        self.divisible_iters = 0
+        self.remainder_iters = 0
+        self.initialized = False
+
+    def has_batch_norm(self, module):
+        if isinstance(module, _BatchNorm):
+            return True
+        for m in module.children():
+            if self.has_batch_norm(m):
+                return True
+        return False
+
+    def _init(self, runner):
+        if runner.iter % self.cumulative_iters != 0:
+            runner.logger.warning(
+                'Resume iter number is not divisible by cumulative_iters in '
+                'GradientCumulativeOptimizerHook, which means the gradient of '
+                'some iters is lost and the result may be influenced slightly.'
+            )
+
+        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+            runner.logger.warning(
+                'GradientCumulativeOptimizerHook may slightly decrease '
+                'performance if the model has BatchNorm layers.')
+
+        residual_iters = runner.max_iters - runner.iter
+
+        self.divisible_iters = (
+            residual_iters // self.cumulative_iters * self.cumulative_iters)
+        self.remainder_iters = residual_iters - self.divisible_iters
+
+        self.initialized = True
+
+    def after_train_iter(self, runner):
+        if not self.initialized:
+            self._init(runner)
+
+        if runner.iter < self.divisible_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+        loss = runner.outputs['loss']
+        loss = loss / loss_factor
+        loss.backward()
+
+        if (self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner)):
+
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
+
+
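The docstring's "almost equals" claim can be verified in isolation: dividing each micro-batch loss by cumulative_iters and accumulating gradients yields the same gradient as one backward pass over the combined batch (exactly, for a mean-reduced loss over equal-sized chunks). A minimal, self-contained sketch; the model and data here are hypothetical stand-ins:

import torch

w = torch.ones(1, requires_grad=True)   # one-parameter "model"
x = torch.randn(16)
y = torch.randn(16)

def loss_fn(xb, yb):
    return ((w * xb - yb) ** 2).mean()

# Reference: one backward over the full batch.
loss_fn(x, y).backward()
full_grad = w.grad.clone()

# Accumulated: 4 micro-batches, each loss divided by cumulative_iters=4,
# exactly as the hook divides by loss_factor before backward().
w.grad = None
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (loss_fn(xb, yb) / 4).backward()

assert torch.allclose(w.grad, full_grad)  # same gradient, smaller memory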
+if (TORCH_VERSION != 'parrots'
+        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+    @HOOKS.register_module()
+    class Fp16OptimizerHook(OptimizerHook):
+        """FP16 optimizer hook (using PyTorch's implementation).
+
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+
+        Args:
+            loss_scale (float | str | dict): Scale factor configuration.
+                If loss_scale is a float, static loss scaling will be used with
+                the specified scale. If loss_scale is a string, it must be
+                'dynamic', then dynamic loss scaling will be used.
+                It can also be a dict containing arguments of GradScaler.
+                Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+                implementation of GradScaler. If you use a dict version of
+                loss_scale to create GradScaler, please refer to:
+                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+                for the parameters.
+
+        Examples:
+            >>> loss_scale = dict(
+            ...     init_scale=65536.0,
+            ...     growth_factor=2.0,
+            ...     backoff_factor=0.5,
+            ...     growth_interval=2000
+            ... )
+            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+        """
+
+        def __init__(self,
+                     grad_clip=None,
+                     coalesce=True,
+                     bucket_size_mb=-1,
+                     loss_scale=512.,
+                     distributed=True):
+            self.grad_clip = grad_clip
+            self.coalesce = coalesce
+            self.bucket_size_mb = bucket_size_mb
+            self.distributed = distributed
+            self._scale_update_param = None
+            if loss_scale == 'dynamic':
+                self.loss_scaler = GradScaler()
+            elif isinstance(loss_scale, float):
+                self._scale_update_param = loss_scale
+                self.loss_scaler = GradScaler(init_scale=loss_scale)
+            elif isinstance(loss_scale, dict):
+                self.loss_scaler = GradScaler(**loss_scale)
+            else:
+                raise ValueError('loss_scale must be of type float, dict, or '
+                                 f'"dynamic", got {loss_scale}')
+
+        def before_run(self, runner):
+            """Preparing steps before Mixed Precision Training."""
+            # wrap model mode to fp16
+            wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
+
+        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+            """Copy gradients from fp16 model to fp32 weight copy."""
+            for fp32_param, fp16_param in zip(fp32_weights,
+                                              fp16_net.parameters()):
+                if fp16_param.grad is not None:
+                    if fp32_param.grad is None:
+                        fp32_param.grad = fp32_param.data.new(
+                            fp32_param.size())
+                    fp32_param.grad.copy_(fp16_param.grad)
+
+        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+            """Copy updated params from fp32 weight copy to fp16 model."""
+            for fp16_param, fp32_param in zip(fp16_net.parameters(),
+                                              fp32_weights):
+                fp16_param.data.copy_(fp32_param.data)
+
+        def after_train_iter(self, runner):
+            """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+
+            1. Scale the loss by a scale factor.
+            2. Backward the loss to obtain the gradients.
+            3. Unscale the optimizer's gradient tensors.
+            4. Call optimizer.step() and update scale factor.
+            5. Save loss_scaler state_dict for resume purpose.
+            """
+            # clear grads of last iteration
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
+
+            self.loss_scaler.scale(runner.outputs['loss']).backward()
+            self.loss_scaler.unscale_(runner.optimizer)
+            # grad clip
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            # step optimizer and update scaler
+            self.loss_scaler.step(runner.optimizer)
+            self.loss_scaler.update(self._scale_update_param)
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
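The scale/unscale/step/update sequence above is the standard torch.cuda.amp recipe. A minimal, self-contained sketch of one training step outside the runner (hypothetical model and data, and it assumes a CUDA device since GradScaler/autocast target fp16 CUDA kernels):

import torch

device = 'cuda'
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(init_scale=512.)

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.unscale_(optimizer)      # restore true fp32 grads before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=35)
scaler.step(optimizer)          # skips the step if grads overflowed
scaler.update()                 # adjust the scale for the next iteration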
+    @HOOKS.register_module()
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using PyTorch's implementation) implements
+        multi-iters gradient cumulating.
+
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+        """
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+            loss = runner.outputs['loss']
+            loss = loss / loss_factor
+
+            self.loss_scaler.scale(loss).backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # unscale the accumulated grads before optional clipping
+                self.loss_scaler.unscale_(runner.optimizer)
+
+                if self.grad_clip is not None:
+                    grad_norm = self.clip_grads(runner.model.parameters())
+                    if grad_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(grad_norm)},
+                            runner.outputs['num_samples'])
+
+                # step optimizer and update scaler
+                self.loss_scaler.step(runner.optimizer)
+                self.loss_scaler.update(self._scale_update_param)
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
+
+
else:
|
295 |
+
|
296 |
+
@HOOKS.register_module()
|
297 |
+
class Fp16OptimizerHook(OptimizerHook):
|
298 |
+
"""FP16 optimizer hook (mmcv's implementation).
|
299 |
+
|
300 |
+
The steps of fp16 optimizer is as follows.
|
301 |
+
1. Scale the loss value.
|
302 |
+
2. BP in the fp16 model.
|
303 |
+
2. Copy gradients from fp16 model to fp32 weights.
|
304 |
+
3. Update fp32 weights.
|
305 |
+
4. Copy updated parameters from fp32 weights to fp16 model.
|
306 |
+
|
307 |
+
Refer to https://arxiv.org/abs/1710.03740 for more details.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
loss_scale (float | str | dict): Scale factor configuration.
|
311 |
+
If loss_scale is a float, static loss scaling will be used with
|
312 |
+
the specified scale. If loss_scale is a string, it must be
|
313 |
+
'dynamic', then dynamic loss scaling will be used.
|
314 |
+
It can also be a dict containing arguments of LossScaler.
|
315 |
+
Defaults to 512.
|
316 |
+
"""
|
317 |
+
|
318 |
+
def __init__(self,
|
319 |
+
grad_clip=None,
|
320 |
+
coalesce=True,
|
321 |
+
bucket_size_mb=-1,
|
322 |
+
loss_scale=512.,
|
323 |
+
distributed=True):
|
324 |
+
self.grad_clip = grad_clip
|
325 |
+
self.coalesce = coalesce
|
326 |
+
self.bucket_size_mb = bucket_size_mb
|
327 |
+
self.distributed = distributed
|
328 |
+
if loss_scale == 'dynamic':
|
329 |
+
self.loss_scaler = LossScaler(mode='dynamic')
|
330 |
+
elif isinstance(loss_scale, float):
|
331 |
+
self.loss_scaler = LossScaler(
|
332 |
+
init_scale=loss_scale, mode='static')
|
333 |
+
elif isinstance(loss_scale, dict):
|
334 |
+
self.loss_scaler = LossScaler(**loss_scale)
|
335 |
+
else:
|
336 |
+
raise ValueError('loss_scale must be of type float, dict, or '
|
337 |
+
f'"dynamic", got {loss_scale}')
|
338 |
+
|
339 |
+
def before_run(self, runner):
|
340 |
+
"""Preparing steps before Mixed Precision Training.
|
341 |
+
|
342 |
+
1. Make a master copy of fp32 weights for optimization.
|
343 |
+
2. Convert the main model from fp32 to fp16.
|
344 |
+
"""
|
345 |
+
# keep a copy of fp32 weights
|
346 |
+
old_groups = runner.optimizer.param_groups
|
347 |
+
runner.optimizer.param_groups = copy.deepcopy(
|
348 |
+
runner.optimizer.param_groups)
|
349 |
+
state = defaultdict(dict)
|
350 |
+
p_map = {
|
351 |
+
old_p: p
|
352 |
+
for old_p, p in zip(
|
353 |
+
chain(*(g['params'] for g in old_groups)),
|
354 |
+
chain(*(g['params']
|
355 |
+
for g in runner.optimizer.param_groups)))
|
356 |
+
}
|
357 |
+
for k, v in runner.optimizer.state.items():
|
358 |
+
state[p_map[k]] = v
|
359 |
+
runner.optimizer.state = state
|
360 |
+
# convert model to fp16
|
361 |
+
wrap_fp16_model(runner.model)
|
362 |
+
# resume from state dict
|
363 |
+
+        if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+            scaler_state_dict = runner.meta['fp16']['loss_scaler']
+            self.loss_scaler.load_state_dict(scaler_state_dict)
+
+    def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+        """Copy gradients from fp16 model to fp32 weight copy."""
+        for fp32_param, fp16_param in zip(fp32_weights,
+                                          fp16_net.parameters()):
+            if fp16_param.grad is not None:
+                if fp32_param.grad is None:
+                    fp32_param.grad = fp32_param.data.new(
+                        fp32_param.size())
+                fp32_param.grad.copy_(fp16_param.grad)
+
+    def copy_params_to_fp16(self, fp16_net, fp32_weights):
+        """Copy updated params from fp32 weight copy to fp16 model."""
+        for fp16_param, fp32_param in zip(fp16_net.parameters(),
+                                          fp32_weights):
+            fp16_param.data.copy_(fp32_param.data)
+
+    def after_train_iter(self, runner):
+        """Backward optimization steps for Mixed Precision Training. For
+        dynamic loss scaling, please refer to `loss_scaler.py`.
+
+        1. Scale the loss by a scale factor.
+        2. Backward the loss to obtain the gradients (fp16).
+        3. Copy gradients from the model to the fp32 weight copy.
+        4. Scale the gradients back and update the fp32 weight copy.
+        5. Copy back the params from fp32 weight copy to the fp16 model.
+        6. Save loss_scaler state_dict for resume purpose.
+        """
+        # clear grads of last iteration
+        runner.model.zero_grad()
+        runner.optimizer.zero_grad()
+        # scale the loss value
+        scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
+        scaled_loss.backward()
+        # copy fp16 grads in the model to fp32 params in the optimizer
+
+        fp32_weights = []
+        for param_group in runner.optimizer.param_groups:
+            fp32_weights += param_group['params']
+        self.copy_grads_to_fp32(runner.model, fp32_weights)
+        # allreduce grads
+        if self.distributed:
+            allreduce_grads(fp32_weights, self.coalesce,
+                            self.bucket_size_mb)
+
+        has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+        # if has overflow, skip this iteration
+        if not has_overflow:
+            # scale the gradients back
+            for param in fp32_weights:
+                if param.grad is not None:
+                    param.grad.div_(self.loss_scaler.loss_scale)
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(fp32_weights)
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update(
+                        {'grad_norm': float(grad_norm)},
+                        runner.outputs['num_samples'])
+            # update fp32 params
+            runner.optimizer.step()
+            # copy fp32 params to the fp16 model
+            self.copy_params_to_fp16(runner.model, fp32_weights)
+        self.loss_scaler.update_scale(has_overflow)
+        if has_overflow:
+            runner.logger.warning('Check overflow, downscale loss scale '
+                                  f'to {self.loss_scaler.cur_scale}')
+
+        # save state_dict of loss_scaler
+        runner.meta.setdefault(
+            'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+
+@HOOKS.register_module()
+class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                          Fp16OptimizerHook):
+    """Fp16 optimizer hook (using the mmcv implementation) that also
+    accumulates gradients over multiple iterations."""
+
+    def __init__(self, *args, **kwargs):
+        super(GradientCumulativeFp16OptimizerHook,
+              self).__init__(*args, **kwargs)
+
+    def after_train_iter(self, runner):
+        if not self.initialized:
+            self._init(runner)
+
+        if runner.iter < self.divisible_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+
+        loss = runner.outputs['loss']
+        loss = loss / loss_factor
+
+        # scale the loss value
+        scaled_loss = loss * self.loss_scaler.loss_scale
+        scaled_loss.backward()
+
+        if (self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner)):
+
+            # copy fp16 grads in the model to fp32 params in the optimizer
+            fp32_weights = []
+            for param_group in runner.optimizer.param_groups:
+                fp32_weights += param_group['params']
+            self.copy_grads_to_fp32(runner.model, fp32_weights)
+            # allreduce grads
+            if self.distributed:
+                allreduce_grads(fp32_weights, self.coalesce,
+                                self.bucket_size_mb)
+
+            has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+            # if has overflow, skip this iteration
+            if not has_overflow:
+                # scale the gradients back
+                for param in fp32_weights:
+                    if param.grad is not None:
+                        param.grad.div_(self.loss_scaler.loss_scale)
+                if self.grad_clip is not None:
+                    grad_norm = self.clip_grads(fp32_weights)
+                    if grad_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(grad_norm)},
+                            runner.outputs['num_samples'])
+                # update fp32 params
+                runner.optimizer.step()
+                # copy fp32 params to the fp16 model
+                self.copy_params_to_fp16(runner.model, fp32_weights)
+            else:
+                runner.logger.warning(
+                    'Check overflow, downscale loss scale '
+                    f'to {self.loss_scaler.cur_scale}')
+
+            self.loss_scaler.update_scale(has_overflow)
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+            # clear grads
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
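The six numbered steps in `after_train_iter` above are the classic fp32-master-weights loss-scaling recipe. Below is a minimal standalone sketch of the same update on toy tensors rather than a runner; the tensor names and the static scale of 512 are illustrative assumptions, not part of this diff, and the dynamic-scale bookkeeping (`update_scale`) is omitted.

import torch

# fp16 parameter and its fp32 master copy (placeholders for a real model).
fp16_param = torch.randn(3, dtype=torch.float16)
fp32_param = fp16_param.float().requires_grad_()
optimizer = torch.optim.SGD([fp32_param], lr=0.1)
loss_scale = 512.0  # static scale; LossScaler can also adapt it dynamically

loss = (fp32_param * 2).sum()           # forward stands in for the fp16 net
(loss * loss_scale).backward()          # steps 1-2: scale the loss, backward
overflow = not torch.isfinite(fp32_param.grad).all()
if not overflow:                        # step 4: unscale, then update fp32
    fp32_param.grad.div_(loss_scale)
    optimizer.step()
    fp16_param.copy_(fp32_param.detach().half())  # step 5: back to fp16
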
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/profiler.py
ADDED
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ProfilerHook(Hook):
+    """Profiler to analyze performance during training.
+
+    PyTorch Profiler is a tool that allows the collection of performance
+    metrics during training. More details on Profiler can be found at
+    https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+
+    Args:
+        by_epoch (bool): Profile performance by epoch or by iteration.
+            Default: True.
+        profile_iters (int): Number of iterations for profiling.
+            If ``by_epoch=True``, profiling covers the first
+            ``profile_iters`` epochs of training; otherwise it covers the
+            first ``profile_iters`` iterations. Default: 1.
+        activities (list[str]): List of activity groups (CPU, CUDA) to use in
+            profiling. Default: ['cpu', 'cuda'].
+        schedule (dict, optional): Config for generating the callable
+            schedule. If schedule is None, the profiler will not add step
+            markers into the trace and table view. Default: None.
+        on_trace_ready (callable, dict): Either a handler or a dict used to
+            generate a handler. Default: None.
+        record_shapes (bool): Save information about operator's input shapes.
+            Default: False.
+        profile_memory (bool): Track tensor memory allocation/deallocation.
+            Default: False.
+        with_stack (bool): Record source information (file and line number)
+            for the ops. Default: False.
+        with_flops (bool): Use formula to estimate the FLOPS of specific
+            operators (matrix multiplication and 2D convolution).
+            Default: False.
+        json_trace_path (str, optional): Exports the collected trace in Chrome
+            JSON format. Default: None.
+
+    Example:
+        >>> runner = ... # instantiate a Runner
+        >>> # tensorboard trace
+        >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+        >>> profiler_config = dict(on_trace_ready=trace_config)
+        >>> runner.register_profiler_hook(profiler_config)
+        >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+    """
+
+    def __init__(self,
+                 by_epoch: bool = True,
+                 profile_iters: int = 1,
+                 activities: List[str] = ['cpu', 'cuda'],
+                 schedule: Optional[dict] = None,
+                 on_trace_ready: Optional[Union[Callable, dict]] = None,
+                 record_shapes: bool = False,
+                 profile_memory: bool = False,
+                 with_stack: bool = False,
+                 with_flops: bool = False,
+                 json_trace_path: Optional[str] = None) -> None:
+        try:
+            from torch import profiler  # torch version >= 1.8.1
+        except ImportError:
+            raise ImportError('profiler is a feature of torch >= 1.8.1, '
+                              f'but your version is {torch.__version__}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+        self.by_epoch = by_epoch
+
+        if profile_iters < 1:
+            raise ValueError('profile_iters should be greater than 0, but got '
+                             f'{profile_iters}')
+        self.profile_iters = profile_iters
+
+        if not isinstance(activities, list):
+            raise ValueError(
+                f'activities should be list, but got {type(activities)}')
+        self.activities = []
+        for activity in activities:
+            activity = activity.lower()
+            if activity == 'cpu':
+                self.activities.append(profiler.ProfilerActivity.CPU)
+            elif activity == 'cuda':
+                self.activities.append(profiler.ProfilerActivity.CUDA)
+            else:
+                raise ValueError(
+                    f'activity should be "cpu" or "cuda", but got {activity}')
+
+        if schedule is not None:
+            self.schedule = profiler.schedule(**schedule)
+        else:
+            self.schedule = None
+
+        self.on_trace_ready = on_trace_ready
+        self.record_shapes = record_shapes
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.with_flops = with_flops
+        self.json_trace_path = json_trace_path
+
+    @master_only
+    def before_run(self, runner):
+        if self.by_epoch and runner.max_epochs < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_epochs}')
+
+        if not self.by_epoch and runner.max_iters < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_iters}')
+
+        if callable(self.on_trace_ready):  # handler
+            _on_trace_ready = self.on_trace_ready
+        elif isinstance(self.on_trace_ready, dict):  # config of handler
+            trace_cfg = self.on_trace_ready.copy()
+            trace_type = trace_cfg.pop('type')  # log_trace handler
+            if trace_type == 'log_trace':
+
+                def _log_handler(prof):
+                    print(prof.key_averages().table(**trace_cfg))
+
+                _on_trace_ready = _log_handler
+            elif trace_type == 'tb_trace':  # tensorboard_trace handler
+                try:
+                    import torch_tb_profiler  # noqa: F401
+                except ImportError:
+                    raise ImportError('please run "pip install '
+                                      'torch-tb-profiler" to install '
+                                      'torch_tb_profiler')
+                _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+                    **trace_cfg)
+            else:
+                raise ValueError('trace_type should be "log_trace" or '
+                                 f'"tb_trace", but got {trace_type}')
+        elif self.on_trace_ready is None:
+            _on_trace_ready = None  # type: ignore
+        else:
+            raise ValueError('on_trace_ready should be handler, dict or None, '
+                             f'but got {type(self.on_trace_ready)}')
+
+        if runner.max_epochs > 1:
+            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+                          'instead of 1 epoch. Since profiler will slow down '
+                          'the training, it is recommended to train 1 epoch '
+                          'with ProfilerHook and adjust your setting according'
+                          ' to the profiler summary. During normal training '
+                          '(epoch > 1), you may disable the ProfilerHook.')
+
+        self.profiler = torch.profiler.profile(
+            activities=self.activities,
+            schedule=self.schedule,
+            on_trace_ready=_on_trace_ready,
+            record_shapes=self.record_shapes,
+            profile_memory=self.profile_memory,
+            with_stack=self.with_stack,
+            with_flops=self.with_flops)
+
+        self.profiler.__enter__()
+        runner.logger.info('profiler is profiling...')
+
+    @master_only
+    def after_train_epoch(self, runner):
+        if self.by_epoch and runner.epoch == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
+
+    @master_only
+    def after_train_iter(self, runner):
+        self.profiler.step()
+        if not self.by_epoch and runner.iter == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
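For reference, the 'log_trace' branch above reduces to a plain torch.profiler session followed by key_averages().table(). A minimal sketch outside the runner (assumes torch >= 1.8.1; the matmul workload is a stand-in for a training step):

import torch
from torch import profiler

with profiler.profile(activities=[profiler.ProfilerActivity.CPU],
                      record_shapes=True) as prof:
    x = torch.randn(64, 64)
    for _ in range(10):
        x = x @ x  # placeholder for a training step
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=5))
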
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sampler_seed.py
ADDED
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class DistSamplerSeedHook(Hook):
+    """Data-loading sampler for distributed training.
+
+    In distributed training, it is only useful in conjunction with
+    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
+    purpose with :obj:`IterLoader`.
+    """
+
+    def before_epoch(self, runner):
+        if hasattr(runner.data_loader.sampler, 'set_epoch'):
+            # in case the data loader uses `SequentialSampler` in PyTorch
+            runner.data_loader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+            # the batch sampler in PyTorch wraps the sampler as an attribute.
+            runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
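The reason for calling set_epoch is that DistributedSampler seeds its shuffle with the epoch number; without the call, every epoch replays the same permutation. A single-process sketch (the num_replicas=1, rank=0 arguments are just to avoid needing an initialized process group):

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
for epoch in range(2):
    sampler.set_epoch(epoch)    # what DistSamplerSeedHook.before_epoch does
    print(list(iter(sampler)))  # a different permutation each epoch
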
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/hooks/sync_buffer.py
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+    """Synchronize model buffers such as running_mean and running_var in BN at
+    the end of each epoch.
+
+    Args:
+        distributed (bool): Whether distributed training is used. It is
+            effective only for distributed training. Defaults to True.
+    """
+
+    def __init__(self, distributed=True):
+        self.distributed = distributed
+
+    def after_epoch(self, runner):
+        """All-reduce model buffers at the end of each epoch."""
+        if self.distributed:
+            allreduce_params(runner.model.buffers())
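What "all-reduce model buffers" amounts to, sketched with raw torch.distributed calls rather than mmcv's allreduce_params helper (a sketch only: it assumes an initialized process group when called, and averaging is one common reduction choice; the real helper also supports coalescing):

import torch.distributed as dist

def sync_buffers(model):
    """Average floating-point buffers (e.g. BN running stats) across ranks."""
    world_size = dist.get_world_size()
    for buf in model.buffers():
        if buf.is_floating_point():
            dist.all_reduce(buf.data)   # sum the buffer across all ranks
            buf.data.div_(world_size)   # then divide to get the mean
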
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/iter_based_runner.py
ADDED
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.mmpkg.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+
+
+class IterLoader:
+
+    def __init__(self, dataloader):
+        self._dataloader = dataloader
+        self.iter_loader = iter(self._dataloader)
+        self._epoch = 0
+
+    @property
+    def epoch(self):
+        return self._epoch
+
+    def __next__(self):
+        try:
+            data = next(self.iter_loader)
+        except StopIteration:
+            self._epoch += 1
+            if hasattr(self._dataloader.sampler, 'set_epoch'):
+                self._dataloader.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self.iter_loader = iter(self._dataloader)
+            data = next(self.iter_loader)
+
+        return data
+
+    def __len__(self):
+        return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedRunner(BaseRunner):
+    """Iteration-based Runner.
+
+    This runner trains models iteration by iteration.
+    """
+
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._epoch = data_loader.epoch
+        data_batch = next(data_loader)
+        self.call_hook('before_train_iter')
+        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('model.train_step() must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+        self.call_hook('after_train_iter')
+        self._inner_iter += 1
+        self._iter += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        data_batch = next(data_loader)
+        self.call_hook('before_val_iter')
+        outputs = self.model.val_step(data_batch, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('model.val_step() must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+        self.call_hook('after_val_iter')
+        self._inner_iter += 1
+
+    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, iters) to specify the
+                running order and iterations. E.g., [('train', 10000),
+                ('val', 1000)] means running 10000 iterations for training and
+                1000 iterations for validation, iteratively.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_iters is not None:
+            warnings.warn(
+                'setting max_iters in run is deprecated, '
+                'please set max_iters in runner_config', DeprecationWarning)
+            self._max_iters = max_iters
+        assert self._max_iters is not None, (
+            'max_iters must be specified during instantiation')
+
+        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d iters', workflow,
+                         self._max_iters)
+        self.call_hook('before_run')
+
+        iter_loaders = [IterLoader(x) for x in data_loaders]
+
+        self.call_hook('before_epoch')
+
+        while self.iter < self._max_iters:
+            for i, flow in enumerate(workflow):
+                self._inner_iter = 0
+                mode, iters = flow
+                if not isinstance(mode, str) or not hasattr(self, mode):
+                    raise ValueError(
+                        'runner has no method named "{}" to run a workflow'.
+                        format(mode))
+                iter_runner = getattr(self, mode)
+                for _ in range(iters):
+                    if mode == 'train' and self.iter >= self._max_iters:
+                        break
+                    iter_runner(iter_loaders[i], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_epoch')
+        self.call_hook('after_run')
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default'):
+        """Resume model from checkpoint.
+
+        Args:
+            checkpoint (str): Checkpoint to resume from.
+            resume_optimizer (bool, optional): Whether to resume the
+                optimizer(s) if the checkpoint file includes optimizer(s).
+                Defaults to True.
+            map_location (str, optional): Same as :func:`torch.load`.
+                Defaults to 'default'.
+        """
+        if map_location == 'default':
+            device_id = torch.cuda.current_device()
+            checkpoint = self.load_checkpoint(
+                checkpoint,
+                map_location=lambda storage, loc: storage.cuda(device_id))
+        else:
+            checkpoint = self.load_checkpoint(
+                checkpoint, map_location=map_location)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        self._inner_iter = checkpoint['meta']['iter']
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='iter_{}.pth',
+                        meta=None,
+                        save_optimizer=True,
+                        create_symlink=True):
+        """Save checkpoint to file.
+
+        Args:
+            out_dir (str): Directory to save checkpoint files.
+            filename_tmpl (str, optional): Checkpoint file template.
+                Defaults to 'iter_{}.pth'.
+            meta (dict, optional): Metadata to be saved in checkpoint.
+                Defaults to None.
+            save_optimizer (bool, optional): Whether to save the optimizer.
+                Defaults to True.
+            create_symlink (bool, optional): Whether to create a symlink to
+                the latest checkpoint file. Defaults to True.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+            # Note: meta.update(self.meta) should be done before
+            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+            # there will be problems with resumed checkpoints.
+            # More details in https://github.com/open-mmlab/mmcv/pull/1108
+        meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+        filename = filename_tmpl.format(self.iter + 1)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if create_symlink:
+            dst_file = osp.join(out_dir, 'latest.pth')
+            if platform.system() != 'Windows':
+                mmcv.symlink(filename, dst_file)
+            else:
+                shutil.copy(filepath, dst_file)
+
+    def register_training_hooks(self,
+                                lr_config,
+                                optimizer_config=None,
+                                checkpoint_config=None,
+                                log_config=None,
+                                momentum_config=None,
+                                custom_hooks_config=None):
+        """Register default hooks for iter-based training.
+
+        Checkpoint hook, optimizer stepper hook and logger hooks will be set to
+        `by_epoch=False` by default.
+
+        Default hooks include:
+
+        +----------------------+-------------------------+
+        | Hooks                | Priority                |
+        +======================+=========================+
+        | LrUpdaterHook        | VERY_HIGH (10)          |
+        +----------------------+-------------------------+
+        | MomentumUpdaterHook  | HIGH (30)               |
+        +----------------------+-------------------------+
+        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
+        +----------------------+-------------------------+
+        | CheckpointSaverHook  | NORMAL (50)             |
+        +----------------------+-------------------------+
+        | IterTimerHook        | LOW (70)                |
+        +----------------------+-------------------------+
+        | LoggerHook(s)        | VERY_LOW (90)           |
+        +----------------------+-------------------------+
+        | CustomHook(s)        | defaults to NORMAL (50) |
+        +----------------------+-------------------------+
+
+        If custom hooks have the same priority as default hooks, custom hooks
+        will be triggered after default hooks.
+        """
+        if checkpoint_config is not None:
+            checkpoint_config.setdefault('by_epoch', False)
+        if lr_config is not None:
+            lr_config.setdefault('by_epoch', False)
+        if log_config is not None:
+            for info in log_config['hooks']:
+                info.setdefault('by_epoch', False)
+        super(IterBasedRunner, self).register_training_hooks(
+            lr_config=lr_config,
+            momentum_config=momentum_config,
+            optimizer_config=optimizer_config,
+            checkpoint_config=checkpoint_config,
+            log_config=log_config,
+            timer_config=IterTimerHook(),
+            custom_hooks_config=custom_hooks_config)
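IterLoader's contract in isolation: it turns a finite DataLoader into an endless iterator and bumps its epoch counter whenever the underlying loader is exhausted. A toy demonstration (assumes the IterLoader class from this diff is in scope):

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=2)
it = IterLoader(loader)   # two batches per pass over the data
for _ in range(4):        # draw four batches, i.e. two full passes
    batch = next(it)
print(it.epoch)           # 1: the loader wrapped around once
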
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/log_buffer.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer:
+
+    def __init__(self):
+        self.val_history = OrderedDict()
+        self.n_history = OrderedDict()
+        self.output = OrderedDict()
+        self.ready = False
+
+    def clear(self):
+        self.val_history.clear()
+        self.n_history.clear()
+        self.clear_output()
+
+    def clear_output(self):
+        self.output.clear()
+        self.ready = False
+
+    def update(self, vars, count=1):
+        assert isinstance(vars, dict)
+        for key, var in vars.items():
+            if key not in self.val_history:
+                self.val_history[key] = []
+                self.n_history[key] = []
+            self.val_history[key].append(var)
+            self.n_history[key].append(count)
+
+    def average(self, n=0):
+        """Average latest n values or all values."""
+        assert n >= 0
+        for key in self.val_history:
+            values = np.array(self.val_history[key][-n:])
+            nums = np.array(self.n_history[key][-n:])
+            avg = np.sum(values * nums) / np.sum(nums)
+            self.output[key] = avg
+        self.ready = True
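Note that LogBuffer.average computes a count-weighted mean, not a plain mean of the logged values; each update carries the number of samples behind it. For example (assuming LogBuffer from this diff is importable):

buf = LogBuffer()
buf.update({'loss': 1.0}, count=2)  # mean loss 1.0 over 2 samples
buf.update({'loss': 4.0}, count=1)  # mean loss 4.0 over 1 sample
buf.average()
print(buf.output['loss'])           # (1.0 * 2 + 4.0 * 1) / 3 = 2.0
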
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+                      build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
+
+__all__ = [
+    'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
+    'build_optimizer', 'build_optimizer_constructor'
+]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/builder.py
ADDED
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+
+import torch
+
+from ...utils import Registry, build_from_cfg
+
+OPTIMIZERS = Registry('optimizer')
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
+
+
+def register_torch_optimizers():
+    torch_optimizers = []
+    for module_name in dir(torch.optim):
+        if module_name.startswith('__'):
+            continue
+        _optim = getattr(torch.optim, module_name)
+        if inspect.isclass(_optim) and issubclass(_optim,
+                                                  torch.optim.Optimizer):
+            OPTIMIZERS.register_module()(_optim)
+            torch_optimizers.append(module_name)
+    return torch_optimizers
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
+
+
+def build_optimizer_constructor(cfg):
+    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+
+
+def build_optimizer(model, cfg):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
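Because register_torch_optimizers() pre-registers every class in torch.optim, a plain config dict is enough to build an optimizer. A usage sketch (assumes build_optimizer from this module is importable):

import torch

model = torch.nn.Linear(2, 2)
cfg = dict(type='SGD', lr=0.01, momentum=0.9)
optimizer = build_optimizer(model, cfg)
print(type(optimizer).__name__)  # SGD
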
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/optimizer/default_constructor.py
ADDED
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+from torch.nn import GroupNorm, LayerNorm
+
+from annotator.mmpkg.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from annotator.mmpkg.mmcv.utils.ext_loader import check_ops_exist
+from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class DefaultOptimizerConstructor:
+    """Default constructor for optimizers.
+
+    By default each parameter shares the same optimizer settings, and we
+    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+    It is a dict and may contain the following fields:
+
+    - ``custom_keys`` (dict): Specifies parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of one
+      parameter, then the setting of the parameter will be specified by
+      ``custom_keys[key]`` and other settings like ``bias_lr_mult`` will
+      be ignored. It should be noted that the aforementioned ``key`` is the
+      longest key that is a substring of the name of the parameter. If there
+      are multiple matched keys with the same length, then the key that
+      comes first alphabetically will be chosen.
+      ``custom_keys[key]`` should be a dict and may contain fields ``lr_mult``
+      and ``decay_mult``. See Example 2 below.
+    - ``bias_lr_mult`` (float): Multiplier for the learning rate of all
+      bias parameters (except those in normalization layers and offset
+      layers of DCN).
+    - ``bias_decay_mult`` (float): Multiplier for the weight decay of all
+      bias parameters (except those in normalization layers, depthwise
+      conv layers and offset layers of DCN).
+    - ``norm_decay_mult`` (float): Multiplier for the weight decay of all
+      weight and bias parameters of normalization layers.
+    - ``dwconv_decay_mult`` (float): Multiplier for the weight decay of all
+      weight and bias parameters of depthwise conv layers.
+    - ``dcn_offset_lr_mult`` (float): Multiplier for the learning rate of
+      the parameters of offset layers in the deformable convs of a model.
+    - ``bypass_duplicate`` (bool): If true, duplicate parameters will not
+      be added to the optimizer. Default: False.
+
+    Note:
+        1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+        override the effect of ``bias_lr_mult`` in the bias of offset
+        layer. So be careful when using both ``bias_lr_mult`` and
+        ``dcn_offset_lr_mult``. If you wish to apply both of them to the
+        offset layer in deformable convs, set ``dcn_offset_lr_mult``
+        to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+        2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+        apply it to all the DCN layers in the model. So be careful when
+        the model contains multiple DCN layers in places other than
+        backbone.
+
+    Args:
+        model (:obj:`nn.Module`): The model with parameters to be optimized.
+        optimizer_cfg (dict): The config dict of the optimizer.
+            Positional fields are
+
+                - `type`: class name of the optimizer.
+
+            Optional fields are
+
+                - any arguments of the corresponding optimizer type, e.g.,
+                  lr, weight_decay, momentum, etc.
+        paramwise_cfg (dict, optional): Parameter-wise options.
+
+    Example 1:
+        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+        >>>                      weight_decay=0.0001)
+        >>> paramwise_cfg = dict(norm_decay_mult=0.)
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+
+    Example 2:
+        >>> # assume model have attribute model.backbone and model.cls_head
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+        >>> paramwise_cfg = dict(custom_keys={
+                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+        >>> # Then the `lr` and `weight_decay` for model.backbone is
+        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
+        >>> # model.cls_head is (0.01, 0.95).
+    """
+
+    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+        if not isinstance(optimizer_cfg, dict):
+            raise TypeError('optimizer_cfg should be a dict, '
+                            f'but got {type(optimizer_cfg)}')
+        self.optimizer_cfg = optimizer_cfg
+        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+        self.base_lr = optimizer_cfg.get('lr', None)
+        self.base_wd = optimizer_cfg.get('weight_decay', None)
+        self._validate_cfg()
+
+    def _validate_cfg(self):
+        if not isinstance(self.paramwise_cfg, dict):
+            raise TypeError('paramwise_cfg should be None or a dict, '
+                            f'but got {type(self.paramwise_cfg)}')
+
+        if 'custom_keys' in self.paramwise_cfg:
+            if not isinstance(self.paramwise_cfg['custom_keys'], dict):
+                raise TypeError(
+                    'If specified, custom_keys must be a dict, '
+                    f'but got {type(self.paramwise_cfg["custom_keys"])}')
+            if self.base_wd is None:
+                for key in self.paramwise_cfg['custom_keys']:
+                    if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
+                        raise ValueError('base_wd should not be None')
+
+        # get base lr and weight decay
+        # weight_decay must be explicitly specified if mult is specified
+        if ('bias_decay_mult' in self.paramwise_cfg
+                or 'norm_decay_mult' in self.paramwise_cfg
+                or 'dwconv_decay_mult' in self.paramwise_cfg):
+            if self.base_wd is None:
+                raise ValueError('base_wd should not be None')
+
+    def _is_in(self, param_group, param_group_list):
+        assert is_list_of(param_group_list, dict)
+        param = set(param_group['params'])
+        param_set = set()
+        for group in param_group_list:
+            param_set.update(set(group['params']))
+
+        return not param.isdisjoint(param_set)
+
+    def add_params(self, params, module, prefix='', is_dcn_module=None):
+        """Add all parameters of module to the params list.
+
+        The parameters of the given module will be added to the list of param
+        groups, with specific rules defined by paramwise_cfg.
+
+        Args:
+            params (list[dict]): A list of param groups, it will be modified
+                in place.
+            module (nn.Module): The module to be added.
+            prefix (str): The prefix of the module.
+            is_dcn_module (int|float|None): If the current module is a
+                submodule of DCN, `is_dcn_module` will be passed to
+                control conv_offset layer's learning rate. Defaults to None.
+        """
+        # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # sort alphabetically first, then by descending key length (stable)
+        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+
+        # special rules for norm layers and depth-wise conv layers
+        is_norm = isinstance(module,
+                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+        is_dwconv = (
+            isinstance(module, torch.nn.Conv2d)
+            and module.in_channels == module.groups)
+
+        for name, param in module.named_parameters(recurse=False):
+            param_group = {'params': [param]}
+            if not param.requires_grad:
+                params.append(param_group)
+                continue
+            if bypass_duplicate and self._is_in(param_group, params):
+                warnings.warn(f'{prefix} is duplicate. It is skipped since '
+                              f'bypass_duplicate={bypass_duplicate}')
+                continue
+            # if the parameter matches one of the custom keys, ignore other
+            # rules
+            is_custom = False
+            for key in sorted_keys:
+                if key in f'{prefix}.{name}':
+                    is_custom = True
+                    lr_mult = custom_keys[key].get('lr_mult', 1.)
+                    param_group['lr'] = self.base_lr * lr_mult
+                    if self.base_wd is not None:
+                        decay_mult = custom_keys[key].get('decay_mult', 1.)
+                        param_group['weight_decay'] = self.base_wd * decay_mult
+                    break
+
+            if not is_custom:
+                # bias_lr_mult affects all bias parameters
+                # except for norm.bias dcn.conv_offset.bias
+                if name == 'bias' and not (is_norm or is_dcn_module):
+                    param_group['lr'] = self.base_lr * bias_lr_mult
+
+                if (prefix.find('conv_offset') != -1 and is_dcn_module
+                        and isinstance(module, torch.nn.Conv2d)):
+                    # deal with both dcn_offset's bias & weight
+                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+                # apply weight decay policies
+                if self.base_wd is not None:
+                    # norm decay
+                    if is_norm:
+                        param_group[
+                            'weight_decay'] = self.base_wd * norm_decay_mult
+                    # depth-wise conv
+                    elif is_dwconv:
+                        param_group[
+                            'weight_decay'] = self.base_wd * dwconv_decay_mult
+                    # bias lr and decay
+                    elif name == 'bias' and not is_dcn_module:
+                        # TODO: the current bias_decay_mult will affect DCN
+                        param_group[
+                            'weight_decay'] = self.base_wd * bias_decay_mult
+            params.append(param_group)
+
+        if check_ops_exist():
+            from annotator.mmpkg.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+            is_dcn_module = isinstance(module,
+                                       (DeformConv2d, ModulatedDeformConv2d))
+        else:
+            is_dcn_module = False
+        for child_name, child_mod in module.named_children():
+            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+            self.add_params(
+                params,
+                child_mod,
+                prefix=child_prefix,
+                is_dcn_module=is_dcn_module)
+
+    def __call__(self, model):
+        if hasattr(model, 'module'):
+            model = model.module
+
+        optimizer_cfg = self.optimizer_cfg.copy()
+        # if no paramwise option is specified, just use the global setting
+        if not self.paramwise_cfg:
+            optimizer_cfg['params'] = model.parameters()
+            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+
+        # set param-wise lr and weight decay recursively
+        params = []
+        self.add_params(params, model)
+        optimizer_cfg['params'] = params
+
+        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
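The custom_keys matching rule above, in isolation: keys are tried longest first with alphabetical tie-breaking, and the first key that is a substring of '<prefix>.<name>' wins. A self-contained sketch of just that rule (the key names are illustrative):

custom_keys = {'backbone': dict(lr_mult=0.1),
               'backbone.stem': dict(lr_mult=0.01)}
sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
param_name = 'backbone.stem.conv.weight'
match = next((k for k in sorted_keys if k in param_name), None)
print(match)  # backbone.stem -- the longer, more specific key wins
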
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/priority.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+
+class Priority(Enum):
+    """Hook priority levels.
+
+    +--------------+------------+
+    | Level        | Value      |
+    +==============+============+
+    | HIGHEST      | 0          |
+    +--------------+------------+
+    | VERY_HIGH    | 10         |
+    +--------------+------------+
+    | HIGH         | 30         |
+    +--------------+------------+
+    | ABOVE_NORMAL | 40         |
+    +--------------+------------+
+    | NORMAL       | 50         |
+    +--------------+------------+
+    | BELOW_NORMAL | 60         |
+    +--------------+------------+
+    | LOW          | 70         |
+    +--------------+------------+
+    | VERY_LOW     | 90         |
+    +--------------+------------+
+    | LOWEST       | 100        |
+    +--------------+------------+
+    """
+
+    HIGHEST = 0
+    VERY_HIGH = 10
+    HIGH = 30
+    ABOVE_NORMAL = 40
+    NORMAL = 50
+    BELOW_NORMAL = 60
+    LOW = 70
+    VERY_LOW = 90
+    LOWEST = 100
+
+
+def get_priority(priority):
+    """Get priority value.
+
+    Args:
+        priority (int or str or :obj:`Priority`): Priority.
+
+    Returns:
+        int: The priority value.
+    """
+    if isinstance(priority, int):
+        if priority < 0 or priority > 100:
+            raise ValueError('priority must be between 0 and 100')
+        return priority
+    elif isinstance(priority, Priority):
+        return priority.value
+    elif isinstance(priority, str):
+        return Priority[priority.upper()].value
+    else:
+        raise TypeError('priority must be an integer or Priority enum value')
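get_priority accepts all three spellings interchangeably, e.g.:

>>> get_priority(50)
50
>>> get_priority(Priority.NORMAL)
50
>>> get_priority('normal')
50
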
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/runner/utils.py
ADDED
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+
+import numpy as np
+import torch
+
+import annotator.mmpkg.mmcv as mmcv
+
+
+def get_host_info():
+    """Get hostname and username.
+
+    Return an empty string if an exception is raised, e.g.
+    ``getpass.getuser()`` can fail in a docker container.
+    """
+    host = ''
+    try:
+        host = f'{getuser()}@{gethostname()}'
+    except Exception as e:
+        warnings.warn(f'Host or user not found: {str(e)}')
+    finally:
+        return host
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+    """Initialize an object from dict.
+
+    The dict must contain the key "type", which indicates the object type;
+    it can be either a string or a type, such as "list" or ``list``.
+    Remaining fields are treated as the arguments for constructing the
+    object.
+
+    Args:
+        info (dict): Object types and arguments.
+        parent (:class:`module`): Module which may contain expected object
+            classes.
+        default_args (dict, optional): Default arguments for initializing the
+            object.
+
+    Returns:
+        any type: Object built from the dict.
+    """
+    assert isinstance(info, dict) and 'type' in info
+    assert isinstance(default_args, dict) or default_args is None
+    args = info.copy()
+    obj_type = args.pop('type')
+    if mmcv.is_str(obj_type):
+        if parent is not None:
+            obj_type = getattr(parent, obj_type)
+        else:
+            obj_type = sys.modules[obj_type]
+    elif not isinstance(obj_type, type):
+        raise TypeError('type must be a str or valid type, but '
+                        f'got {type(obj_type)}')
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+    return obj_type(**args)
+
+
+def set_random_seed(seed, deterministic=False, use_rank_shift=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+        use_rank_shift (bool): Whether to add the rank number to the random
+            seed so that different processes use different seeds.
+            Default: False.
+    """
+    if use_rank_shift:
+        rank, _ = mmcv.runner.get_dist_info()
+        seed += rank
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
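obj_from_dict in action: the 'type' key selects a class from `parent` and the remaining fields become constructor arguments. A short sketch (assumes the function above is importable):

import torch

sgd = obj_from_dict(
    dict(type='SGD', lr=0.1),
    parent=torch.optim,
    default_args=dict(params=[torch.nn.Parameter(torch.zeros(1))]))
print(type(sgd).__name__)  # SGD
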
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/__init__.py
ADDED
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+                   has_method, import_modules_from_strings, is_list_of,
+                   is_method_overridden, is_seq_of, is_str, is_tuple_of,
+                   iter_cast, list_cast, requires_executable, requires_package,
+                   slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+                   to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+                   scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+                          track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+                      assert_dict_has_keys, assert_is_norm_layer,
+                      assert_keys_equal, assert_params_all_zeros,
+                      check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
+
+try:
+    import torch
+except ImportError:
+    __all__ = [
+        'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
+        'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
+        'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
+        'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
+        'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
+        'track_progress', 'track_iter_progress', 'track_parallel_progress',
+        'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
+        'digit_version', 'get_git_hash', 'import_modules_from_strings',
+        'assert_dict_contains_subset', 'assert_attrs_equal',
+        'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+        'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+        'is_method_overridden', 'has_method'
+    ]
+else:
+    from .env import collect_env
+    from .logging import get_logger, print_log
+    from .parrots_jit import jit, skip_no_elena
+    from .parrots_wrapper import (
+        TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
+        PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+        _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
+        _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
+    from .registry import Registry, build_from_cfg
+    from .trace import is_jit_tracing
+    __all__ = [
+        'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
+        'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
+        'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
+        'check_prerequisites', 'requires_package', 'requires_executable',
+        'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
+        'symlink', 'scandir', 'ProgressBar', 'track_progress',
+        'track_iter_progress', 'track_parallel_progress', 'Registry',
+        'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
+        '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
+        '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
+        'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
+        'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
+        'deprecated_api_warning', 'digit_version', 'get_git_hash',
+        'import_modules_from_strings', 'jit', 'skip_no_elena',
+        'assert_dict_contains_subset', 'assert_attrs_equal',
+        'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
+        'assert_params_all_zeros', 'check_python_script',
+        'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
+        '_get_cuda_home', 'has_method'
+    ]
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/config.py
ADDED
@@ -0,0 +1,688 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import ast
|
3 |
+
import copy
|
4 |
+
import os
|
5 |
+
import os.path as osp
|
6 |
+
import platform
|
7 |
+
import shutil
|
8 |
+
import sys
|
9 |
+
import tempfile
|
10 |
+
import uuid
|
11 |
+
import warnings
|
12 |
+
from argparse import Action, ArgumentParser
|
13 |
+
from collections import abc
|
14 |
+
from importlib import import_module
|
15 |
+
|
16 |
+
from addict import Dict
|
17 |
+
from yapf.yapflib.yapf_api import FormatCode
|
18 |
+
|
19 |
+
from .misc import import_modules_from_strings
|
20 |
+
from .path import check_file_exist
|
21 |
+
|
22 |
+
if platform.system() == 'Windows':
|
23 |
+
import regex as re
|
24 |
+
else:
|
25 |
+
import re
|
26 |
+
|
27 |
+
BASE_KEY = '_base_'
|
28 |
+
DELETE_KEY = '_delete_'
|
29 |
+
DEPRECATION_KEY = '_deprecation_'
|
30 |
+
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
|
31 |
+
|
32 |
+
|
33 |
+
class ConfigDict(Dict):
|
34 |
+
|
35 |
+
def __missing__(self, name):
|
36 |
+
raise KeyError(name)
|
37 |
+
|
38 |
+
def __getattr__(self, name):
|
39 |
+
try:
|
40 |
+
value = super(ConfigDict, self).__getattr__(name)
|
41 |
+
except KeyError:
|
42 |
+
ex = AttributeError(f"'{self.__class__.__name__}' object has no "
|
43 |
+
f"attribute '{name}'")
|
44 |
+
except Exception as e:
|
45 |
+
ex = e
|
46 |
+
else:
|
47 |
+
return value
|
48 |
+
raise ex
|
49 |
+
|
50 |
+
|
def add_args(parser, cfg, prefix=''):
    for k, v in cfg.items():
        if isinstance(v, str):
            parser.add_argument('--' + prefix + k)
        elif isinstance(v, int):
            parser.add_argument('--' + prefix + k, type=int)
        elif isinstance(v, float):
            parser.add_argument('--' + prefix + k, type=float)
        elif isinstance(v, bool):
            parser.add_argument('--' + prefix + k, action='store_true')
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser

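# A short sketch of what add_args produces (hypothetical config): for
# cfg = dict(optimizer=dict(lr=0.1), resume=False) it registers
# '--optimizer.lr' with type=float and '--resume' as a store_true flag,
# recursing through nested dicts to build dotted option names.
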
class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The
    interface is the same as a dict object and also allows accessing config
    values as attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename, 'r', encoding='utf-8') as f:
            # Setting encoding explicitly to resolve coding issue on windows
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}: {e}')

    @staticmethod
    def _substitute_predefined_vars(filename, temp_config_name):
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname)
        with open(filename, 'r', encoding='utf-8') as f:
            # Setting encoding explicitly to resolve coding issue on windows
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
            value = value.replace('\\', '/')
            config_file = re.sub(regexp, value, config_file)
        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _pre_substitute_base_vars(filename, temp_config_name):
        """Substitute base variable placeholders with strings, so that
        parsing would work."""
        with open(filename, 'r', encoding='utf-8') as f:
            # Setting encoding explicitly to resolve coding issue on windows
            config_file = f.read()
        base_var_dict = {}
        regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
        base_vars = set(re.findall(regexp, config_file))
        for base_var in base_vars:
            randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
            base_var_dict[randstr] = base_var
            regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
            config_file = re.sub(regexp, f'"{randstr}"', config_file)
        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
            tmp_config_file.write(config_file)
        return base_var_dict

    @staticmethod
    def _substitute_base_vars(cfg, base_var_dict, base_cfg):
        """Substitute variable strings to their actual values."""
        cfg = copy.deepcopy(cfg)

        if isinstance(cfg, dict):
            for k, v in cfg.items():
                if isinstance(v, str) and v in base_var_dict:
                    new_v = base_cfg
                    for new_k in base_var_dict[v].split('.'):
                        new_v = new_v[new_k]
                    cfg[k] = new_v
                elif isinstance(v, (list, tuple, dict)):
                    cfg[k] = Config._substitute_base_vars(
                        v, base_var_dict, base_cfg)
        elif isinstance(cfg, tuple):
            cfg = tuple(
                Config._substitute_base_vars(c, base_var_dict, base_cfg)
                for c in cfg)
        elif isinstance(cfg, list):
            cfg = [
                Config._substitute_base_vars(c, base_var_dict, base_cfg)
                for c in cfg
            ]
        elif isinstance(cfg, str) and cfg in base_var_dict:
            new_v = base_cfg
            for new_k in base_var_dict[cfg].split('.'):
                new_v = new_v[new_k]
            cfg = new_v

        return cfg

    @staticmethod
    def _file2dict(filename, use_predefined_variables=True):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix=fileExtname)
            if platform.system() == 'Windows':
                temp_config_file.close()
            temp_config_name = osp.basename(temp_config_file.name)
            # Substitute predefined variables
            if use_predefined_variables:
                Config._substitute_predefined_vars(filename,
                                                   temp_config_file.name)
            else:
                shutil.copyfile(filename, temp_config_file.name)
            # Substitute base variables from placeholders to strings
            base_var_dict = Config._pre_substitute_base_vars(
                temp_config_file.name, temp_config_file.name)

            if filename.endswith('.py'):
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                Config._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
            elif filename.endswith(('.yml', '.yaml', '.json')):
                import annotator.mmpkg.mmcv as mmcv
                cfg_dict = mmcv.load(temp_config_file.name)
            # close temp file
            temp_config_file.close()

        # check deprecation information
        if DEPRECATION_KEY in cfg_dict:
            deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
            warning_msg = f'The config file {filename} will be deprecated ' \
                'in the future.'
            if 'expected' in deprecation_info:
                warning_msg += f' Please use {deprecation_info["expected"]} ' \
                    'instead.'
            if 'reference' in deprecation_info:
                warning_msg += ' More information can be found at ' \
                    f'{deprecation_info["reference"]}'
            warnings.warn(warning_msg)

        cfg_text = filename + '\n'
        with open(filename, 'r', encoding='utf-8') as f:
            # Setting encoding explicitly to resolve coding issue on windows
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                duplicate_keys = base_cfg_dict.keys() & c.keys()
                if len(duplicate_keys) > 0:
                    raise KeyError('Duplicate key is not allowed among bases. '
                                   f'Duplicate keys: {duplicate_keys}')
                base_cfg_dict.update(c)

            # Substitute base variables from strings to their actual values
            cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
                                                    base_cfg_dict)

            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

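    # A minimal sketch of the `_base_` handling above (hypothetical files):
    #
    #   # base.py
    #   model = dict(type='ResNet', depth=50)
    #
    #   # child.py
    #   _base_ = './base.py'
    #   model = dict(depth=101)                 # merged over the base dict
    #   depth_copy = {{ _base_.model.depth }}   # resolved to 50 from base.py
    #
    # Config.fromfile('child.py') then yields
    # model=dict(type='ResNet', depth=101) and depth_copy=50.
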
    @staticmethod
    def _merge_a_into_b(a, b, allow_list_keys=False):
        """Merge dict ``a`` into dict ``b`` (non-inplace).

        Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
        in-place modifications.

        Args:
            a (dict): The source dict to be merged into ``b``.
            b (dict): The origin dict to fetch keys from ``a``.
            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in source ``a`` and will replace the element of the
                corresponding index in b if b is a list. Default: False.

        Returns:
            dict: The modified dict of ``b`` using ``a``.

        Examples:
            # Normally merge a into b.
            >>> Config._merge_a_into_b(
            ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
            {'obj': {'a': 2}}

            # Delete b first and merge a into b.
            >>> Config._merge_a_into_b(
            ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
            {'obj': {'a': 2}}

            # b is a list
            >>> Config._merge_a_into_b(
            ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
            [{'a': 2}, {'b': 2}]
        """
        b = b.copy()
        for k, v in a.items():
            if allow_list_keys and k.isdigit() and isinstance(b, list):
                k = int(k)
                if len(b) <= k:
                    raise KeyError(f'Index {k} exceeds the length of list {b}')
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            elif isinstance(v,
                            dict) and k in b and not v.pop(DELETE_KEY, False):
                allowed_types = (dict, list) if allow_list_keys else dict
                if not isinstance(b[k], allowed_types):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename,
                 use_predefined_variables=True,
                 import_custom_modules=True):
        cfg_dict, cfg_text = Config._file2dict(filename,
                                               use_predefined_variables)
        if import_custom_modules and cfg_dict.get('custom_imports', None):
            import_modules_from_strings(**cfg_dict['custom_imports'])
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def fromstring(cfg_str, file_format):
        """Generate config from config str.

        Args:
            cfg_str (str): Config str.
            file_format (str): Config file format corresponding to the
                config str. Only py/yml/yaml/json types are supported now!

        Returns:
            :obj:`Config`: Config obj.
        """
        if file_format not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')
        if file_format != '.py' and 'dict(' in cfg_str:
            # check if users specify a wrong suffix for python
            warnings.warn(
                'Please check "file_format", the file format may be .py')
        with tempfile.NamedTemporaryFile(
                'w', encoding='utf-8', suffix=file_format,
                delete=False) as temp_file:
            temp_file.write(cfg_str)
            # on windows, the previous implementation caused an error
            # see PR 1077 for details
        cfg = Config.fromfile(temp_file.name)
        os.remove(temp_file.name)
        return cfg

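    # e.g. Config.fromstring('a = 1\nb = dict(c=[2, 3])', '.py') round-trips
    # the string through a temporary file and returns a Config instance.
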
    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)"""
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def __getstate__(self):
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        _cfg_dict, _filename, _text = state
        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
        super(Config, self).__setattr__('_filename', _filename)
        super(Config, self).__setattr__('_text', _text)

    def dump(self, file=None):
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        if self.filename.endswith('.py'):
            if file is None:
                return self.pretty_text
            else:
                with open(file, 'w', encoding='utf-8') as f:
                    f.write(self.pretty_text)
        else:
            import annotator.mmpkg.mmcv as mmcv
            if file is None:
                file_format = self.filename.split('.')[-1]
                return mmcv.dump(cfg_dict, file_format=file_format)
            else:
                mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options, allow_list_keys=True):
        """Merge list into cfg_dict.

        Merge the dict parsed by MultipleKVAction into this cfg.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))

            # Merge list element
            >>> cfg = Config(dict(pipeline=[
            ...     dict(type='LoadImage'), dict(type='LoadAnnotations')]))
            >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
            >>> cfg.merge_from_dict(options, allow_list_keys=True)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(pipeline=[
            ...     dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])

        Args:
            options (dict): dict of configs to merge from.
            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
                are allowed in ``options`` and will replace the element of the
                corresponding index in the config if the config is a list.
                Default: True.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict',
            Config._merge_a_into_b(
                option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))

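# A minimal usage sketch of the Config API above ('config.py' and
# 'config_dump.py' are hypothetical paths):
#
#   cfg = Config.fromfile('config.py')
#   cfg.merge_from_dict({'model.backbone.depth': 101})
#   print(cfg.pretty_text)       # yapf-formatted dump of the merged config
#   cfg.dump('config_dump.py')   # write the merged config back to disk
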
class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options can
    be passed as comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit
    brackets, i.e. 'KEY=[V1,V2,V3]'. It also supports nested brackets to build
    list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return True if val.lower() == 'true' else False
        return val

    @staticmethod
    def _parse_iterable(val):
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.

        Args:
            val (str): Value string.

        Returns:
            list | tuple: The expanded list or tuple from the string.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
        """

        def find_next_comma(string):
            """Find the position of next comma in the string.

            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus ','
            inside these brackets are ignored.
            """
            assert (string.count('(') == string.count(')')) and (
                    string.count('[') == string.count(']')), \
                f'Imbalanced brackets exist in {string}'
            end = len(string)
            for idx, char in enumerate(string):
                pre = string[:idx]
                # The string before this ',' is balanced
                if ((char == ',') and (pre.count('(') == pre.count(')'))
                        and (pre.count('[') == pre.count(']'))):
                    end = idx
                    break
            return end

        # Strip ' and " characters and replace whitespace.
        val = val.strip('\'\"').replace(' ', '')
        is_tuple = False
        if val.startswith('(') and val.endswith(')'):
            is_tuple = True
            val = val[1:-1]
        elif val.startswith('[') and val.endswith(']'):
            val = val[1:-1]
        elif ',' not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            element = DictAction._parse_iterable(val[:comma_idx])
            values.append(element)
            val = val[comma_idx + 1:]
        if is_tuple:
            values = tuple(values)
        return values

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split('=', maxsplit=1)
            options[key] = self._parse_iterable(val)
        setattr(namespace, self.dest, options)
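A minimal sketch of wiring DictAction into an argument parser (the
--cfg-options flag name and the sample keys are illustrative, not mandated
by this file):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--cfg-options', nargs='+', action=DictAction)

    args = parser.parse_args(
        ['--cfg-options', 'model.depth=50', 'lr=[0.01,0.001]', 'size=(256,256)'])
    # args.cfg_options == {'model.depth': 50, 'lr': [0.01, 0.001],
    #                      'size': (256, 256)}
    # The dotted keys can then be applied to a Config instance with
    # cfg.merge_from_dict(args.cfg_options).
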
extensions/microsoftexcel-controlnet/annotator/mmpkg/mmcv/utils/env.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""This file holds some environment constants for sharing by other files."""

import os.path as osp
import subprocess
import sys
from collections import defaultdict

import cv2
import torch

import annotator.mmpkg.mmcv as mmcv
from .parrots_wrapper import get_build_config

def collect_env():
    """Collect the information of the running environments.

    Returns:
        dict: The environment information. The following fields are contained.

            - sys.platform: The variable of ``sys.platform``.
            - Python: Python version.
            - CUDA available: Bool, indicating if CUDA is available.
            - GPU devices: Device type of each GPU.
            - CUDA_HOME (optional): The env var ``CUDA_HOME``.
            - NVCC (optional): NVCC version.
            - GCC: GCC version, "n/a" if GCC is not installed.
            - PyTorch: PyTorch version.
            - PyTorch compiling details: The output of
              ``torch.__config__.show()``.
            - TorchVision (optional): TorchVision version.
            - OpenCV: OpenCV version.
            - MMCV: MMCV version.
            - MMCV Compiler: The GCC version for compiling MMCV ops.
            - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
    """
    env_info = {}
    env_info['sys.platform'] = sys.platform
    env_info['Python'] = sys.version.replace('\n', '')

    cuda_available = torch.cuda.is_available()
    env_info['CUDA available'] = cuda_available

    if cuda_available:
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
        for name, device_ids in devices.items():
            env_info['GPU ' + ','.join(device_ids)] = name

        from annotator.mmpkg.mmcv.utils.parrots_wrapper import _get_cuda_home
        CUDA_HOME = _get_cuda_home()
        env_info['CUDA_HOME'] = CUDA_HOME

        if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
            try:
                nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
                nvcc = subprocess.check_output(
                    f'"{nvcc}" -V | tail -n1', shell=True)
                nvcc = nvcc.decode('utf-8').strip()
            except subprocess.SubprocessError:
                nvcc = 'Not Available'
            env_info['NVCC'] = nvcc

    try:
        gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
        gcc = gcc.decode('utf-8').strip()
        env_info['GCC'] = gcc
    except subprocess.CalledProcessError:  # gcc is unavailable
        env_info['GCC'] = 'n/a'

    env_info['PyTorch'] = torch.__version__
    env_info['PyTorch compiling details'] = get_build_config()

    try:
        import torchvision
        env_info['TorchVision'] = torchvision.__version__
    except ModuleNotFoundError:
        pass

    env_info['OpenCV'] = cv2.__version__

    env_info['MMCV'] = mmcv.__version__

    try:
        from annotator.mmpkg.mmcv.ops import get_compiler_version, get_compiling_cuda_version
    except ModuleNotFoundError:
        env_info['MMCV Compiler'] = 'n/a'
        env_info['MMCV CUDA Compiler'] = 'n/a'
    else:
        env_info['MMCV Compiler'] = get_compiler_version()
        env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()

    return env_info
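A short usage sketch for collect_env (assuming the annotator.mmpkg package is
importable from the working directory):

    from annotator.mmpkg.mmcv.utils.env import collect_env

    # Print the collected environment report, one "name: value" pair per line.
    for name, value in collect_env().items():
        print(f'{name}: {value}')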