onescotch committed on
Commit
010a8bc
1 Parent(s): d200058

clean up for zero gpus

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. app.py +8 -15
  2. common/base.py +1 -1
  3. common/utils/distribute_utils.py +1 -1
  4. main/SMPLer_X.py +1 -1
  5. main/config.py +2 -1
  6. main/inference.py +7 -1
  7. main/transformer_utils/mmpose/__init__.py +1 -1
  8. main/transformer_utils/mmpose/core/camera/camera_base.py +1 -1
  9. main/transformer_utils/mmpose/core/distributed_wrapper.py +1 -1
  10. main/transformer_utils/mmpose/core/evaluation/eval_hooks.py +511 -2
  11. main/transformer_utils/mmpose/core/fp16/hooks.py +77 -2
  12. main/transformer_utils/mmpose/core/optimizers/builder.py +20 -7
  13. main/transformer_utils/mmpose/core/optimizers/layer_decay_optimizer_constructor.py +2 -2
  14. main/transformer_utils/mmpose/core/post_processing/smoother.py +2 -2
  15. main/transformer_utils/mmpose/core/post_processing/temporal_filters/builder.py +1 -1
  16. main/transformer_utils/mmpose/core/post_processing/temporal_filters/smoothnet_filter.py +1 -1
  17. main/transformer_utils/mmpose/core/utils/dist_utils.py +1 -1
  18. main/transformer_utils/mmpose/core/utils/model_util_hooks.py +2 -2
  19. main/transformer_utils/mmpose/core/visualization/image.py +1 -1
  20. main/transformer_utils/mmpose/models/__init__.py +1 -1
  21. main/transformer_utils/mmpose/models/backbones/__init__.py +39 -38
  22. main/transformer_utils/mmpose/models/backbones/alexnet.py +0 -56
  23. main/transformer_utils/mmpose/models/backbones/cpm.py +0 -186
  24. main/transformer_utils/mmpose/models/backbones/hourglass.py +0 -212
  25. main/transformer_utils/mmpose/models/backbones/hourglass_ae.py +0 -212
  26. main/transformer_utils/mmpose/models/backbones/hrformer.py +0 -746
  27. main/transformer_utils/mmpose/models/backbones/hrnet.py +0 -604
  28. main/transformer_utils/mmpose/models/backbones/hrt.py +0 -676
  29. main/transformer_utils/mmpose/models/backbones/hrt_checkpoint.py +0 -500
  30. main/transformer_utils/mmpose/models/backbones/i3d.py +0 -215
  31. main/transformer_utils/mmpose/models/backbones/litehrnet.py +0 -984
  32. main/transformer_utils/mmpose/models/backbones/mobilenet_v2.py +0 -275
  33. main/transformer_utils/mmpose/models/backbones/mobilenet_v3.py +0 -188
  34. main/transformer_utils/mmpose/models/backbones/modules/basic_block.py +1 -3
  35. main/transformer_utils/mmpose/models/backbones/mspn.py +0 -513
  36. main/transformer_utils/mmpose/models/backbones/pvt.py +0 -592
  37. main/transformer_utils/mmpose/models/backbones/regnet.py +0 -317
  38. main/transformer_utils/mmpose/models/backbones/resnest.py +0 -338
  39. main/transformer_utils/mmpose/models/backbones/resnet.py +3 -3
  40. main/transformer_utils/mmpose/models/backbones/resnext.py +0 -162
  41. main/transformer_utils/mmpose/models/backbones/rsn.py +0 -616
  42. main/transformer_utils/mmpose/models/backbones/scnet.py +0 -248
  43. main/transformer_utils/mmpose/models/backbones/seresnet.py +0 -125
  44. main/transformer_utils/mmpose/models/backbones/seresnext.py +0 -168
  45. main/transformer_utils/mmpose/models/backbones/shufflenet_v1.py +0 -329
  46. main/transformer_utils/mmpose/models/backbones/shufflenet_v2.py +0 -302
  47. main/transformer_utils/mmpose/models/backbones/swin.py +0 -733
  48. main/transformer_utils/mmpose/models/backbones/tcformer.py +0 -283
  49. main/transformer_utils/mmpose/models/backbones/tcn.py +0 -267
  50. main/transformer_utils/mmpose/models/backbones/utils/utils.py +12 -12
app.py CHANGED
@@ -13,29 +13,22 @@ try:
 except:
     os.system('pip install /home/user/app/main/transformer_utils')
 hf_hub_download(repo_id="caizhongang/SMPLer-X", filename="smpler_x_h32.pth.tar", local_dir="/home/user/app/pretrained_models")
-os.system('cp -rf /home/user/app/assets/conversions.py /home/user/.pyenv/versions/3.9.18/lib/python3.9/site-packages/torchgeometry/core/conversions.py')
+os.system('cp -rf /home/user/app/assets/conversions.py /usr/local/lib/python3.10/site-packages/torchgeometry/core/conversions.py')
 DEFAULT_MODEL='smpler_x_h32'
 OUT_FOLDER = '/home/user/app/demo_out'
 os.makedirs(OUT_FOLDER, exist_ok=True)
-# num_gpus = 1 if torch.cuda.is_available() else -1
-# print("!!!", torch.cuda.is_available())
-# print(torch.cuda.device_count())
-# print(torch.version.cuda)
-# index = torch.cuda.current_device()
-# print(index)
-# print(torch.cuda.get_device_name(index))
+num_gpus = 1 if torch.cuda.is_available() else -1
+print("!!!", torch.cuda.is_available())
+print(torch.cuda.device_count())
+print(torch.version.cuda)
+index = torch.cuda.current_device()
+print(index)
+print(torch.cuda.get_device_name(index))
 # from main.inference import Inferer
 # inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
 
 @spaces.GPU(enable_queue=True)
 def infer(video_input, in_threshold=0.5, num_people="Single person", render_mesh=False):
-    num_gpus = 1 if torch.cuda.is_available() else -1
-    print("!!!", torch.cuda.is_available())
-    print(torch.cuda.device_count())
-    print(torch.version.cuda)
-    index = torch.cuda.current_device()
-    print(index)
-    print(torch.cuda.get_device_name(index))
     from main.inference import Inferer
     inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
     os.system(f'rm -rf {OUT_FOLDER}/*')
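Read as plain Python, the net effect of this hunk is that the CUDA probe and num_gpus move from the body of infer to module scope, so they run once at import time, and infer reuses the module-level value. A minimal sketch of the resulting pattern, assuming the Hugging Face spaces package (the decorator comes from the diff; the ZeroGPU timing note is an assumption, not something the commit states):

import torch
import spaces

# Runs once at import; on ZeroGPU hardware CUDA may only be attached while
# a @spaces.GPU call is active, so this can report -1 at startup (assumption).
num_gpus = 1 if torch.cuda.is_available() else -1

@spaces.GPU(enable_queue=True)
def infer(video_input, in_threshold=0.5):
    from main.inference import Inferer  # heavy import deferred to call time
    inferer = Inferer('smpler_x_h32', num_gpus, '/home/user/app/demo_out')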
common/base.py CHANGED
@@ -17,7 +17,7 @@ import torch.utils.data.distributed
 from utils.distribute_utils import (
     get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups
 )
-from mmcv.runner import get_dist_info
+
 
 class Base(object):
     __metaclass__ = abc.ABCMeta
common/utils/distribute_utils.py CHANGED
@@ -7,7 +7,7 @@ import tempfile
 import time
 import torch
 import torch.distributed as dist
-from mmcv.runner import get_dist_info
+from mmengine.dist import get_dist_info
 import random
 import numpy as np
 import subprocess
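mmengine's get_dist_info is a drop-in replacement for the mmcv.runner version: it returns the same (rank, world_size) pair and falls back to (0, 1) when no process group has been initialized. A quick sketch:

from mmengine.dist import get_dist_info

rank, world_size = get_dist_info()  # (0, 1) outside torch.distributed
if rank == 0:
    print(f'running on {world_size} process(es)')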
main/SMPLer_X.py CHANGED
@@ -9,7 +9,7 @@ from config import cfg
 import math
 import copy
 from mmpose.models import build_posenet
-from mmcv import Config
+from mmengine.config import Config
 
 class Model(nn.Module):
     def __init__(self, encoder, body_position_net, body_rotation_net, box_net, hand_position_net, hand_roi_net,
main/config.py CHANGED
@@ -2,7 +2,8 @@ import os
 import os.path as osp
 import sys
 import datetime
-from mmcv import Config as MMConfig
+from mmengine.config import Config as MMConfig
+
 
 class Config:
     def get_config_fromfile(self, config_path):
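mmengine.config.Config keeps the mmcv 1.x loading interface, so get_config_fromfile should work unchanged after the import swap. A minimal sketch (the config path is hypothetical):

from mmengine.config import Config as MMConfig

cfg = MMConfig.fromfile('main/config/config_smpler_x_h32.py')  # hypothetical path
print(cfg.pretty_text)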
main/inference.py CHANGED
@@ -53,8 +53,14 @@ class Inferer:
 
         ## mmdet inference
         mmdet_results = inference_detector(self.model, original_img)
-        mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)
 
+        pred_instance = mmdet_results.pred_instances.cpu().numpy()
+        bboxes = np.concatenate(
+            (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+        bboxes = bboxes[pred_instance.labels == 0]
+        bboxes = np.expand_dims(bboxes, axis=0)
+        mmdet_box = process_mmdet_results(bboxes, cat_id=0, multi_person=True)
+
         # save original image if no bbox
         if len(mmdet_box[0])<1:
             return original_img, [], []
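The replacement block is the standard mmdet 3.x idiom: inference_detector now returns a DetDataSample whose pred_instances hold bboxes, scores, and labels as array fields, rather than the per-class list that mmdet 2.x returned and that process_mmdet_results still expects. A standalone sketch of the same conversion (config and checkpoint paths are hypothetical):

import numpy as np
from mmdet.apis import inference_detector, init_detector

model = init_detector('mmdet_cfg.py', 'mmdet_ckpt.pth', device='cpu')  # hypothetical files
result = inference_detector(model, 'demo.jpg')

inst = result.pred_instances.cpu().numpy()
person = inst[inst.labels == 0]  # COCO category 0 is 'person'
boxes = np.concatenate((person.bboxes, person.scores[:, None]), axis=1)  # (N, 5) xyxy+score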
main/transformer_utils/mmpose/__init__.py CHANGED
@@ -17,7 +17,7 @@ def digit_version(version_str):
 
 
 mmcv_minimum_version = '1.3.8'
-mmcv_maximum_version = '1.8.0'
+mmcv_maximum_version = '2.3.0'
 mmcv_version = digit_version(mmcv.__version__)
 
 
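Raising mmcv_maximum_version to '2.3.0' widens the accepted range so mmcv 2.x passes the version guard; the guard itself is ordinary tuple comparison over parsed version digits. A simplified sketch of the check (the real digit_version also handles suffixes such as rc releases):

def digit_version(version_str):
    return tuple(int(x) for x in version_str.split('.'))

assert digit_version('1.3.8') <= digit_version('2.1.0') < digit_version('2.3.0')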
main/transformer_utils/mmpose/core/camera/camera_base.py CHANGED
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
 
-from mmcv.utils import Registry
+from mmengine import Registry
 
 CAMERAS = Registry('camera')
 
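mmengine's Registry keeps the same register-then-build surface as mmcv.utils.Registry, which is why a one-line import swap suffices here and in the registry changes below. A small sketch (the SimplePinhole class is invented for illustration):

from mmengine import Registry

CAMERAS = Registry('camera')

@CAMERAS.register_module()
class SimplePinhole:  # hypothetical camera model
    def __init__(self, focal=1000.0):
        self.focal = focal

cam = CAMERAS.build(dict(type='SimplePinhole', focal=1500.0))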
main/transformer_utils/mmpose/core/distributed_wrapper.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
 from mmcv.parallel import MODULE_WRAPPERS as MMCV_MODULE_WRAPPERS
 from mmcv.parallel import MMDistributedDataParallel
 from mmcv.parallel.scatter_gather import scatter_kwargs
-from mmcv.utils import Registry
+from mmengine import Registry
 from torch.cuda._utils import _get_device_index
 
 MODULE_WRAPPERS = Registry('module wrapper', parent=MMCV_MODULE_WRAPPERS)
main/transformer_utils/mmpose/core/evaluation/eval_hooks.py CHANGED
@@ -1,8 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+import os.path as osp
+import warnings
+from math import inf
+from typing import Callable, List, Optional
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from mmengine.fileio import FileClient
+from mmengine.utils import is_seq_of
+from mmengine.hooks import Hook, LoggerHook
 
-from mmcv.runner import DistEvalHook as _DistEvalHook
-from mmcv.runner import EvalHook as _EvalHook
 
 MMPOSE_GREATER_KEYS = [
     'acc', 'ap', 'ar', 'pck', 'auc', '3dpck', 'p-3dpck', '3dauc', 'p-3dauc',
@@ -10,6 +20,505 @@ MMPOSE_GREATER_KEYS = [
 ]
 MMPOSE_LESS_KEYS = ['loss', 'epe', 'nme', 'mpjpe', 'p-mpjpe', 'n-mpjpe']
 
+class _EvalHook(Hook):
+    """Non-Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in non-distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch or iteration.
+            It enables evaluation before the training starts if ``start`` <=
+            the resuming epoch or iteration. If None, whether to evaluate is
+            merely decided by ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+            be inferred by 'less' rule. Options are 'greater', 'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader, and return the test results. If ``None``, the default
+            test function ``mmcv.engine.single_gpu_test`` will be used.
+            (default: ``None``)
+        greater_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'greater' comparison rule. If ``None``,
+            _default_greater_keys will be used. (default: ``None``)
+        less_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'less' comparison rule. If ``None``, _default_less_keys
+            will be used. (default: ``None``)
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+            `New in version 1.3.16.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.16.`
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+
+    Note:
+        If new arguments are added for EvalHook, tools/test.py,
+        tools/eval_metric.py may be affected.
+    """
+
+    # Since the key for determine greater or less is related to the downstream
+    # tasks, downstream repos may need to overwrite the following inner
+    # variable accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
+    def __init__(self,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 **eval_kwargs):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+
+        if interval <= 0:
+            raise ValueError(f'interval must be a positive number, '
+                             f'but got {interval}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+        if start is not None and start < 0:
+            raise ValueError(f'The evaluation start epoch {start} is smaller '
+                             f'than 0')
+
+        self.dataloader = dataloader
+        self.interval = interval
+        self.start = start
+        self.by_epoch = by_epoch
+
+        assert isinstance(save_best, str) or save_best is None, \
+            '""save_best"" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+        self.eval_kwargs = eval_kwargs
+        self.initial_flag = True
+
+        if test_fn is None:
+            from mmcv.engine import single_gpu_test
+            self.test_fn = single_gpu_test
+        else:
+            self.test_fn = test_fn
+
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            if not isinstance(greater_keys, (list, tuple)):
+                assert isinstance(greater_keys, str)
+                greater_keys = (greater_keys, )
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys
+
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                assert isinstance(greater_keys, str)
+                less_keys = (less_keys, )
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
+        self.out_dir = out_dir
+        self.file_client_args = file_client_args
+
+    def _init_rule(self, rule: Optional[str], key_indicator: str):
+        """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which rule is used for key indicator
+        when the rule is not specific (note that the key indicator matching
+        is case-insensitive):
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if any one item in ``self.greater_keys`` is a substring of
+           key_indicator , the rule will be specified as 'greater'.
+        4. Or if any one item in ``self.less_keys`` is a substring of
+           key_indicator , the rule will be specified as 'less'.
+
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+        """
+        if rule not in self.rule_map and rule is not None:
+            raise KeyError(f'rule must be greater, less or None, '
+                           f'but got {rule}.')
+
+        if rule is None:
+            if key_indicator != 'auto':
+                # `_lc` here means we use the lower case of keys for
+                # case-insensitive matching
+                assert isinstance(key_indicator, str)
+                key_indicator_lc = key_indicator.lower()
+                greater_keys = [key.lower() for key in self.greater_keys]
+                less_keys = [key.lower() for key in self.less_keys]
+
+                if key_indicator_lc in greater_keys:
+                    rule = 'greater'
+                elif key_indicator_lc in less_keys:
+                    rule = 'less'
+                elif any(key in key_indicator_lc for key in greater_keys):
+                    rule = 'greater'
+                elif any(key in key_indicator_lc for key in less_keys):
+                    rule = 'less'
+                else:
+                    raise ValueError(f'Cannot infer the rule for key '
+                                     f'{key_indicator}, thus a specific rule '
+                                     f'must be specified.')
+        self.rule = rule
+        self.key_indicator = key_indicator
+        if self.rule is not None:
+            self.compare_func = self.rule_map[self.rule]
+
+    def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                f'The best checkpoint will be saved to {self.out_dir} by '
+                f'{self.file_client.name}')
+
+        if self.save_best is not None:
+            if runner.meta is None:
+                warnings.warn('runner.meta is None. Creating an empty one.')
+                runner.meta = dict()
+            runner.meta.setdefault('hook_msgs', dict())
+            self.best_ckpt_path = runner.meta['hook_msgs'].get(
+                'best_ckpt', None)
+
+    def before_train_iter(self, runner):
+        """Evaluate the model only at the start of training by iteration."""
+        if self.by_epoch or not self.initial_flag:
+            return
+        if self.start is not None and runner.iter >= self.start:
+            self.after_train_iter(runner)
+        self.initial_flag = False
+
+    def before_train_epoch(self, runner):
+        """Evaluate the model only at the start of training by epoch."""
+        if not (self.by_epoch and self.initial_flag):
+            return
+        if self.start is not None and runner.epoch >= self.start:
+            self.after_train_epoch(runner)
+        self.initial_flag = False
+
+    def after_train_iter(self, runner):
+        """Called after every training iter to evaluate the results."""
+        if not self.by_epoch and self._should_evaluate(runner):
+            # Because the priority of EvalHook is higher than LoggerHook, the
+            # training log and the evaluating log are mixed. Therefore,
+            # we need to dump the training log and clear it before evaluating
+            # log is generated. In addition, this problem will only appear in
+            # `IterBasedRunner` whose `self.by_epoch` is False, because
+            # `EpochBasedRunner` whose `self.by_epoch` is True calls
+            # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+            # the training log has been printed, so it will not cause any
+            # problem. more details at
+            # https://github.com/open-mmlab/mmsegmentation/issues/694
+            for hook in runner._hooks:
+                if isinstance(hook, LoggerHook):
+                    hook.after_train_iter(runner)
+            runner.log_buffer.clear()
+
+            self._do_evaluate(runner)
+
+    def after_train_epoch(self, runner):
+        """Called after every training epoch to evaluate the results."""
+        if self.by_epoch and self._should_evaluate(runner):
+            self._do_evaluate(runner)
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        results = self.test_fn(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        # the key_score may be `None` so it needs to skip the action to save
+        # the best checkpoint
+        if self.save_best and key_score:
+            self._save_ckpt(runner, key_score)
+
+    def _should_evaluate(self, runner):
+        """Judge whether to perform evaluation.
+
+        Here is the rule to judge whether to perform evaluation:
+        1. It will not perform evaluation during the epoch/iteration interval,
+           which is determined by ``self.interval``.
+        2. It will not perform evaluation if the start time is larger than
+           current time.
+        3. It will not perform evaluation when current time is larger than
+           the start time but during epoch/iteration interval.
+
+        Returns:
+            bool: The flag indicating whether to perform evaluation.
+        """
+        if self.by_epoch:
+            current = runner.epoch
+            check_time = self.every_n_epochs
+        else:
+            current = runner.iter
+            check_time = self.every_n_iters
+
+        if self.start is None:
+            if not check_time(runner, self.interval):
+                # No evaluation during the interval.
+                return False
+        elif (current + 1) < self.start:
+            # No evaluation if start is larger than the current time.
+            return False
+        else:
+            # Evaluation only at epochs/iters 3, 5, 7...
+            # if start==3 and interval==2
+            if (current + 1 - self.start) % self.interval:
+                return False
+        return True
+
+    def _save_ckpt(self, runner, key_score):
+        """Save the best checkpoint.
+
+        It will compare the score according to the compare function, write
+        related information (best score, best checkpoint path) and save the
+        best checkpoint into ``work_dir``.
+        """
+        if self.by_epoch:
+            current = f'epoch_{runner.epoch + 1}'
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            current = f'iter_{runner.iter + 1}'
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        best_score = runner.meta['hook_msgs'].get(
+            'best_score', self.init_value_map[self.rule])
+        if self.compare_func(key_score, best_score):
+            best_score = key_score
+            runner.meta['hook_msgs']['best_score'] = best_score
+
+            if self.best_ckpt_path and self.file_client.isfile(
+                    self.best_ckpt_path):
+                self.file_client.remove(self.best_ckpt_path)
+                runner.logger.info(
+                    f'The previous best checkpoint {self.best_ckpt_path} was '
+                    'removed')
+
+            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+            self.best_ckpt_path = self.file_client.join_path(
+                self.out_dir, best_ckpt_name)
+            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+            runner.save_checkpoint(
+                self.out_dir,
+                filename_tmpl=best_ckpt_name,
+                create_symlink=False)
+            runner.logger.info(
+                f'Now best checkpoint is saved as {best_ckpt_name}.')
+            runner.logger.info(
+                f'Best {self.key_indicator} is {best_score:0.4f} '
+                f'at {cur_time} {cur_type}.')
+
+    def evaluate(self, runner, results):
+        """Evaluate the results.
+
+        Args:
+            runner (:obj:`mmcv.Runner`): The underlined training runner.
+            results (list): Output results.
+        """
+        eval_res = self.dataloader.dataset.evaluate(
+            results, logger=runner.logger, **self.eval_kwargs)
+
+        for name, val in eval_res.items():
+            runner.log_buffer.output[name] = val
+        runner.log_buffer.ready = True
+
+        if self.save_best is not None:
+            # If the performance of model is poor, the `eval_res` may be an
+            # empty dict and it will raise exception when `self.save_best` is
+            # not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+
+            if self.key_indicator == 'auto':
+                # infer from eval_results
+                self._init_rule(self.rule, list(eval_res.keys())[0])
+            return eval_res[self.key_indicator]
+
+        return None
+
+
+class _DistEvalHook(_EvalHook):
+    """Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+            be inferred by 'less' rule. Options are 'greater', 'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader in a multi-gpu manner, and return the test results. If
+            ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+            will be used. (default: ``None``)
+        tmpdir (str | None): Temporary directory to save the results of all
+            processes. Default: None.
+        gpu_collect (bool): Whether to use gpu or cpu to collect results.
+            Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer(running_mean and running_var) of rank 0 to other rank
+            before evaluation. Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+    """
+
+    def __init__(self,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 broadcast_bn_buffer: bool = True,
+                 tmpdir: Optional[str] = None,
+                 gpu_collect: bool = False,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 **eval_kwargs):
+
+        if test_fn is None:
+            from mmcv.engine import multi_gpu_test
+            test_fn = multi_gpu_test
+
+        super().__init__(
+            dataloader,
+            start=start,
+            interval=interval,
+            by_epoch=by_epoch,
+            save_best=save_best,
+            rule=rule,
+            test_fn=test_fn,
+            greater_keys=greater_keys,
+            less_keys=less_keys,
+            out_dir=out_dir,
+            file_client_args=file_client_args,
+            **eval_kwargs)
+
+        self.broadcast_bn_buffer = broadcast_bn_buffer
+        self.tmpdir = tmpdir
+        self.gpu_collect = gpu_collect
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+        results = self.test_fn(
+            runner.model,
+            self.dataloader,
+            tmpdir=tmpdir,
+            gpu_collect=self.gpu_collect)
+        if runner.rank == 0:
+            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            # the key_score may be `None` so it needs to skip the action to
+            # save the best checkpoint
+            if self.save_best and key_score:
+                self._save_ckpt(runner, key_score)
 
 class EvalHook(_EvalHook):
 
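The scheduling arithmetic in _should_evaluate is easy to sanity-check in isolation: with start=3 and interval=2 it fires at epochs 3, 5, 7, ... (1-based), exactly as the inline comment claims. A tiny standalone check mirroring the by_epoch branch:

def should_eval(current, start=3, interval=2):
    # `current` is the 0-based epoch index, as in runner.epoch
    if start is None:
        return (current + 1) % interval == 0
    if (current + 1) < start:
        return False
    return (current + 1 - start) % interval == 0

assert [e + 1 for e in range(8) if should_eval(e)] == [3, 5, 7]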
main/transformer_utils/mmpose/core/fp16/hooks.py CHANGED
@@ -1,15 +1,90 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import logging
+from typing import Optional
 
 import torch
 import torch.nn as nn
-from mmcv.runner import OptimizerHook
-from mmcv.utils import _BatchNorm
+from torch import Tensor
+from torch.nn.utils import clip_grad
+from mmengine.hooks import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 from ..utils.dist_utils import allreduce_grads
 from .utils import cast_tensor_type
 
 
+class OptimizerHook(Hook):
+    """A hook contains custom operations for the optimizer.
+
+    Args:
+        grad_clip (dict, optional): A config dict to control the clip_grad.
+            Default: None.
+        detect_anomalous_params (bool): This option is only used for
+            debugging which will slow down the training speed.
+            Detect anomalous parameters that are not included in
+            the computational graph with `loss` as the root.
+            There are two cases
+
+                - Parameters were not used during
+                  forward pass.
+                - Parameters were not used to produce
+                  loss.
+            Default: False.
+    """
+
+    def __init__(self,
+                 grad_clip: Optional[dict] = None,
+                 detect_anomalous_params: bool = False):
+        self.grad_clip = grad_clip
+        self.detect_anomalous_params = detect_anomalous_params
+
+    def clip_grads(self, params):
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+    def after_train_iter(self, runner):
+        runner.optimizer.zero_grad()
+        if self.detect_anomalous_params:
+            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
+        runner.outputs['loss'].backward()
+
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(runner.model.parameters())
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
+
+    def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
+        logger = runner.logger
+        parameters_in_graph = set()
+        visited = set()
+
+        def traverse(grad_fn):
+            if grad_fn is None:
+                return
+            if grad_fn not in visited:
+                visited.add(grad_fn)
+                if hasattr(grad_fn, 'variable'):
+                    parameters_in_graph.add(grad_fn.variable)
+                parents = grad_fn.next_functions
+                if parents is not None:
+                    for parent in parents:
+                        grad_fn = parent[0]
+                        traverse(grad_fn)
+
+        traverse(loss.grad_fn)
+        for n, p in runner.model.named_parameters():
+            if p not in parameters_in_graph and p.requires_grad:
+                logger.log(
+                    level=logging.ERROR,
+                    msg=f'{n} with shape {p.size()} is not '
+                    f'in the computational graph \n')
+
 class Fp16OptimizerHook(OptimizerHook):
     """FP16 optimizer hook.
 
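OptimizerHook.clip_grads forwards the grad_clip dict straight into torch.nn.utils.clip_grad.clip_grad_norm_, so its keys are that function's keyword arguments. A self-contained sketch with a typical OpenMMLab-style setting:

import torch
from torch import nn
from torch.nn.utils import clip_grad

grad_clip = dict(max_norm=35, norm_type=2)  # common OpenMMLab config value

model = nn.Linear(4, 2)
model(torch.randn(1, 4)).sum().backward()
params = [p for p in model.parameters() if p.grad is not None]
total_norm = clip_grad.clip_grad_norm_(params, **grad_clip)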
main/transformer_utils/mmpose/core/optimizers/builder.py CHANGED
@@ -1,24 +1,37 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import build_optimizer
-from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
-from mmcv.utils import Registry, build_from_cfg
+import copy
+from typing import Dict
+# from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
+from mmengine import Registry
+from mmengine.registry import build_from_cfg
 
 OPTIMIZERS = Registry('optimizers')
-OPTIMIZER_BUILDERS = Registry(
-    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
 
 
 def build_optimizer_constructor(cfg):
     constructor_type = cfg.get('type')
     if constructor_type in OPTIMIZER_BUILDERS:
         return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
-    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
     else:
         raise KeyError(f'{constructor_type} is not registered '
                        'in the optimizer builder registry.')
 
 
+def build_optimizer(model, cfg: Dict):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
+
+
 def build_optimizers(model, cfgs):
     """Build multiple optimizers from configs.
 
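The rewritten build_optimizer reproduces the mmcv 1.x behavior on top of the local registry: constructor and paramwise_cfg are popped off the config, and everything left is passed through as optimizer_cfg. A hypothetical usage sketch; note it assumes a DefaultOptimizerConstructor has been registered into OPTIMIZER_BUILDERS elsewhere, which this file no longer does on its own:

import torch.nn as nn

model = nn.Linear(8, 2)
cfg = dict(
    type='AdamW', lr=1e-4, weight_decay=0.05,
    constructor='DefaultOptimizerConstructor',  # resolved via OPTIMIZER_BUILDERS
    paramwise_cfg=dict(custom_keys={'bias': dict(decay_mult=0.0)}))
optimizer = build_optimizer(model, cfg)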
main/transformer_utils/mmpose/core/optimizers/layer_decay_optimizer_constructor.py CHANGED
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 import warnings
-
-from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
+from mmengine.dist import get_dist_info
+from mmcv.runner import DefaultOptimizerConstructor
 
 from mmpose.utils import get_root_logger
 from .builder import OPTIMIZER_BUILDERS
main/transformer_utils/mmpose/core/post_processing/smoother.py CHANGED
@@ -4,8 +4,8 @@ import warnings
 from typing import Dict, Union
 
 import numpy as np
-from mmcv import Config, is_seq_of
-
+from mmengine.config import Config
+from mmengine.utils import is_seq_of
 from mmpose.core.post_processing.temporal_filters import build_filter
 
 
main/transformer_utils/mmpose/core/post_processing/temporal_filters/builder.py CHANGED
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.utils import Registry
+from mmengine import Registry
 
 FILTERS = Registry('filters')
 
main/transformer_utils/mmpose/core/post_processing/temporal_filters/smoothnet_filter.py CHANGED
@@ -3,7 +3,7 @@ from typing import Optional
 
 import numpy as np
 import torch
-from mmcv.runner import load_checkpoint
+from mmengine.runner import load_checkpoint
 from torch import Tensor, nn
 
 from .builder import FILTERS
main/transformer_utils/mmpose/core/utils/dist_utils.py CHANGED
@@ -4,7 +4,7 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import torch.distributed as dist
-from mmcv.runner import get_dist_info
+from mmengine.dist import get_dist_info
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
main/transformer_utils/mmpose/core/utils/model_util_hooks.py CHANGED
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.runner import HOOKS, Hook
-
+from mmengine.registry import HOOKS
+from mmengine.hooks import Hook
 
 @HOOKS.register_module()
 class ModelSetEpochHook(Hook):
main/transformer_utils/mmpose/core/visualization/image.py CHANGED
@@ -7,7 +7,7 @@ import cv2
 import mmcv
 import numpy as np
 from matplotlib import pyplot as plt
-from mmcv.utils.misc import deprecated_api_warning
+from mmengine.utils import deprecated_api_warning
 from mmcv.visualization.color import color_val
 
 try:
main/transformer_utils/mmpose/models/__init__.py CHANGED
@@ -3,9 +3,9 @@ from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS,
                       build_backbone, build_head, build_loss, build_mesh_model,
                       build_neck, build_posenet)
 from .detectors import *  # noqa
+from .backbones import *
 from .heads import *  # noqa
 from .losses import *  # noqa
-from .necks import *  # noqa
 from .utils import *  # noqa
 
 
main/transformer_utils/mmpose/models/backbones/__init__.py CHANGED
@@ -1,41 +1,42 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .alexnet import AlexNet
-from .cpm import CPM
-from .hourglass import HourglassNet
-from .hourglass_ae import HourglassAENet
-from .hrformer import HRFormer
-from .hrnet import HRNet
-from .i3d import I3D
-from .litehrnet import LiteHRNet
-from .mobilenet_v2 import MobileNetV2
-from .mobilenet_v3 import MobileNetV3
-from .mspn import MSPN
-from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
-from .regnet import RegNet
-from .resnest import ResNeSt
-from .resnet import ResNet, ResNetV1d
-from .resnext import ResNeXt
-from .rsn import RSN
-from .scnet import SCNet
-from .seresnet import SEResNet
-from .seresnext import SEResNeXt
-from .shufflenet_v1 import ShuffleNetV1
-from .shufflenet_v2 import ShuffleNetV2
-from .swin import SwinTransformer
-from .tcformer import TCFormer
-from .tcn import TCN
-from .v2v_net import V2VNet
-from .vgg import VGG
-from .vipnas_mbv3 import ViPNAS_MobileNetV3
-from .vipnas_resnet import ViPNAS_ResNet
-from .hrt import HRT
+# from .alexnet import AlexNet
+# from .cpm import CPM
+# from .hourglass import HourglassNet
+# from .hourglass_ae import HourglassAENet
+# from .hrformer import HRFormer
+# from .hrnet import HRNet
+# from .i3d import I3D
+# from .litehrnet import LiteHRNet
+# from .mobilenet_v2 import MobileNetV2
+# from .mobilenet_v3 import MobileNetV3
+# from .mspn import MSPN
+# from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
+# from .regnet import RegNet
+# from .resnest import ResNeSt
+# from .resnet import ResNet, ResNetV1d
+# from .resnext import ResNeXt
+# from .rsn import RSN
+# from .scnet import SCNet
+# from .seresnet import SEResNet
+# from .seresnext import SEResNeXt
+# from .shufflenet_v1 import ShuffleNetV1
+# from .shufflenet_v2 import ShuffleNetV2
+# from .swin import SwinTransformer
+# from .tcformer import TCFormer
+# from .tcn import TCN
+# from .v2v_net import V2VNet
+# from .vgg import VGG
+# from .vipnas_mbv3 import ViPNAS_MobileNetV3
+# from .vipnas_resnet import ViPNAS_ResNet
+# from .hrt import HRT
 from .vit import ViT
 
-__all__ = [
-    'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
-    'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
-    'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
-    'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
-    'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
-    'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer', 'ViT'
-]
+# __all__ = [
+#     'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
+#     'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
+#     'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
+#     'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
+#     'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
+#     'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer', 'ViT'
+# ]
+__all__ = ['ViT']
main/transformer_utils/mmpose/models/backbones/alexnet.py DELETED
@@ -1,56 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn as nn
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-@BACKBONES.register_module()
-class AlexNet(BaseBackbone):
-    """`AlexNet <https://en.wikipedia.org/wiki/AlexNet>`__ backbone.
-
-    The input for AlexNet is a 224x224 RGB image.
-
-    Args:
-        num_classes (int): number of classes for classification.
-            The default value is -1, which uses the backbone as
-            a feature extractor without the top classifier.
-    """
-
-    def __init__(self, num_classes=-1):
-        super().__init__()
-        self.num_classes = num_classes
-        self.features = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-            nn.Conv2d(64, 192, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-            nn.Conv2d(192, 384, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-        )
-        if self.num_classes > 0:
-            self.classifier = nn.Sequential(
-                nn.Dropout(),
-                nn.Linear(256 * 6 * 6, 4096),
-                nn.ReLU(inplace=True),
-                nn.Dropout(),
-                nn.Linear(4096, 4096),
-                nn.ReLU(inplace=True),
-                nn.Linear(4096, num_classes),
-            )
-
-    def forward(self, x):
-
-        x = self.features(x)
-        if self.num_classes > 0:
-            x = x.view(x.size(0), 256 * 6 * 6)
-            x = self.classifier(x)
-
-        return x
main/transformer_utils/mmpose/models/backbones/cpm.py DELETED
@@ -1,186 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-
-import torch
-import torch.nn as nn
-from mmcv.cnn import ConvModule, constant_init, normal_init
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from mmpose.utils import get_root_logger
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-from .utils import load_checkpoint
-
-
-class CpmBlock(nn.Module):
-    """CpmBlock for Convolutional Pose Machine.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        channels (list): Output channels of each conv module.
-        kernels (list): Kernel sizes of each conv module.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 channels=(128, 128, 128),
-                 kernels=(11, 11, 11),
-                 norm_cfg=None):
-        super().__init__()
-
-        assert len(channels) == len(kernels)
-        layers = []
-        for i in range(len(channels)):
-            if i == 0:
-                input_channels = in_channels
-            else:
-                input_channels = channels[i - 1]
-            layers.append(
-                ConvModule(
-                    input_channels,
-                    channels[i],
-                    kernels[i],
-                    padding=(kernels[i] - 1) // 2,
-                    norm_cfg=norm_cfg))
-        self.model = nn.Sequential(*layers)
-
-    def forward(self, x):
-        """Model forward function."""
-        out = self.model(x)
-        return out
-
-
-@BACKBONES.register_module()
-class CPM(BaseBackbone):
-    """CPM backbone.
-
-    Convolutional Pose Machines.
-    More details can be found in the `paper
-    <https://arxiv.org/abs/1602.00134>`__ .
-
-    Args:
-        in_channels (int): The input channels of the CPM.
-        out_channels (int): The output channels of the CPM.
-        feat_channels (int): Feature channel of each CPM stage.
-        middle_channels (int): Feature channel of conv after the middle stage.
-        num_stages (int): Number of stages.
-        norm_cfg (dict): Dictionary to construct and config norm layer.
-
-    Example:
-        >>> from mmpose.models import CPM
-        >>> import torch
-        >>> self = CPM(3, 17)
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 368, 368)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_output in level_outputs:
-        ...     print(tuple(level_output.shape))
-        (1, 17, 46, 46)
-        (1, 17, 46, 46)
-        (1, 17, 46, 46)
-        (1, 17, 46, 46)
-        (1, 17, 46, 46)
-        (1, 17, 46, 46)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 feat_channels=128,
-                 middle_channels=32,
-                 num_stages=6,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        super().__init__()
-
-        assert in_channels == 3
-
-        self.num_stages = num_stages
-        assert self.num_stages >= 1
-
-        self.stem = nn.Sequential(
-            ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
-            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
-            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
-            ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg),
-            ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg),
-            ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg),
-            ConvModule(512, out_channels, 1, padding=0, act_cfg=None))
-
-        self.middle = nn.Sequential(
-            ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
-            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
-            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
-            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
-
-        self.cpm_stages = nn.ModuleList([
-            CpmBlock(
-                middle_channels + out_channels,
-                channels=[feat_channels, feat_channels, feat_channels],
-                kernels=[11, 11, 11],
-                norm_cfg=norm_cfg) for _ in range(num_stages - 1)
-        ])
-
-        self.middle_conv = nn.ModuleList([
-            nn.Sequential(
-                ConvModule(
-                    128, middle_channels, 5, padding=2, norm_cfg=norm_cfg))
-            for _ in range(num_stages - 1)
-        ])
-
-        self.out_convs = nn.ModuleList([
-            nn.Sequential(
-                ConvModule(
-                    feat_channels,
-                    feat_channels,
-                    1,
-                    padding=0,
-                    norm_cfg=norm_cfg),
-                ConvModule(feat_channels, out_channels, 1, act_cfg=None))
-            for _ in range(num_stages - 1)
-        ])
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, std=0.001)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-    def forward(self, x):
-        """Model forward function."""
-        stage1_out = self.stem(x)
-        middle_out = self.middle(x)
-        out_feats = []
-
-        out_feats.append(stage1_out)
-
-        for ind in range(self.num_stages - 1):
-            single_stage = self.cpm_stages[ind]
-            out_conv = self.out_convs[ind]
-
-            inp_feat = torch.cat(
-                [out_feats[-1], self.middle_conv[ind](middle_out)], 1)
-            cpm_feat = single_stage(inp_feat)
-            out_feat = out_conv(cpm_feat)
-            out_feats.append(out_feat)
-
-        return out_feats
main/transformer_utils/mmpose/models/backbones/hourglass.py DELETED
@@ -1,212 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
-
- import torch.nn as nn
- from mmcv.cnn import ConvModule, constant_init, normal_init
- from torch.nn.modules.batchnorm import _BatchNorm
-
- from mmpose.utils import get_root_logger
- from ..builder import BACKBONES
- from .base_backbone import BaseBackbone
- from .resnet import BasicBlock, ResLayer
- from .utils import load_checkpoint
-
-
- class HourglassModule(nn.Module):
- """Hourglass Module for HourglassNet backbone.
-
- Generate module recursively and use BasicBlock as the base unit.
-
- Args:
- depth (int): Depth of current HourglassModule.
- stage_channels (list[int]): Feature channels of sub-modules in current
- and follow-up HourglassModule.
- stage_blocks (list[int]): Number of sub-modules stacked in current and
- follow-up HourglassModule.
- norm_cfg (dict): Dictionary to construct and config norm layer.
- """
-
- def __init__(self,
- depth,
- stage_channels,
- stage_blocks,
- norm_cfg=dict(type='BN', requires_grad=True)):
- # Protect mutable default arguments
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
-
- self.depth = depth
-
- cur_block = stage_blocks[0]
- next_block = stage_blocks[1]
-
- cur_channel = stage_channels[0]
- next_channel = stage_channels[1]
-
- self.up1 = ResLayer(
- BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg)
-
- self.low1 = ResLayer(
- BasicBlock,
- cur_block,
- cur_channel,
- next_channel,
- stride=2,
- norm_cfg=norm_cfg)
-
- if self.depth > 1:
- self.low2 = HourglassModule(depth - 1, stage_channels[1:],
- stage_blocks[1:])
- else:
- self.low2 = ResLayer(
- BasicBlock,
- next_block,
- next_channel,
- next_channel,
- norm_cfg=norm_cfg)
-
- self.low3 = ResLayer(
- BasicBlock,
- cur_block,
- next_channel,
- cur_channel,
- norm_cfg=norm_cfg,
- downsample_first=False)
-
- self.up2 = nn.Upsample(scale_factor=2)
-
- def forward(self, x):
- """Model forward function."""
- up1 = self.up1(x)
- low1 = self.low1(x)
- low2 = self.low2(low1)
- low3 = self.low3(low2)
- up2 = self.up2(low3)
- return up1 + up2
-
-
- @BACKBONES.register_module()
- class HourglassNet(BaseBackbone):
- """HourglassNet backbone.
-
- Stacked Hourglass Networks for Human Pose Estimation.
- More details can be found in the `paper
- <https://arxiv.org/abs/1603.06937>`__ .
-
- Args:
- downsample_times (int): Downsample times in a HourglassModule.
- num_stacks (int): Number of HourglassModule modules stacked,
- 1 for Hourglass-52, 2 for Hourglass-104.
- stage_channels (list[int]): Feature channel of each sub-module in a
- HourglassModule.
- stage_blocks (list[int]): Number of sub-modules stacked in a
- HourglassModule.
- feat_channel (int): Feature channel of conv after a HourglassModule.
- norm_cfg (dict): Dictionary to construct and config norm layer.
-
- Example:
- >>> from mmpose.models import HourglassNet
- >>> import torch
- >>> self = HourglassNet()
- >>> self.eval()
- >>> inputs = torch.rand(1, 3, 511, 511)
- >>> level_outputs = self.forward(inputs)
- >>> for level_output in level_outputs:
- ... print(tuple(level_output.shape))
- (1, 256, 128, 128)
- (1, 256, 128, 128)
- """
-
- def __init__(self,
- downsample_times=5,
- num_stacks=2,
- stage_channels=(256, 256, 384, 384, 384, 512),
- stage_blocks=(2, 2, 2, 2, 2, 4),
- feat_channel=256,
- norm_cfg=dict(type='BN', requires_grad=True)):
- # Protect mutable default arguments
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
-
- self.num_stacks = num_stacks
- assert self.num_stacks >= 1
- assert len(stage_channels) == len(stage_blocks)
- assert len(stage_channels) > downsample_times
-
- cur_channel = stage_channels[0]
-
- self.stem = nn.Sequential(
- ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
- ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg))
-
- self.hourglass_modules = nn.ModuleList([
- HourglassModule(downsample_times, stage_channels, stage_blocks)
- for _ in range(num_stacks)
- ])
-
- self.inters = ResLayer(
- BasicBlock,
- num_stacks - 1,
- cur_channel,
- cur_channel,
- norm_cfg=norm_cfg)
-
- self.conv1x1s = nn.ModuleList([
- ConvModule(
- cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
- for _ in range(num_stacks - 1)
- ])
-
- self.out_convs = nn.ModuleList([
- ConvModule(
- cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
- for _ in range(num_stacks)
- ])
-
- self.remap_convs = nn.ModuleList([
- ConvModule(
- feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
- for _ in range(num_stacks - 1)
- ])
-
- self.relu = nn.ReLU(inplace=True)
-
- def init_weights(self, pretrained=None):
- """Initialize the weights in backbone.
-
- Args:
- pretrained (str, optional): Path to pre-trained weights.
- Defaults to None.
- """
- if isinstance(pretrained, str):
- logger = get_root_logger()
- load_checkpoint(self, pretrained, strict=False, logger=logger)
- elif pretrained is None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- normal_init(m, std=0.001)
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
- constant_init(m, 1)
- else:
- raise TypeError('pretrained must be a str or None')
-
- def forward(self, x):
- """Model forward function."""
- inter_feat = self.stem(x)
- out_feats = []
-
- for ind in range(self.num_stacks):
- single_hourglass = self.hourglass_modules[ind]
- out_conv = self.out_convs[ind]
-
- hourglass_feat = single_hourglass(inter_feat)
- out_feat = out_conv(hourglass_feat)
- out_feats.append(out_feat)
-
- if ind < self.num_stacks - 1:
- inter_feat = self.conv1x1s[ind](
- inter_feat) + self.remap_convs[ind](
- out_feat)
- inter_feat = self.inters[ind](self.relu(inter_feat))
-
- return out_feats
main/transformer_utils/mmpose/models/backbones/hourglass_ae.py DELETED
@@ -1,212 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
-
- import torch.nn as nn
- from mmcv.cnn import ConvModule, MaxPool2d, constant_init, normal_init
- from torch.nn.modules.batchnorm import _BatchNorm
-
- from mmpose.utils import get_root_logger
- from ..builder import BACKBONES
- from .base_backbone import BaseBackbone
- from .utils import load_checkpoint
-
-
- class HourglassAEModule(nn.Module):
- """Modified Hourglass Module for HourglassNet_AE backbone.
-
- Generate module recursively and use ConvModule as the base unit.
-
- Args:
- depth (int): Depth of current HourglassModule.
- stage_channels (list[int]): Feature channels of sub-modules in current
- and follow-up HourglassModule.
- norm_cfg (dict): Dictionary to construct and config norm layer.
- """
-
- def __init__(self,
- depth,
- stage_channels,
- norm_cfg=dict(type='BN', requires_grad=True)):
- # Protect mutable default arguments
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
-
- self.depth = depth
-
- cur_channel = stage_channels[0]
- next_channel = stage_channels[1]
-
- self.up1 = ConvModule(
- cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)
-
- self.pool1 = MaxPool2d(2, 2)
-
- self.low1 = ConvModule(
- cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)
-
- if self.depth > 1:
- self.low2 = HourglassAEModule(depth - 1, stage_channels[1:])
- else:
- self.low2 = ConvModule(
- next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)
-
- self.low3 = ConvModule(
- next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)
-
- self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
-
- def forward(self, x):
- """Model forward function."""
- up1 = self.up1(x)
- pool1 = self.pool1(x)
- low1 = self.low1(pool1)
- low2 = self.low2(low1)
- low3 = self.low3(low2)
- up2 = self.up2(low3)
- return up1 + up2
-
-
- @BACKBONES.register_module()
- class HourglassAENet(BaseBackbone):
- """Hourglass-AE Network proposed by Newell et al.
-
- Associative Embedding: End-to-End Learning for Joint
- Detection and Grouping.
-
- More details can be found in the `paper
- <https://arxiv.org/abs/1611.05424>`__ .
-
- Args:
- downsample_times (int): Downsample times in a HourglassModule.
- num_stacks (int): Number of HourglassModule modules stacked,
- 1 for Hourglass-52, 2 for Hourglass-104.
- stage_channels (list[int]): Feature channel of each sub-module in a
- HourglassModule.
- out_channels (int): Number of output channels of each stack.
- feat_channels (int): Feature channel of conv after a HourglassModule.
- norm_cfg (dict): Dictionary to construct and config norm layer.
-
- Example:
- >>> from mmpose.models import HourglassAENet
- >>> import torch
- >>> self = HourglassAENet()
- >>> self.eval()
- >>> inputs = torch.rand(1, 3, 512, 512)
- >>> level_outputs = self.forward(inputs)
- >>> for level_output in level_outputs:
- ... print(tuple(level_output.shape))
- (1, 34, 128, 128)
- """
-
- def __init__(self,
- downsample_times=4,
- num_stacks=1,
- out_channels=34,
- stage_channels=(256, 384, 512, 640, 768),
- feat_channels=256,
- norm_cfg=dict(type='BN', requires_grad=True)):
- # Protect mutable default arguments
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
-
- self.num_stacks = num_stacks
- assert self.num_stacks >= 1
- assert len(stage_channels) > downsample_times
-
- cur_channels = stage_channels[0]
-
- self.stem = nn.Sequential(
- ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
- ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
- MaxPool2d(2, 2),
- ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
- ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
- )
-
- self.hourglass_modules = nn.ModuleList([
- nn.Sequential(
- HourglassAEModule(
- downsample_times, stage_channels, norm_cfg=norm_cfg),
- ConvModule(
- feat_channels,
- feat_channels,
- 3,
- padding=1,
- norm_cfg=norm_cfg),
- ConvModule(
- feat_channels,
- feat_channels,
- 3,
- padding=1,
- norm_cfg=norm_cfg)) for _ in range(num_stacks)
- ])
-
- self.out_convs = nn.ModuleList([
- ConvModule(
- cur_channels,
- out_channels,
- 1,
- padding=0,
- norm_cfg=None,
- act_cfg=None) for _ in range(num_stacks)
- ])
-
- self.remap_out_convs = nn.ModuleList([
- ConvModule(
- out_channels,
- feat_channels,
- 1,
- norm_cfg=norm_cfg,
- act_cfg=None) for _ in range(num_stacks - 1)
- ])
-
- self.remap_feature_convs = nn.ModuleList([
- ConvModule(
- feat_channels,
- feat_channels,
- 1,
- norm_cfg=norm_cfg,
- act_cfg=None) for _ in range(num_stacks - 1)
- ])
-
- self.relu = nn.ReLU(inplace=True)
-
- def init_weights(self, pretrained=None):
- """Initialize the weights in backbone.
-
- Args:
- pretrained (str, optional): Path to pre-trained weights.
- Defaults to None.
- """
- if isinstance(pretrained, str):
- logger = get_root_logger()
- load_checkpoint(self, pretrained, strict=False, logger=logger)
- elif pretrained is None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- normal_init(m, std=0.001)
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
- constant_init(m, 1)
- else:
- raise TypeError('pretrained must be a str or None')
-
- def forward(self, x):
- """Model forward function."""
- inter_feat = self.stem(x)
- out_feats = []
-
- for ind in range(self.num_stacks):
- single_hourglass = self.hourglass_modules[ind]
- out_conv = self.out_convs[ind]
-
- hourglass_feat = single_hourglass(inter_feat)
- out_feat = out_conv(hourglass_feat)
- out_feats.append(out_feat)
-
- if ind < self.num_stacks - 1:
- inter_feat = inter_feat + self.remap_out_convs[ind](
- out_feat) + self.remap_feature_convs[ind](
- hourglass_feat)
-
- return out_feats
main/transformer_utils/mmpose/models/backbones/hrformer.py DELETED
@@ -1,746 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
-
- import math
-
- import torch
- import torch.nn as nn
- # from timm.models.layers import to_2tuple, trunc_normal_
- from mmcv.cnn import (build_activation_layer, build_conv_layer,
- build_norm_layer, trunc_normal_init)
- from mmcv.cnn.bricks.transformer import build_dropout
- from mmcv.runner import BaseModule
- from torch.nn.functional import pad
-
- from ..builder import BACKBONES
- from .hrnet import Bottleneck, HRModule, HRNet
-
-
- def nlc_to_nchw(x, hw_shape):
- """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.
-
- Args:
- x (Tensor): The input tensor of shape [N, L, C] before conversion.
- hw_shape (Sequence[int]): The height and width of output feature map.
-
- Returns:
- Tensor: The output tensor of shape [N, C, H, W] after conversion.
- """
- H, W = hw_shape
- assert len(x.shape) == 3
- B, L, C = x.shape
- assert L == H * W, 'The seq_len doesn\'t match H, W'
- return x.transpose(1, 2).reshape(B, C, H, W)
-
-
- def nchw_to_nlc(x):
- """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.
-
- Args:
- x (Tensor): The input tensor of shape [N, C, H, W] before conversion.
-
- Returns:
- Tensor: The output tensor of shape [N, L, C] after conversion.
- """
- assert len(x.shape) == 4
- return x.flatten(2).transpose(1, 2).contiguous()
-
-
- def build_drop_path(drop_path_rate):
- """Build drop path layer."""
- return build_dropout(dict(type='DropPath', drop_prob=drop_path_rate))
-
-
- class WindowMSA(BaseModule):
- """Window based multi-head self-attention (W-MSA) module with relative
- position bias.
-
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (tuple[int]): The height and width of the window.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Default: 0.0
- proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
- with_rpe (bool, optional): If True, use relative position bias.
- Default: True.
- init_cfg (dict | None, optional): The Config for initialization.
- Default: None.
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0.,
- proj_drop_rate=0.,
- with_rpe=True,
- init_cfg=None):
-
- super().__init__(init_cfg=init_cfg)
- self.embed_dims = embed_dims
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_embed_dims = embed_dims // num_heads
- self.scale = qk_scale or head_embed_dims**-0.5
-
- self.with_rpe = with_rpe
- if self.with_rpe:
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros(
- (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
- num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- Wh, Ww = self.window_size
- rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
- rel_position_index = rel_index_coords + rel_index_coords.T
- rel_position_index = rel_position_index.flip(1).contiguous()
- self.register_buffer('relative_position_index', rel_position_index)
-
- self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop_rate)
- self.proj = nn.Linear(embed_dims, embed_dims)
- self.proj_drop = nn.Dropout(proj_drop_rate)
-
- self.softmax = nn.Softmax(dim=-1)
-
- def init_weights(self):
- trunc_normal_init(self.relative_position_bias_table, std=0.02)
-
- def forward(self, x, mask=None):
- """
- Args:
-
- x (tensor): input features with shape of (B*num_windows, N, C)
- mask (tensor | None, Optional): mask with shape of (num_windows,
- Wh*Ww, Wh*Ww), value should be between (-inf, 0].
- """
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- if self.with_rpe:
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B // nW, nW, self.num_heads, N,
- N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- @staticmethod
- def double_step_seq(step1, len1, step2, len2):
- seq1 = torch.arange(0, step1 * len1, step1)
- seq2 = torch.arange(0, step2 * len2, step2)
- return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
-
-
- class LocalWindowSelfAttention(BaseModule):
- r""" Local-window Self Attention (LSA) module with relative position bias.
-
- This module is the short-range self-attention module in the
- Interlaced Sparse Self-Attention <https://arxiv.org/abs/1907.12273>`_.
-
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (tuple[int] | int): The height and width of the window.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Default: 0.0
- proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
- with_rpe (bool, optional): If True, use relative position bias.
- Default: True.
- with_pad_mask (bool, optional): If True, mask out the padded tokens in
- the attention process. Default: False.
- init_cfg (dict | None, optional): The Config for initialization.
- Default: None.
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0.,
- proj_drop_rate=0.,
- with_rpe=True,
- with_pad_mask=False,
- init_cfg=None):
- super().__init__(init_cfg=init_cfg)
- if isinstance(window_size, int):
- window_size = (window_size, window_size)
- self.window_size = window_size
- self.with_pad_mask = with_pad_mask
- self.attn = WindowMSA(
- embed_dims=embed_dims,
- num_heads=num_heads,
- window_size=window_size,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop_rate=attn_drop_rate,
- proj_drop_rate=proj_drop_rate,
- with_rpe=with_rpe,
- init_cfg=init_cfg)
-
- def forward(self, x, H, W, **kwargs):
- """Forward function."""
- B, N, C = x.shape
- x = x.view(B, H, W, C)
- Wh, Ww = self.window_size
-
- # center-pad the feature on H and W axes
- pad_h = math.ceil(H / Wh) * Wh - H
- pad_w = math.ceil(W / Ww) * Ww - W
- x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
- pad_h - pad_h // 2))
-
- # permute
- x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C)
- x = x.permute(0, 1, 3, 2, 4, 5)
- x = x.reshape(-1, Wh * Ww, C) # (B*num_window, Wh*Ww, C)
-
- # attention
- if self.with_pad_mask and pad_h > 0 and pad_w > 0:
- pad_mask = x.new_zeros(1, H, W, 1)
- pad_mask = pad(
- pad_mask, [
- 0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
- pad_h - pad_h // 2
- ],
- value=-float('inf'))
- pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh,
- math.ceil(W / Ww), Ww, 1)
- pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5)
- pad_mask = pad_mask.reshape(-1, Wh * Ww)
- pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1])
- out = self.attn(x, pad_mask, **kwargs)
- else:
- out = self.attn(x, **kwargs)
-
- # reverse permutation
- out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C)
- out = out.permute(0, 1, 3, 2, 4, 5)
- out = out.reshape(B, H + pad_h, W + pad_w, C)
-
- # de-pad
- out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2]
- return out.reshape(B, N, C)
-
-
- class CrossFFN(BaseModule):
- r"""FFN with Depthwise Conv of HRFormer.
-
- Args:
- in_features (int): The feature dimension.
- hidden_features (int, optional): The hidden dimension of FFNs.
- Defaults: The same as in_features.
- act_cfg (dict, optional): Config of activation layer.
- Default: dict(type='GELU').
- dw_act_cfg (dict, optional): Config of activation layer appended
- right after DW Conv. Default: dict(type='GELU').
- norm_cfg (dict, optional): Config of norm layer.
- Default: dict(type='SyncBN').
- init_cfg (dict | list | None, optional): The init config.
- Default: None.
- """
-
- def __init__(self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_cfg=dict(type='GELU'),
- dw_act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='SyncBN'),
- init_cfg=None):
- super().__init__(init_cfg=init_cfg)
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
- self.act1 = build_activation_layer(act_cfg)
- self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
- self.dw3x3 = nn.Conv2d(
- hidden_features,
- hidden_features,
- kernel_size=3,
- stride=1,
- groups=hidden_features,
- padding=1)
- self.act2 = build_activation_layer(dw_act_cfg)
- self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1]
- self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
- self.act3 = build_activation_layer(act_cfg)
- self.norm3 = build_norm_layer(norm_cfg, out_features)[1]
-
- # put the modules together
- self.layers = [
- self.fc1, self.norm1, self.act1, self.dw3x3, self.norm2, self.act2,
- self.fc2, self.norm3, self.act3
- ]
-
- def forward(self, x, H, W):
- """Forward function."""
- x = nlc_to_nchw(x, (H, W))
- for layer in self.layers:
- x = layer(x)
- x = nchw_to_nlc(x)
- return x
-
-
- class HRFormerBlock(BaseModule):
- """High-Resolution Block for HRFormer.
-
- Args:
- in_features (int): The input dimension.
- out_features (int): The output dimension.
- num_heads (int): The number of head within each LSA.
- window_size (int, optional): The window size for the LSA.
- Default: 7
- mlp_ratio (int, optional): The expansion ratio of FFN.
- Default: 4
- act_cfg (dict, optional): Config of activation layer.
- Default: dict(type='GELU').
- norm_cfg (dict, optional): Config of norm layer.
- Default: dict(type='SyncBN').
- transformer_norm_cfg (dict, optional): Config of transformer norm
- layer. Default: dict(type='LN', eps=1e-6).
- init_cfg (dict | list | None, optional): The init config.
- Default: None.
- """
-
- expansion = 1
-
- def __init__(self,
- in_features,
- out_features,
- num_heads,
- window_size=7,
- mlp_ratio=4.0,
- drop_path=0.0,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='SyncBN'),
- transformer_norm_cfg=dict(type='LN', eps=1e-6),
- init_cfg=None,
- **kwargs):
- super(HRFormerBlock, self).__init__(init_cfg=init_cfg)
- self.num_heads = num_heads
- self.window_size = window_size
- self.mlp_ratio = mlp_ratio
-
- self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1]
- self.attn = LocalWindowSelfAttention(
- in_features,
- num_heads=num_heads,
- window_size=window_size,
- init_cfg=None,
- **kwargs)
-
- self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1]
- self.ffn = CrossFFN(
- in_features=in_features,
- hidden_features=int(in_features * mlp_ratio),
- out_features=out_features,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg,
- dw_act_cfg=act_cfg,
- init_cfg=None)
-
- self.drop_path = build_drop_path(
- drop_path) if drop_path > 0.0 else nn.Identity()
-
- def forward(self, x):
- """Forward function."""
- B, C, H, W = x.size()
- # Attention
- x = x.view(B, C, -1).permute(0, 2, 1)
- x = x + self.drop_path(self.attn(self.norm1(x), H, W))
- # FFN
- x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
- x = x.permute(0, 2, 1).view(B, C, H, W)
- return x
-
- def extra_repr(self):
- """(Optional) Set the extra information about this module."""
- return 'num_heads={}, window_size={}, mlp_ratio={}'.format(
- self.num_heads, self.window_size, self.mlp_ratio)
-
-
- class HRFomerModule(HRModule):
- """High-Resolution Module for HRFormer.
-
- Args:
- num_branches (int): The number of branches in the HRFormerModule.
- block (nn.Module): The building block of HRFormer.
- The block should be the HRFormerBlock.
- num_blocks (tuple): The number of blocks in each branch.
- The length must be equal to num_branches.
- num_inchannels (tuple): The number of input channels in each branch.
- The length must be equal to num_branches.
- num_channels (tuple): The number of channels in each branch.
- The length must be equal to num_branches.
- num_heads (tuple): The number of heads within the LSAs.
- num_window_sizes (tuple): The window size for the LSAs.
- num_mlp_ratios (tuple): The expansion ratio for the FFNs.
- drop_paths (float, optional): The drop path rate of HRFormer.
- Default: 0.0
- multiscale_output (bool, optional): Whether to output multi-level
- features produced by multiple branches. If False, only the first
- level feature will be output. Default: True.
- conv_cfg (dict, optional): Config of the conv layers.
- Default: None.
- norm_cfg (dict, optional): Config of the norm layers appended
- right after conv. Default: dict(type='SyncBN', requires_grad=True)
- transformer_norm_cfg (dict, optional): Config of the norm layers.
- Default: dict(type='LN', eps=1e-6)
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
- memory while slowing down the training speed. Default: False
- upsample_cfg(dict, optional): The config of upsample layers in fuse
- layers. Default: dict(mode='bilinear', align_corners=False)
- """
-
- def __init__(self,
- num_branches,
- block,
- num_blocks,
- num_inchannels,
- num_channels,
- num_heads,
- num_window_sizes,
- num_mlp_ratios,
- multiscale_output=True,
- drop_paths=0.0,
- with_rpe=True,
- with_pad_mask=False,
- conv_cfg=None,
- norm_cfg=dict(type='SyncBN', requires_grad=True),
- transformer_norm_cfg=dict(type='LN', eps=1e-6),
- with_cp=False,
- upsample_cfg=dict(mode='bilinear', align_corners=False)):
-
- self.transformer_norm_cfg = transformer_norm_cfg
- self.drop_paths = drop_paths
- self.num_heads = num_heads
- self.num_window_sizes = num_window_sizes
- self.num_mlp_ratios = num_mlp_ratios
- self.with_rpe = with_rpe
- self.with_pad_mask = with_pad_mask
-
- super().__init__(num_branches, block, num_blocks, num_inchannels,
- num_channels, multiscale_output, with_cp, conv_cfg,
- norm_cfg, upsample_cfg)
-
- def _make_one_branch(self,
- branch_index,
- block,
- num_blocks,
- num_channels,
- stride=1):
- """Build one branch."""
- # HRFormerBlock does not support down sample layer yet.
- assert stride == 1 and self.in_channels[branch_index] == num_channels[
- branch_index]
- layers = []
- layers.append(
- block(
- self.in_channels[branch_index],
- num_channels[branch_index],
- num_heads=self.num_heads[branch_index],
- window_size=self.num_window_sizes[branch_index],
- mlp_ratio=self.num_mlp_ratios[branch_index],
- drop_path=self.drop_paths[0],
- norm_cfg=self.norm_cfg,
- transformer_norm_cfg=self.transformer_norm_cfg,
- init_cfg=None,
- with_rpe=self.with_rpe,
- with_pad_mask=self.with_pad_mask))
-
- self.in_channels[
- branch_index] = self.in_channels[branch_index] * block.expansion
- for i in range(1, num_blocks[branch_index]):
- layers.append(
- block(
- self.in_channels[branch_index],
- num_channels[branch_index],
- num_heads=self.num_heads[branch_index],
- window_size=self.num_window_sizes[branch_index],
- mlp_ratio=self.num_mlp_ratios[branch_index],
- drop_path=self.drop_paths[i],
- norm_cfg=self.norm_cfg,
- transformer_norm_cfg=self.transformer_norm_cfg,
- init_cfg=None,
- with_rpe=self.with_rpe,
- with_pad_mask=self.with_pad_mask))
- return nn.Sequential(*layers)
-
- def _make_fuse_layers(self):
- """Build fuse layers."""
- if self.num_branches == 1:
- return None
- num_branches = self.num_branches
- num_inchannels = self.in_channels
- fuse_layers = []
- for i in range(num_branches if self.multiscale_output else 1):
- fuse_layer = []
- for j in range(num_branches):
- if j > i:
- fuse_layer.append(
- nn.Sequential(
- build_conv_layer(
- self.conv_cfg,
- num_inchannels[j],
- num_inchannels[i],
- kernel_size=1,
- stride=1,
- bias=False),
- build_norm_layer(self.norm_cfg,
- num_inchannels[i])[1],
- nn.Upsample(
- scale_factor=2**(j - i),
- mode=self.upsample_cfg['mode'],
- align_corners=self.
- upsample_cfg['align_corners'])))
- elif j == i:
- fuse_layer.append(None)
- else:
- conv3x3s = []
- for k in range(i - j):
- if k == i - j - 1:
- num_outchannels_conv3x3 = num_inchannels[i]
- with_out_act = False
- else:
- num_outchannels_conv3x3 = num_inchannels[j]
- with_out_act = True
- sub_modules = [
- build_conv_layer(
- self.conv_cfg,
- num_inchannels[j],
- num_inchannels[j],
- kernel_size=3,
- stride=2,
- padding=1,
- groups=num_inchannels[j],
- bias=False,
- ),
- build_norm_layer(self.norm_cfg,
- num_inchannels[j])[1],
- build_conv_layer(
- self.conv_cfg,
- num_inchannels[j],
- num_outchannels_conv3x3,
- kernel_size=1,
- stride=1,
- bias=False,
- ),
- build_norm_layer(self.norm_cfg,
- num_outchannels_conv3x3)[1]
- ]
- if with_out_act:
- sub_modules.append(nn.ReLU(False))
- conv3x3s.append(nn.Sequential(*sub_modules))
- fuse_layer.append(nn.Sequential(*conv3x3s))
- fuse_layers.append(nn.ModuleList(fuse_layer))
-
- return nn.ModuleList(fuse_layers)
-
- def get_num_inchannels(self):
- """Return the number of input channels."""
- return self.in_channels
-
-
- @BACKBONES.register_module()
- class HRFormer(HRNet):
- """HRFormer backbone.
-
- This backbone is the implementation of `HRFormer: High-Resolution
- Transformer for Dense Prediction <https://arxiv.org/abs/2110.09408>`_.
-
- Args:
- extra (dict): Detailed configuration for each stage of HRNet.
- There must be 4 stages, the configuration for each stage must have
- 5 keys:
-
- - num_modules (int): The number of HRModule in this stage.
- - num_branches (int): The number of branches in the HRModule.
- - block (str): The type of block.
- - num_blocks (tuple): The number of blocks in each branch.
- The length must be equal to num_branches.
- - num_channels (tuple): The number of channels in each branch.
- The length must be equal to num_branches.
- in_channels (int): Number of input image channels. Normally 3.
- conv_cfg (dict): Dictionary to construct and config conv layer.
- Default: None.
- norm_cfg (dict): Config of norm layer.
- Use `SyncBN` by default.
- transformer_norm_cfg (dict): Config of transformer norm layer.
- Use `LN` by default.
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
- freeze running stats (mean and var). Note: Effect on Batch Norm
- and its variants only. Default: False.
- zero_init_residual (bool): Whether to use zero init for last norm layer
- in resblocks to let them behave as identity. Default: False.
- frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
- -1 means not freezing any parameters. Default: -1.
- Example:
- >>> from mmpose.models import HRFormer
- >>> import torch
- >>> extra = dict(
- >>> stage1=dict(
- >>> num_modules=1,
- >>> num_branches=1,
- >>> block='BOTTLENECK',
- >>> num_blocks=(2, ),
- >>> num_channels=(64, )),
- >>> stage2=dict(
- >>> num_modules=1,
- >>> num_branches=2,
- >>> block='HRFORMER',
- >>> window_sizes=(7, 7),
- >>> num_heads=(1, 2),
- >>> mlp_ratios=(4, 4),
- >>> num_blocks=(2, 2),
- >>> num_channels=(32, 64)),
- >>> stage3=dict(
- >>> num_modules=4,
- >>> num_branches=3,
- >>> block='HRFORMER',
- >>> window_sizes=(7, 7, 7),
- >>> num_heads=(1, 2, 4),
- >>> mlp_ratios=(4, 4, 4),
- >>> num_blocks=(2, 2, 2),
- >>> num_channels=(32, 64, 128)),
- >>> stage4=dict(
- >>> num_modules=2,
- >>> num_branches=4,
- >>> block='HRFORMER',
- >>> window_sizes=(7, 7, 7, 7),
- >>> num_heads=(1, 2, 4, 8),
- >>> mlp_ratios=(4, 4, 4, 4),
- >>> num_blocks=(2, 2, 2, 2),
- >>> num_channels=(32, 64, 128, 256)))
- >>> self = HRFormer(extra, in_channels=1)
- >>> self.eval()
- >>> inputs = torch.rand(1, 1, 32, 32)
- >>> level_outputs = self.forward(inputs)
- >>> for level_out in level_outputs:
- ... print(tuple(level_out.shape))
- (1, 32, 8, 8)
- (1, 64, 4, 4)
- (1, 128, 2, 2)
- (1, 256, 1, 1)
- """
-
- blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock}
-
- def __init__(self,
- extra,
- in_channels=3,
- conv_cfg=None,
- norm_cfg=dict(type='BN', requires_grad=True),
- transformer_norm_cfg=dict(type='LN', eps=1e-6),
- norm_eval=False,
- with_cp=False,
- zero_init_residual=False,
- frozen_stages=-1):
-
- # stochastic depth
- depths = [
- extra[stage]['num_blocks'][0] * extra[stage]['num_modules']
- for stage in ['stage2', 'stage3', 'stage4']
- ]
- depth_s2, depth_s3, _ = depths
- drop_path_rate = extra['drop_path_rate']
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
- ]
- extra['stage2']['drop_path_rates'] = dpr[0:depth_s2]
- extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3]
- extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:]
-
- # HRFormer uses bilinear upsample by default
- upsample_cfg = extra.get('upsample', {
- 'mode': 'bilinear',
- 'align_corners': False
- })
- extra['upsample'] = upsample_cfg
- self.transformer_norm_cfg = transformer_norm_cfg
- self.with_rpe = extra.get('with_rpe', True)
- self.with_pad_mask = extra.get('with_pad_mask', False)
-
- super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
- with_cp, zero_init_residual, frozen_stages)
-
- def _make_stage(self,
- layer_config,
- num_inchannels,
- multiscale_output=True):
- """Make each stage."""
- num_modules = layer_config['num_modules']
- num_branches = layer_config['num_branches']
- num_blocks = layer_config['num_blocks']
- num_channels = layer_config['num_channels']
- block = self.blocks_dict[layer_config['block']]
- num_heads = layer_config['num_heads']
- num_window_sizes = layer_config['window_sizes']
- num_mlp_ratios = layer_config['mlp_ratios']
- drop_path_rates = layer_config['drop_path_rates']
-
- modules = []
- for i in range(num_modules):
- # multiscale_output is only used at the last module
- if not multiscale_output and i == num_modules - 1:
- reset_multiscale_output = False
- else:
- reset_multiscale_output = True
-
- modules.append(
- HRFomerModule(
- num_branches,
- block,
- num_blocks,
- num_inchannels,
- num_channels,
- num_heads,
- num_window_sizes,
- num_mlp_ratios,
- reset_multiscale_output,
- drop_paths=drop_path_rates[num_blocks[0] *
- i:num_blocks[0] * (i + 1)],
- with_rpe=self.with_rpe,
- with_pad_mask=self.with_pad_mask,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg,
- transformer_norm_cfg=self.transformer_norm_cfg,
- with_cp=self.with_cp,
- upsample_cfg=self.upsample_cfg))
- num_inchannels = modules[-1].get_num_inchannels()
-
- return nn.Sequential(*modules), num_inchannels
main/transformer_utils/mmpose/models/backbones/hrnet.py DELETED
@@ -1,604 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import copy
3
-
4
- import torch.nn as nn
5
- from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
6
- normal_init)
7
- from torch.nn.modules.batchnorm import _BatchNorm
8
-
9
- from mmpose.utils import get_root_logger
10
- from ..builder import BACKBONES
11
- from .resnet import BasicBlock, Bottleneck, get_expansion
12
- from .utils import load_checkpoint
13
-
14
-
15
- class HRModule(nn.Module):
16
- """High-Resolution Module for HRNet.
17
-
18
- In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
19
- is in this module.
20
- """
21
-
22
- def __init__(self,
23
- num_branches,
24
- blocks,
25
- num_blocks,
26
- in_channels,
27
- num_channels,
28
- multiscale_output=False,
29
- with_cp=False,
30
- conv_cfg=None,
31
- norm_cfg=dict(type='BN'),
32
- upsample_cfg=dict(mode='nearest', align_corners=None)):
33
-
34
- # Protect mutable default arguments
35
- norm_cfg = copy.deepcopy(norm_cfg)
36
- super().__init__()
37
- self._check_branches(num_branches, num_blocks, in_channels,
38
- num_channels)
39
-
40
- self.in_channels = in_channels
41
- self.num_branches = num_branches
42
-
43
- self.multiscale_output = multiscale_output
44
- self.norm_cfg = norm_cfg
45
- self.conv_cfg = conv_cfg
46
- self.upsample_cfg = upsample_cfg
47
- self.with_cp = with_cp
48
- self.branches = self._make_branches(num_branches, blocks, num_blocks,
49
- num_channels)
50
- self.fuse_layers = self._make_fuse_layers()
51
- self.relu = nn.ReLU(inplace=True)
52
-
53
- @staticmethod
54
- def _check_branches(num_branches, num_blocks, in_channels, num_channels):
55
- """Check input to avoid ValueError."""
56
- if num_branches != len(num_blocks):
57
- error_msg = f'NUM_BRANCHES({num_branches}) ' \
58
- f'!= NUM_BLOCKS({len(num_blocks)})'
59
- raise ValueError(error_msg)
60
-
61
- if num_branches != len(num_channels):
62
- error_msg = f'NUM_BRANCHES({num_branches}) ' \
63
- f'!= NUM_CHANNELS({len(num_channels)})'
64
- raise ValueError(error_msg)
65
-
66
- if num_branches != len(in_channels):
67
- error_msg = f'NUM_BRANCHES({num_branches}) ' \
68
- f'!= NUM_INCHANNELS({len(in_channels)})'
69
- raise ValueError(error_msg)
70
-
71
- def _make_one_branch(self,
72
- branch_index,
73
- block,
74
- num_blocks,
75
- num_channels,
76
- stride=1):
77
- """Make one branch."""
78
- downsample = None
79
- if stride != 1 or \
80
- self.in_channels[branch_index] != \
81
- num_channels[branch_index] * get_expansion(block):
82
- downsample = nn.Sequential(
83
- build_conv_layer(
84
- self.conv_cfg,
85
- self.in_channels[branch_index],
86
- num_channels[branch_index] * get_expansion(block),
87
- kernel_size=1,
88
- stride=stride,
89
- bias=False),
90
- build_norm_layer(
91
- self.norm_cfg,
92
- num_channels[branch_index] * get_expansion(block))[1])
93
-
94
- layers = []
95
- layers.append(
96
- block(
97
- self.in_channels[branch_index],
98
- num_channels[branch_index] * get_expansion(block),
99
- stride=stride,
100
- downsample=downsample,
101
- with_cp=self.with_cp,
102
- norm_cfg=self.norm_cfg,
103
- conv_cfg=self.conv_cfg))
104
- self.in_channels[branch_index] = \
105
- num_channels[branch_index] * get_expansion(block)
106
- for _ in range(1, num_blocks[branch_index]):
107
- layers.append(
108
- block(
109
- self.in_channels[branch_index],
110
- num_channels[branch_index] * get_expansion(block),
111
- with_cp=self.with_cp,
112
- norm_cfg=self.norm_cfg,
113
- conv_cfg=self.conv_cfg))
114
-
115
- return nn.Sequential(*layers)
116
-
117
- def _make_branches(self, num_branches, block, num_blocks, num_channels):
118
- """Make branches."""
119
- branches = []
120
-
121
- for i in range(num_branches):
122
- branches.append(
123
- self._make_one_branch(i, block, num_blocks, num_channels))
124
-
125
- return nn.ModuleList(branches)
126
-
127
- def _make_fuse_layers(self):
128
- """Make fuse layer."""
129
- if self.num_branches == 1:
130
- return None
131
-
132
- num_branches = self.num_branches
133
- in_channels = self.in_channels
134
- fuse_layers = []
135
- num_out_branches = num_branches if self.multiscale_output else 1
136
-
137
- for i in range(num_out_branches):
138
- fuse_layer = []
139
- for j in range(num_branches):
140
- if j > i:
141
- fuse_layer.append(
142
- nn.Sequential(
143
- build_conv_layer(
144
- self.conv_cfg,
145
- in_channels[j],
146
- in_channels[i],
147
- kernel_size=1,
148
- stride=1,
149
- padding=0,
150
- bias=False),
151
- build_norm_layer(self.norm_cfg, in_channels[i])[1],
152
- nn.Upsample(
153
- scale_factor=2**(j - i),
154
- mode=self.upsample_cfg['mode'],
155
- align_corners=self.
156
- upsample_cfg['align_corners'])))
157
- elif j == i:
158
- fuse_layer.append(None)
159
- else:
160
- conv_downsamples = []
161
- for k in range(i - j):
162
- if k == i - j - 1:
163
- conv_downsamples.append(
164
- nn.Sequential(
165
- build_conv_layer(
166
- self.conv_cfg,
167
- in_channels[j],
168
- in_channels[i],
169
- kernel_size=3,
170
- stride=2,
171
- padding=1,
172
- bias=False),
173
- build_norm_layer(self.norm_cfg,
174
- in_channels[i])[1]))
175
- else:
176
- conv_downsamples.append(
177
- nn.Sequential(
178
- build_conv_layer(
179
- self.conv_cfg,
180
- in_channels[j],
181
- in_channels[j],
182
- kernel_size=3,
183
- stride=2,
184
- padding=1,
185
- bias=False),
186
- build_norm_layer(self.norm_cfg,
187
- in_channels[j])[1],
188
- nn.ReLU(inplace=True)))
189
- fuse_layer.append(nn.Sequential(*conv_downsamples))
190
- fuse_layers.append(nn.ModuleList(fuse_layer))
191
-
192
- return nn.ModuleList(fuse_layers)
193
-
194
- def forward(self, x):
195
- """Forward function."""
196
- if self.num_branches == 1:
197
- return [self.branches[0](x[0])]
198
-
199
- for i in range(self.num_branches):
200
- x[i] = self.branches[i](x[i])
201
-
202
- x_fuse = []
203
- for i in range(len(self.fuse_layers)):
204
- y = 0
205
- for j in range(self.num_branches):
206
- if i == j:
207
- y += x[j]
208
- else:
209
- y += self.fuse_layers[i][j](x[j])
210
- x_fuse.append(self.relu(y))
211
- return x_fuse
212
-
213
-
214
- @BACKBONES.register_module()
215
- class HRNet(nn.Module):
216
- """HRNet backbone.
217
-
218
- `High-Resolution Representations for Labeling Pixels and Regions
219
- <https://arxiv.org/abs/1904.04514>`__
220
-
221
- Args:
222
- extra (dict): detailed configuration for each stage of HRNet.
223
- in_channels (int): Number of input image channels. Default: 3.
224
- conv_cfg (dict): dictionary to construct and config conv layer.
225
- norm_cfg (dict): dictionary to construct and config norm layer.
226
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
227
- freeze running stats (mean and var). Note: Effect on Batch Norm
228
- and its variants only. Default: False
229
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
230
- memory while slowing down the training speed.
231
- zero_init_residual (bool): whether to use zero init for last norm layer
232
- in resblocks to let them behave as identity.
233
- frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
234
- -1 means not freezing any parameters. Default: -1.
235
-
236
- Example:
237
- >>> from mmpose.models import HRNet
238
- >>> import torch
239
- >>> extra = dict(
240
- >>> stage1=dict(
241
- >>> num_modules=1,
242
- >>> num_branches=1,
243
- >>> block='BOTTLENECK',
244
- >>> num_blocks=(4, ),
245
- >>> num_channels=(64, )),
246
- >>> stage2=dict(
247
- >>> num_modules=1,
248
- >>> num_branches=2,
249
- >>> block='BASIC',
250
- >>> num_blocks=(4, 4),
251
- >>> num_channels=(32, 64)),
252
- >>> stage3=dict(
253
- >>> num_modules=4,
254
- >>> num_branches=3,
255
- >>> block='BASIC',
256
- >>> num_blocks=(4, 4, 4),
257
- >>> num_channels=(32, 64, 128)),
258
- >>> stage4=dict(
259
- >>> num_modules=3,
260
- >>> num_branches=4,
261
- >>> block='BASIC',
262
- >>> num_blocks=(4, 4, 4, 4),
263
- >>> num_channels=(32, 64, 128, 256)))
264
- >>> self = HRNet(extra, in_channels=1)
265
- >>> self.eval()
266
- >>> inputs = torch.rand(1, 1, 32, 32)
267
- >>> level_outputs = self.forward(inputs)
268
- >>> for level_out in level_outputs:
269
- ... print(tuple(level_out.shape))
270
- (1, 32, 8, 8)
271
- """
272
-
273
- blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
274
-
275
- def __init__(self,
276
- extra,
277
- in_channels=3,
278
- conv_cfg=None,
279
- norm_cfg=dict(type='BN'),
280
- norm_eval=False,
281
- with_cp=False,
282
- zero_init_residual=False,
283
- frozen_stages=-1):
284
- # Protect mutable default arguments
285
- norm_cfg = copy.deepcopy(norm_cfg)
286
- super().__init__()
287
- self.extra = extra
288
- self.conv_cfg = conv_cfg
289
- self.norm_cfg = norm_cfg
290
- self.norm_eval = norm_eval
291
- self.with_cp = with_cp
292
- self.zero_init_residual = zero_init_residual
293
- self.frozen_stages = frozen_stages
294
-
295
- # stem net
296
- self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
297
- self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
298
-
299
- self.conv1 = build_conv_layer(
300
- self.conv_cfg,
301
- in_channels,
302
- 64,
303
- kernel_size=3,
304
- stride=2,
305
- padding=1,
306
- bias=False)
307
-
308
- self.add_module(self.norm1_name, norm1)
309
- self.conv2 = build_conv_layer(
310
- self.conv_cfg,
311
- 64,
312
- 64,
313
- kernel_size=3,
314
- stride=2,
315
- padding=1,
316
- bias=False)
317
-
318
- self.add_module(self.norm2_name, norm2)
319
- self.relu = nn.ReLU(inplace=True)
320
-
321
- self.upsample_cfg = self.extra.get('upsample', {
322
- 'mode': 'nearest',
323
- 'align_corners': None
324
- })
325
-
326
- # stage 1
327
- self.stage1_cfg = self.extra['stage1']
328
- num_channels = self.stage1_cfg['num_channels'][0]
329
- block_type = self.stage1_cfg['block']
330
- num_blocks = self.stage1_cfg['num_blocks'][0]
331
-
332
- block = self.blocks_dict[block_type]
333
- stage1_out_channels = num_channels * get_expansion(block)
334
- self.layer1 = self._make_layer(block, 64, stage1_out_channels,
335
- num_blocks)
336
-
337
- # stage 2
338
- self.stage2_cfg = self.extra['stage2']
339
- num_channels = self.stage2_cfg['num_channels']
340
- block_type = self.stage2_cfg['block']
341
-
342
- block = self.blocks_dict[block_type]
343
- num_channels = [
344
- channel * get_expansion(block) for channel in num_channels
345
- ]
346
- self.transition1 = self._make_transition_layer([stage1_out_channels],
347
- num_channels)
348
- self.stage2, pre_stage_channels = self._make_stage(
349
- self.stage2_cfg, num_channels)
350
-
351
- # stage 3
352
- self.stage3_cfg = self.extra['stage3']
353
- num_channels = self.stage3_cfg['num_channels']
354
- block_type = self.stage3_cfg['block']
355
-
356
- block = self.blocks_dict[block_type]
357
- num_channels = [
358
- channel * get_expansion(block) for channel in num_channels
359
- ]
360
- self.transition2 = self._make_transition_layer(pre_stage_channels,
361
- num_channels)
362
- self.stage3, pre_stage_channels = self._make_stage(
363
- self.stage3_cfg, num_channels)
364
-
365
- # stage 4
366
- self.stage4_cfg = self.extra['stage4']
367
- num_channels = self.stage4_cfg['num_channels']
368
- block_type = self.stage4_cfg['block']
369
-
370
- block = self.blocks_dict[block_type]
371
- num_channels = [
372
- channel * get_expansion(block) for channel in num_channels
373
- ]
374
- self.transition3 = self._make_transition_layer(pre_stage_channels,
375
- num_channels)
376
-
377
- self.stage4, pre_stage_channels = self._make_stage(
378
- self.stage4_cfg,
379
- num_channels,
380
- multiscale_output=self.stage4_cfg.get('multiscale_output', False))
381
-
382
- self._freeze_stages()
383
-
384
- @property
385
- def norm1(self):
386
- """nn.Module: the normalization layer named "norm1" """
387
- return getattr(self, self.norm1_name)
388
-
389
- @property
390
- def norm2(self):
391
- """nn.Module: the normalization layer named "norm2" """
392
- return getattr(self, self.norm2_name)
393
-
394
- def _make_transition_layer(self, num_channels_pre_layer,
395
- num_channels_cur_layer):
396
- """Make transition layer."""
397
- num_branches_cur = len(num_channels_cur_layer)
398
- num_branches_pre = len(num_channels_pre_layer)
399
-
400
- transition_layers = []
401
- for i in range(num_branches_cur):
402
- if i < num_branches_pre:
403
- if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
404
- transition_layers.append(
405
- nn.Sequential(
406
- build_conv_layer(
407
- self.conv_cfg,
408
- num_channels_pre_layer[i],
409
- num_channels_cur_layer[i],
410
- kernel_size=3,
411
- stride=1,
412
- padding=1,
413
- bias=False),
414
- build_norm_layer(self.norm_cfg,
415
- num_channels_cur_layer[i])[1],
416
- nn.ReLU(inplace=True)))
417
- else:
418
- transition_layers.append(None)
419
- else:
420
- conv_downsamples = []
421
- for j in range(i + 1 - num_branches_pre):
422
- in_channels = num_channels_pre_layer[-1]
423
- out_channels = num_channels_cur_layer[i] \
424
- if j == i - num_branches_pre else in_channels
425
- conv_downsamples.append(
426
- nn.Sequential(
427
- build_conv_layer(
428
- self.conv_cfg,
429
- in_channels,
430
- out_channels,
431
- kernel_size=3,
432
- stride=2,
433
- padding=1,
434
- bias=False),
435
- build_norm_layer(self.norm_cfg, out_channels)[1],
436
- nn.ReLU(inplace=True)))
437
- transition_layers.append(nn.Sequential(*conv_downsamples))
438
-
439
- return nn.ModuleList(transition_layers)
440
-
441
- def _make_layer(self, block, in_channels, out_channels, blocks, stride=1):
442
- """Make layer."""
443
- downsample = None
444
- if stride != 1 or in_channels != out_channels:
445
- downsample = nn.Sequential(
446
- build_conv_layer(
447
- self.conv_cfg,
448
- in_channels,
449
- out_channels,
450
- kernel_size=1,
451
- stride=stride,
452
- bias=False),
453
- build_norm_layer(self.norm_cfg, out_channels)[1])
454
-
455
- layers = []
456
- layers.append(
457
- block(
458
- in_channels,
459
- out_channels,
460
- stride=stride,
461
- downsample=downsample,
462
-                with_cp=self.with_cp,
-                norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg))
-        for _ in range(1, blocks):
-            layers.append(
-                block(
-                    out_channels,
-                    out_channels,
-                    with_cp=self.with_cp,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg))
-
-        return nn.Sequential(*layers)
-
-    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
-        """Make stage."""
-        num_modules = layer_config['num_modules']
-        num_branches = layer_config['num_branches']
-        num_blocks = layer_config['num_blocks']
-        num_channels = layer_config['num_channels']
-        block = self.blocks_dict[layer_config['block']]
-
-        hr_modules = []
-        for i in range(num_modules):
-            # multi_scale_output is only used for the last module
-            if not multiscale_output and i == num_modules - 1:
-                reset_multiscale_output = False
-            else:
-                reset_multiscale_output = True
-
-            hr_modules.append(
-                HRModule(
-                    num_branches,
-                    block,
-                    num_blocks,
-                    in_channels,
-                    num_channels,
-                    reset_multiscale_output,
-                    with_cp=self.with_cp,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                    upsample_cfg=self.upsample_cfg))
-
-            in_channels = hr_modules[-1].in_channels
-
-        return nn.Sequential(*hr_modules), in_channels
-
-    def _freeze_stages(self):
-        """Freeze parameters."""
-        if self.frozen_stages >= 0:
-            self.norm1.eval()
-            self.norm2.eval()
-
-            for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
-                for param in m.parameters():
-                    param.requires_grad = False
-
-        for i in range(1, self.frozen_stages + 1):
-            if i == 1:
-                m = getattr(self, 'layer1')
-            else:
-                m = getattr(self, f'stage{i}')
-
-            m.eval()
-            for param in m.parameters():
-                param.requires_grad = False
-
-            if i < 4:
-                m = getattr(self, f'transition{i}')
-                m.eval()
-                for param in m.parameters():
-                    param.requires_grad = False
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, std=0.001)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-    def forward(self, x):
-        """Forward function."""
-        x = self.conv1(x)
-        x = self.norm1(x)
-        x = self.relu(x)
-        x = self.conv2(x)
-        x = self.norm2(x)
-        x = self.relu(x)
-        x = self.layer1(x)
-
-        x_list = []
-        for i in range(self.stage2_cfg['num_branches']):
-            if self.transition1[i] is not None:
-                x_list.append(self.transition1[i](x))
-            else:
-                x_list.append(x)
-        y_list = self.stage2(x_list)
-
-        x_list = []
-        for i in range(self.stage3_cfg['num_branches']):
-            if self.transition2[i] is not None:
-                x_list.append(self.transition2[i](y_list[-1]))
-            else:
-                x_list.append(y_list[i])
-        y_list = self.stage3(x_list)
-
-        x_list = []
-        for i in range(self.stage4_cfg['num_branches']):
-            if self.transition3[i] is not None:
-                x_list.append(self.transition3[i](y_list[-1]))
-            else:
-                x_list.append(y_list[i])
-        y_list = self.stage4(x_list)
-
-        return y_list
-
-    def train(self, mode=True):
-        """Convert the model into training mode."""
-        super().train(mode)
-        self._freeze_stages()
-        if mode and self.norm_eval:
-            for m in self.modules():
-                if isinstance(m, _BatchNorm):
-                    m.eval()
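For reference, a minimal, self-contained sketch of the freeze pattern used by `_freeze_stages` and the `train()` override above: frozen stages are switched to eval mode and their parameters are detached from gradient computation, and the freeze is re-applied on every `train(True)` call, since `nn.Module.train()` would otherwise re-enable BatchNorm statistics updates in the frozen layers. `TinyBackbone` is a hypothetical stand-in, not the deleted HRNet class.

import torch.nn as nn

class TinyBackbone(nn.Module):  # hypothetical stand-in, not the deleted HRNet
    def __init__(self, frozen_stages=1, norm_eval=True):
        super().__init__()
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.stage1 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.stage2 = nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8))

    def _freeze_stages(self):
        # put frozen stages in eval mode and stop their gradients
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'stage{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            # keep every BN layer in eval mode so running stats stay fixed
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self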
main/transformer_utils/mmpose/models/backbones/hrt.py DELETED
@@ -1,676 +0,0 @@
-# --------------------------------------------------------
-# High Resolution Transformer
-# Copyright (c) 2021 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-# Written by Rao Fu, RainbowSecret
-# --------------------------------------------------------
-
-import pdb
-import torch
-import torch.nn as nn
-from mmcv.cnn import (
-    build_conv_layer,
-    build_norm_layer,
-    constant_init,
-    kaiming_init,
-    normal_init,
-)
-# from mmcv.runner import load_checkpoint
-from .hrt_checkpoint import load_checkpoint
-from mmcv.runner.checkpoint import load_state_dict
-from mmcv.utils.parrots_wrapper import _BatchNorm
-
-from mmpose.models.utils.ops import resize
-from mmpose.utils import get_root_logger
-from ..builder import BACKBONES
-from .modules.bottleneck_block import Bottleneck
-from .modules.transformer_block import GeneralTransformerBlock
-
-
-class HighResolutionTransformerModule(nn.Module):
-    def __init__(
-        self,
-        num_branches,
-        blocks,
-        num_blocks,
-        in_channels,
-        num_channels,
-        multiscale_output,
-        with_cp=False,
-        conv_cfg=None,
-        norm_cfg=dict(type="BN", requires_grad=True),
-        num_heads=None,
-        num_window_sizes=None,
-        num_mlp_ratios=None,
-        drop_paths=0.0,
-    ):
-        super(HighResolutionTransformerModule, self).__init__()
-        self._check_branches(num_branches, num_blocks, in_channels, num_channels)
-
-        self.in_channels = in_channels
-        self.num_branches = num_branches
-
-        self.multiscale_output = multiscale_output
-        self.norm_cfg = norm_cfg
-        self.conv_cfg = conv_cfg
-        self.with_cp = with_cp
-        self.branches = self._make_branches(
-            num_branches,
-            blocks,
-            num_blocks,
-            num_channels,
-            num_heads,
-            num_window_sizes,
-            num_mlp_ratios,
-            drop_paths,
-        )
-        self.fuse_layers = self._make_fuse_layers()
-        self.relu = nn.ReLU(inplace=True)
-
-        # MHSA parameters
-        self.num_heads = num_heads
-        self.num_window_sizes = num_window_sizes
-        self.num_mlp_ratios = num_mlp_ratios
-
-    def _check_branches(self, num_branches, num_blocks, in_channels, num_channels):
-        logger = get_root_logger()
-        if num_branches != len(num_blocks):
-            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
-                num_branches, len(num_blocks)
-            )
-            logger.error(error_msg)
-            raise ValueError(error_msg)
-
-        if num_branches != len(num_channels):
-            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
-                num_branches, len(num_channels)
-            )
-            logger.error(error_msg)
-            raise ValueError(error_msg)
-
-        if num_branches != len(in_channels):
-            error_msg = "NUM_BRANCHES({}) <> IN_CHANNELS({})".format(
-                num_branches, len(in_channels)
-            )
-            logger.error(error_msg)
-            raise ValueError(error_msg)
-
-    def _make_one_branch(
-        self,
-        branch_index,
-        block,
-        num_blocks,
-        num_channels,
-        num_heads,
-        num_window_sizes,
-        num_mlp_ratios,
-        drop_paths,
-        stride=1,
-    ):
-        """Make one branch."""
-        downsample = None
-        if (
-            stride != 1
-            or self.in_channels[branch_index]
-            != num_channels[branch_index] * block.expansion
-        ):
-            downsample = nn.Sequential(
-                build_conv_layer(
-                    self.conv_cfg,
-                    self.in_channels[branch_index],
-                    num_channels[branch_index] * block.expansion,
-                    kernel_size=1,
-                    stride=stride,
-                    bias=False,
-                ),
-                build_norm_layer(
-                    self.norm_cfg, num_channels[branch_index] * block.expansion
-                )[1],
-            )
-
-        layers = []
-
-        layers.append(
-            block(
-                self.in_channels[branch_index],
-                num_channels[branch_index],
-                num_heads=num_heads[branch_index],
-                window_size=num_window_sizes[branch_index],
-                mlp_ratio=num_mlp_ratios[branch_index],
-                drop_path=drop_paths[0],
-                norm_cfg=self.norm_cfg,
-                conv_cfg=self.conv_cfg,
-            )
-        )
-        self.in_channels[branch_index] = num_channels[branch_index] * block.expansion
-        for i in range(1, num_blocks[branch_index]):
-            layers.append(
-                block(
-                    self.in_channels[branch_index],
-                    num_channels[branch_index],
-                    num_heads=num_heads[branch_index],
-                    window_size=num_window_sizes[branch_index],
-                    mlp_ratio=num_mlp_ratios[branch_index],
-                    drop_path=drop_paths[i],
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                )
-            )
-
-        return nn.Sequential(*layers)
-
-    def _make_branches(
-        self,
-        num_branches,
-        block,
-        num_blocks,
-        num_channels,
-        num_heads,
-        num_window_sizes,
-        num_mlp_ratios,
-        drop_paths,
-    ):
-        """Make branches."""
-        branches = []
-
-        for i in range(num_branches):
-            branches.append(
-                self._make_one_branch(
-                    i,
-                    block,
-                    num_blocks,
-                    num_channels,
-                    num_heads,
-                    num_window_sizes,
-                    num_mlp_ratios,
-                    drop_paths,
-                )
-            )
-
-        return nn.ModuleList(branches)
-
-    def _make_fuse_layers(self):
-        """Build fuse layer."""
-        if self.num_branches == 1:
-            return None
-
-        num_branches = self.num_branches
-        in_channels = self.in_channels
-        fuse_layers = []
-        num_out_branches = num_branches if self.multiscale_output else 1
-        for i in range(num_out_branches):
-            fuse_layer = []
-            for j in range(num_branches):
-                if j > i:
-                    fuse_layer.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                in_channels[j],
-                                in_channels[i],
-                                kernel_size=1,
-                                stride=1,
-                                padding=0,
-                                bias=False,
-                            ),
-                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
-                            nn.Upsample(
-                                scale_factor=2 ** (j - i),
-                                mode="bilinear",
-                                align_corners=False,
-                            ),
-                        )
-                    )
-                elif j == i:
-                    fuse_layer.append(None)
-                else:
-                    conv_downsamples = []
-                    for k in range(i - j):
-                        if k == i - j - 1:
-                            conv_downsamples.append(
-                                nn.Sequential(
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=1,
-                                        groups=in_channels[j],
-                                        bias=False,
-                                    ),
-                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[i],
-                                        kernel_size=1,
-                                        stride=1,
-                                        bias=False,
-                                    ),
-                                    build_norm_layer(self.norm_cfg, in_channels[i])[1],
-                                )
-                            )
-                        else:
-                            conv_downsamples.append(
-                                nn.Sequential(
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=1,
-                                        groups=in_channels[j],
-                                        bias=False,
-                                    ),
-                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=1,
-                                        stride=1,
-                                        bias=False,
-                                    ),
-                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
-                                    nn.ReLU(inplace=True),
-                                )
-                            )
-                    fuse_layer.append(nn.Sequential(*conv_downsamples))
-            fuse_layers.append(nn.ModuleList(fuse_layer))
-        return nn.ModuleList(fuse_layers)
-
-    def forward(self, x):
-        """Forward function."""
-        if self.num_branches == 1:
-            return [self.branches[0](x[0])]
-
-        for i in range(self.num_branches):
-            x[i] = self.branches[i](x[i])
-
-        x_fuse = []
-        for i in range(len(self.fuse_layers)):
-            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
-            for j in range(1, self.num_branches):
-                if i == j:
-                    y += x[j]
-                elif j > i:
-                    y = y + resize(
-                        self.fuse_layers[i][j](x[j]),
-                        size=x[i].shape[2:],
-                        mode="bilinear",
-                        align_corners=False,
-                    )
-                else:
-                    y += self.fuse_layers[i][j](x[j])
-            x_fuse.append(self.relu(y))
-        return x_fuse
-
-
-@BACKBONES.register_module()
-class HRT(nn.Module):
-    """HRT backbone.
-    High Resolution Transformer Backbone
-    """
-
-    blocks_dict = {
-        "BOTTLENECK": Bottleneck,
-        "TRANSFORMER_BLOCK": GeneralTransformerBlock,
-    }
-
-    def __init__(
-        self,
-        extra,
-        in_channels=3,
-        conv_cfg=None,
-        norm_cfg=dict(type="BN", requires_grad=True),
-        norm_eval=False,
-        with_cp=False,
-        zero_init_residual=False,
-    ):
-        super(HRT, self).__init__()
-        self.extra = extra
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-        self.zero_init_residual = zero_init_residual
-
-        # stem net
-        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
-
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            in_channels,
-            64,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            bias=False,
-        )
-        self.add_module(self.norm1_name, norm1)
-
-        self.conv2 = build_conv_layer(
-            self.conv_cfg, 64, 64, kernel_size=3, stride=2, padding=1, bias=False
-        )
-        self.add_module(self.norm2_name, norm2)
-        self.relu = nn.ReLU(inplace=True)
-
-        # generat drop path rate list
-        depth_s2 = (
-            self.extra["stage2"]["num_blocks"][0] * self.extra["stage2"]["num_modules"]
-        )
-        depth_s3 = (
-            self.extra["stage3"]["num_blocks"][0] * self.extra["stage3"]["num_modules"]
-        )
-        depth_s4 = (
-            self.extra["stage4"]["num_blocks"][0] * self.extra["stage4"]["num_modules"]
-        )
-        depths = [depth_s2, depth_s3, depth_s4]
-        drop_path_rate = self.extra["drop_path_rate"]
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-
-        logger = get_root_logger()
-        logger.info(dpr)
-
-        # stage 1
-        self.stage1_cfg = self.extra["stage1"]
-        num_channels = self.stage1_cfg["num_channels"][0]
-        block_type = self.stage1_cfg["block"]
-        num_blocks = self.stage1_cfg["num_blocks"][0]
-
-        block = self.blocks_dict[block_type]
-        stage1_out_channels = num_channels * block.expansion
-        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
-
-        # stage 2
-        self.stage2_cfg = self.extra["stage2"]
-        num_channels = self.stage2_cfg["num_channels"]
-        block_type = self.stage2_cfg["block"]
-
-        block = self.blocks_dict[block_type]
-        num_channels = [channel * block.expansion for channel in num_channels]
-        self.transition1 = self._make_transition_layer(
-            [stage1_out_channels], num_channels
-        )
-        self.stage2, pre_stage_channels = self._make_stage(
-            self.stage2_cfg, num_channels, drop_paths=dpr[0:depth_s2]
-        )
-
-        # stage 3
-        self.stage3_cfg = self.extra["stage3"]
-        num_channels = self.stage3_cfg["num_channels"]
-        block_type = self.stage3_cfg["block"]
-
-        block = self.blocks_dict[block_type]
-        num_channels = [channel * block.expansion for channel in num_channels]
-        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
-        self.stage3, pre_stage_channels = self._make_stage(
-            self.stage3_cfg,
-            num_channels,
-            drop_paths=dpr[depth_s2 : depth_s2 + depth_s3],
-        )
-
-        # stage 4
-        self.stage4_cfg = self.extra["stage4"]
-        num_channels = self.stage4_cfg["num_channels"]
-        block_type = self.stage4_cfg["block"]
-
-        block = self.blocks_dict[block_type]
-        num_channels = [channel * block.expansion for channel in num_channels]
-        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
-        self.stage4, pre_stage_channels = self._make_stage(
-            self.stage4_cfg,
-            num_channels,
-            multiscale_output=self.stage4_cfg.get("multiscale_output", True),
-            drop_paths=dpr[depth_s2 + depth_s3 :],
-        )
-
-    @property
-    def norm1(self):
-        """nn.Module: the normalization layer named "norm1" """
-        return getattr(self, self.norm1_name)
-
-    @property
-    def norm2(self):
-        """nn.Module: the normalization layer named "norm2" """
-        return getattr(self, self.norm2_name)
-
-    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
-        """Make transition layer."""
-        num_branches_cur = len(num_channels_cur_layer)
-        num_branches_pre = len(num_channels_pre_layer)
-
-        transition_layers = []
-        for i in range(num_branches_cur):
-            if i < num_branches_pre:
-                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
-                    transition_layers.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                num_channels_pre_layer[i],
-                                num_channels_cur_layer[i],
-                                kernel_size=3,
-                                stride=1,
-                                padding=1,
-                                bias=False,
-                            ),
-                            build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[
-                                1
-                            ],
-                            nn.ReLU(inplace=True),
-                        )
-                    )
-                else:
-                    transition_layers.append(None)
-            else:
-                conv_downsamples = []
-                for j in range(i + 1 - num_branches_pre):
-                    in_channels = num_channels_pre_layer[-1]
-                    out_channels = (
-                        num_channels_cur_layer[i]
-                        if j == i - num_branches_pre
-                        else in_channels
-                    )
-                    conv_downsamples.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                in_channels,
-                                out_channels,
-                                kernel_size=3,
-                                stride=2,
-                                padding=1,
-                                bias=False,
-                            ),
-                            build_norm_layer(self.norm_cfg, out_channels)[1],
-                            nn.ReLU(inplace=True),
-                        )
-                    )
-                transition_layers.append(nn.Sequential(*conv_downsamples))
-
-        return nn.ModuleList(transition_layers)
-
-    def _make_layer(
-        self,
-        block,
-        inplanes,
-        planes,
-        blocks,
-        stride=1,
-        num_heads=1,
-        window_size=7,
-        mlp_ratio=4.0,
-    ):
-        """Make each layer."""
-        downsample = None
-        if stride != 1 or inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                build_conv_layer(
-                    self.conv_cfg,
-                    inplanes,
-                    planes * block.expansion,
-                    kernel_size=1,
-                    stride=stride,
-                    bias=False,
-                ),
-                build_norm_layer(self.norm_cfg, planes * block.expansion)[1],
-            )
-
-        layers = []
-        if isinstance(block, GeneralTransformerBlock):
-            layers.append(
-                block(
-                    inplanes,
-                    planes,
-                    num_heads=num_heads,
-                    window_size=window_size,
-                    mlp_ratio=mlp_ratio,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                )
-            )
-        else:
-            layers.append(
-                block(
-                    inplanes,
-                    planes,
-                    stride,
-                    downsample=downsample,
-                    with_cp=self.with_cp,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                )
-            )
-        inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(
-                block(
-                    inplanes,
-                    planes,
-                    with_cp=self.with_cp,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                )
-            )
-
-        return nn.Sequential(*layers)
-
-    def _make_stage(
-        self, layer_config, in_channels, multiscale_output=True, drop_paths=0.0
-    ):
-        """Make each stage."""
-        num_modules = layer_config["num_modules"]
-        num_branches = layer_config["num_branches"]
-        num_blocks = layer_config["num_blocks"]
-        num_channels = layer_config["num_channels"]
-        block = self.blocks_dict[layer_config["block"]]
-
-        num_heads = layer_config["num_heads"]
-        num_window_sizes = layer_config["num_window_sizes"]
-        num_mlp_ratios = layer_config["num_mlp_ratios"]
-
-        hr_modules = []
-        for i in range(num_modules):
-            # multi_scale_output is only used for the last module
-            if not multiscale_output and i == num_modules - 1:
-                reset_multiscale_output = False
-            else:
-                reset_multiscale_output = True
-
-            hr_modules.append(
-                HighResolutionTransformerModule(
-                    num_branches,
-                    block,
-                    num_blocks,
-                    in_channels,
-                    num_channels,
-                    reset_multiscale_output,
-                    with_cp=self.with_cp,
-                    norm_cfg=self.norm_cfg,
-                    conv_cfg=self.conv_cfg,
-                    num_heads=num_heads,
-                    num_window_sizes=num_window_sizes,
-                    num_mlp_ratios=num_mlp_ratios,
-                    drop_paths=drop_paths[num_blocks[0] * i : num_blocks[0] * (i + 1)],
-                )
-            )
-
-        return nn.Sequential(*hr_modules), in_channels
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            ckpt = load_checkpoint(self, pretrained, strict=False)
-            if "model" in ckpt:
-                msg = self.load_state_dict(ckpt["model"], strict=False)
-                logger.info(msg)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    """mmseg: kaiming_init(m)"""
-                    normal_init(m, std=0.001)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-            if self.zero_init_residual:
-                for m in self.modules():
-                    if isinstance(m, Bottleneck):
-                        constant_init(m.norm3, 0)
-                    elif isinstance(m, BasicBlock):
-                        constant_init(m.norm2, 0)
-        else:
-            raise TypeError("pretrained must be a str or None")
-
-    def forward(self, x):
-        """Forward function."""
-        x = self.conv1(x)
-        x = self.norm1(x)
-        x = self.relu(x)
-        x = self.conv2(x)
-        x = self.norm2(x)
-        x = self.relu(x)
-        x = self.layer1(x)
-
-        x_list = []
-        for i in range(self.stage2_cfg["num_branches"]):
-            if self.transition1[i] is not None:
-                x_list.append(self.transition1[i](x))
-            else:
-                x_list.append(x)
-        y_list = self.stage2(x_list)
-
-        x_list = []
-        for i in range(self.stage3_cfg["num_branches"]):
-            if self.transition2[i] is not None:
-                x_list.append(self.transition2[i](y_list[-1]))
-            else:
-                x_list.append(y_list[i])
-        y_list = self.stage3(x_list)
-
-        x_list = []
-        for i in range(self.stage4_cfg["num_branches"]):
-            if self.transition3[i] is not None:
-                x_list.append(self.transition3[i](y_list[-1]))
-            else:
-                x_list.append(y_list[i])
-        y_list = self.stage4(x_list)
-
-        return y_list
-
-    def train(self, mode=True):
-        """Convert the model into training mode."""
-        super(HRT, self).train(mode)
-        if mode and self.norm_eval:
-            for m in self.modules():
-                if isinstance(m, _BatchNorm):
-                    m.eval()
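The HRT constructor above builds a linearly increasing stochastic-depth schedule over all transformer blocks of stages 2-4 and hands each stage a contiguous slice of it. A standalone sketch of that bookkeeping, with illustrative config numbers rather than values from this repository:

import torch

# drop-path rates grow linearly from 0 to drop_path_rate across stages 2-4
extra = {
    "drop_path_rate": 0.2,
    "stage2": {"num_modules": 1, "num_blocks": [2]},
    "stage3": {"num_modules": 4, "num_blocks": [2]},
    "stage4": {"num_modules": 2, "num_blocks": [2]},
}

depths = [
    extra[s]["num_blocks"][0] * extra[s]["num_modules"]
    for s in ("stage2", "stage3", "stage4")
]
dpr = [x.item() for x in torch.linspace(0, extra["drop_path_rate"], sum(depths))]

d2, d3, _ = depths
print(dpr[0:d2])        # slice consumed by stage 2
print(dpr[d2:d2 + d3])  # slice consumed by stage 3
print(dpr[d2 + d3:])    # slice consumed by stage 4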
main/transformer_utils/mmpose/models/backbones/hrt_checkpoint.py DELETED
@@ -1,500 +0,0 @@
-# Copyright (c) Open-MMLab. All rights reserved.
-import io
-import os
-import os.path as osp
-import pkgutil
-import time
-import warnings
-from collections import OrderedDict
-from importlib import import_module
-from tempfile import TemporaryDirectory
-
-import torch
-import torchvision
-from torch.optim import Optimizer
-from torch.utils import model_zoo
-from torch.nn import functional as F
-
-import mmcv
-from mmcv.fileio import FileClient
-from mmcv.fileio import load as load_file
-from mmcv.parallel import is_module_wrapper
-from mmcv.utils import mkdir_or_exist
-from mmcv.runner import get_dist_info
-
-ENV_MMCV_HOME = 'MMCV_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-
-
-def _get_mmcv_home():
-    mmcv_home = os.path.expanduser(
-        os.getenv(
-            ENV_MMCV_HOME,
-            os.path.join(
-                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
-
-    mkdir_or_exist(mmcv_home)
-    return mmcv_home
-
-
-def load_state_dict(module, state_dict, strict=False, logger=None):
-    """Load state_dict to a module.
-
-    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
-    Default value for ``strict`` is set to ``False`` and the message for
-    param mismatch will be shown even if strict is False.
-
-    Args:
-        module (Module): Module that receives the state_dict.
-        state_dict (OrderedDict): Weights.
-        strict (bool): whether to strictly enforce that the keys
-            in :attr:`state_dict` match the keys returned by this module's
-            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
-        logger (:obj:`logging.Logger`, optional): Logger to log the error
-            message. If not specified, print function will be used.
-    """
-    unexpected_keys = []
-    all_missing_keys = []
-    err_msg = []
-
-    metadata = getattr(state_dict, '_metadata', None)
-    state_dict = state_dict.copy()
-    if metadata is not None:
-        state_dict._metadata = metadata
-
-    # use _load_from_state_dict to enable checkpoint version control
-    def load(module, prefix=''):
-        # recursively check parallel module in case that the model has a
-        # complicated structure, e.g., nn.Module(nn.Module(DDP))
-        if is_module_wrapper(module):
-            module = module.module
-        local_metadata = {} if metadata is None else metadata.get(
-            prefix[:-1], {})
-        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
-                                     all_missing_keys, unexpected_keys,
-                                     err_msg)
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, prefix + name + '.')
-
-    load(module)
-    load = None  # break load->load reference cycle
-
-    # ignore "num_batches_tracked" of BN layers
-    missing_keys = [
-        key for key in all_missing_keys if 'num_batches_tracked' not in key
-    ]
-
-    if unexpected_keys:
-        err_msg.append('unexpected key in source '
-                       f'state_dict: {", ".join(unexpected_keys)}\n')
-    if missing_keys:
-        err_msg.append(
-            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
-
-    rank, _ = get_dist_info()
-    if len(err_msg) > 0 and rank == 0:
-        err_msg.insert(
-            0, 'The model and loaded state dict do not match exactly\n')
-        err_msg = '\n'.join(err_msg)
-        if strict:
-            raise RuntimeError(err_msg)
-        elif logger is not None:
-            logger.warning(err_msg)
-        else:
-            print(err_msg)
-
-
-def load_url_dist(url, model_dir=None):
-    """In distributed setting, this function only download checkpoint at local
-    rank 0."""
-    rank, world_size = get_dist_info()
-    rank = int(os.environ.get('LOCAL_RANK', rank))
-    if rank == 0:
-        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
-    if world_size > 1:
-        torch.distributed.barrier()
-        if rank > 0:
-            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
-    return checkpoint
-
-
-def load_pavimodel_dist(model_path, map_location=None):
-    """In distributed setting, this function only download checkpoint at local
-    rank 0."""
-    try:
-        from pavi import modelcloud
-    except ImportError:
-        raise ImportError(
-            'Please install pavi to load checkpoint from modelcloud.')
-    rank, world_size = get_dist_info()
-    rank = int(os.environ.get('LOCAL_RANK', rank))
-    if rank == 0:
-        model = modelcloud.get(model_path)
-        with TemporaryDirectory() as tmp_dir:
-            downloaded_file = osp.join(tmp_dir, model.name)
-            model.download(downloaded_file)
-            checkpoint = torch.load(downloaded_file, map_location=map_location)
-    if world_size > 1:
-        torch.distributed.barrier()
-        if rank > 0:
-            model = modelcloud.get(model_path)
-            with TemporaryDirectory() as tmp_dir:
-                downloaded_file = osp.join(tmp_dir, model.name)
-                model.download(downloaded_file)
-                checkpoint = torch.load(
-                    downloaded_file, map_location=map_location)
-    return checkpoint
-
-
-def load_fileclient_dist(filename, backend, map_location):
-    """In distributed setting, this function only download checkpoint at local
-    rank 0."""
-    rank, world_size = get_dist_info()
-    rank = int(os.environ.get('LOCAL_RANK', rank))
-    allowed_backends = ['ceph']
-    if backend not in allowed_backends:
-        raise ValueError(f'Load from Backend {backend} is not supported.')
-    if rank == 0:
-        fileclient = FileClient(backend=backend)
-        buffer = io.BytesIO(fileclient.get(filename))
-        checkpoint = torch.load(buffer, map_location=map_location)
-    if world_size > 1:
-        torch.distributed.barrier()
-        if rank > 0:
-            fileclient = FileClient(backend=backend)
-            buffer = io.BytesIO(fileclient.get(filename))
-            checkpoint = torch.load(buffer, map_location=map_location)
-    return checkpoint
-
-
-def get_torchvision_models():
-    model_urls = dict()
-    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
-        if ispkg:
-            continue
-        _zoo = import_module(f'torchvision.models.{name}')
-        if hasattr(_zoo, 'model_urls'):
-            _urls = getattr(_zoo, 'model_urls')
-            model_urls.update(_urls)
-    return model_urls
-
-
-def get_external_models():
-    mmcv_home = _get_mmcv_home()
-    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
-    default_urls = load_file(default_json_path)
-    assert isinstance(default_urls, dict)
-    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
-    if osp.exists(external_json_path):
-        external_urls = load_file(external_json_path)
-        assert isinstance(external_urls, dict)
-        default_urls.update(external_urls)
-
-    return default_urls
-
-
-def get_mmcls_models():
-    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
-    mmcls_urls = load_file(mmcls_json_path)
-
-    return mmcls_urls
-
-
-def get_deprecated_model_names():
-    deprecate_json_path = osp.join(mmcv.__path__[0],
-                                   'model_zoo/deprecated.json')
-    deprecate_urls = load_file(deprecate_json_path)
-    assert isinstance(deprecate_urls, dict)
-
-    return deprecate_urls
-
-
-def _process_mmcls_checkpoint(checkpoint):
-    state_dict = checkpoint['state_dict']
-    new_state_dict = OrderedDict()
-    for k, v in state_dict.items():
-        if k.startswith('backbone.'):
-            new_state_dict[k[9:]] = v
-    new_checkpoint = dict(state_dict=new_state_dict)
-
-    return new_checkpoint
-
-
-def _load_checkpoint(filename, map_location=None):
-    """Load checkpoint from somewhere (modelzoo, file, url).
-
-    Args:
-        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
-            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
-            details.
-        map_location (str | None): Same as :func:`torch.load`. Default: None.
-
-    Returns:
-        dict | OrderedDict: The loaded checkpoint. It can be either an
-            OrderedDict storing model weights or a dict containing other
-            information, which depends on the checkpoint.
-    """
-    if filename.startswith('modelzoo://'):
-        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
-                      'use "torchvision://" instead')
-        model_urls = get_torchvision_models()
-        model_name = filename[11:]
-        checkpoint = load_url_dist(model_urls[model_name])
-    elif filename.startswith('torchvision://'):
-        model_urls = get_torchvision_models()
-        model_name = filename[14:]
-        checkpoint = load_url_dist(model_urls[model_name])
-    elif filename.startswith('open-mmlab://'):
-        model_urls = get_external_models()
-        model_name = filename[13:]
-        deprecated_urls = get_deprecated_model_names()
-        if model_name in deprecated_urls:
-            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
-                          f'of open-mmlab://{deprecated_urls[model_name]}')
-            model_name = deprecated_urls[model_name]
-        model_url = model_urls[model_name]
-        # check if is url
-        if model_url.startswith(('http://', 'https://')):
-            checkpoint = load_url_dist(model_url)
-        else:
-            filename = osp.join(_get_mmcv_home(), model_url)
-            if not osp.isfile(filename):
-                raise IOError(f'{filename} is not a checkpoint file')
-            checkpoint = torch.load(filename, map_location=map_location)
-    elif filename.startswith('mmcls://'):
-        model_urls = get_mmcls_models()
-        model_name = filename[8:]
-        checkpoint = load_url_dist(model_urls[model_name])
-        checkpoint = _process_mmcls_checkpoint(checkpoint)
-    elif filename.startswith(('http://', 'https://')):
-        checkpoint = load_url_dist(filename)
-    elif filename.startswith('pavi://'):
-        model_path = filename[7:]
-        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
-    elif filename.startswith('s3://'):
-        checkpoint = load_fileclient_dist(
-            filename, backend='ceph', map_location=map_location)
-    else:
-        if not osp.isfile(filename):
-            raise IOError(f'{filename} is not a checkpoint file')
-        checkpoint = torch.load(filename, map_location=map_location)
-    return checkpoint
-
-
-def load_checkpoint(model,
-                    filename,
-                    map_location='cpu',
-                    strict=False,
-                    logger=None):
-    """Load checkpoint from a file or URI.
-
-    Args:
-        model (Module): Module to load checkpoint.
-        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
-            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
-            details.
-        map_location (str): Same as :func:`torch.load`.
-        strict (bool): Whether to allow different params for the model and
-            checkpoint.
-        logger (:mod:`logging.Logger` or None): The logger for error message.
-
-    Returns:
-        dict or OrderedDict: The loaded checkpoint.
-    """
-    checkpoint = _load_checkpoint(filename, map_location)
-    # OrderedDict is a subclass of dict
-    if not isinstance(checkpoint, dict):
-        raise RuntimeError(
-            f'No state_dict found in checkpoint file {filename}')
-    # get state_dict from checkpoint
-    if 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
-    elif 'model' in checkpoint:
-        state_dict = checkpoint['model']
-    else:
-        state_dict = checkpoint
-    # strip prefix of state_dict
-    if list(state_dict.keys())[0].startswith('module.'):
-        state_dict = {k[7:]: v for k, v in state_dict.items()}
-
-    # for MoBY, load model of online branch
-    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
-        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
-
-    # reshape absolute position embedding
-    if state_dict.get('absolute_pos_embed') is not None:
-        absolute_pos_embed = state_dict['absolute_pos_embed']
-        N1, L, C1 = absolute_pos_embed.size()
-        N2, C2, H, W = model.absolute_pos_embed.size()
-        if N1 != N2 or C1 != C2 or L != H*W:
-            logger.warning("Error in loading absolute_pos_embed, pass")
-        else:
-            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
-
-    # interpolate position bias table if needed
-    # relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
-    # for table_key in relative_position_bias_table_keys:
-    #     table_pretrained = state_dict[table_key]
-    #     table_current = model.state_dict()[table_key]
-    #     L1, nH1 = table_pretrained.size()
-    #     L2, nH2 = table_current.size()
-    #     if nH1 != nH2:
-    #         logger.warning(f"Error in loading {table_key}, pass")
-    #     else:
-    #         if L1 != L2:
-    #             S1 = int(L1 ** 0.5)
-    #             S2 = int(L2 ** 0.5)
-    #             table_pretrained_resized = F.interpolate(
-    #                 table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
-    #                 size=(S2, S2), mode='bicubic')
-    #             state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
-
-    # load state_dict
-    load_state_dict(model, state_dict, strict, logger)
-    return checkpoint
-
-
-def weights_to_cpu(state_dict):
-    """Copy a model state_dict to cpu.
-
-    Args:
-        state_dict (OrderedDict): Model weights on GPU.
-
-    Returns:
-        OrderedDict: Model weights on GPU.
-    """
-    state_dict_cpu = OrderedDict()
-    for key, val in state_dict.items():
-        state_dict_cpu[key] = val.cpu()
-    return state_dict_cpu
-
-
-def _save_to_state_dict(module, destination, prefix, keep_vars):
-    """Saves module state to `destination` dictionary.
-
-    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
-
-    Args:
-        module (nn.Module): The module to generate state_dict.
-        destination (dict): A dict where state will be stored.
-        prefix (str): The prefix for parameters and buffers used in this
-            module.
-    """
-    for name, param in module._parameters.items():
-        if param is not None:
-            destination[prefix + name] = param if keep_vars else param.detach()
-    for name, buf in module._buffers.items():
-        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
-        if buf is not None:
-            destination[prefix + name] = buf if keep_vars else buf.detach()
-
-
-def get_state_dict(module, destination=None, prefix='', keep_vars=False):
-    """Returns a dictionary containing a whole state of the module.
-
-    Both parameters and persistent buffers (e.g. running averages) are
-    included. Keys are corresponding parameter and buffer names.
-
-    This method is modified from :meth:`torch.nn.Module.state_dict` to
-    recursively check parallel module in case that the model has a complicated
-    structure, e.g., nn.Module(nn.Module(DDP)).
-
-    Args:
-        module (nn.Module): The module to generate state_dict.
-        destination (OrderedDict): Returned dict for the state of the
-            module.
-        prefix (str): Prefix of the key.
-        keep_vars (bool): Whether to keep the variable property of the
-            parameters. Default: False.
-
-    Returns:
-        dict: A dictionary containing a whole state of the module.
-    """
-    # recursively check parallel module in case that the model has a
-    # complicated structure, e.g., nn.Module(nn.Module(DDP))
-    if is_module_wrapper(module):
-        module = module.module
-
-    # below is the same as torch.nn.Module.state_dict()
-    if destination is None:
-        destination = OrderedDict()
-        destination._metadata = OrderedDict()
-    destination._metadata[prefix[:-1]] = local_metadata = dict(
-        version=module._version)
-    _save_to_state_dict(module, destination, prefix, keep_vars)
-    for name, child in module._modules.items():
-        if child is not None:
-            get_state_dict(
-                child, destination, prefix + name + '.', keep_vars=keep_vars)
-    for hook in module._state_dict_hooks.values():
-        hook_result = hook(module, destination, prefix, local_metadata)
-        if hook_result is not None:
-            destination = hook_result
-    return destination
-
-
-def save_checkpoint(model, filename, optimizer=None, meta=None):
-    """Save checkpoint to file.
-
-    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
-    ``optimizer``. By default ``meta`` will contain version and time info.
-
-    Args:
-        model (Module): Module whose params are to be saved.
-        filename (str): Checkpoint filename.
-        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
-        meta (dict, optional): Metadata to be saved in checkpoint.
-    """
-    if meta is None:
-        meta = {}
-    elif not isinstance(meta, dict):
-        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
-    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
-
-    if is_module_wrapper(model):
-        model = model.module
-
-    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
-        # save class name to the meta
-        meta.update(CLASSES=model.CLASSES)
-
-    checkpoint = {
-        'meta': meta,
-        'state_dict': weights_to_cpu(get_state_dict(model))
-    }
-    # save optimizer state dict in the checkpoint
-    if isinstance(optimizer, Optimizer):
-        checkpoint['optimizer'] = optimizer.state_dict()
-    elif isinstance(optimizer, dict):
-        checkpoint['optimizer'] = {}
-        for name, optim in optimizer.items():
-            checkpoint['optimizer'][name] = optim.state_dict()
-
-    if filename.startswith('pavi://'):
-        try:
-            from pavi import modelcloud
-            from pavi.exception import NodeNotFoundError
-        except ImportError:
-            raise ImportError(
-                'Please install pavi to load checkpoint from modelcloud.')
-        model_path = filename[7:]
-        root = modelcloud.Folder()
-        model_dir, model_name = osp.split(model_path)
-        try:
-            model = modelcloud.get(model_dir)
-        except NodeNotFoundError:
-            model = root.create_training_model(model_dir)
-        with TemporaryDirectory() as tmp_dir:
-            checkpoint_file = osp.join(tmp_dir, model_name)
-            with open(checkpoint_file, 'wb') as f:
-                torch.save(checkpoint, f)
-                f.flush()
-            model.create_file(checkpoint_file, name=model_name)
-    else:
-        mmcv.mkdir_or_exist(osp.dirname(filename))
-        # immediately flush buffer
-        with open(filename, 'wb') as f:
-            torch.save(checkpoint, f)
-            f.flush()
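`load_checkpoint` above does two key remappings before delegating to `load_state_dict`: it strips the `module.` prefix left by `nn.DataParallel`/DDP wrappers, and for MoBY-style self-supervised checkpoints it keeps only the online `encoder.` branch. A toy sketch of just that dict surgery (the keys are made up and the tensor values are stubbed with `None`):

from collections import OrderedDict

state_dict = OrderedDict([
    ('module.encoder.conv1.weight', None),
    ('module.encoder.bn1.weight', None),
    ('module.encoder_k.conv1.weight', None),  # momentum branch, dropped below
])

# strip 'module.' added by nn.DataParallel / DistributedDataParallel
if list(state_dict.keys())[0].startswith('module.'):
    state_dict = {k[7:]: v for k, v in state_dict.items()}

# keep only the online 'encoder.' branch (MoBY-style checkpoint)
if sorted(state_dict.keys())[0].startswith('encoder'):
    state_dict = {k.replace('encoder.', ''): v
                  for k, v in state_dict.items() if k.startswith('encoder.')}

print(list(state_dict))  # ['conv1.weight', 'bn1.weight']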
main/transformer_utils/mmpose/models/backbones/i3d.py DELETED
@@ -1,215 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Code is modified from `Third-party pytorch implementation of i3d
-# <https://github.com/hassony2/kinetics_i3d_pytorch>`.
-
-import torch
-import torch.nn as nn
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-class Conv3dBlock(nn.Module):
-    """Basic 3d convolution block for I3D.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-        kernel_size (tuple[int]): kernel size of the 3d convolution layer.
-            Default: (1, 1, 1).
-        stride (tuple[int]): stride of the block. Default: (1, 1, 1)
-        padding (tuple[int]): padding of the input tensor. Default: (0, 0, 0)
-        use_bias (bool): whether to enable bias in 3d convolution layer.
-            Default: False
-        use_bn (bool): whether to use Batch Normalization after 3d convolution
-            layer. Default: True
-        use_relu (bool): whether to use ReLU after Batch Normalization layer.
-            Default: True
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 expansion=1.0,
-                 kernel_size=(1, 1, 1),
-                 stride=(1, 1, 1),
-                 padding=(0, 0, 0),
-                 use_bias=False,
-                 use_bn=True,
-                 use_relu=True):
-        super().__init__()
-
-        in_channels = int(in_channels * expansion)
-        out_channels = int(out_channels * expansion)
-
-        self.conv3d = nn.Conv3d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=padding,
-            stride=stride,
-            bias=use_bias)
-
-        self.use_bn = use_bn
-        self.use_relu = use_relu
-
-        if self.use_bn:
-            self.batch3d = nn.BatchNorm3d(out_channels)
-
-        if self.use_relu:
-            self.activation = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        """Forward function."""
-        out = self.conv3d(x)
-        if self.use_bn:
-            out = self.batch3d(out)
-        if self.use_relu:
-            out = self.activation(out)
-        return out
-
-
-class Mixed(nn.Module):
-    """Inception block for I3D.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-    """
-
-    def __init__(self, in_channels, out_channels, expansion=1.0):
-        super(Mixed, self).__init__()
-        # Branch 0
-        self.branch_0 = Conv3dBlock(
-            in_channels, out_channels[0], expansion, kernel_size=(1, 1, 1))
-
-        # Branch 1
-        branch_1_conv1 = Conv3dBlock(
-            in_channels, out_channels[1], expansion, kernel_size=(1, 1, 1))
-        branch_1_conv2 = Conv3dBlock(
-            out_channels[1],
-            out_channels[2],
-            expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.branch_1 = nn.Sequential(branch_1_conv1, branch_1_conv2)
-
-        # Branch 2
-        branch_2_conv1 = Conv3dBlock(
-            in_channels, out_channels[3], expansion, kernel_size=(1, 1, 1))
-        branch_2_conv2 = Conv3dBlock(
-            out_channels[3],
-            out_channels[4],
-            expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.branch_2 = nn.Sequential(branch_2_conv1, branch_2_conv2)
-
-        # Branch3
-        branch_3_pool = nn.MaxPool3d(
-            kernel_size=(3, 3, 3),
-            stride=(1, 1, 1),
-            padding=(1, 1, 1),
-            ceil_mode=True)
-        branch_3_conv2 = Conv3dBlock(
-            in_channels, out_channels[5], expansion, kernel_size=(1, 1, 1))
-        self.branch_3 = nn.Sequential(branch_3_pool, branch_3_conv2)
-
-    def forward(self, x):
-        """Forward function."""
-        out_0 = self.branch_0(x)
-        out_1 = self.branch_1(x)
-        out_2 = self.branch_2(x)
-        out_3 = self.branch_3(x)
-        out = torch.cat((out_0, out_1, out_2, out_3), 1)
-        return out
-
-
-@BACKBONES.register_module()
-class I3D(BaseBackbone):
-    """I3D backbone.
-
-    Please refer to the `paper <https://arxiv.org/abs/1705.07750>`__ for
-    details.
-
-    Args:
-        in_channels (int): Input channels of the backbone, which is decided
-            on the input modality.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-    """
-
-    def __init__(self, in_channels=3, expansion=1.0):
-        super(I3D, self).__init__()
-
-        # expansion must be an integer multiple of 1/8
-        expansion = round(8 * expansion) / 8.0
-
-        # xut Layer
-        self.conv3d_1a_7x7 = Conv3dBlock(
-            out_channels=64,
-            in_channels=in_channels / expansion,
-            expansion=expansion,
-            kernel_size=(7, 7, 7),
-            stride=(2, 2, 2),
-            padding=(2, 3, 3))
-        self.maxPool3d_2a_3x3 = nn.MaxPool3d(
-            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
-
-        # Layer 2
-        self.conv3d_2b_1x1 = Conv3dBlock(
-            out_channels=64,
-            in_channels=64,
-            expansion=expansion,
-            kernel_size=(1, 1, 1))
-        self.conv3d_2c_3x3 = Conv3dBlock(
-            out_channels=192,
-            in_channels=64,
-            expansion=expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.maxPool3d_3a_3x3 = nn.MaxPool3d(
-            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
-
-        # Mixed_3b
-        self.mixed_3b = Mixed(192, [64, 96, 128, 16, 32, 32], expansion)
-        self.mixed_3c = Mixed(256, [128, 128, 192, 32, 96, 64], expansion)
-        self.maxPool3d_4a_3x3 = nn.MaxPool3d(
-            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
-
-        # Mixed 4
-        self.mixed_4b = Mixed(480, [192, 96, 208, 16, 48, 64], expansion)
-        self.mixed_4c = Mixed(512, [160, 112, 224, 24, 64, 64], expansion)
-        self.mixed_4d = Mixed(512, [128, 128, 256, 24, 64, 64], expansion)
-        self.mixed_4e = Mixed(512, [112, 144, 288, 32, 64, 64], expansion)
-        self.mixed_4f = Mixed(528, [256, 160, 320, 32, 128, 128], expansion)
-
-        self.maxPool3d_5a_2x2 = nn.MaxPool3d(
-            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
-
-        # Mixed 5
-        self.mixed_5b = Mixed(832, [256, 160, 320, 32, 128, 128], expansion)
-        self.mixed_5c = Mixed(832, [384, 192, 384, 48, 128, 128], expansion)
-
-    def forward(self, x):
-        out = self.conv3d_1a_7x7(x)
-        out = self.maxPool3d_2a_3x3(out)
-        out = self.conv3d_2b_1x1(out)
-        out = self.conv3d_2c_3x3(out)
-        out = self.maxPool3d_3a_3x3(out)
-        out = self.mixed_3b(out)
-        out = self.mixed_3c(out)
-        out = self.maxPool3d_4a_3x3(out)
-        out = self.mixed_4b(out)
-        out = self.mixed_4c(out)
-        out = self.mixed_4d(out)
-        out = self.mixed_4e(out)
-        out = self.mixed_4f(out)
-        out = self.maxPool3d_5a_2x2(out)
-        out = self.mixed_5b(out)
-        out = self.mixed_5c(out)
-        return out
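A quick sanity check of the `Mixed` block's channel bookkeeping: the concatenation in `Mixed.forward` yields `out_channels[0] + out_channels[2] + out_channels[4] + out_channels[5]` channels, since indices 1 and 3 are only the hidden 1x1 reductions inside branches 1 and 2. The numbers below are the `mixed_3b`/`mixed_3c` settings from `I3D.__init__` (with `expansion=1.0`):

def mixed_out_channels(out_channels):
    # branches 0..3 each contribute their final conv's channel count
    return out_channels[0] + out_channels[2] + out_channels[4] + out_channels[5]

assert mixed_out_channels([64, 96, 128, 16, 32, 32]) == 256    # mixed_3b -> feeds mixed_3c(256, ...)
assert mixed_out_channels([128, 128, 192, 32, 96, 64]) == 480  # mixed_3c -> feeds mixed_4b(480, ...)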
main/transformer_utils/mmpose/models/backbones/litehrnet.py DELETED
@@ -1,984 +0,0 @@
-# ------------------------------------------------------------------------------
-# Adapted from https://github.com/HRNet/Lite-HRNet
-# Original licence: Apache License 2.0.
-# ------------------------------------------------------------------------------
-
-import mmcv
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
-                      build_conv_layer, build_norm_layer, constant_init,
-                      normal_init)
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from mmpose.utils import get_root_logger
-from ..builder import BACKBONES
-from .utils import channel_shuffle, load_checkpoint
-
-
-class SpatialWeighting(nn.Module):
-    """Spatial weighting module.
-
-    Args:
-        channels (int): The channels of the module.
-        ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: None.
-        act_cfg (dict): Config dict for activation layer.
-            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
-            The last ConvModule uses Sigmoid by default.
-    """
-
-    def __init__(self,
-                 channels,
-                 ratio=16,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
-        super().__init__()
-        if isinstance(act_cfg, dict):
-            act_cfg = (act_cfg, act_cfg)
-        assert len(act_cfg) == 2
-        assert mmcv.is_tuple_of(act_cfg, dict)
-        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
-        self.conv1 = ConvModule(
-            in_channels=channels,
-            out_channels=int(channels / ratio),
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[0])
-        self.conv2 = ConvModule(
-            in_channels=int(channels / ratio),
-            out_channels=channels,
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[1])
-
-    def forward(self, x):
-        out = self.global_avgpool(x)
-        out = self.conv1(out)
-        out = self.conv2(out)
-        return x * out
-
-
-class CrossResolutionWeighting(nn.Module):
-    """Cross-resolution channel weighting module.
-
-    Args:
-        channels (int): The channels of the module.
-        ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: None.
-        act_cfg (dict): Config dict for activation layer.
-            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
-            The last ConvModule uses Sigmoid by default.
-    """
-
-    def __init__(self,
-                 channels,
-                 ratio=16,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
-        super().__init__()
-        if isinstance(act_cfg, dict):
-            act_cfg = (act_cfg, act_cfg)
-        assert len(act_cfg) == 2
-        assert mmcv.is_tuple_of(act_cfg, dict)
-        self.channels = channels
-        total_channel = sum(channels)
-        self.conv1 = ConvModule(
-            in_channels=total_channel,
-            out_channels=int(total_channel / ratio),
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[0])
-        self.conv2 = ConvModule(
-            in_channels=int(total_channel / ratio),
-            out_channels=total_channel,
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[1])
-
-    def forward(self, x):
-        mini_size = x[-1].size()[-2:]
-        out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
-        out = torch.cat(out, dim=1)
-        out = self.conv1(out)
-        out = self.conv2(out)
-        out = torch.split(out, self.channels, dim=1)
-        out = [
-            s * F.interpolate(a, size=s.size()[-2:], mode='nearest')
-            for s, a in zip(x, out)
-        ]
-        return out
-
-
-class ConditionalChannelWeighting(nn.Module):
-    """Conditional channel weighting block.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        stride (int): Stride of the 3x3 convolution layer.
-        reduce_ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 stride,
-                 reduce_ratio,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 with_cp=False):
-        super().__init__()
-        self.with_cp = with_cp
-        self.stride = stride
-        assert stride in [1, 2]
-
-        branch_channels = [channel // 2 for channel in in_channels]
-
-        self.cross_resolution_weighting = CrossResolutionWeighting(
-            branch_channels,
-            ratio=reduce_ratio,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg)
-
-        self.depthwise_convs = nn.ModuleList([
-            ConvModule(
-                channel,
-                channel,
-                kernel_size=3,
-                stride=self.stride,
-                padding=1,
-                groups=channel,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None) for channel in branch_channels
-        ])
-
-        self.spatial_weighting = nn.ModuleList([
-            SpatialWeighting(channels=channel, ratio=4)
-            for channel in branch_channels
-        ])
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            x = [s.chunk(2, dim=1) for s in x]
-            x1 = [s[0] for s in x]
-            x2 = [s[1] for s in x]
-
-            x2 = self.cross_resolution_weighting(x2)
-            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
-            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]
-
-            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
-            out = [channel_shuffle(s, 2) for s in out]
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-class Stem(nn.Module):
-    """Stem network block.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        stem_channels (int): Output channels of the stem layer.
-        out_channels (int): The output channels of the block.
-        expand_ratio (int): adjusts number of channels of the hidden layer
-            in InvertedResidual by this amount.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 stem_channels,
-                 out_channels,
-                 expand_ratio,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 with_cp=False):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.with_cp = with_cp
-
-        self.conv1 = ConvModule(
-            in_channels=in_channels,
-            out_channels=stem_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg,
-            act_cfg=dict(type='ReLU'))
-
-        mid_channels = int(round(stem_channels * expand_ratio))
-        branch_channels = stem_channels // 2
-        if stem_channels == self.out_channels:
-            inc_channels = self.out_channels - branch_channels
-        else:
-            inc_channels = self.out_channels - stem_channels
-
-        self.branch1 = nn.Sequential(
-            ConvModule(
-                branch_channels,
-                branch_channels,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-                groups=branch_channels,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None),
-            ConvModule(
-                branch_channels,
-                inc_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=dict(type='ReLU')),
-        )
-
-        self.expand_conv = ConvModule(
-            branch_channels,
-            mid_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=dict(type='ReLU'))
-        self.depthwise_conv = ConvModule(
-            mid_channels,
-            mid_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            groups=mid_channels,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=None)
-        self.linear_conv = ConvModule(
-            mid_channels,
-            branch_channels
-            if stem_channels == self.out_channels else stem_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=dict(type='ReLU'))
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            x = self.conv1(x)
-            x1, x2 = x.chunk(2, dim=1)
-
-            x2 = self.expand_conv(x2)
-            x2 = self.depthwise_conv(x2)
-            x2 = self.linear_conv(x2)
-
-            out = torch.cat((self.branch1(x1), x2), dim=1)
-
-            out = channel_shuffle(out, 2)
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-class IterativeHead(nn.Module):
-    """Extra iterative head for feature learning.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-    """
-
-    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
-        super().__init__()
-        projects = []
-        num_branchs = len(in_channels)
-        self.in_channels = in_channels[::-1]
-
-        for i in range(num_branchs):
-            if i != num_branchs - 1:
-                projects.append(
-                    DepthwiseSeparableConvModule(
-                        in_channels=self.in_channels[i],
-                        out_channels=self.in_channels[i + 1],
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        norm_cfg=norm_cfg,
-                        act_cfg=dict(type='ReLU'),
-                        dw_act_cfg=None,
-                        pw_act_cfg=dict(type='ReLU')))
-            else:
-                projects.append(
-                    DepthwiseSeparableConvModule(
-                        in_channels=self.in_channels[i],
-                        out_channels=self.in_channels[i],
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        norm_cfg=norm_cfg,
-                        act_cfg=dict(type='ReLU'),
-                        dw_act_cfg=None,
-                        pw_act_cfg=dict(type='ReLU')))
-        self.projects = nn.ModuleList(projects)
-
-    def forward(self, x):
-        x = x[::-1]
-
-        y = []
-        last_x = None
-        for i, s in enumerate(x):
-            if last_x is not None:
-                last_x = F.interpolate(
-                    last_x,
-                    size=s.size()[-2:],
-                    mode='bilinear',
-                    align_corners=True)
387
- s = s + last_x
388
- s = self.projects[i](s)
389
- y.append(s)
390
- last_x = s
391
-
392
- return y[::-1]
393
-
394
-
395
- class ShuffleUnit(nn.Module):
396
- """InvertedResidual block for ShuffleNetV2 backbone.
397
-
398
- Args:
399
- in_channels (int): The input channels of the block.
400
- out_channels (int): The output channels of the block.
401
- stride (int): Stride of the 3x3 convolution layer. Default: 1
402
- conv_cfg (dict): Config dict for convolution layer.
403
- Default: None, which means using conv2d.
404
- norm_cfg (dict): Config dict for normalization layer.
405
- Default: dict(type='BN').
406
- act_cfg (dict): Config dict for activation layer.
407
- Default: dict(type='ReLU').
408
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
409
- memory while slowing down the training speed. Default: False.
410
- """
411
-
412
- def __init__(self,
413
- in_channels,
414
- out_channels,
415
- stride=1,
416
- conv_cfg=None,
417
- norm_cfg=dict(type='BN'),
418
- act_cfg=dict(type='ReLU'),
419
- with_cp=False):
420
- super().__init__()
421
- self.stride = stride
422
- self.with_cp = with_cp
423
-
424
- branch_features = out_channels // 2
425
- if self.stride == 1:
426
- assert in_channels == branch_features * 2, (
427
- f'in_channels ({in_channels}) should equal to '
428
- f'branch_features * 2 ({branch_features * 2}) '
429
- 'when stride is 1')
430
-
431
- if in_channels != branch_features * 2:
432
- assert self.stride != 1, (
433
- f'stride ({self.stride}) should not equal 1 when '
434
- f'in_channels != branch_features * 2')
435
-
436
- if self.stride > 1:
437
- self.branch1 = nn.Sequential(
438
- ConvModule(
439
- in_channels,
440
- in_channels,
441
- kernel_size=3,
442
- stride=self.stride,
443
- padding=1,
444
- groups=in_channels,
445
- conv_cfg=conv_cfg,
446
- norm_cfg=norm_cfg,
447
- act_cfg=None),
448
- ConvModule(
449
- in_channels,
450
- branch_features,
451
- kernel_size=1,
452
- stride=1,
453
- padding=0,
454
- conv_cfg=conv_cfg,
455
- norm_cfg=norm_cfg,
456
- act_cfg=act_cfg),
457
- )
458
-
459
- self.branch2 = nn.Sequential(
460
- ConvModule(
461
- in_channels if (self.stride > 1) else branch_features,
462
- branch_features,
463
- kernel_size=1,
464
- stride=1,
465
- padding=0,
466
- conv_cfg=conv_cfg,
467
- norm_cfg=norm_cfg,
468
- act_cfg=act_cfg),
469
- ConvModule(
470
- branch_features,
471
- branch_features,
472
- kernel_size=3,
473
- stride=self.stride,
474
- padding=1,
475
- groups=branch_features,
476
- conv_cfg=conv_cfg,
477
- norm_cfg=norm_cfg,
478
- act_cfg=None),
479
- ConvModule(
480
- branch_features,
481
- branch_features,
482
- kernel_size=1,
483
- stride=1,
484
- padding=0,
485
- conv_cfg=conv_cfg,
486
- norm_cfg=norm_cfg,
487
- act_cfg=act_cfg))
488
-
489
- def forward(self, x):
490
-
491
- def _inner_forward(x):
492
- if self.stride > 1:
493
- out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
494
- else:
495
- x1, x2 = x.chunk(2, dim=1)
496
- out = torch.cat((x1, self.branch2(x2)), dim=1)
497
-
498
- out = channel_shuffle(out, 2)
499
-
500
- return out
501
-
502
- if self.with_cp and x.requires_grad:
503
- out = cp.checkpoint(_inner_forward, x)
504
- else:
505
- out = _inner_forward(x)
506
-
507
- return out
508
-
509
-
510
- class LiteHRModule(nn.Module):
511
- """High-Resolution Module for LiteHRNet.
512
-
513
- It contains conditional channel weighting blocks and
514
- shuffle blocks.
515
-
516
-
517
- Args:
518
- num_branches (int): Number of branches in the module.
519
- num_blocks (int): Number of blocks in the module.
520
- in_channels (list(int)): Number of input channels of each branch.
521
- reduce_ratio (int): Channel reduction ratio.
522
- module_type (str): 'LITE' or 'NAIVE'
523
- multiscale_output (bool): Whether to output multi-scale features.
524
- with_fuse (bool): Whether to use fuse layers.
525
- conv_cfg (dict): dictionary to construct and config conv layer.
526
- norm_cfg (dict): dictionary to construct and config norm layer.
527
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
528
- memory while slowing down the training speed.
529
- """
530
-
531
- def __init__(
532
- self,
533
- num_branches,
534
- num_blocks,
535
- in_channels,
536
- reduce_ratio,
537
- module_type,
538
- multiscale_output=False,
539
- with_fuse=True,
540
- conv_cfg=None,
541
- norm_cfg=dict(type='BN'),
542
- with_cp=False,
543
- ):
544
- super().__init__()
545
- self._check_branches(num_branches, in_channels)
546
-
547
- self.in_channels = in_channels
548
- self.num_branches = num_branches
549
-
550
- self.module_type = module_type
551
- self.multiscale_output = multiscale_output
552
- self.with_fuse = with_fuse
553
- self.norm_cfg = norm_cfg
554
- self.conv_cfg = conv_cfg
555
- self.with_cp = with_cp
556
-
557
- if self.module_type.upper() == 'LITE':
558
- self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
559
- elif self.module_type.upper() == 'NAIVE':
560
- self.layers = self._make_naive_branches(num_branches, num_blocks)
561
- else:
562
- raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
563
- if self.with_fuse:
564
- self.fuse_layers = self._make_fuse_layers()
565
- self.relu = nn.ReLU()
566
-
567
- def _check_branches(self, num_branches, in_channels):
568
- """Check input to avoid ValueError."""
569
- if num_branches != len(in_channels):
570
- error_msg = f'NUM_BRANCHES({num_branches}) ' \
571
- f'!= NUM_INCHANNELS({len(in_channels)})'
572
- raise ValueError(error_msg)
573
-
574
- def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
575
- """Make channel weighting blocks."""
576
- layers = []
577
- for i in range(num_blocks):
578
- layers.append(
579
- ConditionalChannelWeighting(
580
- self.in_channels,
581
- stride=stride,
582
- reduce_ratio=reduce_ratio,
583
- conv_cfg=self.conv_cfg,
584
- norm_cfg=self.norm_cfg,
585
- with_cp=self.with_cp))
586
-
587
- return nn.Sequential(*layers)
588
-
589
- def _make_one_branch(self, branch_index, num_blocks, stride=1):
590
- """Make one branch."""
591
- layers = []
592
- layers.append(
593
- ShuffleUnit(
594
- self.in_channels[branch_index],
595
- self.in_channels[branch_index],
596
- stride=stride,
597
- conv_cfg=self.conv_cfg,
598
- norm_cfg=self.norm_cfg,
599
- act_cfg=dict(type='ReLU'),
600
- with_cp=self.with_cp))
601
- for i in range(1, num_blocks):
602
- layers.append(
603
- ShuffleUnit(
604
- self.in_channels[branch_index],
605
- self.in_channels[branch_index],
606
- stride=1,
607
- conv_cfg=self.conv_cfg,
608
- norm_cfg=self.norm_cfg,
609
- act_cfg=dict(type='ReLU'),
610
- with_cp=self.with_cp))
611
-
612
- return nn.Sequential(*layers)
613
-
614
- def _make_naive_branches(self, num_branches, num_blocks):
615
- """Make branches."""
616
- branches = []
617
-
618
- for i in range(num_branches):
619
- branches.append(self._make_one_branch(i, num_blocks))
620
-
621
- return nn.ModuleList(branches)
622
-
623
- def _make_fuse_layers(self):
624
- """Make fuse layer."""
625
- if self.num_branches == 1:
626
- return None
627
-
628
- num_branches = self.num_branches
629
- in_channels = self.in_channels
630
- fuse_layers = []
631
- num_out_branches = num_branches if self.multiscale_output else 1
632
- for i in range(num_out_branches):
633
- fuse_layer = []
634
- for j in range(num_branches):
635
- if j > i:
636
- fuse_layer.append(
637
- nn.Sequential(
638
- build_conv_layer(
639
- self.conv_cfg,
640
- in_channels[j],
641
- in_channels[i],
642
- kernel_size=1,
643
- stride=1,
644
- padding=0,
645
- bias=False),
646
- build_norm_layer(self.norm_cfg, in_channels[i])[1],
647
- nn.Upsample(
648
- scale_factor=2**(j - i), mode='nearest')))
649
- elif j == i:
650
- fuse_layer.append(None)
651
- else:
652
- conv_downsamples = []
653
- for k in range(i - j):
654
- if k == i - j - 1:
655
- conv_downsamples.append(
656
- nn.Sequential(
657
- build_conv_layer(
658
- self.conv_cfg,
659
- in_channels[j],
660
- in_channels[j],
661
- kernel_size=3,
662
- stride=2,
663
- padding=1,
664
- groups=in_channels[j],
665
- bias=False),
666
- build_norm_layer(self.norm_cfg,
667
- in_channels[j])[1],
668
- build_conv_layer(
669
- self.conv_cfg,
670
- in_channels[j],
671
- in_channels[i],
672
- kernel_size=1,
673
- stride=1,
674
- padding=0,
675
- bias=False),
676
- build_norm_layer(self.norm_cfg,
677
- in_channels[i])[1]))
678
- else:
679
- conv_downsamples.append(
680
- nn.Sequential(
681
- build_conv_layer(
682
- self.conv_cfg,
683
- in_channels[j],
684
- in_channels[j],
685
- kernel_size=3,
686
- stride=2,
687
- padding=1,
688
- groups=in_channels[j],
689
- bias=False),
690
- build_norm_layer(self.norm_cfg,
691
- in_channels[j])[1],
692
- build_conv_layer(
693
- self.conv_cfg,
694
- in_channels[j],
695
- in_channels[j],
696
- kernel_size=1,
697
- stride=1,
698
- padding=0,
699
- bias=False),
700
- build_norm_layer(self.norm_cfg,
701
- in_channels[j])[1],
702
- nn.ReLU(inplace=True)))
703
- fuse_layer.append(nn.Sequential(*conv_downsamples))
704
- fuse_layers.append(nn.ModuleList(fuse_layer))
705
-
706
- return nn.ModuleList(fuse_layers)
707
-
708
- def forward(self, x):
709
- """Forward function."""
710
- if self.num_branches == 1:
711
- return [self.layers[0](x[0])]
712
-
713
- if self.module_type.upper() == 'LITE':
714
- out = self.layers(x)
715
- elif self.module_type.upper() == 'NAIVE':
716
- for i in range(self.num_branches):
717
- x[i] = self.layers[i](x[i])
718
- out = x
719
-
720
- if self.with_fuse:
721
- out_fuse = []
722
- for i in range(len(self.fuse_layers)):
723
- # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
724
- y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
725
- for j in range(self.num_branches):
726
- if i == j:
727
- y += out[j]
728
- else:
729
- y += self.fuse_layers[i][j](out[j])
730
- out_fuse.append(self.relu(y))
731
- out = out_fuse
732
- if not self.multiscale_output:
733
- out = [out[0]]
734
- return out
735
-
736
-
737
- @BACKBONES.register_module()
738
- class LiteHRNet(nn.Module):
739
- """Lite-HRNet backbone.
740
-
741
- `Lite-HRNet: A Lightweight High-Resolution Network
742
- <https://arxiv.org/abs/2104.06403>`_.
743
-
744
- Code adapted from 'https://github.com/HRNet/Lite-HRNet'.
745
-
746
- Args:
747
- extra (dict): Detailed configuration for each stage of Lite-HRNet.
748
- in_channels (int): Number of input image channels. Default: 3.
749
- conv_cfg (dict): dictionary to construct and config conv layer.
750
- norm_cfg (dict): dictionary to construct and config norm layer.
751
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
752
- freeze running stats (mean and var). Note: Effect on Batch Norm
753
- and its variants only. Default: False
754
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
755
- memory while slowing down the training speed.
756
-
757
- Example:
758
- >>> from mmpose.models import LiteHRNet
759
- >>> import torch
760
- >>> extra=dict(
761
- >>> stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
762
- >>> num_stages=3,
763
- >>> stages_spec=dict(
764
- >>> num_modules=(2, 4, 2),
765
- >>> num_branches=(2, 3, 4),
766
- >>> num_blocks=(2, 2, 2),
767
- >>> module_type=('LITE', 'LITE', 'LITE'),
768
- >>> with_fuse=(True, True, True),
769
- >>> reduce_ratios=(8, 8, 8),
770
- >>> num_channels=(
771
- >>> (40, 80),
772
- >>> (40, 80, 160),
773
- >>> (40, 80, 160, 320),
774
- >>> )),
775
- >>> with_head=False)
776
- >>> self = LiteHRNet(extra, in_channels=1)
777
- >>> self.eval()
778
- >>> inputs = torch.rand(1, 1, 32, 32)
779
- >>> level_outputs = self.forward(inputs)
780
- >>> for level_out in level_outputs:
781
- ... print(tuple(level_out.shape))
782
- (1, 40, 8, 8)
783
- """
784
-
785
- def __init__(self,
786
- extra,
787
- in_channels=3,
788
- conv_cfg=None,
789
- norm_cfg=dict(type='BN'),
790
- norm_eval=False,
791
- with_cp=False):
792
- super().__init__()
793
- self.extra = extra
794
- self.conv_cfg = conv_cfg
795
- self.norm_cfg = norm_cfg
796
- self.norm_eval = norm_eval
797
- self.with_cp = with_cp
798
-
799
- self.stem = Stem(
800
- in_channels,
801
- stem_channels=self.extra['stem']['stem_channels'],
802
- out_channels=self.extra['stem']['out_channels'],
803
- expand_ratio=self.extra['stem']['expand_ratio'],
804
- conv_cfg=self.conv_cfg,
805
- norm_cfg=self.norm_cfg)
806
-
807
- self.num_stages = self.extra['num_stages']
808
- self.stages_spec = self.extra['stages_spec']
809
-
810
- num_channels_last = [
811
- self.stem.out_channels,
812
- ]
813
- for i in range(self.num_stages):
814
- num_channels = self.stages_spec['num_channels'][i]
815
- num_channels = [num_channels[i] for i in range(len(num_channels))]
816
- setattr(
817
- self, f'transition{i}',
818
- self._make_transition_layer(num_channels_last, num_channels))
819
-
820
- stage, num_channels_last = self._make_stage(
821
- self.stages_spec, i, num_channels, multiscale_output=True)
822
- setattr(self, f'stage{i}', stage)
823
-
824
- self.with_head = self.extra['with_head']
825
- if self.with_head:
826
- self.head_layer = IterativeHead(
827
- in_channels=num_channels_last,
828
- norm_cfg=self.norm_cfg,
829
- )
830
-
831
- def _make_transition_layer(self, num_channels_pre_layer,
832
- num_channels_cur_layer):
833
- """Make transition layer."""
834
- num_branches_cur = len(num_channels_cur_layer)
835
- num_branches_pre = len(num_channels_pre_layer)
836
-
837
- transition_layers = []
838
- for i in range(num_branches_cur):
839
- if i < num_branches_pre:
840
- if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
841
- transition_layers.append(
842
- nn.Sequential(
843
- build_conv_layer(
844
- self.conv_cfg,
845
- num_channels_pre_layer[i],
846
- num_channels_pre_layer[i],
847
- kernel_size=3,
848
- stride=1,
849
- padding=1,
850
- groups=num_channels_pre_layer[i],
851
- bias=False),
852
- build_norm_layer(self.norm_cfg,
853
- num_channels_pre_layer[i])[1],
854
- build_conv_layer(
855
- self.conv_cfg,
856
- num_channels_pre_layer[i],
857
- num_channels_cur_layer[i],
858
- kernel_size=1,
859
- stride=1,
860
- padding=0,
861
- bias=False),
862
- build_norm_layer(self.norm_cfg,
863
- num_channels_cur_layer[i])[1],
864
- nn.ReLU()))
865
- else:
866
- transition_layers.append(None)
867
- else:
868
- conv_downsamples = []
869
- for j in range(i + 1 - num_branches_pre):
870
- in_channels = num_channels_pre_layer[-1]
871
- out_channels = num_channels_cur_layer[i] \
872
- if j == i - num_branches_pre else in_channels
873
- conv_downsamples.append(
874
- nn.Sequential(
875
- build_conv_layer(
876
- self.conv_cfg,
877
- in_channels,
878
- in_channels,
879
- kernel_size=3,
880
- stride=2,
881
- padding=1,
882
- groups=in_channels,
883
- bias=False),
884
- build_norm_layer(self.norm_cfg, in_channels)[1],
885
- build_conv_layer(
886
- self.conv_cfg,
887
- in_channels,
888
- out_channels,
889
- kernel_size=1,
890
- stride=1,
891
- padding=0,
892
- bias=False),
893
- build_norm_layer(self.norm_cfg, out_channels)[1],
894
- nn.ReLU()))
895
- transition_layers.append(nn.Sequential(*conv_downsamples))
896
-
897
- return nn.ModuleList(transition_layers)
898
-
899
- def _make_stage(self,
900
- stages_spec,
901
- stage_index,
902
- in_channels,
903
- multiscale_output=True):
904
- num_modules = stages_spec['num_modules'][stage_index]
905
- num_branches = stages_spec['num_branches'][stage_index]
906
- num_blocks = stages_spec['num_blocks'][stage_index]
907
- reduce_ratio = stages_spec['reduce_ratios'][stage_index]
908
- with_fuse = stages_spec['with_fuse'][stage_index]
909
- module_type = stages_spec['module_type'][stage_index]
910
-
911
- modules = []
912
- for i in range(num_modules):
913
- # multi_scale_output is only used in the last module
914
- if not multiscale_output and i == num_modules - 1:
915
- reset_multiscale_output = False
916
- else:
917
- reset_multiscale_output = True
918
-
919
- modules.append(
920
- LiteHRModule(
921
- num_branches,
922
- num_blocks,
923
- in_channels,
924
- reduce_ratio,
925
- module_type,
926
- multiscale_output=reset_multiscale_output,
927
- with_fuse=with_fuse,
928
- conv_cfg=self.conv_cfg,
929
- norm_cfg=self.norm_cfg,
930
- with_cp=self.with_cp))
931
- in_channels = modules[-1].in_channels
932
-
933
- return nn.Sequential(*modules), in_channels
934
-
935
- def init_weights(self, pretrained=None):
936
- """Initialize the weights in backbone.
937
-
938
- Args:
939
- pretrained (str, optional): Path to pre-trained weights.
940
- Defaults to None.
941
- """
942
- if isinstance(pretrained, str):
943
- logger = get_root_logger()
944
- load_checkpoint(self, pretrained, strict=False, logger=logger)
945
- elif pretrained is None:
946
- for m in self.modules():
947
- if isinstance(m, nn.Conv2d):
948
- normal_init(m, std=0.001)
949
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
950
- constant_init(m, 1)
951
- else:
952
- raise TypeError('pretrained must be a str or None')
953
-
954
- def forward(self, x):
955
- """Forward function."""
956
- x = self.stem(x)
957
-
958
- y_list = [x]
959
- for i in range(self.num_stages):
960
- x_list = []
961
- transition = getattr(self, f'transition{i}')
962
- for j in range(self.stages_spec['num_branches'][i]):
963
- if transition[j]:
964
- if j >= len(y_list):
965
- x_list.append(transition[j](y_list[-1]))
966
- else:
967
- x_list.append(transition[j](y_list[j]))
968
- else:
969
- x_list.append(y_list[j])
970
- y_list = getattr(self, f'stage{i}')(x_list)
971
-
972
- x = y_list
973
- if self.with_head:
974
- x = self.head_layer(x)
975
-
976
- return [x[0]]
977
-
978
- def train(self, mode=True):
979
- """Convert the model into training mode."""
980
- super().train(mode)
981
- if mode and self.norm_eval:
982
- for m in self.modules():
983
- if isinstance(m, _BatchNorm):
984
- m.eval()
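
Note: the deleted Lite-HRNet blocks above (ConditionalChannelWeighting, Stem,
ShuffleUnit) all finish with channel_shuffle(out, 2), which is imported from
the backbone utils and does not appear in this diff. For reference, a minimal
sketch of that operation in the standard ShuffleNet formulation (an
illustrative reimplementation, not the repo's exact code):

    import torch

    def channel_shuffle(x, groups):
        # ShuffleNet-style shuffle: interleave channels across `groups`
        batch, channels, height, width = x.size()
        assert channels % groups == 0, 'channels must be divisible by groups'
        # (B, C, H, W) -> (B, g, C//g, H, W), swap the group and channel
        # axes, then flatten back to (B, C, H, W)
        x = x.view(batch, groups, channels // groups, height, width)
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, channels, height, width)

    # with groups=2 the two halves of the channel dim are interleaved
    out = channel_shuffle(torch.arange(4.).view(1, 4, 1, 1), 2)
    assert out.flatten().tolist() == [0.0, 2.0, 1.0, 3.0]
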
main/transformer_utils/mmpose/models/backbones/mobilenet_v2.py DELETED
@@ -1,275 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import copy
3
- import logging
4
-
5
- import torch.nn as nn
6
- import torch.utils.checkpoint as cp
7
- from mmcv.cnn import ConvModule, constant_init, kaiming_init
8
- from torch.nn.modules.batchnorm import _BatchNorm
9
-
10
- from ..builder import BACKBONES
11
- from .base_backbone import BaseBackbone
12
- from .utils import load_checkpoint, make_divisible
13
-
14
-
15
- class InvertedResidual(nn.Module):
16
- """InvertedResidual block for MobileNetV2.
17
-
18
- Args:
19
- in_channels (int): The input channels of the InvertedResidual block.
20
- out_channels (int): The output channels of the InvertedResidual block.
21
- stride (int): Stride of the middle (first) 3x3 convolution.
22
- expand_ratio (int): adjusts number of channels of the hidden layer
23
- in InvertedResidual by this amount.
24
- conv_cfg (dict): Config dict for convolution layer.
25
- Default: None, which means using conv2d.
26
- norm_cfg (dict): Config dict for normalization layer.
27
- Default: dict(type='BN').
28
- act_cfg (dict): Config dict for activation layer.
29
- Default: dict(type='ReLU6').
30
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
31
- memory while slowing down the training speed. Default: False.
32
- """
33
-
34
- def __init__(self,
35
- in_channels,
36
- out_channels,
37
- stride,
38
- expand_ratio,
39
- conv_cfg=None,
40
- norm_cfg=dict(type='BN'),
41
- act_cfg=dict(type='ReLU6'),
42
- with_cp=False):
43
- # Protect mutable default arguments
44
- norm_cfg = copy.deepcopy(norm_cfg)
45
- act_cfg = copy.deepcopy(act_cfg)
46
- super().__init__()
47
- self.stride = stride
48
- assert stride in [1, 2], f'stride must in [1, 2]. ' \
49
- f'But received {stride}.'
50
- self.with_cp = with_cp
51
- self.use_res_connect = self.stride == 1 and in_channels == out_channels
52
- hidden_dim = int(round(in_channels * expand_ratio))
53
-
54
- layers = []
55
- if expand_ratio != 1:
56
- layers.append(
57
- ConvModule(
58
- in_channels=in_channels,
59
- out_channels=hidden_dim,
60
- kernel_size=1,
61
- conv_cfg=conv_cfg,
62
- norm_cfg=norm_cfg,
63
- act_cfg=act_cfg))
64
- layers.extend([
65
- ConvModule(
66
- in_channels=hidden_dim,
67
- out_channels=hidden_dim,
68
- kernel_size=3,
69
- stride=stride,
70
- padding=1,
71
- groups=hidden_dim,
72
- conv_cfg=conv_cfg,
73
- norm_cfg=norm_cfg,
74
- act_cfg=act_cfg),
75
- ConvModule(
76
- in_channels=hidden_dim,
77
- out_channels=out_channels,
78
- kernel_size=1,
79
- conv_cfg=conv_cfg,
80
- norm_cfg=norm_cfg,
81
- act_cfg=None)
82
- ])
83
- self.conv = nn.Sequential(*layers)
84
-
85
- def forward(self, x):
86
-
87
- def _inner_forward(x):
88
- if self.use_res_connect:
89
- return x + self.conv(x)
90
- return self.conv(x)
91
-
92
- if self.with_cp and x.requires_grad:
93
- out = cp.checkpoint(_inner_forward, x)
94
- else:
95
- out = _inner_forward(x)
96
-
97
- return out
98
-
99
-
100
- @BACKBONES.register_module()
101
- class MobileNetV2(BaseBackbone):
102
- """MobileNetV2 backbone.
103
-
104
- Args:
105
- widen_factor (float): Width multiplier, multiply number of
106
- channels in each layer by this amount. Default: 1.0.
107
- out_indices (None or Sequence[int]): Output from which stages.
108
- Default: (7, ).
109
- frozen_stages (int): Stages to be frozen (all param fixed).
110
- Default: -1, which means not freezing any parameters.
111
- conv_cfg (dict): Config dict for convolution layer.
112
- Default: None, which means using conv2d.
113
- norm_cfg (dict): Config dict for normalization layer.
114
- Default: dict(type='BN').
115
- act_cfg (dict): Config dict for activation layer.
116
- Default: dict(type='ReLU6').
117
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
118
- freeze running stats (mean and var). Note: Effect on Batch Norm
119
- and its variants only. Default: False.
120
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
121
- memory while slowing down the training speed. Default: False.
122
- """
123
-
124
- # Parameters to build layers. 4 parameters are needed to construct a
125
- # layer, from left to right: expand_ratio, channel, num_blocks, stride.
126
- arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
127
- [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
128
- [6, 320, 1, 1]]
129
-
130
- def __init__(self,
131
- widen_factor=1.,
132
- out_indices=(7, ),
133
- frozen_stages=-1,
134
- conv_cfg=None,
135
- norm_cfg=dict(type='BN'),
136
- act_cfg=dict(type='ReLU6'),
137
- norm_eval=False,
138
- with_cp=False):
139
- # Protect mutable default arguments
140
- norm_cfg = copy.deepcopy(norm_cfg)
141
- act_cfg = copy.deepcopy(act_cfg)
142
- super().__init__()
143
- self.widen_factor = widen_factor
144
- self.out_indices = out_indices
145
- for index in out_indices:
146
- if index not in range(0, 8):
147
- raise ValueError('the item in out_indices must be in '
148
- f'range(0, 8). But received {index}')
149
-
150
- if frozen_stages not in range(-1, 8):
151
- raise ValueError('frozen_stages must be in range(-1, 8). '
152
- f'But received {frozen_stages}')
153
- self.out_indices = out_indices
154
- self.frozen_stages = frozen_stages
155
- self.conv_cfg = conv_cfg
156
- self.norm_cfg = norm_cfg
157
- self.act_cfg = act_cfg
158
- self.norm_eval = norm_eval
159
- self.with_cp = with_cp
160
-
161
- self.in_channels = make_divisible(32 * widen_factor, 8)
162
-
163
- self.conv1 = ConvModule(
164
- in_channels=3,
165
- out_channels=self.in_channels,
166
- kernel_size=3,
167
- stride=2,
168
- padding=1,
169
- conv_cfg=self.conv_cfg,
170
- norm_cfg=self.norm_cfg,
171
- act_cfg=self.act_cfg)
172
-
173
- self.layers = []
174
-
175
- for i, layer_cfg in enumerate(self.arch_settings):
176
- expand_ratio, channel, num_blocks, stride = layer_cfg
177
- out_channels = make_divisible(channel * widen_factor, 8)
178
- inverted_res_layer = self.make_layer(
179
- out_channels=out_channels,
180
- num_blocks=num_blocks,
181
- stride=stride,
182
- expand_ratio=expand_ratio)
183
- layer_name = f'layer{i + 1}'
184
- self.add_module(layer_name, inverted_res_layer)
185
- self.layers.append(layer_name)
186
-
187
- if widen_factor > 1.0:
188
- self.out_channel = int(1280 * widen_factor)
189
- else:
190
- self.out_channel = 1280
191
-
192
- layer = ConvModule(
193
- in_channels=self.in_channels,
194
- out_channels=self.out_channel,
195
- kernel_size=1,
196
- stride=1,
197
- padding=0,
198
- conv_cfg=self.conv_cfg,
199
- norm_cfg=self.norm_cfg,
200
- act_cfg=self.act_cfg)
201
- self.add_module('conv2', layer)
202
- self.layers.append('conv2')
203
-
204
- def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
205
- """Stack InvertedResidual blocks to build a layer for MobileNetV2.
206
-
207
- Args:
208
- out_channels (int): out_channels of block.
209
- num_blocks (int): number of blocks.
210
- stride (int): stride of the first block. Default: 1
211
- expand_ratio (int): Expand the number of channels of the
212
- hidden layer in InvertedResidual by this ratio. Default: 6.
213
- """
214
- layers = []
215
- for i in range(num_blocks):
216
- if i >= 1:
217
- stride = 1
218
- layers.append(
219
- InvertedResidual(
220
- self.in_channels,
221
- out_channels,
222
- stride,
223
- expand_ratio=expand_ratio,
224
- conv_cfg=self.conv_cfg,
225
- norm_cfg=self.norm_cfg,
226
- act_cfg=self.act_cfg,
227
- with_cp=self.with_cp))
228
- self.in_channels = out_channels
229
-
230
- return nn.Sequential(*layers)
231
-
232
- def init_weights(self, pretrained=None):
233
- if isinstance(pretrained, str):
234
- logger = logging.getLogger()
235
- load_checkpoint(self, pretrained, strict=False, logger=logger)
236
- elif pretrained is None:
237
- for m in self.modules():
238
- if isinstance(m, nn.Conv2d):
239
- kaiming_init(m)
240
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
241
- constant_init(m, 1)
242
- else:
243
- raise TypeError('pretrained must be a str or None')
244
-
245
- def forward(self, x):
246
- x = self.conv1(x)
247
-
248
- outs = []
249
- for i, layer_name in enumerate(self.layers):
250
- layer = getattr(self, layer_name)
251
- x = layer(x)
252
- if i in self.out_indices:
253
- outs.append(x)
254
-
255
- if len(outs) == 1:
256
- return outs[0]
257
- return tuple(outs)
258
-
259
- def _freeze_stages(self):
260
- if self.frozen_stages >= 0:
261
- for param in self.conv1.parameters():
262
- param.requires_grad = False
263
- for i in range(1, self.frozen_stages + 1):
264
- layer = getattr(self, f'layer{i}')
265
- layer.eval()
266
- for param in layer.parameters():
267
- param.requires_grad = False
268
-
269
- def train(self, mode=True):
270
- super().train(mode)
271
- self._freeze_stages()
272
- if mode and self.norm_eval:
273
- for m in self.modules():
274
- if isinstance(m, _BatchNorm):
275
- m.eval()
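
Note: MobileNetV2 above derives every stage width with
make_divisible(channel * widen_factor, 8), imported from the backbone utils
and not shown in this diff. A sketch following the common MobileNet rounding
convention (an assumption for illustration, not the repo's exact code):

    def make_divisible(value, divisor=8, min_value=None):
        # round `value` to the nearest multiple of `divisor`, but never
        # drop below 90% of the original value (MobileNet convention)
        if min_value is None:
            min_value = divisor
        new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
        if new_value < 0.9 * value:
            new_value += divisor
        return new_value

    # e.g. widen_factor=0.75 rounds the 32-channel stem to 24 channels
    assert make_divisible(32 * 0.75, 8) == 24
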
main/transformer_utils/mmpose/models/backbones/mobilenet_v3.py DELETED
@@ -1,188 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import copy
3
- import logging
4
-
5
- import torch.nn as nn
6
- from mmcv.cnn import ConvModule, constant_init, kaiming_init
7
- from torch.nn.modules.batchnorm import _BatchNorm
8
-
9
- from ..builder import BACKBONES
10
- from .base_backbone import BaseBackbone
11
- from .utils import InvertedResidual, load_checkpoint
12
-
13
-
14
- @BACKBONES.register_module()
15
- class MobileNetV3(BaseBackbone):
16
- """MobileNetV3 backbone.
17
-
18
- Args:
19
- arch (str): Architecture of MobileNetV3, from {small, big}.
20
- Default: small.
21
- conv_cfg (dict): Config dict for convolution layer.
22
- Default: None, which means using conv2d.
23
- norm_cfg (dict): Config dict for normalization layer.
24
- Default: dict(type='BN').
25
- out_indices (None or Sequence[int]): Output from which stages.
26
- Default: (-1, ), which means output tensors from final stage.
27
- frozen_stages (int): Stages to be frozen (all param fixed).
28
- Default: -1, which means not freezing any parameters.
29
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
30
- freeze running stats (mean and var). Note: Effect on Batch Norm
31
- and its variants only. Default: False.
32
- with_cp (bool): Use checkpoint or not. Using checkpoint will save
33
- some memory while slowing down the training speed.
34
- Default: False.
35
- """
36
- # Parameters to build each block:
37
- # [kernel size, mid channels, out channels, with_se, act type, stride]
38
- arch_settings = {
39
- 'small': [[3, 16, 16, True, 'ReLU', 2],
40
- [3, 72, 24, False, 'ReLU', 2],
41
- [3, 88, 24, False, 'ReLU', 1],
42
- [5, 96, 40, True, 'HSwish', 2],
43
- [5, 240, 40, True, 'HSwish', 1],
44
- [5, 240, 40, True, 'HSwish', 1],
45
- [5, 120, 48, True, 'HSwish', 1],
46
- [5, 144, 48, True, 'HSwish', 1],
47
- [5, 288, 96, True, 'HSwish', 2],
48
- [5, 576, 96, True, 'HSwish', 1],
49
- [5, 576, 96, True, 'HSwish', 1]],
50
- 'big': [[3, 16, 16, False, 'ReLU', 1],
51
- [3, 64, 24, False, 'ReLU', 2],
52
- [3, 72, 24, False, 'ReLU', 1],
53
- [5, 72, 40, True, 'ReLU', 2],
54
- [5, 120, 40, True, 'ReLU', 1],
55
- [5, 120, 40, True, 'ReLU', 1],
56
- [3, 240, 80, False, 'HSwish', 2],
57
- [3, 200, 80, False, 'HSwish', 1],
58
- [3, 184, 80, False, 'HSwish', 1],
59
- [3, 184, 80, False, 'HSwish', 1],
60
- [3, 480, 112, True, 'HSwish', 1],
61
- [3, 672, 112, True, 'HSwish', 1],
62
- [5, 672, 160, True, 'HSwish', 1],
63
- [5, 672, 160, True, 'HSwish', 2],
64
- [5, 960, 160, True, 'HSwish', 1]]
65
- } # yapf: disable
66
-
67
- def __init__(self,
68
- arch='small',
69
- conv_cfg=None,
70
- norm_cfg=dict(type='BN'),
71
- out_indices=(-1, ),
72
- frozen_stages=-1,
73
- norm_eval=False,
74
- with_cp=False):
75
- # Protect mutable default arguments
76
- norm_cfg = copy.deepcopy(norm_cfg)
77
- super().__init__()
78
- assert arch in self.arch_settings
79
- for index in out_indices:
80
- if index not in range(-len(self.arch_settings[arch]),
81
- len(self.arch_settings[arch])):
82
- raise ValueError('the item in out_indices must be in '
83
- f'range(-{len(self.arch_settings[arch])}, {len(self.arch_settings[arch])}). '
84
- f'But received {index}')
85
-
86
- if frozen_stages not in range(-1, len(self.arch_settings[arch])):
87
- raise ValueError('frozen_stages must be in range(-1, '
88
- f'{len(self.arch_settings[arch])}). '
89
- f'But received {frozen_stages}')
90
- self.arch = arch
91
- self.conv_cfg = conv_cfg
92
- self.norm_cfg = norm_cfg
93
- self.out_indices = out_indices
94
- self.frozen_stages = frozen_stages
95
- self.norm_eval = norm_eval
96
- self.with_cp = with_cp
97
-
98
- self.in_channels = 16
99
- self.conv1 = ConvModule(
100
- in_channels=3,
101
- out_channels=self.in_channels,
102
- kernel_size=3,
103
- stride=2,
104
- padding=1,
105
- conv_cfg=conv_cfg,
106
- norm_cfg=norm_cfg,
107
- act_cfg=dict(type='HSwish'))
108
-
109
- self.layers = self._make_layer()
110
- self.feat_dim = self.arch_settings[arch][-1][2]
111
-
112
- def _make_layer(self):
113
- layers = []
114
- layer_setting = self.arch_settings[self.arch]
115
- for i, params in enumerate(layer_setting):
116
- (kernel_size, mid_channels, out_channels, with_se, act,
117
- stride) = params
118
- if with_se:
119
- se_cfg = dict(
120
- channels=mid_channels,
121
- ratio=4,
122
- act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
123
- else:
124
- se_cfg = None
125
-
126
- layer = InvertedResidual(
127
- in_channels=self.in_channels,
128
- out_channels=out_channels,
129
- mid_channels=mid_channels,
130
- kernel_size=kernel_size,
131
- stride=stride,
132
- se_cfg=se_cfg,
133
- with_expand_conv=True,
134
- conv_cfg=self.conv_cfg,
135
- norm_cfg=self.norm_cfg,
136
- act_cfg=dict(type=act),
137
- with_cp=self.with_cp)
138
- self.in_channels = out_channels
139
- layer_name = f'layer{i + 1}'
140
- self.add_module(layer_name, layer)
141
- layers.append(layer_name)
142
- return layers
143
-
144
- def init_weights(self, pretrained=None):
145
- if isinstance(pretrained, str):
146
- logger = logging.getLogger()
147
- load_checkpoint(self, pretrained, strict=False, logger=logger)
148
- elif pretrained is None:
149
- for m in self.modules():
150
- if isinstance(m, nn.Conv2d):
151
- kaiming_init(m)
152
- elif isinstance(m, nn.BatchNorm2d):
153
- constant_init(m, 1)
154
- else:
155
- raise TypeError('pretrained must be a str or None')
156
-
157
- def forward(self, x):
158
- x = self.conv1(x)
159
-
160
- outs = []
161
- for i, layer_name in enumerate(self.layers):
162
- layer = getattr(self, layer_name)
163
- x = layer(x)
164
- if i in self.out_indices or \
165
- i - len(self.layers) in self.out_indices:
166
- outs.append(x)
167
-
168
- if len(outs) == 1:
169
- return outs[0]
170
- return tuple(outs)
171
-
172
- def _freeze_stages(self):
173
- if self.frozen_stages >= 0:
174
- for param in self.conv1.parameters():
175
- param.requires_grad = False
176
- for i in range(1, self.frozen_stages + 1):
177
- layer = getattr(self, f'layer{i}')
178
- layer.eval()
179
- for param in layer.parameters():
180
- param.requires_grad = False
181
-
182
- def train(self, mode=True):
183
- super().train(mode)
184
- self._freeze_stages()
185
- if mode and self.norm_eval:
186
- for m in self.modules():
187
- if isinstance(m, _BatchNorm):
188
- m.eval()
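
Note: the deleted MobileNetV3.forward collects stage i when either i or
i - len(self.layers) is listed in out_indices, which is how the default
(-1, ) selects the final stage. A small self-contained illustration of that
indexing rule:

    def selected_stages(num_layers, out_indices):
        # mirrors the membership test in the deleted MobileNetV3.forward
        return [i for i in range(num_layers)
                if i in out_indices or i - num_layers in out_indices]

    # the 'small' arch has 11 layers, so (-1,) keeps only stage 10
    assert selected_stages(11, (-1,)) == [10]
    assert selected_stages(11, (0, -1)) == [0, 10]
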
main/transformer_utils/mmpose/models/backbones/modules/basic_block.py CHANGED
@@ -12,13 +12,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from .transformer_block import TransformerBlock
-
+from mmengine.model import constant_init, kaiming_init
 from mmcv.cnn import (
     build_conv_layer,
     build_norm_layer,
     build_plugin_layer,
-    constant_init,
-    kaiming_init,
 )
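
Note: this is the only backbone file edited rather than deleted in this
commit: constant_init and kaiming_init now come from mmengine.model instead
of mmcv.cnn, following the mmcv 2.x / mmengine split. Code that has to run
against both dependency generations sometimes uses a compatibility shim like
the following (a hypothetical pattern for illustration, not part of this
commit):

    try:
        # mmcv >= 2.0: the weight-init helpers live in mmengine
        from mmengine.model import constant_init, kaiming_init
    except ImportError:
        # mmcv < 2.0: fall back to the legacy location
        from mmcv.cnn import constant_init, kaiming_init
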
main/transformer_utils/mmpose/models/backbones/mspn.py DELETED
@@ -1,513 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import copy as cp
3
- from collections import OrderedDict
4
-
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
8
- normal_init)
9
- from mmcv.runner.checkpoint import load_state_dict
10
-
11
- from mmpose.utils import get_root_logger
12
- from ..builder import BACKBONES
13
- from .base_backbone import BaseBackbone
14
- from .resnet import Bottleneck as _Bottleneck
15
- from .utils.utils import get_state_dict
16
-
17
-
18
- class Bottleneck(_Bottleneck):
19
- expansion = 4
20
- """Bottleneck block for MSPN.
21
-
22
- Args:
23
- in_channels (int): Input channels of this block.
24
- out_channels (int): Output channels of this block.
25
- stride (int): stride of the block. Default: 1
26
- downsample (nn.Module): downsample operation on identity branch.
27
- Default: None
28
- norm_cfg (dict): dictionary to construct and config norm layer.
29
- Default: dict(type='BN')
30
- """
31
-
32
- def __init__(self, in_channels, out_channels, **kwargs):
33
- super().__init__(in_channels, out_channels * 4, **kwargs)
34
-
35
-
36
- class DownsampleModule(nn.Module):
37
- """Downsample module for MSPN.
38
-
39
- Args:
40
- block (nn.Module): Downsample block.
41
- num_blocks (list): Number of blocks in each downsample unit.
42
- num_units (int): Numbers of downsample units. Default: 4
43
- has_skip (bool): Whether to add skip connections from the prior
44
- upsample module. Default: False
45
- norm_cfg (dict): dictionary to construct and config norm layer.
46
- Default: dict(type='BN')
47
- in_channels (int): Number of channels of the input feature to
48
- downsample module. Default: 64
49
- """
50
-
51
- def __init__(self,
52
- block,
53
- num_blocks,
54
- num_units=4,
55
- has_skip=False,
56
- norm_cfg=dict(type='BN'),
57
- in_channels=64):
58
- # Protect mutable default arguments
59
- norm_cfg = cp.deepcopy(norm_cfg)
60
- super().__init__()
61
- self.has_skip = has_skip
62
- self.in_channels = in_channels
63
- assert len(num_blocks) == num_units
64
- self.num_blocks = num_blocks
65
- self.num_units = num_units
66
- self.norm_cfg = norm_cfg
67
- self.layer1 = self._make_layer(block, in_channels, num_blocks[0])
68
- for i in range(1, num_units):
69
- module_name = f'layer{i + 1}'
70
- self.add_module(
71
- module_name,
72
- self._make_layer(
73
- block, in_channels * pow(2, i), num_blocks[i], stride=2))
74
-
75
- def _make_layer(self, block, out_channels, blocks, stride=1):
76
- downsample = None
77
- if stride != 1 or self.in_channels != out_channels * block.expansion:
78
- downsample = ConvModule(
79
- self.in_channels,
80
- out_channels * block.expansion,
81
- kernel_size=1,
82
- stride=stride,
83
- padding=0,
84
- norm_cfg=self.norm_cfg,
85
- act_cfg=None,
86
- inplace=True)
87
-
88
- units = list()
89
- units.append(
90
- block(
91
- self.in_channels,
92
- out_channels,
93
- stride=stride,
94
- downsample=downsample,
95
- norm_cfg=self.norm_cfg))
96
- self.in_channels = out_channels * block.expansion
97
- for _ in range(1, blocks):
98
- units.append(block(self.in_channels, out_channels))
99
-
100
- return nn.Sequential(*units)
101
-
102
- def forward(self, x, skip1, skip2):
103
- out = list()
104
- for i in range(self.num_units):
105
- module_name = f'layer{i + 1}'
106
- module_i = getattr(self, module_name)
107
- x = module_i(x)
108
- if self.has_skip:
109
- x = x + skip1[i] + skip2[i]
110
- out.append(x)
111
- out.reverse()
112
-
113
- return tuple(out)
114
-
115
-
116
- class UpsampleUnit(nn.Module):
117
- """Upsample unit for upsample module.
118
-
119
- Args:
120
- ind (int): Indicates whether to interpolate (>0) and whether to
121
- generate feature map for the next hourglass-like module.
122
- num_units (int): Number of units that form an upsample module. Along
123
- with ind and gen_cross_conv, num_units is used to decide whether
124
- to generate feature map for the next hourglass-like module.
125
- in_channels (int): Channel number of the skip-in feature maps from
126
- the corresponding downsample unit.
127
- unit_channels (int): Channel number in this unit. Default:256.
128
- gen_skip (bool): Whether or not to generate skips for the posterior
129
- downsample module. Default:False
130
- gen_cross_conv (bool): Whether to generate feature map for the next
131
- hourglass-like module. Default:False
132
- norm_cfg (dict): dictionary to construct and config norm layer.
133
- Default: dict(type='BN')
134
- out_channels (int): Number of channels of feature output by upsample
135
- module. Must equal to in_channels of downsample module. Default:64
136
- """
137
-
138
- def __init__(self,
139
- ind,
140
- num_units,
141
- in_channels,
142
- unit_channels=256,
143
- gen_skip=False,
144
- gen_cross_conv=False,
145
- norm_cfg=dict(type='BN'),
146
- out_channels=64):
147
- # Protect mutable default arguments
148
- norm_cfg = cp.deepcopy(norm_cfg)
149
- super().__init__()
150
- self.num_units = num_units
151
- self.norm_cfg = norm_cfg
152
- self.in_skip = ConvModule(
153
- in_channels,
154
- unit_channels,
155
- kernel_size=1,
156
- stride=1,
157
- padding=0,
158
- norm_cfg=self.norm_cfg,
159
- act_cfg=None,
160
- inplace=True)
161
- self.relu = nn.ReLU(inplace=True)
162
-
163
- self.ind = ind
164
- if self.ind > 0:
165
- self.up_conv = ConvModule(
166
- unit_channels,
167
- unit_channels,
168
- kernel_size=1,
169
- stride=1,
170
- padding=0,
171
- norm_cfg=self.norm_cfg,
172
- act_cfg=None,
173
- inplace=True)
174
-
175
- self.gen_skip = gen_skip
176
- if self.gen_skip:
177
- self.out_skip1 = ConvModule(
178
- in_channels,
179
- in_channels,
180
- kernel_size=1,
181
- stride=1,
182
- padding=0,
183
- norm_cfg=self.norm_cfg,
184
- inplace=True)
185
-
186
- self.out_skip2 = ConvModule(
187
- unit_channels,
188
- in_channels,
189
- kernel_size=1,
190
- stride=1,
191
- padding=0,
192
- norm_cfg=self.norm_cfg,
193
- inplace=True)
194
-
195
- self.gen_cross_conv = gen_cross_conv
196
- if self.ind == num_units - 1 and self.gen_cross_conv:
197
- self.cross_conv = ConvModule(
198
- unit_channels,
199
- out_channels,
200
- kernel_size=1,
201
- stride=1,
202
- padding=0,
203
- norm_cfg=self.norm_cfg,
204
- inplace=True)
205
-
206
- def forward(self, x, up_x):
207
- out = self.in_skip(x)
208
-
209
- if self.ind > 0:
210
- up_x = F.interpolate(
211
- up_x,
212
- size=(x.size(2), x.size(3)),
213
- mode='bilinear',
214
- align_corners=True)
215
- up_x = self.up_conv(up_x)
216
- out = out + up_x
217
- out = self.relu(out)
218
-
219
- skip1 = None
220
- skip2 = None
221
- if self.gen_skip:
222
- skip1 = self.out_skip1(x)
223
- skip2 = self.out_skip2(out)
224
-
225
- cross_conv = None
226
- if self.ind == self.num_units - 1 and self.gen_cross_conv:
227
- cross_conv = self.cross_conv(out)
228
-
229
- return out, skip1, skip2, cross_conv
230
-
231
-
232
- class UpsampleModule(nn.Module):
233
- """Upsample module for MSPN.
234
-
235
- Args:
236
- unit_channels (int): Channel number in the upsample units.
237
- Default:256.
238
- num_units (int): Numbers of upsample units. Default: 4
239
- gen_skip (bool): Whether to generate skip for posterior downsample
240
- module or not. Default:False
241
- gen_cross_conv (bool): Whether to generate feature map for the next
242
- hourglass-like module. Default:False
243
- norm_cfg (dict): dictionary to construct and config norm layer.
244
- Default: dict(type='BN')
245
- out_channels (int): Number of channels of feature output by upsample
246
- module. Must equal to in_channels of downsample module. Default:64
247
- """
248
-
249
- def __init__(self,
250
- unit_channels=256,
251
- num_units=4,
252
- gen_skip=False,
253
- gen_cross_conv=False,
254
- norm_cfg=dict(type='BN'),
255
- out_channels=64):
256
- # Protect mutable default arguments
257
- norm_cfg = cp.deepcopy(norm_cfg)
258
- super().__init__()
259
- self.in_channels = list()
260
- for i in range(num_units):
261
- self.in_channels.append(Bottleneck.expansion * out_channels *
262
- pow(2, i))
263
- self.in_channels.reverse()
264
- self.num_units = num_units
265
- self.gen_skip = gen_skip
266
- self.gen_cross_conv = gen_cross_conv
267
- self.norm_cfg = norm_cfg
268
- for i in range(num_units):
269
- module_name = f'up{i + 1}'
270
- self.add_module(
271
- module_name,
272
- UpsampleUnit(
273
- i,
274
- self.num_units,
275
- self.in_channels[i],
276
- unit_channels,
277
- self.gen_skip,
278
- self.gen_cross_conv,
279
- norm_cfg=self.norm_cfg,
280
- out_channels=64))
281
-
282
- def forward(self, x):
283
- out = list()
284
- skip1 = list()
285
- skip2 = list()
286
- cross_conv = None
287
- for i in range(self.num_units):
288
- module_i = getattr(self, f'up{i + 1}')
289
- if i == 0:
290
- outi, skip1_i, skip2_i, _ = module_i(x[i], None)
291
- elif i == self.num_units - 1:
292
- outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
293
- else:
294
- outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
295
- out.append(outi)
296
- skip1.append(skip1_i)
297
- skip2.append(skip2_i)
298
- skip1.reverse()
299
- skip2.reverse()
300
-
301
- return out, skip1, skip2, cross_conv
302
-
303
-
304
- class SingleStageNetwork(nn.Module):
305
- """Single-stage network.
306
-
307
- Args:
308
- unit_channels (int): Channel number in the upsample units. Default:256.
309
- num_units (int): Numbers of downsample/upsample units. Default: 4
310
- gen_skip (bool): Whether to generate skip for posterior downsample
311
- module or not. Default:False
312
- gen_cross_conv (bool): Whether to generate feature map for the next
313
- hourglass-like module. Default:False
314
- has_skip (bool): Whether to add skip connections from the prior
315
- upsample module. Default: False
316
- num_blocks (list): Number of blocks in each downsample unit.
317
- Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
318
- norm_cfg (dict): dictionary to construct and config norm layer.
319
- Default: dict(type='BN')
320
- in_channels (int): Number of channels of the feature from ResNetTop.
321
- Default: 64.
322
- """
323
-
324
- def __init__(self,
325
- has_skip=False,
326
- gen_skip=False,
327
- gen_cross_conv=False,
328
- unit_channels=256,
329
- num_units=4,
330
- num_blocks=[2, 2, 2, 2],
331
- norm_cfg=dict(type='BN'),
332
- in_channels=64):
333
- # Protect mutable default arguments
334
- norm_cfg = cp.deepcopy(norm_cfg)
335
- num_blocks = cp.deepcopy(num_blocks)
336
- super().__init__()
337
- assert len(num_blocks) == num_units
338
- self.has_skip = has_skip
339
- self.gen_skip = gen_skip
340
- self.gen_cross_conv = gen_cross_conv
341
- self.num_units = num_units
342
- self.unit_channels = unit_channels
343
- self.num_blocks = num_blocks
344
- self.norm_cfg = norm_cfg
345
-
346
- self.downsample = DownsampleModule(Bottleneck, num_blocks, num_units,
347
- has_skip, norm_cfg, in_channels)
348
- self.upsample = UpsampleModule(unit_channels, num_units, gen_skip,
349
- gen_cross_conv, norm_cfg, in_channels)
350
-
351
- def forward(self, x, skip1, skip2):
352
- mid = self.downsample(x, skip1, skip2)
353
- out, skip1, skip2, cross_conv = self.upsample(mid)
354
-
355
- return out, skip1, skip2, cross_conv
356
-
357
-
358
- class ResNetTop(nn.Module):
359
- """ResNet top for MSPN.
360
-
361
- Args:
362
- norm_cfg (dict): dictionary to construct and config norm layer.
363
- Default: dict(type='BN')
364
- channels (int): Number of channels of the feature output by ResNetTop.
365
- """
366
-
367
- def __init__(self, norm_cfg=dict(type='BN'), channels=64):
368
- # Protect mutable default arguments
369
- norm_cfg = cp.deepcopy(norm_cfg)
370
- super().__init__()
371
- self.top = nn.Sequential(
372
- ConvModule(
373
- 3,
374
- channels,
375
- kernel_size=7,
376
- stride=2,
377
- padding=3,
378
- norm_cfg=norm_cfg,
379
- inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))
380
-
381
- def forward(self, img):
382
- return self.top(img)
383
-
384
-
385
- @BACKBONES.register_module()
386
- class MSPN(BaseBackbone):
387
- """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks
388
- for Human Pose Estimation" (CVPR 2020).
389
-
390
- Args:
391
- unit_channels (int): Number of Channels in an upsample unit.
392
- Default: 256
393
- num_stages (int): Number of stages in a multi-stage MSPN. Default: 4
394
- num_units (int): Number of downsample/upsample units in a single-stage
395
- network. Default: 4
396
- Note: Make sure num_units == len(self.num_blocks)
397
- num_blocks (list): Number of bottlenecks in each
398
- downsample unit. Default: [2, 2, 2, 2]
399
- norm_cfg (dict): dictionary to construct and config norm layer.
400
- Default: dict(type='BN')
401
- res_top_channels (int): Number of channels of feature from ResNetTop.
402
- Default: 64.
403
-
404
- Example:
405
- >>> from mmpose.models import MSPN
406
- >>> import torch
407
- >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2])
408
- >>> self.eval()
409
- >>> inputs = torch.rand(1, 3, 511, 511)
410
- >>> level_outputs = self.forward(inputs)
411
- >>> for level_output in level_outputs:
412
- ... for feature in level_output:
413
- ... print(tuple(feature.shape))
414
- ...
415
- (1, 256, 64, 64)
416
- (1, 256, 128, 128)
417
- (1, 256, 64, 64)
418
- (1, 256, 128, 128)
419
- """
420
-
421
- def __init__(self,
422
- unit_channels=256,
423
- num_stages=4,
424
- num_units=4,
425
- num_blocks=[2, 2, 2, 2],
426
- norm_cfg=dict(type='BN'),
427
- res_top_channels=64):
428
- # Protect mutable default arguments
429
- norm_cfg = cp.deepcopy(norm_cfg)
430
- num_blocks = cp.deepcopy(num_blocks)
431
- super().__init__()
432
- self.unit_channels = unit_channels
433
- self.num_stages = num_stages
434
- self.num_units = num_units
435
- self.num_blocks = num_blocks
436
- self.norm_cfg = norm_cfg
437
-
438
- assert self.num_stages > 0
439
- assert self.num_units > 1
440
- assert self.num_units == len(self.num_blocks)
441
- self.top = ResNetTop(norm_cfg=norm_cfg)
442
- self.multi_stage_mspn = nn.ModuleList([])
443
- for i in range(self.num_stages):
444
- if i == 0:
445
- has_skip = False
446
- else:
447
- has_skip = True
448
- if i != self.num_stages - 1:
449
- gen_skip = True
450
- gen_cross_conv = True
451
- else:
452
- gen_skip = False
453
- gen_cross_conv = False
454
- self.multi_stage_mspn.append(
455
- SingleStageNetwork(has_skip, gen_skip, gen_cross_conv,
456
- unit_channels, num_units, num_blocks,
457
- norm_cfg, res_top_channels))
458
-
459
- def forward(self, x):
460
- """Model forward function."""
461
- out_feats = []
462
- skip1 = None
463
- skip2 = None
464
- x = self.top(x)
465
- for i in range(self.num_stages):
466
- out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2)
467
- out_feats.append(out)
468
-
469
- return out_feats
470
-
471
- def init_weights(self, pretrained=None):
472
- """Initialize model weights."""
473
- if isinstance(pretrained, str):
474
- logger = get_root_logger()
475
- state_dict_tmp = get_state_dict(pretrained)
476
- state_dict = OrderedDict()
477
- state_dict['top'] = OrderedDict()
478
- state_dict['bottlenecks'] = OrderedDict()
479
- for k, v in state_dict_tmp.items():
480
- if k.startswith('layer'):
481
- if 'downsample.0' in k:
482
- state_dict['bottlenecks'][k.replace(
483
- 'downsample.0', 'downsample.conv')] = v
484
- elif 'downsample.1' in k:
485
- state_dict['bottlenecks'][k.replace(
486
- 'downsample.1', 'downsample.bn')] = v
487
- else:
488
- state_dict['bottlenecks'][k] = v
489
- elif k.startswith('conv1'):
490
- state_dict['top'][k.replace('conv1', 'top.0.conv')] = v
491
- elif k.startswith('bn1'):
492
- state_dict['top'][k.replace('bn1', 'top.0.bn')] = v
493
-
494
- load_state_dict(
495
- self.top, state_dict['top'], strict=False, logger=logger)
496
- for i in range(self.num_stages):
497
- load_state_dict(
498
- self.multi_stage_mspn[i].downsample,
499
- state_dict['bottlenecks'],
500
- strict=False,
501
- logger=logger)
502
- else:
503
- for m in self.multi_stage_mspn.modules():
504
- if isinstance(m, nn.Conv2d):
505
- kaiming_init(m)
506
- elif isinstance(m, nn.BatchNorm2d):
507
- constant_init(m, 1)
508
- elif isinstance(m, nn.Linear):
509
- normal_init(m, std=0.01)
510
-
511
- for m in self.top.modules():
512
- if isinstance(m, nn.Conv2d):
513
- kaiming_init(m)
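
Note: MSPN.init_weights above loads a torchvision-style ResNet checkpoint by
renaming its keys before calling load_state_dict. A minimal sketch of that
remapping step, using the same key conventions as the deleted code:

    from collections import OrderedDict

    def remap_resnet_keys(state_dict_tmp):
        # split a ResNet checkpoint into 'top' and 'bottlenecks' groups,
        # renaming keys the way the deleted MSPN.init_weights does
        top, bottlenecks = OrderedDict(), OrderedDict()
        for k, v in state_dict_tmp.items():
            if k.startswith('layer'):
                if 'downsample.0' in k:
                    bottlenecks[k.replace('downsample.0', 'downsample.conv')] = v
                elif 'downsample.1' in k:
                    bottlenecks[k.replace('downsample.1', 'downsample.bn')] = v
                else:
                    bottlenecks[k] = v
            elif k.startswith('conv1'):
                top[k.replace('conv1', 'top.0.conv')] = v
            elif k.startswith('bn1'):
                top[k.replace('bn1', 'top.0.bn')] = v
        return top, bottlenecks
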
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
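For reference, the `init_weights` logic above remaps torchvision-style ResNet checkpoint keys onto MSPN's module naming ('downsample.0' -> 'downsample.conv', 'downsample.1' -> 'downsample.bn'). A minimal standalone sketch of that remapping, assuming a locally available ImageNet checkpoint (the path below is hypothetical):

from collections import OrderedDict

import torch

# Hypothetical checkpoint path; any torchvision-style ResNet state dict works.
state_dict_tmp = torch.load('resnet50_imagenet.pth', map_location='cpu')
bottlenecks = OrderedDict()
for k, v in state_dict_tmp.items():
    if k.startswith('layer'):
        # Rename the downsample conv/bn to MSPN's naming scheme.
        k = k.replace('downsample.0', 'downsample.conv')
        k = k.replace('downsample.1', 'downsample.bn')
        bottlenecks[k] = v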
main/transformer_utils/mmpose/models/backbones/pvt.py DELETED
@@ -1,592 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import math
- import warnings
-
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
-                       constant_init, normal_init, trunc_normal_init)
- from mmcv.cnn.bricks.drop import build_dropout
- from mmcv.cnn.bricks.transformer import MultiheadAttention
- from mmcv.cnn.utils.weight_init import trunc_normal_
- from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
-                          load_state_dict)
- from torch.nn.modules.utils import _pair as to_2tuple
-
- from ...utils import get_root_logger
- from ..builder import BACKBONES
- from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert
-
-
- class MixFFN(BaseModule):
-     """An implementation of MixFFN of PVT.
-
-     The differences between MixFFN & FFN:
-     1. Use 1X1 Conv to replace Linear layer.
-     2. Introduce 3X3 Depth-wise Conv to encode positional information.
-
-     Args:
-         embed_dims (int): The feature dimension. Same as
-             `MultiheadAttention`.
-         feedforward_channels (int): The hidden dimension of FFNs.
-         act_cfg (dict, optional): The activation config for FFNs.
-             Default: dict(type='GELU').
-         ffn_drop (float, optional): Probability of an element to be
-             zeroed in FFN. Default 0.0.
-         dropout_layer (obj:`ConfigDict`): The dropout_layer used
-             when adding the shortcut.
-             Default: None.
-         use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
-             Defaults: False.
-         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
-             Default: None.
-     """
-
-     def __init__(self,
-                  embed_dims,
-                  feedforward_channels,
-                  act_cfg=dict(type='GELU'),
-                  ffn_drop=0.,
-                  dropout_layer=None,
-                  use_conv=False,
-                  init_cfg=None):
-         super(MixFFN, self).__init__(init_cfg=init_cfg)
-
-         self.embed_dims = embed_dims
-         self.feedforward_channels = feedforward_channels
-         self.act_cfg = act_cfg
-         activate = build_activation_layer(act_cfg)
-
-         in_channels = embed_dims
-         fc1 = Conv2d(
-             in_channels=in_channels,
-             out_channels=feedforward_channels,
-             kernel_size=1,
-             stride=1,
-             bias=True)
-         if use_conv:
-             # 3x3 depth wise conv to provide positional encode information
-             dw_conv = Conv2d(
-                 in_channels=feedforward_channels,
-                 out_channels=feedforward_channels,
-                 kernel_size=3,
-                 stride=1,
-                 padding=(3 - 1) // 2,
-                 bias=True,
-                 groups=feedforward_channels)
-         fc2 = Conv2d(
-             in_channels=feedforward_channels,
-             out_channels=in_channels,
-             kernel_size=1,
-             stride=1,
-             bias=True)
-         drop = nn.Dropout(ffn_drop)
-         layers = [fc1, activate, drop, fc2, drop]
-         if use_conv:
-             layers.insert(1, dw_conv)
-         self.layers = Sequential(*layers)
-         self.dropout_layer = build_dropout(
-             dropout_layer) if dropout_layer else torch.nn.Identity()
-
-     def forward(self, x, hw_shape, identity=None):
-         out = nlc_to_nchw(x, hw_shape)
-         out = self.layers(out)
-         out = nchw_to_nlc(out)
-         if identity is None:
-             identity = x
-         return identity + self.dropout_layer(out)
-
-
- class SpatialReductionAttention(MultiheadAttention):
-     """An implementation of Spatial Reduction Attention of PVT.
-
-     This module is modified from MultiheadAttention which is a module from
-     mmcv.cnn.bricks.transformer.
-
-     Args:
-         embed_dims (int): The embedding dimension.
-         num_heads (int): Parallel attention heads.
-         attn_drop (float): A Dropout layer on attn_output_weights.
-             Default: 0.0.
-         proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
-             Default: 0.0.
-         dropout_layer (obj:`ConfigDict`): The dropout_layer used
-             when adding the shortcut. Default: None.
-         batch_first (bool): Key, Query and Value are shape of
-             (batch, n, embed_dim)
-             or (n, batch, embed_dim). Default: False.
-         qkv_bias (bool): enable bias for qkv if True. Default: True.
-         norm_cfg (dict): Config dict for normalization layer.
-             Default: dict(type='LN').
-         sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
-             Attention of PVT. Default: 1.
-         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
-             Default: None.
-     """
-
-     def __init__(self,
-                  embed_dims,
-                  num_heads,
-                  attn_drop=0.,
-                  proj_drop=0.,
-                  dropout_layer=None,
-                  batch_first=True,
-                  qkv_bias=True,
-                  norm_cfg=dict(type='LN'),
-                  sr_ratio=1,
-                  init_cfg=None):
-         super().__init__(
-             embed_dims,
-             num_heads,
-             attn_drop,
-             proj_drop,
-             batch_first=batch_first,
-             dropout_layer=dropout_layer,
-             bias=qkv_bias,
-             init_cfg=init_cfg)
-
-         self.sr_ratio = sr_ratio
-         if sr_ratio > 1:
-             self.sr = Conv2d(
-                 in_channels=embed_dims,
-                 out_channels=embed_dims,
-                 kernel_size=sr_ratio,
-                 stride=sr_ratio)
-             # The ret[0] of build_norm_layer is norm name.
-             self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
-
-         # handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
-         from mmpose import digit_version, mmcv_version
-         if mmcv_version < digit_version('1.3.17'):
-             warnings.warn('The legacy version of forward function in'
-                           'SpatialReductionAttention is deprecated in'
-                           'mmcv>=1.3.17 and will no longer support in the'
-                           'future. Please upgrade your mmcv.')
-             self.forward = self.legacy_forward
-
-     def forward(self, x, hw_shape, identity=None):
-
-         x_q = x
-         if self.sr_ratio > 1:
-             x_kv = nlc_to_nchw(x, hw_shape)
-             x_kv = self.sr(x_kv)
-             x_kv = nchw_to_nlc(x_kv)
-             x_kv = self.norm(x_kv)
-         else:
-             x_kv = x
-
-         if identity is None:
-             identity = x_q
-
-         # Because the dataflow('key', 'query', 'value') of
-         # ``torch.nn.MultiheadAttention`` is (num_query, batch,
-         # embed_dims), We should adjust the shape of dataflow from
-         # batch_first (batch, num_query, embed_dims) to num_query_first
-         # (num_query ,batch, embed_dims), and recover ``attn_output``
-         # from num_query_first to batch_first.
-         if self.batch_first:
-             x_q = x_q.transpose(0, 1)
-             x_kv = x_kv.transpose(0, 1)
-
-         out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
-
-         if self.batch_first:
-             out = out.transpose(0, 1)
-
-         return identity + self.dropout_layer(self.proj_drop(out))
-
-     def legacy_forward(self, x, hw_shape, identity=None):
-         """multi head attention forward in mmcv version < 1.3.17."""
-         x_q = x
-         if self.sr_ratio > 1:
-             x_kv = nlc_to_nchw(x, hw_shape)
-             x_kv = self.sr(x_kv)
-             x_kv = nchw_to_nlc(x_kv)
-             x_kv = self.norm(x_kv)
-         else:
-             x_kv = x
-
-         if identity is None:
-             identity = x_q
-
-         out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
-
-         return identity + self.dropout_layer(self.proj_drop(out))
-
-
- class PVTEncoderLayer(BaseModule):
-     """Implements one encoder layer in PVT.
-
-     Args:
-         embed_dims (int): The feature dimension.
-         num_heads (int): Parallel attention heads.
-         feedforward_channels (int): The hidden dimension for FFNs.
-         drop_rate (float): Probability of an element to be zeroed.
-             after the feed forward layer. Default: 0.0.
-         attn_drop_rate (float): The drop out rate for attention layer.
-             Default: 0.0.
-         drop_path_rate (float): stochastic depth rate. Default: 0.0.
-         qkv_bias (bool): enable bias for qkv if True.
-             Default: True.
-         act_cfg (dict): The activation config for FFNs.
-             Default: dict(type='GELU').
-         norm_cfg (dict): Config dict for normalization layer.
-             Default: dict(type='LN').
-         sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
-             Attention of PVT. Default: 1.
-         use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
-             Default: False.
-         init_cfg (dict, optional): Initialization config dict.
-             Default: None.
-     """
-
-     def __init__(self,
-                  embed_dims,
-                  num_heads,
-                  feedforward_channels,
-                  drop_rate=0.,
-                  attn_drop_rate=0.,
-                  drop_path_rate=0.,
-                  qkv_bias=True,
-                  act_cfg=dict(type='GELU'),
-                  norm_cfg=dict(type='LN'),
-                  sr_ratio=1,
-                  use_conv_ffn=False,
-                  init_cfg=None):
-         super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)
-
-         # The ret[0] of build_norm_layer is norm name.
-         self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
-
-         self.attn = SpatialReductionAttention(
-             embed_dims=embed_dims,
-             num_heads=num_heads,
-             attn_drop=attn_drop_rate,
-             proj_drop=drop_rate,
-             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
-             qkv_bias=qkv_bias,
-             norm_cfg=norm_cfg,
-             sr_ratio=sr_ratio)
-
-         # The ret[0] of build_norm_layer is norm name.
-         self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
-
-         self.ffn = MixFFN(
-             embed_dims=embed_dims,
-             feedforward_channels=feedforward_channels,
-             ffn_drop=drop_rate,
-             dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
-             use_conv=use_conv_ffn,
-             act_cfg=act_cfg)
-
-     def forward(self, x, hw_shape):
-         x = self.attn(self.norm1(x), hw_shape, identity=x)
-         x = self.ffn(self.norm2(x), hw_shape, identity=x)
-
-         return x
-
-
- class AbsolutePositionEmbedding(BaseModule):
-     """An implementation of the absolute position embedding in PVT.
-
-     Args:
-         pos_shape (int): The shape of the absolute position embedding.
-         pos_dim (int): The dimension of the absolute position embedding.
-         drop_rate (float): Probability of an element to be zeroed.
-             Default: 0.0.
-     """
-
-     def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
-         super().__init__(init_cfg=init_cfg)
-
-         if isinstance(pos_shape, int):
-             pos_shape = to_2tuple(pos_shape)
-         elif isinstance(pos_shape, tuple):
-             if len(pos_shape) == 1:
-                 pos_shape = to_2tuple(pos_shape[0])
-             assert len(pos_shape) == 2, \
-                 f'The size of image should have length 1 or 2, ' \
-                 f'but got {len(pos_shape)}'
-         self.pos_shape = pos_shape
-         self.pos_dim = pos_dim
-
-         self.pos_embed = nn.Parameter(
-             torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
-         self.drop = nn.Dropout(p=drop_rate)
-
-     def init_weights(self):
-         trunc_normal_(self.pos_embed, std=0.02)
-
-     def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
-         """Resize pos_embed weights.
-
-         Resize pos_embed using bilinear interpolate method.
-
-         Args:
-             pos_embed (torch.Tensor): Position embedding weights.
-             input_shape (tuple): Tuple for (downsampled input image height,
-                 downsampled input image width).
-             mode (str): Algorithm used for upsampling:
-                 ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
-                 ``'trilinear'``. Default: ``'bilinear'``.
-
-         Return:
-             torch.Tensor: The resized pos_embed of shape [B, L_new, C].
-         """
-         assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
-         pos_h, pos_w = self.pos_shape
-         pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
-         pos_embed_weight = pos_embed_weight.reshape(
-             1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
-         pos_embed_weight = F.interpolate(
-             pos_embed_weight, size=input_shape, mode=mode)
-         pos_embed_weight = torch.flatten(pos_embed_weight,
-                                          2).transpose(1, 2).contiguous()
-         pos_embed = pos_embed_weight
-
-         return pos_embed
-
-     def forward(self, x, hw_shape, mode='bilinear'):
-         pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
-         return self.drop(x + pos_embed)
-
-
- @BACKBONES.register_module()
- class PyramidVisionTransformer(BaseModule):
-     """Pyramid Vision Transformer (PVT)
-
-     Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
-     Dense Prediction without Convolutions
-     <https://arxiv.org/pdf/2102.12122.pdf>`_.
-
-     Args:
-         pretrain_img_size (int | tuple[int]): The size of input image when
-             pretrain. Defaults: 224.
-         in_channels (int): Number of input channels. Default: 3.
-         embed_dims (int): Embedding dimension. Default: 64.
-         num_stags (int): The num of stages. Default: 4.
-         num_layers (Sequence[int]): The layer number of each transformer encode
-             layer. Default: [3, 4, 6, 3].
-         num_heads (Sequence[int]): The attention heads of each transformer
-             encode layer. Default: [1, 2, 5, 8].
-         patch_sizes (Sequence[int]): The patch_size of each patch embedding.
-             Default: [4, 2, 2, 2].
-         strides (Sequence[int]): The stride of each patch embedding.
-             Default: [4, 2, 2, 2].
-         paddings (Sequence[int]): The padding of each patch embedding.
-             Default: [0, 0, 0, 0].
-         sr_ratios (Sequence[int]): The spatial reduction rate of each
-             transformer encode layer. Default: [8, 4, 2, 1].
-         out_indices (Sequence[int] | int): Output from which stages.
-             Default: (0, 1, 2, 3).
-         mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
-             embedding dim of each transformer encode layer.
-             Default: [8, 8, 4, 4].
-         qkv_bias (bool): Enable bias for qkv if True. Default: True.
-         drop_rate (float): Probability of an element to be zeroed.
-             Default 0.0.
-         attn_drop_rate (float): The drop out rate for attention layer.
-             Default 0.0.
-         drop_path_rate (float): stochastic depth rate. Default 0.1.
-         use_abs_pos_embed (bool): If True, add absolute position embedding to
-             the patch embedding. Defaults: True.
-         use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
-             Default: False.
-         act_cfg (dict): The activation config for FFNs.
-             Default: dict(type='GELU').
-         norm_cfg (dict): Config dict for normalization layer.
-             Default: dict(type='LN').
-         pretrained (str, optional): model pretrained path. Default: None.
-         convert_weights (bool): The flag indicates whether the
-             pre-trained model is from the original repo. We may need
-             to convert some keys to make it compatible.
-             Default: True.
-         init_cfg (dict or list[dict], optional): Initialization config dict.
-             Default: None.
-     """
-
-     def __init__(self,
-                  pretrain_img_size=224,
-                  in_channels=3,
-                  embed_dims=64,
-                  num_stages=4,
-                  num_layers=[3, 4, 6, 3],
-                  num_heads=[1, 2, 5, 8],
-                  patch_sizes=[4, 2, 2, 2],
-                  strides=[4, 2, 2, 2],
-                  paddings=[0, 0, 0, 0],
-                  sr_ratios=[8, 4, 2, 1],
-                  out_indices=(0, 1, 2, 3),
-                  mlp_ratios=[8, 8, 4, 4],
-                  qkv_bias=True,
-                  drop_rate=0.,
-                  attn_drop_rate=0.,
-                  drop_path_rate=0.1,
-                  use_abs_pos_embed=True,
-                  norm_after_stage=False,
-                  use_conv_ffn=False,
-                  act_cfg=dict(type='GELU'),
-                  norm_cfg=dict(type='LN', eps=1e-6),
-                  pretrained=None,
-                  convert_weights=True,
-                  init_cfg=None):
-         super().__init__(init_cfg=init_cfg)
-
-         self.convert_weights = convert_weights
-         if isinstance(pretrain_img_size, int):
-             pretrain_img_size = to_2tuple(pretrain_img_size)
-         elif isinstance(pretrain_img_size, tuple):
-             if len(pretrain_img_size) == 1:
-                 pretrain_img_size = to_2tuple(pretrain_img_size[0])
-             assert len(pretrain_img_size) == 2, \
-                 f'The size of image should have length 1 or 2, ' \
-                 f'but got {len(pretrain_img_size)}'
-
-         assert not (init_cfg and pretrained), \
-             'init_cfg and pretrained cannot be setting at the same time'
-         if isinstance(pretrained, str):
-             self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-         elif pretrained is None:
-             self.init_cfg = init_cfg
-         else:
-             raise TypeError('pretrained must be a str or None')
-
-         self.embed_dims = embed_dims
-
-         self.num_stages = num_stages
-         self.num_layers = num_layers
-         self.num_heads = num_heads
-         self.patch_sizes = patch_sizes
-         self.strides = strides
-         self.sr_ratios = sr_ratios
-         assert num_stages == len(num_layers) == len(num_heads) \
-                == len(patch_sizes) == len(strides) == len(sr_ratios)
-
-         self.out_indices = out_indices
-         assert max(out_indices) < self.num_stages
-         self.pretrained = pretrained
-
-         # transformer encoder
-         dpr = [
-             x.item()
-             for x in torch.linspace(0, drop_path_rate, sum(num_layers))
-         ]  # stochastic num_layer decay rule
-
-         cur = 0
-         self.layers = ModuleList()
-         for i, num_layer in enumerate(num_layers):
-             embed_dims_i = embed_dims * num_heads[i]
-             patch_embed = PatchEmbed(
-                 in_channels=in_channels,
-                 embed_dims=embed_dims_i,
-                 kernel_size=patch_sizes[i],
-                 stride=strides[i],
-                 padding=paddings[i],
-                 bias=True,
-                 norm_cfg=norm_cfg)
-
-             layers = ModuleList()
-             if use_abs_pos_embed:
-                 pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
-                 pos_embed = AbsolutePositionEmbedding(
-                     pos_shape=pos_shape,
-                     pos_dim=embed_dims_i,
-                     drop_rate=drop_rate)
-                 layers.append(pos_embed)
-             layers.extend([
-                 PVTEncoderLayer(
-                     embed_dims=embed_dims_i,
-                     num_heads=num_heads[i],
-                     feedforward_channels=mlp_ratios[i] * embed_dims_i,
-                     drop_rate=drop_rate,
-                     attn_drop_rate=attn_drop_rate,
-                     drop_path_rate=dpr[cur + idx],
-                     qkv_bias=qkv_bias,
-                     act_cfg=act_cfg,
-                     norm_cfg=norm_cfg,
-                     sr_ratio=sr_ratios[i],
-                     use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
-             ])
-             in_channels = embed_dims_i
-             # The ret[0] of build_norm_layer is norm name.
-             if norm_after_stage:
-                 norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
-             else:
-                 norm = nn.Identity()
-             self.layers.append(ModuleList([patch_embed, layers, norm]))
-             cur += num_layer
-
-     def init_weights(self, pretrained=None):
-         if isinstance(pretrained, str):
-             self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-
-         logger = get_root_logger()
-         if self.init_cfg is None:
-             logger.warn(f'No pre-trained weights for '
-                         f'{self.__class__.__name__}, '
-                         f'training start from scratch')
-             for m in self.modules():
-                 if isinstance(m, nn.Linear):
-                     trunc_normal_init(m, std=.02, bias=0.)
-                 elif isinstance(m, nn.LayerNorm):
-                     constant_init(m, 1.0)
-                 elif isinstance(m, nn.Conv2d):
-                     fan_out = m.kernel_size[0] * m.kernel_size[
-                         1] * m.out_channels
-                     fan_out //= m.groups
-                     normal_init(m, 0, math.sqrt(2.0 / fan_out))
-                 elif isinstance(m, AbsolutePositionEmbedding):
-                     m.init_weights()
-         else:
-             assert 'checkpoint' in self.init_cfg, f'Only support ' \
-                                                   f'specify `Pretrained` in ' \
-                                                   f'`init_cfg` in ' \
-                                                   f'{self.__class__.__name__} '
-             checkpoint = _load_checkpoint(
-                 self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
-             logger.warn(f'Load pre-trained model for '
-                         f'{self.__class__.__name__} from original repo')
-             if 'state_dict' in checkpoint:
-                 state_dict = checkpoint['state_dict']
-             elif 'model' in checkpoint:
-                 state_dict = checkpoint['model']
-             else:
-                 state_dict = checkpoint
-             if self.convert_weights:
-                 # Because pvt backbones are not supported by mmcls,
-                 # so we need to convert pre-trained weights to match this
-                 # implementation.
-                 state_dict = pvt_convert(state_dict)
-             load_state_dict(self, state_dict, strict=False, logger=logger)
-
-     def forward(self, x):
-         outs = []
-
-         for i, layer in enumerate(self.layers):
-             x, hw_shape = layer[0](x)
-
-             for block in layer[1]:
-                 x = block(x, hw_shape)
-             x = layer[2](x)
-             x = nlc_to_nchw(x, hw_shape)
-             if i in self.out_indices:
-                 outs.append(x)
-
-         return outs
-
-
- @BACKBONES.register_module()
- class PyramidVisionTransformerV2(PyramidVisionTransformer):
-     """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
-     Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_."""
-
-     def __init__(self, **kwargs):
-         super(PyramidVisionTransformerV2, self).__init__(
-             patch_sizes=[7, 3, 3, 3],
-             paddings=[3, 1, 1, 1],
-             use_abs_pos_embed=False,
-             norm_after_stage=True,
-             use_conv_ffn=True,
-             **kwargs)
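A minimal usage sketch for the removed PVT backbone, assuming a checkout that predates this commit; the PVTv2-b0-style hyper-parameters below are an assumption, and the remaining settings follow the class defaults:

import torch

from mmpose.models.backbones.pvt import PyramidVisionTransformerV2

model = PyramidVisionTransformerV2(embed_dims=32, num_layers=[2, 2, 2, 2])
model.init_weights()  # no init_cfg, so this falls back to random init
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 256, 256))
for out in outs:
    # Four pyramid levels at strides 4, 8, 16 and 32.
    print(tuple(out.shape))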
main/transformer_utils/mmpose/models/backbones/regnet.py DELETED
@@ -1,317 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
-
- import numpy as np
- import torch.nn as nn
- from mmcv.cnn import build_conv_layer, build_norm_layer
-
- from ..builder import BACKBONES
- from .resnet import ResNet
- from .resnext import Bottleneck
-
-
- @BACKBONES.register_module()
- class RegNet(ResNet):
-     """RegNet backbone.
-
-     More details can be found in `paper <https://arxiv.org/abs/2003.13678>`__ .
-
-     Args:
-         arch (dict): The parameter of RegNets.
-             - w0 (int): initial width
-             - wa (float): slope of width
-             - wm (float): quantization parameter to quantize the width
-             - depth (int): depth of the backbone
-             - group_w (int): width of group
-             - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
-         strides (Sequence[int]): Strides of the first block of each stage.
-         base_channels (int): Base channels after stem layer.
-         in_channels (int): Number of input image channels. Default: 3.
-         dilations (Sequence[int]): Dilation of each stage.
-         out_indices (Sequence[int]): Output from which stages.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer. Default: "pytorch".
-         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
-             not freezing any parameters. Default: -1.
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN', requires_grad=True).
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-         zero_init_residual (bool): whether to use zero init for last norm layer
-             in resblocks to let them behave as identity. Default: True.
-
-     Example:
-         >>> from mmpose.models import RegNet
-         >>> import torch
-         >>> self = RegNet(
-                 arch=dict(
-                     w0=88,
-                     wa=26.31,
-                     wm=2.25,
-                     group_w=48,
-                     depth=25,
-                     bot_mul=1.0),
-                 out_indices=(0, 1, 2, 3))
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 32, 32)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_out in level_outputs:
-         ...     print(tuple(level_out.shape))
-         (1, 96, 8, 8)
-         (1, 192, 4, 4)
-         (1, 432, 2, 2)
-         (1, 1008, 1, 1)
-     """
-     arch_settings = {
-         'regnetx_400mf':
-         dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
-         'regnetx_800mf':
-         dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
-         'regnetx_1.6gf':
-         dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
-         'regnetx_3.2gf':
-         dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
-         'regnetx_4.0gf':
-         dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
-         'regnetx_6.4gf':
-         dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
-         'regnetx_8.0gf':
-         dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
-         'regnetx_12gf':
-         dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
-     }
-
-     def __init__(self,
-                  arch,
-                  in_channels=3,
-                  stem_channels=32,
-                  base_channels=32,
-                  strides=(2, 2, 2, 2),
-                  dilations=(1, 1, 1, 1),
-                  out_indices=(3, ),
-                  style='pytorch',
-                  deep_stem=False,
-                  avg_down=False,
-                  frozen_stages=-1,
-                  conv_cfg=None,
-                  norm_cfg=dict(type='BN', requires_grad=True),
-                  norm_eval=False,
-                  with_cp=False,
-                  zero_init_residual=True):
-         # Protect mutable default arguments
-         norm_cfg = copy.deepcopy(norm_cfg)
-         super(ResNet, self).__init__()
-
-         # Generate RegNet parameters first
-         if isinstance(arch, str):
-             assert arch in self.arch_settings, \
-                 f'"arch": "{arch}" is not one of the' \
-                 ' arch_settings'
-             arch = self.arch_settings[arch]
-         elif not isinstance(arch, dict):
-             raise TypeError('Expect "arch" to be either a string '
-                             f'or a dict, got {type(arch)}')
-
-         widths, num_stages = self.generate_regnet(
-             arch['w0'],
-             arch['wa'],
-             arch['wm'],
-             arch['depth'],
-         )
-         # Convert to per stage format
-         stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
-         # Generate group widths and bot muls
-         group_widths = [arch['group_w'] for _ in range(num_stages)]
-         self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
-         # Adjust the compatibility of stage_widths and group_widths
-         stage_widths, group_widths = self.adjust_width_group(
-             stage_widths, self.bottleneck_ratio, group_widths)
-
-         # Group params by stage
-         self.stage_widths = stage_widths
-         self.group_widths = group_widths
-         self.depth = sum(stage_blocks)
-         self.stem_channels = stem_channels
-         self.base_channels = base_channels
-         self.num_stages = num_stages
-         assert 1 <= num_stages <= 4
-         self.strides = strides
-         self.dilations = dilations
-         assert len(strides) == len(dilations) == num_stages
-         self.out_indices = out_indices
-         assert max(out_indices) < num_stages
-         self.style = style
-         self.deep_stem = deep_stem
-         if self.deep_stem:
-             raise NotImplementedError(
-                 'deep_stem has not been implemented for RegNet')
-         self.avg_down = avg_down
-         self.frozen_stages = frozen_stages
-         self.conv_cfg = conv_cfg
-         self.norm_cfg = norm_cfg
-         self.with_cp = with_cp
-         self.norm_eval = norm_eval
-         self.zero_init_residual = zero_init_residual
-         self.stage_blocks = stage_blocks[:num_stages]
-
-         self._make_stem_layer(in_channels, stem_channels)
-
-         _in_channels = stem_channels
-         self.res_layers = []
-         for i, num_blocks in enumerate(self.stage_blocks):
-             stride = self.strides[i]
-             dilation = self.dilations[i]
-             group_width = self.group_widths[i]
-             width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
-             stage_groups = width // group_width
-
-             res_layer = self.make_res_layer(
-                 block=Bottleneck,
-                 num_blocks=num_blocks,
-                 in_channels=_in_channels,
-                 out_channels=self.stage_widths[i],
-                 expansion=1,
-                 stride=stride,
-                 dilation=dilation,
-                 style=self.style,
-                 avg_down=self.avg_down,
-                 with_cp=self.with_cp,
-                 conv_cfg=self.conv_cfg,
-                 norm_cfg=self.norm_cfg,
-                 base_channels=self.stage_widths[i],
-                 groups=stage_groups,
-                 width_per_group=group_width)
-             _in_channels = self.stage_widths[i]
-             layer_name = f'layer{i + 1}'
-             self.add_module(layer_name, res_layer)
-             self.res_layers.append(layer_name)
-
-         self._freeze_stages()
-
-         self.feat_dim = stage_widths[-1]
-
-     def _make_stem_layer(self, in_channels, base_channels):
-         self.conv1 = build_conv_layer(
-             self.conv_cfg,
-             in_channels,
-             base_channels,
-             kernel_size=3,
-             stride=2,
-             padding=1,
-             bias=False)
-         self.norm1_name, norm1 = build_norm_layer(
-             self.norm_cfg, base_channels, postfix=1)
-         self.add_module(self.norm1_name, norm1)
-         self.relu = nn.ReLU(inplace=True)
-
-     @staticmethod
-     def generate_regnet(initial_width,
-                         width_slope,
-                         width_parameter,
-                         depth,
-                         divisor=8):
-         """Generates per block width from RegNet parameters.
-
-         Args:
-             initial_width ([int]): Initial width of the backbone
-             width_slope ([float]): Slope of the quantized linear function
-             width_parameter ([int]): Parameter used to quantize the width.
-             depth ([int]): Depth of the backbone.
-             divisor (int, optional): The divisor of channels. Defaults to 8.
-
-         Returns:
-             list, int: return a list of widths of each stage and the number of
-                 stages
-         """
-         assert width_slope >= 0
-         assert initial_width > 0
-         assert width_parameter > 1
-         assert initial_width % divisor == 0
-         widths_cont = np.arange(depth) * width_slope + initial_width
-         ks = np.round(
-             np.log(widths_cont / initial_width) / np.log(width_parameter))
-         widths = initial_width * np.power(width_parameter, ks)
-         widths = np.round(np.divide(widths, divisor)) * divisor
-         num_stages = len(np.unique(widths))
-         widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
-         return widths, num_stages
-
-     @staticmethod
-     def quantize_float(number, divisor):
-         """Converts a float to closest non-zero int divisible by divior.
-
-         Args:
-             number (int): Original number to be quantized.
-             divisor (int): Divisor used to quantize the number.
-
-         Returns:
-             int: quantized number that is divisible by devisor.
-         """
-         return int(round(number / divisor) * divisor)
-
-     def adjust_width_group(self, widths, bottleneck_ratio, groups):
-         """Adjusts the compatibility of widths and groups.
-
-         Args:
-             widths (list[int]): Width of each stage.
-             bottleneck_ratio (float): Bottleneck ratio.
-             groups (int): number of groups in each stage
-
-         Returns:
-             tuple(list): The adjusted widths and groups of each stage.
-         """
-         bottleneck_width = [
-             int(w * b) for w, b in zip(widths, bottleneck_ratio)
-         ]
-         groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
-         bottleneck_width = [
-             self.quantize_float(w_bot, g)
-             for w_bot, g in zip(bottleneck_width, groups)
-         ]
-         widths = [
-             int(w_bot / b)
-             for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
-         ]
-         return widths, groups
-
-     def get_stages_from_blocks(self, widths):
-         """Gets widths/stage_blocks of network at each stage.
-
-         Args:
-             widths (list[int]): Width in each stage.
-
-         Returns:
-             tuple(list): width and depth of each stage
-         """
-         width_diff = [
-             width != width_prev
-             for width, width_prev in zip(widths + [0], [0] + widths)
-         ]
-         stage_widths = [
-             width for width, diff in zip(widths, width_diff[:-1]) if diff
-         ]
-         stage_blocks = np.diff([
-             depth for depth, diff in zip(range(len(width_diff)), width_diff)
-             if diff
-         ]).tolist()
-         return stage_widths, stage_blocks
-
-     def forward(self, x):
-         x = self.conv1(x)
-         x = self.norm1(x)
-         x = self.relu(x)
-
-         outs = []
-         for i, layer_name in enumerate(self.res_layers):
-             res_layer = getattr(self, layer_name)
-             x = res_layer(x)
-             if i in self.out_indices:
-                 outs.append(x)
-
-         if len(outs) == 1:
-             return outs[0]
-         return tuple(outs)
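The width schedule computed by `generate_regnet` above is a quantized linear rule: continuous widths u_j = w0 + wa * j are snapped to w0 * wm^k and then rounded to a multiple of the divisor. A standalone sketch for the regnetx_3.2gf parameters, which reproduces the per-stage widths in the docstring example:

import numpy as np

w0, wa, wm, depth, divisor = 88, 26.31, 2.25, 25, 8
widths_cont = np.arange(depth) * wa + w0              # u_j = w0 + wa * j
ks = np.round(np.log(widths_cont / w0) / np.log(wm))  # quantization exponents
widths = np.round(w0 * np.power(wm, ks) / divisor) * divisor
print(sorted(set(widths.astype(int).tolist())))       # [96, 192, 432, 1008]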
main/transformer_utils/mmpose/models/backbones/resnest.py DELETED
@@ -1,338 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as cp
- from mmcv.cnn import build_conv_layer, build_norm_layer
-
- from ..builder import BACKBONES
- from .resnet import Bottleneck as _Bottleneck
- from .resnet import ResLayer, ResNetV1d
-
-
- class RSoftmax(nn.Module):
-     """Radix Softmax module in ``SplitAttentionConv2d``.
-
-     Args:
-         radix (int): Radix of input.
-         groups (int): Groups of input.
-     """
-
-     def __init__(self, radix, groups):
-         super().__init__()
-         self.radix = radix
-         self.groups = groups
-
-     def forward(self, x):
-         batch = x.size(0)
-         if self.radix > 1:
-             x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
-             x = F.softmax(x, dim=1)
-             x = x.reshape(batch, -1)
-         else:
-             x = torch.sigmoid(x)
-         return x
-
-
- class SplitAttentionConv2d(nn.Module):
-     """Split-Attention Conv2d.
-
-     Args:
-         in_channels (int): Same as nn.Conv2d.
-         out_channels (int): Same as nn.Conv2d.
-         kernel_size (int | tuple[int]): Same as nn.Conv2d.
-         stride (int | tuple[int]): Same as nn.Conv2d.
-         padding (int | tuple[int]): Same as nn.Conv2d.
-         dilation (int | tuple[int]): Same as nn.Conv2d.
-         groups (int): Same as nn.Conv2d.
-         radix (int): Radix of SpltAtConv2d. Default: 2
-         reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-             Default: 4.
-         conv_cfg (dict): Config dict for convolution layer. Default: None,
-             which means using conv2d.
-         norm_cfg (dict): Config dict for normalization layer. Default: None.
-     """
-
-     def __init__(self,
-                  in_channels,
-                  channels,
-                  kernel_size,
-                  stride=1,
-                  padding=0,
-                  dilation=1,
-                  groups=1,
-                  radix=2,
-                  reduction_factor=4,
-                  conv_cfg=None,
-                  norm_cfg=dict(type='BN')):
-         super().__init__()
-         inter_channels = max(in_channels * radix // reduction_factor, 32)
-         self.radix = radix
-         self.groups = groups
-         self.channels = channels
-         self.conv = build_conv_layer(
-             conv_cfg,
-             in_channels,
-             channels * radix,
-             kernel_size,
-             stride=stride,
-             padding=padding,
-             dilation=dilation,
-             groups=groups * radix,
-             bias=False)
-         self.norm0_name, norm0 = build_norm_layer(
-             norm_cfg, channels * radix, postfix=0)
-         self.add_module(self.norm0_name, norm0)
-         self.relu = nn.ReLU(inplace=True)
-         self.fc1 = build_conv_layer(
-             None, channels, inter_channels, 1, groups=self.groups)
-         self.norm1_name, norm1 = build_norm_layer(
-             norm_cfg, inter_channels, postfix=1)
-         self.add_module(self.norm1_name, norm1)
-         self.fc2 = build_conv_layer(
-             None, inter_channels, channels * radix, 1, groups=self.groups)
-         self.rsoftmax = RSoftmax(radix, groups)
-
-     @property
-     def norm0(self):
-         return getattr(self, self.norm0_name)
-
-     @property
-     def norm1(self):
-         return getattr(self, self.norm1_name)
-
-     def forward(self, x):
-         x = self.conv(x)
-         x = self.norm0(x)
-         x = self.relu(x)
-
-         batch, rchannel = x.shape[:2]
-         if self.radix > 1:
-             splits = x.view(batch, self.radix, -1, *x.shape[2:])
-             gap = splits.sum(dim=1)
-         else:
-             gap = x
-         gap = F.adaptive_avg_pool2d(gap, 1)
-         gap = self.fc1(gap)
-
-         gap = self.norm1(gap)
-         gap = self.relu(gap)
-
-         atten = self.fc2(gap)
-         atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
-
-         if self.radix > 1:
-             attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
-             out = torch.sum(attens * splits, dim=1)
-         else:
-             out = atten * x
-         return out.contiguous()
-
-
- class Bottleneck(_Bottleneck):
-     """Bottleneck block for ResNeSt.
-
-     Args:
-         in_channels (int): Input channels of this block.
-         out_channels (int): Output channels of this block.
-         groups (int): Groups of conv2.
-         width_per_group (int): Width per group of conv2. 64x4d indicates
-             ``groups=64, width_per_group=4`` and 32x8d indicates
-             ``groups=32, width_per_group=8``.
-         radix (int): Radix of SpltAtConv2d. Default: 2
-         reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-             Default: 4.
-         avg_down_stride (bool): Whether to use average pool for stride in
-             Bottleneck. Default: True.
-         stride (int): stride of the block. Default: 1
-         dilation (int): dilation of convolution. Default: 1
-         downsample (nn.Module): downsample operation on identity branch.
-             Default: None
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         conv_cfg (dict): dictionary to construct and config conv layer.
-             Default: None
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed.
-     """
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  groups=1,
-                  width_per_group=4,
-                  base_channels=64,
-                  radix=2,
-                  reduction_factor=4,
-                  avg_down_stride=True,
-                  **kwargs):
-         super().__init__(in_channels, out_channels, **kwargs)
-
-         self.groups = groups
-         self.width_per_group = width_per_group
-
-         # For ResNet bottleneck, middle channels are determined by expansion
-         # and out_channels, but for ResNeXt bottleneck, it is determined by
-         # groups and width_per_group and the stage it is located in.
-         if groups != 1:
-             assert self.mid_channels % base_channels == 0
-             self.mid_channels = (
-                 groups * width_per_group * self.mid_channels // base_channels)
-
-         self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
-
-         self.norm1_name, norm1 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=1)
-         self.norm3_name, norm3 = build_norm_layer(
-             self.norm_cfg, self.out_channels, postfix=3)
-
-         self.conv1 = build_conv_layer(
-             self.conv_cfg,
-             self.in_channels,
-             self.mid_channels,
-             kernel_size=1,
-             stride=self.conv1_stride,
-             bias=False)
-         self.add_module(self.norm1_name, norm1)
-         self.conv2 = SplitAttentionConv2d(
-             self.mid_channels,
-             self.mid_channels,
-             kernel_size=3,
-             stride=1 if self.avg_down_stride else self.conv2_stride,
-             padding=self.dilation,
-             dilation=self.dilation,
-             groups=groups,
-             radix=radix,
-             reduction_factor=reduction_factor,
-             conv_cfg=self.conv_cfg,
-             norm_cfg=self.norm_cfg)
-         delattr(self, self.norm2_name)
-
-         if self.avg_down_stride:
-             self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
-
-         self.conv3 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels,
-             self.out_channels,
-             kernel_size=1,
-             bias=False)
-         self.add_module(self.norm3_name, norm3)
-
-     def forward(self, x):
-
-         def _inner_forward(x):
-             identity = x
-
-             out = self.conv1(x)
-             out = self.norm1(out)
-             out = self.relu(out)
-
-             out = self.conv2(out)
-
-             if self.avg_down_stride:
-                 out = self.avd_layer(out)
-
-             out = self.conv3(out)
-             out = self.norm3(out)
-
-             if self.downsample is not None:
-                 identity = self.downsample(x)
-
-             out += identity
-
-             return out
-
-         if self.with_cp and x.requires_grad:
-             out = cp.checkpoint(_inner_forward, x)
-         else:
-             out = _inner_forward(x)
-
-         out = self.relu(out)
-
-         return out
-
-
- @BACKBONES.register_module()
- class ResNeSt(ResNetV1d):
-     """ResNeSt backbone.
-
-     Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__
-     for details.
-
-     Args:
-         depth (int): Network depth, from {50, 101, 152, 200}.
-         groups (int): Groups of conv2 in Bottleneck. Default: 32.
-         width_per_group (int): Width per group of conv2 in Bottleneck.
-             Default: 4.
-         radix (int): Radix of SpltAtConv2d. Default: 2
-         reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-             Default: 4.
-         avg_down_stride (bool): Whether to use average pool for stride in
-             Bottleneck. Default: True.
-         in_channels (int): Number of input image channels. Default: 3.
-         stem_channels (int): Output channels of the stem layer. Default: 64.
-         num_stages (int): Stages of the network. Default: 4.
-         strides (Sequence[int]): Strides of the first block of each stage.
-             Default: ``(1, 2, 2, 2)``.
-         dilations (Sequence[int]): Dilation of each stage.
-             Default: ``(1, 1, 1, 1)``.
-         out_indices (Sequence[int]): Output from which stages. If only one
-             stage is specified, a single tensor (feature map) is returned,
-             otherwise multiple stages are specified, a tuple of tensors will
-             be returned. Default: ``(3, )``.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-             Default: False.
-         avg_down (bool): Use AvgPool instead of stride conv when
-             downsampling in the bottleneck. Default: False.
-         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-             -1 means not freezing any parameters. Default: -1.
-         conv_cfg (dict | None): The config dict for conv layers. Default: None.
-         norm_cfg (dict): The config dict for norm layers.
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-         zero_init_residual (bool): Whether to use zero init for last norm layer
-             in resblocks to let them behave as identity. Default: True.
-     """
-
-     arch_settings = {
-         50: (Bottleneck, (3, 4, 6, 3)),
-         101: (Bottleneck, (3, 4, 23, 3)),
-         152: (Bottleneck, (3, 8, 36, 3)),
-         200: (Bottleneck, (3, 24, 36, 3)),
-         269: (Bottleneck, (3, 30, 48, 8))
-     }
-
-     def __init__(self,
-                  depth,
-                  groups=1,
-                  width_per_group=4,
-                  radix=2,
-                  reduction_factor=4,
-                  avg_down_stride=True,
-                  **kwargs):
-         self.groups = groups
-         self.width_per_group = width_per_group
-         self.radix = radix
-         self.reduction_factor = reduction_factor
-         self.avg_down_stride = avg_down_stride
-         super().__init__(depth=depth, **kwargs)
-
-     def make_res_layer(self, **kwargs):
-         return ResLayer(
-             groups=self.groups,
-             width_per_group=self.width_per_group,
-             base_channels=self.base_channels,
-             radix=self.radix,
-             reduction_factor=self.reduction_factor,
-             avg_down_stride=self.avg_down_stride,
-             **kwargs)
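The core of `SplitAttentionConv2d` above is the radix softmax: attention logits are viewed as (batch, groups, radix, -1), transposed so radix becomes dim 1, and softmaxed over that axis, so the radix splits of each channel compete for attention mass. A standalone sketch of that step:

import torch
import torch.nn.functional as F

batch, groups, radix, channels = 2, 1, 2, 64
atten = torch.randn(batch, channels * radix)  # logits from the fc2 layer
x = atten.view(batch, groups, radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)  # normalize across the radix splits
x = x.reshape(batch, -1)
print(x.shape)           # torch.Size([2, 128])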
main/transformer_utils/mmpose/models/backbones/resnet.py CHANGED
@@ -3,9 +3,9 @@ import copy
 
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
-                      constant_init, kaiming_init)
-from mmcv.utils.parrots_wrapper import _BatchNorm
+from mmengine.model import constant_init, kaiming_init
+from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer)
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
 
 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone
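This is the only file in this batch that is modified rather than deleted: the weight-init helpers and `_BatchNorm` now come from mmengine, matching the post-2.0 reorganization in which these utilities moved out of mmcv. If the surrounding code ever has to import cleanly against an older mmcv as well, a version-tolerant variant might look like the sketch below; this is an assumption about deployment needs, not something the commit does:

try:
    # mmcv >= 2.0 layout, as adopted by this commit.
    from mmengine.model import constant_init, kaiming_init
    from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
except ImportError:
    # Legacy mmcv < 2.0 layout, as in the deleted lines above.
    from mmcv.cnn import constant_init, kaiming_init
    from mmcv.utils.parrots_wrapper import _BatchNorm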
main/transformer_utils/mmpose/models/backbones/resnext.py DELETED
@@ -1,162 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmcv.cnn import build_conv_layer, build_norm_layer
-
- from ..builder import BACKBONES
- from .resnet import Bottleneck as _Bottleneck
- from .resnet import ResLayer, ResNet
-
-
- class Bottleneck(_Bottleneck):
-     """Bottleneck block for ResNeXt.
-
-     Args:
-         in_channels (int): Input channels of this block.
-         out_channels (int): Output channels of this block.
-         groups (int): Groups of conv2.
-         width_per_group (int): Width per group of conv2. 64x4d indicates
-             ``groups=64, width_per_group=4`` and 32x8d indicates
-             ``groups=32, width_per_group=8``.
-         stride (int): stride of the block. Default: 1
-         dilation (int): dilation of convolution. Default: 1
-         downsample (nn.Module): downsample operation on identity branch.
-             Default: None
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         conv_cfg (dict): dictionary to construct and config conv layer.
-             Default: None
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed.
-     """
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  base_channels=64,
-                  groups=32,
-                  width_per_group=4,
-                  **kwargs):
-         super().__init__(in_channels, out_channels, **kwargs)
-         self.groups = groups
-         self.width_per_group = width_per_group
-
-         # For ResNet bottleneck, middle channels are determined by expansion
-         # and out_channels, but for ResNeXt bottleneck, it is determined by
-         # groups and width_per_group and the stage it is located in.
-         if groups != 1:
-             assert self.mid_channels % base_channels == 0
-             self.mid_channels = (
-                 groups * width_per_group * self.mid_channels // base_channels)
-
-         self.norm1_name, norm1 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=1)
-         self.norm2_name, norm2 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=2)
-         self.norm3_name, norm3 = build_norm_layer(
-             self.norm_cfg, self.out_channels, postfix=3)
-
-         self.conv1 = build_conv_layer(
-             self.conv_cfg,
-             self.in_channels,
-             self.mid_channels,
-             kernel_size=1,
-             stride=self.conv1_stride,
-             bias=False)
-         self.add_module(self.norm1_name, norm1)
-         self.conv2 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels,
-             self.mid_channels,
-             kernel_size=3,
-             stride=self.conv2_stride,
-             padding=self.dilation,
-             dilation=self.dilation,
-             groups=groups,
-             bias=False)
-
-         self.add_module(self.norm2_name, norm2)
-         self.conv3 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels,
-             self.out_channels,
-             kernel_size=1,
-             bias=False)
-         self.add_module(self.norm3_name, norm3)
-
-
- @BACKBONES.register_module()
- class ResNeXt(ResNet):
-     """ResNeXt backbone.
-
-     Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
-     details.
-
-     Args:
-         depth (int): Network depth, from {50, 101, 152}.
-         groups (int): Groups of conv2 in Bottleneck. Default: 32.
-         width_per_group (int): Width per group of conv2 in Bottleneck.
-             Default: 4.
-         in_channels (int): Number of input image channels. Default: 3.
-         stem_channels (int): Output channels of the stem layer. Default: 64.
-         num_stages (int): Stages of the network. Default: 4.
-         strides (Sequence[int]): Strides of the first block of each stage.
-             Default: ``(1, 2, 2, 2)``.
-         dilations (Sequence[int]): Dilation of each stage.
-             Default: ``(1, 1, 1, 1)``.
-         out_indices (Sequence[int]): Output from which stages. If only one
-             stage is specified, a single tensor (feature map) is returned,
-             otherwise multiple stages are specified, a tuple of tensors will
-             be returned. Default: ``(3, )``.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-             Default: False.
-         avg_down (bool): Use AvgPool instead of stride conv when
-             downsampling in the bottleneck. Default: False.
-         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-             -1 means not freezing any parameters. Default: -1.
-         conv_cfg (dict | None): The config dict for conv layers. Default: None.
-         norm_cfg (dict): The config dict for norm layers.
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-         zero_init_residual (bool): Whether to use zero init for last norm layer
-             in resblocks to let them behave as identity. Default: True.
-
-     Example:
-         >>> from mmpose.models import ResNeXt
-         >>> import torch
-         >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3))
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 32, 32)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_out in level_outputs:
-         ...     print(tuple(level_out.shape))
-         (1, 256, 8, 8)
-         (1, 512, 4, 4)
-         (1, 1024, 2, 2)
-         (1, 2048, 1, 1)
-     """
-
-     arch_settings = {
-         50: (Bottleneck, (3, 4, 6, 3)),
-         101: (Bottleneck, (3, 4, 23, 3)),
-         152: (Bottleneck, (3, 8, 36, 3))
-     }
-
-     def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
-         self.groups = groups
-         self.width_per_group = width_per_group
-         super().__init__(depth, **kwargs)
-
-     def make_res_layer(self, **kwargs):
-         return ResLayer(
-             groups=self.groups,
-             width_per_group=self.width_per_group,
-             base_channels=self.base_channels,
-             **kwargs)
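The only substantive change ResNeXt makes to the ResNet bottleneck above is the grouped-width rule: `mid_channels` is rescaled by `groups * width_per_group / base_channels`. A worked example for the default 32x4d setting in the first stage:

groups, width_per_group, base_channels = 32, 4, 64
out_channels = 256                 # first ResNeXt-50 stage
mid_channels = out_channels // 4   # ResNet bottleneck default (expansion 4)
mid_channels = groups * width_per_group * mid_channels // base_channels
print(mid_channels)                # 128: 32 groups, each 4 channels wide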
main/transformer_utils/mmpose/models/backbones/rsn.py DELETED
@@ -1,616 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy as cp
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
-                       normal_init)
-
- from ..builder import BACKBONES
- from .base_backbone import BaseBackbone
-
-
- class RSB(nn.Module):
-     """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
-     Local Representations for Multi-Person Pose Estimation" (ECCV 2020).
-
-     Args:
-         in_channels (int): Input channels of this block.
-         out_channels (int): Output channels of this block.
-         num_steps (int): Number of steps in RSB.
-         stride (int): stride of the block. Default: 1
-         downsample (nn.Module): downsample operation on identity branch.
-             Default: None.
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         expand_times (int): Times by which the in_channels are expanded.
-             Default: 26.
-         res_top_channels (int): Number of channels of feature output by
-             ResNet_top. Default: 64.
-     """
-
-     expansion = 1
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  num_steps=4,
-                  stride=1,
-                  downsample=None,
-                  with_cp=False,
-                  norm_cfg=dict(type='BN'),
-                  expand_times=26,
-                  res_top_channels=64):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         super().__init__()
-         assert num_steps > 1
-         self.in_channels = in_channels
-         self.branch_channels = self.in_channels * expand_times
-         self.branch_channels //= res_top_channels
-         self.out_channels = out_channels
-         self.stride = stride
-         self.downsample = downsample
-         self.with_cp = with_cp
-         self.norm_cfg = norm_cfg
-         self.num_steps = num_steps
-         self.conv_bn_relu1 = ConvModule(
-             self.in_channels,
-             self.num_steps * self.branch_channels,
-             kernel_size=1,
-             stride=self.stride,
-             padding=0,
-             norm_cfg=self.norm_cfg,
-             inplace=False)
-         for i in range(self.num_steps):
-             for j in range(i + 1):
-                 module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
-                 self.add_module(
-                     module_name,
-                     ConvModule(
-                         self.branch_channels,
-                         self.branch_channels,
-                         kernel_size=3,
-                         stride=1,
-                         padding=1,
-                         norm_cfg=self.norm_cfg,
-                         inplace=False))
-         self.conv_bn3 = ConvModule(
-             self.num_steps * self.branch_channels,
-             self.out_channels * self.expansion,
-             kernel_size=1,
-             stride=1,
-             padding=0,
-             act_cfg=None,
-             norm_cfg=self.norm_cfg,
-             inplace=False)
-         self.relu = nn.ReLU(inplace=False)
-
-     def forward(self, x):
-         """Forward function."""
-
-         identity = x
-         x = self.conv_bn_relu1(x)
-         spx = torch.split(x, self.branch_channels, 1)
-         outputs = list()
-         outs = list()
-         for i in range(self.num_steps):
-             outputs_i = list()
-             outputs.append(outputs_i)
-             for j in range(i + 1):
-                 if j == 0:
-                     inputs = spx[i]
-                 else:
-                     inputs = outputs[i][j - 1]
-                 if i > j:
-                     inputs = inputs + outputs[i - 1][j]
-                 module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
-                 module_i_j = getattr(self, module_name)
-                 outputs[i].append(module_i_j(inputs))
-
-             outs.append(outputs[i][i])
-         out = torch.cat(tuple(outs), 1)
-         out = self.conv_bn3(out)
-
-         if self.downsample is not None:
-             identity = self.downsample(identity)
-         out = out + identity
-
-         out = self.relu(out)
-
-         return out
-
-
- class Downsample_module(nn.Module):
-     """Downsample module for RSN.
-
-     Args:
-         block (nn.Module): Downsample block.
-         num_blocks (list): Number of blocks in each downsample unit.
-         num_units (int): Numbers of downsample units. Default: 4
-         has_skip (bool): Have skip connections from prior upsample
-             module or not. Default: False
-         num_steps (int): Number of steps in a block. Default: 4
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         in_channels (int): Number of channels of the input feature to
-             downsample module. Default: 64
-         expand_times (int): Times by which the in_channels are expanded.
-             Default: 26.
-     """
-
-     def __init__(self,
-                  block,
-                  num_blocks,
-                  num_steps=4,
-                  num_units=4,
-                  has_skip=False,
-                  norm_cfg=dict(type='BN'),
-                  in_channels=64,
-                  expand_times=26):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         super().__init__()
-         self.has_skip = has_skip
-         self.in_channels = in_channels
-         assert len(num_blocks) == num_units
-         self.num_blocks = num_blocks
-         self.num_units = num_units
-         self.num_steps = num_steps
-         self.norm_cfg = norm_cfg
-         self.layer1 = self._make_layer(
-             block,
-             in_channels,
-             num_blocks[0],
-             expand_times=expand_times,
-             res_top_channels=in_channels)
-         for i in range(1, num_units):
-             module_name = f'layer{i + 1}'
-             self.add_module(
-                 module_name,
-                 self._make_layer(
-                     block,
-                     in_channels * pow(2, i),
-                     num_blocks[i],
-                     stride=2,
-                     expand_times=expand_times,
-                     res_top_channels=in_channels))
-
-     def _make_layer(self,
-                     block,
-                     out_channels,
-                     blocks,
-                     stride=1,
-                     expand_times=26,
-                     res_top_channels=64):
-         downsample = None
-         if stride != 1 or self.in_channels != out_channels * block.expansion:
-             downsample = ConvModule(
-                 self.in_channels,
-                 out_channels * block.expansion,
-                 kernel_size=1,
-                 stride=stride,
-                 padding=0,
-                 norm_cfg=self.norm_cfg,
-                 act_cfg=None,
-                 inplace=True)
-
-         units = list()
-         units.append(
-             block(
-                 self.in_channels,
-                 out_channels,
-                 num_steps=self.num_steps,
-                 stride=stride,
-                 downsample=downsample,
-                 norm_cfg=self.norm_cfg,
-                 expand_times=expand_times,
-                 res_top_channels=res_top_channels))
-         self.in_channels = out_channels * block.expansion
-         for _ in range(1, blocks):
-             units.append(
-                 block(
-                     self.in_channels,
-                     out_channels,
-                     num_steps=self.num_steps,
-                     expand_times=expand_times,
-                     res_top_channels=res_top_channels))
-
-         return nn.Sequential(*units)
-
-     def forward(self, x, skip1, skip2):
-         out = list()
-         for i in range(self.num_units):
-             module_name = f'layer{i + 1}'
-             module_i = getattr(self, module_name)
-             x = module_i(x)
-             if self.has_skip:
-                 x = x + skip1[i] + skip2[i]
-             out.append(x)
-         out.reverse()
-
-         return tuple(out)
-
-
- class Upsample_unit(nn.Module):
-     """Upsample unit for upsample module.
-
-     Args:
-         ind (int): Indicates whether to interpolate (>0) and whether to
-             generate feature map for the next hourglass-like module.
-         num_units (int): Number of units that form an upsample module. Along
-             with ind and gen_cross_conv, num_units is used to decide whether
-             to generate feature map for the next hourglass-like module.
-         in_channels (int): Channel number of the skip-in feature maps from
-             the corresponding downsample unit.
-         unit_channels (int): Channel number in this unit. Default: 256.
-         gen_skip (bool): Whether or not to generate skips for the posterior
-             downsample module. Default: False
-         gen_cross_conv (bool): Whether to generate feature map for the next
-             hourglass-like module. Default: False
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         out_channels (int): Number of channels of feature output by upsample
-             module. Must equal to in_channels of downsample module. Default: 64
-     """
-
-     def __init__(self,
-                  ind,
-                  num_units,
-                  in_channels,
-                  unit_channels=256,
-                  gen_skip=False,
-                  gen_cross_conv=False,
-                  norm_cfg=dict(type='BN'),
-                  out_channels=64):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         super().__init__()
-         self.num_units = num_units
-         self.norm_cfg = norm_cfg
-         self.in_skip = ConvModule(
-             in_channels,
-             unit_channels,
-             kernel_size=1,
-             stride=1,
-             padding=0,
-             norm_cfg=self.norm_cfg,
-             act_cfg=None,
-             inplace=True)
-         self.relu = nn.ReLU(inplace=True)
-
-         self.ind = ind
-         if self.ind > 0:
-             self.up_conv = ConvModule(
-                 unit_channels,
-                 unit_channels,
-                 kernel_size=1,
-                 stride=1,
-                 padding=0,
-                 norm_cfg=self.norm_cfg,
-                 act_cfg=None,
-                 inplace=True)
-
-         self.gen_skip = gen_skip
-         if self.gen_skip:
-             self.out_skip1 = ConvModule(
-                 in_channels,
-                 in_channels,
-                 kernel_size=1,
-                 stride=1,
-                 padding=0,
-                 norm_cfg=self.norm_cfg,
-                 inplace=True)
-
-             self.out_skip2 = ConvModule(
-                 unit_channels,
-                 in_channels,
-                 kernel_size=1,
-                 stride=1,
-                 padding=0,
-                 norm_cfg=self.norm_cfg,
-                 inplace=True)
-
-         self.gen_cross_conv = gen_cross_conv
-         if self.ind == num_units - 1 and self.gen_cross_conv:
-             self.cross_conv = ConvModule(
-                 unit_channels,
-                 out_channels,
-                 kernel_size=1,
-                 stride=1,
-                 padding=0,
-                 norm_cfg=self.norm_cfg,
-                 inplace=True)
-
-     def forward(self, x, up_x):
-         out = self.in_skip(x)
-
-         if self.ind > 0:
-             up_x = F.interpolate(
-                 up_x,
-                 size=(x.size(2), x.size(3)),
-                 mode='bilinear',
-                 align_corners=True)
-             up_x = self.up_conv(up_x)
-             out = out + up_x
-         out = self.relu(out)
-
-         skip1 = None
-         skip2 = None
-         if self.gen_skip:
-             skip1 = self.out_skip1(x)
-             skip2 = self.out_skip2(out)
-
-         cross_conv = None
-         if self.ind == self.num_units - 1 and self.gen_cross_conv:
-             cross_conv = self.cross_conv(out)
-
-         return out, skip1, skip2, cross_conv
-
-
- class Upsample_module(nn.Module):
-     """Upsample module for RSN.
-
-     Args:
-         unit_channels (int): Channel number in the upsample units.
-             Default: 256.
-         num_units (int): Numbers of upsample units. Default: 4
-         gen_skip (bool): Whether to generate skip for posterior downsample
-             module or not. Default: False
-         gen_cross_conv (bool): Whether to generate feature map for the next
-             hourglass-like module. Default: False
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         out_channels (int): Number of channels of feature output by upsample
-             module. Must equal to in_channels of downsample module. Default: 64
-     """
-
-     def __init__(self,
-                  unit_channels=256,
-                  num_units=4,
-                  gen_skip=False,
-                  gen_cross_conv=False,
-                  norm_cfg=dict(type='BN'),
-                  out_channels=64):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         super().__init__()
-         self.in_channels = list()
-         for i in range(num_units):
-             self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
-         self.in_channels.reverse()
-         self.num_units = num_units
-         self.gen_skip = gen_skip
-         self.gen_cross_conv = gen_cross_conv
-         self.norm_cfg = norm_cfg
-         for i in range(num_units):
-             module_name = f'up{i + 1}'
-             self.add_module(
-                 module_name,
-                 Upsample_unit(
-                     i,
-                     self.num_units,
-                     self.in_channels[i],
-                     unit_channels,
-                     self.gen_skip,
-                     self.gen_cross_conv,
-                     norm_cfg=self.norm_cfg,
-                     out_channels=64))
-
-     def forward(self, x):
-         out = list()
-         skip1 = list()
-         skip2 = list()
-         cross_conv = None
-         for i in range(self.num_units):
-             module_i = getattr(self, f'up{i + 1}')
-             if i == 0:
-                 outi, skip1_i, skip2_i, _ = module_i(x[i], None)
-             elif i == self.num_units - 1:
-                 outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
-             else:
-                 outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
-             out.append(outi)
-             skip1.append(skip1_i)
-             skip2.append(skip2_i)
-         skip1.reverse()
-         skip2.reverse()
-
-         return out, skip1, skip2, cross_conv
-
-
- class Single_stage_RSN(nn.Module):
-     """Single_stage Residual Steps Network.
-
-     Args:
-         unit_channels (int): Channel number in the upsample units. Default: 256.
-         num_units (int): Numbers of downsample/upsample units. Default: 4
-         gen_skip (bool): Whether to generate skip for posterior downsample
-             module or not. Default: False
-         gen_cross_conv (bool): Whether to generate feature map for the next
-             hourglass-like module. Default: False
-         has_skip (bool): Have skip connections from prior upsample
-             module or not. Default: False
-         num_steps (int): Number of steps in RSB. Default: 4
-         num_blocks (list): Number of blocks in each downsample unit.
-             Default: [2, 2, 2, 2] Note: Make sure num_units == len(num_blocks)
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         in_channels (int): Number of channels of the feature from ResNet_top.
-             Default: 64.
-         expand_times (int): Times by which the in_channels are expanded in RSB.
-             Default: 26.
-     """
-
-     def __init__(self,
-                  has_skip=False,
-                  gen_skip=False,
-                  gen_cross_conv=False,
-                  unit_channels=256,
-                  num_units=4,
-                  num_steps=4,
-                  num_blocks=[2, 2, 2, 2],
-                  norm_cfg=dict(type='BN'),
-                  in_channels=64,
-                  expand_times=26):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         num_blocks = cp.deepcopy(num_blocks)
-         super().__init__()
-         assert len(num_blocks) == num_units
-         self.has_skip = has_skip
-         self.gen_skip = gen_skip
-         self.gen_cross_conv = gen_cross_conv
-         self.num_units = num_units
-         self.num_steps = num_steps
-         self.unit_channels = unit_channels
-         self.num_blocks = num_blocks
-         self.norm_cfg = norm_cfg
-
-         self.downsample = Downsample_module(RSB, num_blocks, num_steps,
-                                             num_units, has_skip, norm_cfg,
-                                             in_channels, expand_times)
-         self.upsample = Upsample_module(unit_channels, num_units, gen_skip,
-                                         gen_cross_conv, norm_cfg, in_channels)
-
-     def forward(self, x, skip1, skip2):
-         mid = self.downsample(x, skip1, skip2)
-         out, skip1, skip2, cross_conv = self.upsample(mid)
-
-         return out, skip1, skip2, cross_conv
-
-
- class ResNet_top(nn.Module):
-     """ResNet top for RSN.
-
-     Args:
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         channels (int): Number of channels of the feature output by ResNet_top.
-     """
-
-     def __init__(self, norm_cfg=dict(type='BN'), channels=64):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         super().__init__()
-         self.top = nn.Sequential(
-             ConvModule(
-                 3,
-                 channels,
-                 kernel_size=7,
-                 stride=2,
-                 padding=3,
-                 norm_cfg=norm_cfg,
-                 inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))
-
-     def forward(self, img):
-         return self.top(img)
-
-
- @BACKBONES.register_module()
- class RSN(BaseBackbone):
-     """Residual Steps Network backbone. Paper ref: Cai et al. "Learning
-     Delicate Local Representations for Multi-Person Pose Estimation" (ECCV
-     2020).
-
-     Args:
-         unit_channels (int): Number of channels in an upsample unit.
-             Default: 256
-         num_stages (int): Number of stages in a multi-stage RSN. Default: 4
-         num_units (int): Number of downsample/upsample units in a single-stage
-             RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks)
-         num_blocks (list): Number of RSBs (Residual Steps Block) in each
-             downsample unit. Default: [2, 2, 2, 2]
-         num_steps (int): Number of steps in a RSB. Default: 4
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         res_top_channels (int): Number of channels of feature from ResNet_top.
-             Default: 64.
-         expand_times (int): Times by which the in_channels are expanded in RSB.
-             Default: 26.
-     Example:
-         >>> from mmpose.models import RSN
-         >>> import torch
-         >>> self = RSN(num_stages=2, num_units=2, num_blocks=[2, 2])
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 511, 511)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_output in level_outputs:
-         ...     for feature in level_output:
-         ...         print(tuple(feature.shape))
-         ...
-         (1, 256, 64, 64)
-         (1, 256, 128, 128)
-         (1, 256, 64, 64)
-         (1, 256, 128, 128)
-     """
-
-     def __init__(self,
-                  unit_channels=256,
-                  num_stages=4,
-                  num_units=4,
-                  num_blocks=[2, 2, 2, 2],
-                  num_steps=4,
-                  norm_cfg=dict(type='BN'),
-                  res_top_channels=64,
-                  expand_times=26):
-         # Protect mutable default arguments
-         norm_cfg = cp.deepcopy(norm_cfg)
-         num_blocks = cp.deepcopy(num_blocks)
-         super().__init__()
-         self.unit_channels = unit_channels
-         self.num_stages = num_stages
-         self.num_units = num_units
-         self.num_blocks = num_blocks
-         self.num_steps = num_steps
-         self.norm_cfg = norm_cfg
-
-         assert self.num_stages > 0
-         assert self.num_steps > 1
-         assert self.num_units > 1
-         assert self.num_units == len(self.num_blocks)
-         self.top = ResNet_top(norm_cfg=norm_cfg)
-         self.multi_stage_rsn = nn.ModuleList([])
-         for i in range(self.num_stages):
-             if i == 0:
-                 has_skip = False
-             else:
-                 has_skip = True
-             if i != self.num_stages - 1:
-                 gen_skip = True
-                 gen_cross_conv = True
-             else:
-                 gen_skip = False
-                 gen_cross_conv = False
-             self.multi_stage_rsn.append(
-                 Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
-                                  unit_channels, num_units, num_steps,
-                                  num_blocks, norm_cfg, res_top_channels,
-                                  expand_times))
-
-     def forward(self, x):
-         """Model forward function."""
-         out_feats = []
-         skip1 = None
-         skip2 = None
-         x = self.top(x)
-         for i in range(self.num_stages):
-             out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2)
-             out_feats.append(out)
-
-         return out_feats
-
-     def init_weights(self, pretrained=None):
-         """Initialize model weights."""
-         for m in self.multi_stage_rsn.modules():
-             if isinstance(m, nn.Conv2d):
-                 kaiming_init(m)
-             elif isinstance(m, nn.BatchNorm2d):
-                 constant_init(m, 1)
-             elif isinstance(m, nn.Linear):
-                 normal_init(m, std=0.01)
-
-         for m in self.top.modules():
-             if isinstance(m, nn.Conv2d):
-                 kaiming_init(m)
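The removed RSN backbone was self-contained, so its docstring example above still describes the expected behaviour. A minimal usage sketch, assuming an mmpose build that still ships this class:

import torch
from mmpose.models import RSN  # only available in builds that still ship RSN

# Two-stage RSN with two downsample/upsample units per stage, as in the docstring.
model = RSN(num_stages=2, num_units=2, num_blocks=[2, 2])
model.eval()
inputs = torch.rand(1, 3, 511, 511)
with torch.no_grad():
    level_outputs = model(inputs)
for level_output in level_outputs:
    for feature in level_output:
        print(tuple(feature.shape))  # (1, 256, 64, 64) and (1, 256, 128, 128) per stage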
main/transformer_utils/mmpose/models/backbones/scnet.py DELETED
@@ -1,248 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as cp
- from mmcv.cnn import build_conv_layer, build_norm_layer
-
- from ..builder import BACKBONES
- from .resnet import Bottleneck, ResNet
-
-
- class SCConv(nn.Module):
-     """SCConv (Self-calibrated Convolution)
-
-     Args:
-         in_channels (int): The input channels of the SCConv.
-         out_channels (int): The output channel of the SCConv.
-         stride (int): stride of SCConv.
-         pooling_r (int): size of pooling for scconv.
-         conv_cfg (dict): dictionary to construct and config conv layer.
-             Default: None
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-     """
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  stride,
-                  pooling_r,
-                  conv_cfg=None,
-                  norm_cfg=dict(type='BN', momentum=0.1)):
-         # Protect mutable default arguments
-         norm_cfg = copy.deepcopy(norm_cfg)
-         super().__init__()
-
-         assert in_channels == out_channels
-
-         self.k2 = nn.Sequential(
-             nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
-             build_conv_layer(
-                 conv_cfg,
-                 in_channels,
-                 in_channels,
-                 kernel_size=3,
-                 stride=1,
-                 padding=1,
-                 bias=False),
-             build_norm_layer(norm_cfg, in_channels)[1],
-         )
-         self.k3 = nn.Sequential(
-             build_conv_layer(
-                 conv_cfg,
-                 in_channels,
-                 in_channels,
-                 kernel_size=3,
-                 stride=1,
-                 padding=1,
-                 bias=False),
-             build_norm_layer(norm_cfg, in_channels)[1],
-         )
-         self.k4 = nn.Sequential(
-             build_conv_layer(
-                 conv_cfg,
-                 in_channels,
-                 in_channels,
-                 kernel_size=3,
-                 stride=stride,
-                 padding=1,
-                 bias=False),
-             build_norm_layer(norm_cfg, out_channels)[1],
-             nn.ReLU(inplace=True),
-         )
-
-     def forward(self, x):
-         """Forward function."""
-         identity = x
-
-         out = torch.sigmoid(
-             torch.add(identity, F.interpolate(self.k2(x),
-                                               identity.size()[2:])))
-         out = torch.mul(self.k3(x), out)
-         out = self.k4(out)
-
-         return out
-
-
- class SCBottleneck(Bottleneck):
-     """SC(Self-calibrated) Bottleneck.
-
-     Args:
-         in_channels (int): The input channels of the SCBottleneck block.
-         out_channels (int): The output channel of the SCBottleneck block.
-     """
-
-     pooling_r = 4
-
-     def __init__(self, in_channels, out_channels, **kwargs):
-         super().__init__(in_channels, out_channels, **kwargs)
-         self.mid_channels = out_channels // self.expansion // 2
-
-         self.norm1_name, norm1 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=1)
-         self.norm2_name, norm2 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=2)
-         self.norm3_name, norm3 = build_norm_layer(
-             self.norm_cfg, out_channels, postfix=3)
-
-         self.conv1 = build_conv_layer(
-             self.conv_cfg,
-             in_channels,
-             self.mid_channels,
-             kernel_size=1,
-             stride=1,
-             bias=False)
-         self.add_module(self.norm1_name, norm1)
-
-         self.k1 = nn.Sequential(
-             build_conv_layer(
-                 self.conv_cfg,
-                 self.mid_channels,
-                 self.mid_channels,
-                 kernel_size=3,
-                 stride=self.stride,
-                 padding=1,
-                 bias=False),
-             build_norm_layer(self.norm_cfg, self.mid_channels)[1],
-             nn.ReLU(inplace=True))
-
-         self.conv2 = build_conv_layer(
-             self.conv_cfg,
-             in_channels,
-             self.mid_channels,
-             kernel_size=1,
-             stride=1,
-             bias=False)
-         self.add_module(self.norm2_name, norm2)
-
-         self.scconv = SCConv(self.mid_channels, self.mid_channels, self.stride,
-                              self.pooling_r, self.conv_cfg, self.norm_cfg)
-
-         self.conv3 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels * 2,
-             out_channels,
-             kernel_size=1,
-             stride=1,
-             bias=False)
-         self.add_module(self.norm3_name, norm3)
-
-     def forward(self, x):
-         """Forward function."""
-
-         def _inner_forward(x):
-             identity = x
-
-             out_a = self.conv1(x)
-             out_a = self.norm1(out_a)
-             out_a = self.relu(out_a)
-
-             out_a = self.k1(out_a)
-
-             out_b = self.conv2(x)
-             out_b = self.norm2(out_b)
-             out_b = self.relu(out_b)
-
-             out_b = self.scconv(out_b)
-
-             out = self.conv3(torch.cat([out_a, out_b], dim=1))
-             out = self.norm3(out)
-
-             if self.downsample is not None:
-                 identity = self.downsample(x)
-
-             out += identity
-
-             return out
-
-         if self.with_cp and x.requires_grad:
-             out = cp.checkpoint(_inner_forward, x)
-         else:
-             out = _inner_forward(x)
-
-         out = self.relu(out)
-
-         return out
-
-
- @BACKBONES.register_module()
- class SCNet(ResNet):
-     """SCNet backbone.
-
-     Improving Convolutional Networks with Self-Calibrated Convolutions,
-     Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng,
-     IEEE CVPR, 2020.
-     http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf
-
-     Args:
-         depth (int): Depth of scnet, from {50, 101}.
-         in_channels (int): Number of input image channels. Normally 3.
-         base_channels (int): Number of base channels of hidden layer.
-         num_stages (int): SCNet stages, normally 4.
-         strides (Sequence[int]): Strides of the first block of each stage.
-         dilations (Sequence[int]): Dilation of each stage.
-         out_indices (Sequence[int]): Output from which stages.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
-         avg_down (bool): Use AvgPool instead of stride conv when
-             downsampling in the bottleneck.
-         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-             -1 means not freezing any parameters.
-         norm_cfg (dict): Dictionary to construct and config norm layer.
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed.
-         zero_init_residual (bool): Whether to use zero init for last norm layer
-             in resblocks to let them behave as identity.
-
-     Example:
-         >>> from mmpose.models import SCNet
-         >>> import torch
-         >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3))
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 224, 224)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_out in level_outputs:
-         ...     print(tuple(level_out.shape))
-         (1, 256, 56, 56)
-         (1, 512, 28, 28)
-         (1, 1024, 14, 14)
-         (1, 2048, 7, 7)
-     """
-
-     arch_settings = {
-         50: (SCBottleneck, [3, 4, 6, 3]),
-         101: (SCBottleneck, [3, 4, 23, 3])
-     }
-
-     def __init__(self, depth, **kwargs):
-         if depth not in self.arch_settings:
-             raise KeyError(f'invalid depth {depth} for SCNet')
-         super().__init__(depth, **kwargs)
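Because SCNet was registered through @BACKBONES.register_module(), configs referred to it purely by name. A minimal sketch of that registry-based construction, assuming the mmpose 0.x build_backbone helper is available:

import torch
from mmpose.models import build_backbone  # mmpose 0.x registry helper

cfg = dict(type='SCNet', depth=50, out_indices=(0, 1, 2, 3))
model = build_backbone(cfg)  # resolves 'SCNet' through the BACKBONES registry
model.eval()
with torch.no_grad():
    level_outputs = model(torch.rand(1, 3, 224, 224))
for level_out in level_outputs:
    print(tuple(level_out.shape))
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)

After this cleanup, a config with type='SCNet' would fail registry lookup, which is why configs referencing removed backbones must be updated alongside the deletion.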
main/transformer_utils/mmpose/models/backbones/seresnet.py DELETED
@@ -1,125 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import torch.utils.checkpoint as cp
-
- from ..builder import BACKBONES
- from .resnet import Bottleneck, ResLayer, ResNet
- from .utils.se_layer import SELayer
-
-
- class SEBottleneck(Bottleneck):
-     """SEBottleneck block for SEResNet.
-
-     Args:
-         in_channels (int): The input channels of the SEBottleneck block.
-         out_channels (int): The output channel of the SEBottleneck block.
-         se_ratio (int): Squeeze ratio in SELayer. Default: 16
-     """
-
-     def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
-         super().__init__(in_channels, out_channels, **kwargs)
-         self.se_layer = SELayer(out_channels, ratio=se_ratio)
-
-     def forward(self, x):
-
-         def _inner_forward(x):
-             identity = x
-
-             out = self.conv1(x)
-             out = self.norm1(out)
-             out = self.relu(out)
-
-             out = self.conv2(out)
-             out = self.norm2(out)
-             out = self.relu(out)
-
-             out = self.conv3(out)
-             out = self.norm3(out)
-
-             out = self.se_layer(out)
-
-             if self.downsample is not None:
-                 identity = self.downsample(x)
-
-             out += identity
-
-             return out
-
-         if self.with_cp and x.requires_grad:
-             out = cp.checkpoint(_inner_forward, x)
-         else:
-             out = _inner_forward(x)
-
-         out = self.relu(out)
-
-         return out
-
-
- @BACKBONES.register_module()
- class SEResNet(ResNet):
-     """SEResNet backbone.
-
-     Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
-     details.
-
-     Args:
-         depth (int): Network depth, from {50, 101, 152}.
-         se_ratio (int): Squeeze ratio in SELayer. Default: 16.
-         in_channels (int): Number of input image channels. Default: 3.
-         stem_channels (int): Output channels of the stem layer. Default: 64.
-         num_stages (int): Stages of the network. Default: 4.
-         strides (Sequence[int]): Strides of the first block of each stage.
-             Default: ``(1, 2, 2, 2)``.
-         dilations (Sequence[int]): Dilation of each stage.
-             Default: ``(1, 1, 1, 1)``.
-         out_indices (Sequence[int]): Output from which stages. If only one
-             stage is specified, a single tensor (feature map) is returned,
-             otherwise multiple stages are specified, a tuple of tensors will
-             be returned. Default: ``(3, )``.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-             Default: False.
-         avg_down (bool): Use AvgPool instead of stride conv when
-             downsampling in the bottleneck. Default: False.
-         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-             -1 means not freezing any parameters. Default: -1.
-         conv_cfg (dict | None): The config dict for conv layers. Default: None.
-         norm_cfg (dict): The config dict for norm layers.
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-         zero_init_residual (bool): Whether to use zero init for last norm layer
-             in resblocks to let them behave as identity. Default: True.
-
-     Example:
-         >>> from mmpose.models import SEResNet
-         >>> import torch
-         >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 224, 224)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_out in level_outputs:
-         ...     print(tuple(level_out.shape))
-         (1, 256, 56, 56)
-         (1, 512, 28, 28)
-         (1, 1024, 14, 14)
-         (1, 2048, 7, 7)
-     """
-
-     arch_settings = {
-         50: (SEBottleneck, (3, 4, 6, 3)),
-         101: (SEBottleneck, (3, 4, 23, 3)),
-         152: (SEBottleneck, (3, 8, 36, 3))
-     }
-
-     def __init__(self, depth, se_ratio=16, **kwargs):
-         if depth not in self.arch_settings:
-             raise KeyError(f'invalid depth {depth} for SEResNet')
-         self.se_ratio = se_ratio
-         super().__init__(depth, **kwargs)
-
-     def make_res_layer(self, **kwargs):
-         return ResLayer(se_ratio=self.se_ratio, **kwargs)
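As the out_indices note in the docstring above says, requesting a single stage returns a bare tensor rather than a tuple. A minimal sketch, again assuming a build that still ships SEResNet:

import torch
from mmpose.models import SEResNet  # only in builds that still ship SEResNet

model = SEResNet(depth=50, se_ratio=16, out_indices=(3, ))
model.eval()
with torch.no_grad():
    feat = model(torch.rand(1, 3, 224, 224))
print(tuple(feat.shape))  # (1, 2048, 7, 7) -- a single tensor, not a tuple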
main/transformer_utils/mmpose/models/backbones/seresnext.py DELETED
@@ -1,168 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmcv.cnn import build_conv_layer, build_norm_layer
-
- from ..builder import BACKBONES
- from .resnet import ResLayer
- from .seresnet import SEBottleneck as _SEBottleneck
- from .seresnet import SEResNet
-
-
- class SEBottleneck(_SEBottleneck):
-     """SEBottleneck block for SEResNeXt.
-
-     Args:
-         in_channels (int): Input channels of this block.
-         out_channels (int): Output channels of this block.
-         base_channels (int): Middle channels of the first stage. Default: 64.
-         groups (int): Groups of conv2.
-         width_per_group (int): Width per group of conv2. 64x4d indicates
-             ``groups=64, width_per_group=4`` and 32x8d indicates
-             ``groups=32, width_per_group=8``.
-         stride (int): stride of the block. Default: 1
-         dilation (int): dilation of convolution. Default: 1
-         downsample (nn.Module): downsample operation on identity branch.
-             Default: None
-         se_ratio (int): Squeeze ratio in SELayer. Default: 16
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         conv_cfg (dict): dictionary to construct and config conv layer.
-             Default: None
-         norm_cfg (dict): dictionary to construct and config norm layer.
-             Default: dict(type='BN')
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed.
-     """
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  base_channels=64,
-                  groups=32,
-                  width_per_group=4,
-                  se_ratio=16,
-                  **kwargs):
-         super().__init__(in_channels, out_channels, se_ratio, **kwargs)
-         self.groups = groups
-         self.width_per_group = width_per_group
-
-         # We follow the same rationale as ResNeXt to compute mid_channels.
-         # For SEResNet bottleneck, middle channels are determined by expansion
-         # and out_channels, but for SEResNeXt bottleneck, it is determined by
-         # groups and width_per_group and the stage it is located in.
-         if groups != 1:
-             assert self.mid_channels % base_channels == 0
-             self.mid_channels = (
-                 groups * width_per_group * self.mid_channels // base_channels)
-
-         self.norm1_name, norm1 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=1)
-         self.norm2_name, norm2 = build_norm_layer(
-             self.norm_cfg, self.mid_channels, postfix=2)
-         self.norm3_name, norm3 = build_norm_layer(
-             self.norm_cfg, self.out_channels, postfix=3)
-
-         self.conv1 = build_conv_layer(
-             self.conv_cfg,
-             self.in_channels,
-             self.mid_channels,
-             kernel_size=1,
-             stride=self.conv1_stride,
-             bias=False)
-         self.add_module(self.norm1_name, norm1)
-         self.conv2 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels,
-             self.mid_channels,
-             kernel_size=3,
-             stride=self.conv2_stride,
-             padding=self.dilation,
-             dilation=self.dilation,
-             groups=groups,
-             bias=False)
-
-         self.add_module(self.norm2_name, norm2)
-         self.conv3 = build_conv_layer(
-             self.conv_cfg,
-             self.mid_channels,
-             self.out_channels,
-             kernel_size=1,
-             bias=False)
-         self.add_module(self.norm3_name, norm3)
-
-
- @BACKBONES.register_module()
- class SEResNeXt(SEResNet):
-     """SEResNeXt backbone.
-
-     Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
-     details.
-
-     Args:
-         depth (int): Network depth, from {50, 101, 152}.
-         groups (int): Groups of conv2 in Bottleneck. Default: 32.
-         width_per_group (int): Width per group of conv2 in Bottleneck.
-             Default: 4.
-         se_ratio (int): Squeeze ratio in SELayer. Default: 16.
-         in_channels (int): Number of input image channels. Default: 3.
-         stem_channels (int): Output channels of the stem layer. Default: 64.
-         num_stages (int): Stages of the network. Default: 4.
-         strides (Sequence[int]): Strides of the first block of each stage.
-             Default: ``(1, 2, 2, 2)``.
-         dilations (Sequence[int]): Dilation of each stage.
-             Default: ``(1, 1, 1, 1)``.
-         out_indices (Sequence[int]): Output from which stages. If only one
-             stage is specified, a single tensor (feature map) is returned,
-             otherwise multiple stages are specified, a tuple of tensors will
-             be returned. Default: ``(3, )``.
-         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-             layer is the 3x3 conv layer, otherwise the stride-two layer is
-             the first 1x1 conv layer.
-         deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-             Default: False.
-         avg_down (bool): Use AvgPool instead of stride conv when
-             downsampling in the bottleneck. Default: False.
-         frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-             -1 means not freezing any parameters. Default: -1.
-         conv_cfg (dict | None): The config dict for conv layers. Default: None.
-         norm_cfg (dict): The config dict for norm layers.
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-         zero_init_residual (bool): Whether to use zero init for last norm layer
-             in resblocks to let them behave as identity. Default: True.
-
-     Example:
-         >>> from mmpose.models import SEResNeXt
-         >>> import torch
-         >>> self = SEResNeXt(depth=50, out_indices=(0, 1, 2, 3))
-         >>> self.eval()
-         >>> inputs = torch.rand(1, 3, 224, 224)
-         >>> level_outputs = self.forward(inputs)
-         >>> for level_out in level_outputs:
-         ...     print(tuple(level_out.shape))
-         (1, 256, 56, 56)
-         (1, 512, 28, 28)
-         (1, 1024, 14, 14)
-         (1, 2048, 7, 7)
-     """
-
-     arch_settings = {
-         50: (SEBottleneck, (3, 4, 6, 3)),
-         101: (SEBottleneck, (3, 4, 23, 3)),
-         152: (SEBottleneck, (3, 8, 36, 3))
-     }
-
-     def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
-         self.groups = groups
-         self.width_per_group = width_per_group
-         super().__init__(depth, **kwargs)
-
-     def make_res_layer(self, **kwargs):
-         return ResLayer(
-             groups=self.groups,
-             width_per_group=self.width_per_group,
-             base_channels=self.base_channels,
-             **kwargs)
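The mid_channels arithmetic in the SEBottleneck above is easiest to see with the default 32x4d setting; a worked example, with values derived directly from the code above rather than from any extra API:

# For the first SEResNeXt-50 stage: out_channels = 256, expansion = 4,
# so the inherited SEResNet mid_channels is 256 // 4 = 64.
base_channels, groups, width_per_group = 64, 32, 4
mid_channels = 256 // 4  # from the parent SEBottleneck
assert mid_channels % base_channels == 0
mid_channels = groups * width_per_group * mid_channels // base_channels
print(mid_channels)  # 128: the grouped 3x3 conv runs 32 groups of width 4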
main/transformer_utils/mmpose/models/backbones/shufflenet_v1.py DELETED
@@ -1,329 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
- import logging
-
- import torch
- import torch.nn as nn
- import torch.utils.checkpoint as cp
- from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
-                       normal_init)
- from torch.nn.modules.batchnorm import _BatchNorm
-
- from ..builder import BACKBONES
- from .base_backbone import BaseBackbone
- from .utils import channel_shuffle, load_checkpoint, make_divisible
-
-
- class ShuffleUnit(nn.Module):
-     """ShuffleUnit block.
-
-     ShuffleNet unit with pointwise group convolution (GConv) and channel
-     shuffle.
-
-     Args:
-         in_channels (int): The input channels of the ShuffleUnit.
-         out_channels (int): The output channels of the ShuffleUnit.
-         groups (int, optional): The number of groups to be used in grouped 1x1
-             convolutions in each ShuffleUnit. Default: 3
-         first_block (bool, optional): Whether it is the first ShuffleUnit in a
-             sequence of ShuffleUnits. Default: True, which means not using the
-             grouped 1x1 convolution.
-         combine (str, optional): The ways to combine the input and output
-             branches. Default: 'add'.
-         conv_cfg (dict): Config dict for convolution layer. Default: None,
-             which means using conv2d.
-         norm_cfg (dict): Config dict for normalization layer.
-             Default: dict(type='BN').
-         act_cfg (dict): Config dict for activation layer.
-             Default: dict(type='ReLU').
-         with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-             will save some memory while slowing down the training speed.
-             Default: False.
-
-     Returns:
-         Tensor: The output tensor.
-     """
-
-     def __init__(self,
-                  in_channels,
-                  out_channels,
-                  groups=3,
-                  first_block=True,
-                  combine='add',
-                  conv_cfg=None,
-                  norm_cfg=dict(type='BN'),
-                  act_cfg=dict(type='ReLU'),
-                  with_cp=False):
-         # Protect mutable default arguments
-         norm_cfg = copy.deepcopy(norm_cfg)
-         act_cfg = copy.deepcopy(act_cfg)
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.first_block = first_block
-         self.combine = combine
-         self.groups = groups
-         self.bottleneck_channels = self.out_channels // 4
-         self.with_cp = with_cp
-
-         if self.combine == 'add':
-             self.depthwise_stride = 1
-             self._combine_func = self._add
-             assert in_channels == out_channels, (
-                 'in_channels must be equal to out_channels when combine '
-                 'is add')
-         elif self.combine == 'concat':
-             self.depthwise_stride = 2
-             self._combine_func = self._concat
-             self.out_channels -= self.in_channels
-             self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
-         else:
-             raise ValueError(f'Cannot combine tensors with {self.combine}. '
-                              'Only "add" and "concat" are supported')
-
-         self.first_1x1_groups = 1 if first_block else self.groups
-         self.g_conv_1x1_compress = ConvModule(
-             in_channels=self.in_channels,
-             out_channels=self.bottleneck_channels,
-             kernel_size=1,
-             groups=self.first_1x1_groups,
-             conv_cfg=conv_cfg,
-             norm_cfg=norm_cfg,
-             act_cfg=act_cfg)
-
-         self.depthwise_conv3x3_bn = ConvModule(
-             in_channels=self.bottleneck_channels,
-             out_channels=self.bottleneck_channels,
-             kernel_size=3,
-             stride=self.depthwise_stride,
-             padding=1,
-             groups=self.bottleneck_channels,
-             conv_cfg=conv_cfg,
-             norm_cfg=norm_cfg,
-             act_cfg=None)
-
-         self.g_conv_1x1_expand = ConvModule(
-             in_channels=self.bottleneck_channels,
-             out_channels=self.out_channels,
-             kernel_size=1,
-             groups=self.groups,
-             conv_cfg=conv_cfg,
-             norm_cfg=norm_cfg,
-             act_cfg=None)
-
-         self.act = build_activation_layer(act_cfg)
-
-     @staticmethod
-     def _add(x, out):
-         # residual connection
-         return x + out
-
-     @staticmethod
-     def _concat(x, out):
-         # concatenate along channel axis
-         return torch.cat((x, out), 1)
-
-     def forward(self, x):
-
-         def _inner_forward(x):
-             residual = x
-
-             out = self.g_conv_1x1_compress(x)
-             out = self.depthwise_conv3x3_bn(out)
-
-             if self.groups > 1:
-                 out = channel_shuffle(out, self.groups)
-
-             out = self.g_conv_1x1_expand(out)
-
-             if self.combine == 'concat':
-                 residual = self.avgpool(residual)
-                 out = self.act(out)
-                 out = self._combine_func(residual, out)
-             else:
-                 out = self._combine_func(residual, out)
-                 out = self.act(out)
-             return out
-
-         if self.with_cp and x.requires_grad:
-             out = cp.checkpoint(_inner_forward, x)
-         else:
-             out = _inner_forward(x)
-
-         return out
-
-
- @BACKBONES.register_module()
- class ShuffleNetV1(BaseBackbone):
-     """ShuffleNetV1 backbone.
-
-     Args:
-         groups (int, optional): The number of groups to be used in grouped 1x1
-             convolutions in each ShuffleUnit. Default: 3.
-         widen_factor (float, optional): Width multiplier - adjusts the number
-             of channels in each layer by this amount. Default: 1.0.
-         out_indices (Sequence[int]): Output from which stages.
-             Default: (2, )
-         frozen_stages (int): Stages to be frozen (all param fixed).
-             Default: -1, which means not freezing any parameters.
-         conv_cfg (dict): Config dict for convolution layer. Default: None,
-             which means using conv2d.
-         norm_cfg (dict): Config dict for normalization layer.
-             Default: dict(type='BN').
-         act_cfg (dict): Config dict for activation layer.
-             Default: dict(type='ReLU').
-         norm_eval (bool): Whether to set norm layers to eval mode, namely,
-             freeze running stats (mean and var). Note: Effect on Batch Norm
-             and its variants only. Default: False.
-         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-             memory while slowing down the training speed. Default: False.
-     """
-
-     def __init__(self,
-                  groups=3,
-                  widen_factor=1.0,
-                  out_indices=(2, ),
-                  frozen_stages=-1,
-                  conv_cfg=None,
-                  norm_cfg=dict(type='BN'),
-                  act_cfg=dict(type='ReLU'),
-                  norm_eval=False,
-                  with_cp=False):
-         # Protect mutable default arguments
-         norm_cfg = copy.deepcopy(norm_cfg)
-         act_cfg = copy.deepcopy(act_cfg)
-         super().__init__()
-         self.stage_blocks = [4, 8, 4]
-         self.groups = groups
-
-         for index in out_indices:
-             if index not in range(0, 3):
-                 raise ValueError('the item in out_indices must in '
-                                  f'range(0, 3). But received {index}')
-
-         if frozen_stages not in range(-1, 3):
-             raise ValueError('frozen_stages must be in range(-1, 3). '
-                              f'But received {frozen_stages}')
-         self.out_indices = out_indices
-         self.frozen_stages = frozen_stages
-         self.conv_cfg = conv_cfg
-         self.norm_cfg = norm_cfg
-         self.act_cfg = act_cfg
-         self.norm_eval = norm_eval
-         self.with_cp = with_cp
-
-         if groups == 1:
-             channels = (144, 288, 576)
-         elif groups == 2:
-             channels = (200, 400, 800)
-         elif groups == 3:
-             channels = (240, 480, 960)
-         elif groups == 4:
-             channels = (272, 544, 1088)
-         elif groups == 8:
-             channels = (384, 768, 1536)
-         else:
-             raise ValueError(f'{groups} groups is not supported for 1x1 '
-                              'Grouped Convolutions')
-
-         channels = [make_divisible(ch * widen_factor, 8) for ch in channels]
-
-         self.in_channels = int(24 * widen_factor)
-
-         self.conv1 = ConvModule(
-             in_channels=3,
-             out_channels=self.in_channels,
-             kernel_size=3,
-             stride=2,
-             padding=1,
-             conv_cfg=conv_cfg,
-             norm_cfg=norm_cfg,
-             act_cfg=act_cfg)
-         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-
-         self.layers = nn.ModuleList()
-         for i, num_blocks in enumerate(self.stage_blocks):
-             first_block = (i == 0)
-             layer = self.make_layer(channels[i], num_blocks, first_block)
-             self.layers.append(layer)
-
-     def _freeze_stages(self):
-         if self.frozen_stages >= 0:
-             for param in self.conv1.parameters():
-                 param.requires_grad = False
-             for i in range(self.frozen_stages):
-                 layer = self.layers[i]
-                 layer.eval()
-                 for param in layer.parameters():
-                     param.requires_grad = False
-
-     def init_weights(self, pretrained=None):
-         if isinstance(pretrained, str):
-             logger = logging.getLogger()
-             load_checkpoint(self, pretrained, strict=False, logger=logger)
-         elif pretrained is None:
-             for name, m in self.named_modules():
-                 if isinstance(m, nn.Conv2d):
-                     if 'conv1' in name:
-                         normal_init(m, mean=0, std=0.01)
-                     else:
-                         normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
-                 elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                     constant_init(m, val=1, bias=0.0001)
-                     if isinstance(m, _BatchNorm):
-                         if m.running_mean is not None:
-                             nn.init.constant_(m.running_mean, 0)
-         else:
-             raise TypeError('pretrained must be a str or None. But received '
-                             f'{type(pretrained)}')
-
-     def make_layer(self, out_channels, num_blocks, first_block=False):
-         """Stack ShuffleUnit blocks to make a layer.
-
-         Args:
-             out_channels (int): out_channels of the block.
-             num_blocks (int): Number of blocks.
-             first_block (bool, optional): Whether it is the first ShuffleUnit
-                 in a sequence of ShuffleUnits. Default: False, which means
-                 using the grouped 1x1 convolution.
-         """
-         layers = []
-         for i in range(num_blocks):
-             first_block = first_block if i == 0 else False
-             combine_mode = 'concat' if i == 0 else 'add'
-             layers.append(
-                 ShuffleUnit(
-                     self.in_channels,
-                     out_channels,
-                     groups=self.groups,
-                     first_block=first_block,
-                     combine=combine_mode,
-                     conv_cfg=self.conv_cfg,
-                     norm_cfg=self.norm_cfg,
-                     act_cfg=self.act_cfg,
-                     with_cp=self.with_cp))
-             self.in_channels = out_channels
-
-         return nn.Sequential(*layers)
-
-     def forward(self, x):
-         x = self.conv1(x)
-         x = self.maxpool(x)
-
-         outs = []
-         for i, layer in enumerate(self.layers):
-             x = layer(x)
-             if i in self.out_indices:
-                 outs.append(x)
-
-         if len(outs) == 1:
-             return outs[0]
-         return tuple(outs)
-
-     def train(self, mode=True):
-         super().train(mode)
-         self._freeze_stages()
-         if mode and self.norm_eval:
-             for m in self.modules():
-                 if isinstance(m, _BatchNorm):
-                     m.eval()
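The channels table in __init__ ties the stage widths to the number of 1x1 groups. A minimal sketch of the default groups=3 variant; the output shape here is inferred from the stride-2 stem, the max-pool, and the three stride-2 stages, so treat it as an assumption rather than a documented value:

import torch
from mmpose.models import ShuffleNetV1  # only in builds that still ship it

model = ShuffleNetV1(groups=3, widen_factor=1.0, out_indices=(2, ))
model.init_weights()
model.eval()
with torch.no_grad():
    feat = model(torch.rand(1, 3, 224, 224))
print(tuple(feat.shape))  # expected (1, 960, 7, 7): stage channels are (240, 480, 960)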
main/transformer_utils/mmpose/models/backbones/shufflenet_v2.py DELETED
@@ -1,302 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import copy
3
- import logging
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.utils.checkpoint as cp
8
- from mmcv.cnn import ConvModule, constant_init, normal_init
9
- from torch.nn.modules.batchnorm import _BatchNorm
10
-
11
- from ..builder import BACKBONES
12
- from .base_backbone import BaseBackbone
13
- from .utils import channel_shuffle, load_checkpoint
14
-
15
-
16
- class InvertedResidual(nn.Module):
17
- """InvertedResidual block for ShuffleNetV2 backbone.
18
-
19
- Args:
20
- in_channels (int): The input channels of the block.
21
- out_channels (int): The output channels of the block.
22
- stride (int): Stride of the 3x3 convolution layer. Default: 1
23
- conv_cfg (dict): Config dict for convolution layer.
24
- Default: None, which means using conv2d.
25
- norm_cfg (dict): Config dict for normalization layer.
26
- Default: dict(type='BN').
27
- act_cfg (dict): Config dict for activation layer.
28
- Default: dict(type='ReLU').
29
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
30
- memory while slowing down the training speed. Default: False.
31
- """
32
-
33
- def __init__(self,
34
- in_channels,
35
- out_channels,
36
- stride=1,
37
- conv_cfg=None,
38
- norm_cfg=dict(type='BN'),
39
- act_cfg=dict(type='ReLU'),
40
- with_cp=False):
41
- # Protect mutable default arguments
42
- norm_cfg = copy.deepcopy(norm_cfg)
43
- act_cfg = copy.deepcopy(act_cfg)
44
- super().__init__()
45
- self.stride = stride
46
- self.with_cp = with_cp
47
-
48
- branch_features = out_channels // 2
49
- if self.stride == 1:
50
- assert in_channels == branch_features * 2, (
51
- f'in_channels ({in_channels}) should equal to '
52
- f'branch_features * 2 ({branch_features * 2}) '
53
- 'when stride is 1')
54
-
55
- if in_channels != branch_features * 2:
56
- assert self.stride != 1, (
57
- f'stride ({self.stride}) should not equal 1 when '
58
- f'in_channels != branch_features * 2')
59
-
60
- if self.stride > 1:
61
- self.branch1 = nn.Sequential(
62
- ConvModule(
63
- in_channels,
64
- in_channels,
65
- kernel_size=3,
66
- stride=self.stride,
67
- padding=1,
68
- groups=in_channels,
69
- conv_cfg=conv_cfg,
70
- norm_cfg=norm_cfg,
71
- act_cfg=None),
72
- ConvModule(
73
- in_channels,
74
- branch_features,
75
- kernel_size=1,
76
- stride=1,
77
- padding=0,
78
- conv_cfg=conv_cfg,
79
- norm_cfg=norm_cfg,
80
- act_cfg=act_cfg),
81
- )
82
-
83
- self.branch2 = nn.Sequential(
84
- ConvModule(
85
- in_channels if (self.stride > 1) else branch_features,
86
- branch_features,
87
- kernel_size=1,
88
- stride=1,
89
- padding=0,
90
- conv_cfg=conv_cfg,
91
- norm_cfg=norm_cfg,
92
- act_cfg=act_cfg),
93
- ConvModule(
94
- branch_features,
95
- branch_features,
96
- kernel_size=3,
97
- stride=self.stride,
98
- padding=1,
99
- groups=branch_features,
100
- conv_cfg=conv_cfg,
101
- norm_cfg=norm_cfg,
102
- act_cfg=None),
103
- ConvModule(
104
- branch_features,
105
- branch_features,
106
- kernel_size=1,
107
- stride=1,
108
- padding=0,
109
- conv_cfg=conv_cfg,
110
- norm_cfg=norm_cfg,
111
- act_cfg=act_cfg))
112
-
113
- def forward(self, x):
114
-
115
- def _inner_forward(x):
116
- if self.stride > 1:
117
- out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
118
- else:
119
- x1, x2 = x.chunk(2, dim=1)
120
- out = torch.cat((x1, self.branch2(x2)), dim=1)
121
-
122
- out = channel_shuffle(out, 2)
123
-
124
- return out
125
-
126
- if self.with_cp and x.requires_grad:
127
- out = cp.checkpoint(_inner_forward, x)
128
- else:
129
- out = _inner_forward(x)
130
-
131
- return out
132
-
133
-
134
- @BACKBONES.register_module()
135
- class ShuffleNetV2(BaseBackbone):
136
- """ShuffleNetV2 backbone.
137
-
138
- Args:
139
- widen_factor (float): Width multiplier - adjusts the number of
140
- channels in each layer by this amount. Default: 1.0.
141
- out_indices (Sequence[int]): Output from which stages.
142
- Default: (0, 1, 2, 3).
143
- frozen_stages (int): Stages to be frozen (all param fixed).
144
- Default: -1, which means not freezing any parameters.
145
- conv_cfg (dict): Config dict for convolution layer.
146
- Default: None, which means using conv2d.
147
- norm_cfg (dict): Config dict for normalization layer.
148
- Default: dict(type='BN').
149
- act_cfg (dict): Config dict for activation layer.
150
- Default: dict(type='ReLU').
151
- norm_eval (bool): Whether to set norm layers to eval mode, namely,
152
- freeze running stats (mean and var). Note: Effect on Batch Norm
153
- and its variants only. Default: False.
154
- with_cp (bool): Use checkpoint or not. Using checkpoint will save some
155
- memory while slowing down the training speed. Default: False.
156
- """
157
-
158
- def __init__(self,
159
- widen_factor=1.0,
160
- out_indices=(3, ),
161
- frozen_stages=-1,
162
- conv_cfg=None,
163
- norm_cfg=dict(type='BN'),
164
- act_cfg=dict(type='ReLU'),
165
- norm_eval=False,
166
- with_cp=False):
167
- # Protect mutable default arguments
168
- norm_cfg = copy.deepcopy(norm_cfg)
169
- act_cfg = copy.deepcopy(act_cfg)
170
- super().__init__()
171
- self.stage_blocks = [4, 8, 4]
172
- for index in out_indices:
173
- if index not in range(0, 4):
174
- raise ValueError('the item in out_indices must in '
175
- f'range(0, 4). But received {index}')
176
-
177
- if frozen_stages not in range(-1, 4):
178
- raise ValueError('frozen_stages must be in range(-1, 4). '
- f'But received {frozen_stages}')
- self.out_indices = out_indices
- self.frozen_stages = frozen_stages
- self.conv_cfg = conv_cfg
- self.norm_cfg = norm_cfg
- self.act_cfg = act_cfg
- self.norm_eval = norm_eval
- self.with_cp = with_cp
-
- if widen_factor == 0.5:
- channels = [48, 96, 192, 1024]
- elif widen_factor == 1.0:
- channels = [116, 232, 464, 1024]
- elif widen_factor == 1.5:
- channels = [176, 352, 704, 1024]
- elif widen_factor == 2.0:
- channels = [244, 488, 976, 2048]
- else:
- raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
- f'But received {widen_factor}')
-
- self.in_channels = 24
- self.conv1 = ConvModule(
- in_channels=3,
- out_channels=self.in_channels,
- kernel_size=3,
- stride=2,
- padding=1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg)
-
- self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-
- self.layers = nn.ModuleList()
- for i, num_blocks in enumerate(self.stage_blocks):
- layer = self._make_layer(channels[i], num_blocks)
- self.layers.append(layer)
-
- output_channels = channels[-1]
- self.layers.append(
- ConvModule(
- in_channels=self.in_channels,
- out_channels=output_channels,
- kernel_size=1,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg,
- act_cfg=act_cfg))
-
- def _make_layer(self, out_channels, num_blocks):
- """Stack blocks to make a layer.
-
- Args:
- out_channels (int): out_channels of the block.
- num_blocks (int): number of blocks.
- """
- layers = []
- for i in range(num_blocks):
- stride = 2 if i == 0 else 1
- layers.append(
- InvertedResidual(
- in_channels=self.in_channels,
- out_channels=out_channels,
- stride=stride,
- conv_cfg=self.conv_cfg,
- norm_cfg=self.norm_cfg,
- act_cfg=self.act_cfg,
- with_cp=self.with_cp))
- self.in_channels = out_channels
-
- return nn.Sequential(*layers)
-
- def _freeze_stages(self):
- if self.frozen_stages >= 0:
- for param in self.conv1.parameters():
- param.requires_grad = False
-
- for i in range(self.frozen_stages):
- m = self.layers[i]
- m.eval()
- for param in m.parameters():
- param.requires_grad = False
-
- def init_weights(self, pretrained=None):
- if isinstance(pretrained, str):
- logger = logging.getLogger()
- load_checkpoint(self, pretrained, strict=False, logger=logger)
- elif pretrained is None:
- for name, m in self.named_modules():
- if isinstance(m, nn.Conv2d):
- if 'conv1' in name:
- normal_init(m, mean=0, std=0.01)
- else:
- normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
- elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
- constant_init(m.weight, val=1, bias=0.0001)
- if isinstance(m, _BatchNorm):
- if m.running_mean is not None:
- nn.init.constant_(m.running_mean, 0)
- else:
- raise TypeError('pretrained must be a str or None. But received '
- f'{type(pretrained)}')
-
- def forward(self, x):
- x = self.conv1(x)
- x = self.maxpool(x)
-
- outs = []
- for i, layer in enumerate(self.layers):
- x = layer(x)
- if i in self.out_indices:
- outs.append(x)
-
- if len(outs) == 1:
- return outs[0]
- return tuple(outs)
-
- def train(self, mode=True):
- super().train(mode)
- self._freeze_stages()
- if mode and self.norm_eval:
- for m in self.modules():
- if isinstance(m, nn.BatchNorm2d):
- m.eval()
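For reference, the widen_factor-to-stage-channel table deleted above is a fixed lookup; a minimal standalone sketch of the same mapping (plain Python restating the if/elif ladder, not the removed class):

# Stage-channel lookup, restated from the deleted widen_factor branches.
STAGE_CHANNELS = {
    0.5: [48, 96, 192, 1024],
    1.0: [116, 232, 464, 1024],
    1.5: [176, 352, 704, 1024],
    2.0: [244, 488, 976, 2048],
}

def shufflenet_v2_channels(widen_factor):
    try:
        return STAGE_CHANNELS[widen_factor]
    except KeyError:
        raise ValueError(
            'widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
            f'But received {widen_factor}')

assert shufflenet_v2_channels(1.0) == [116, 232, 464, 1024]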
main/transformer_utils/mmpose/models/backbones/swin.py DELETED
@@ -1,733 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- from collections import OrderedDict
- from copy import deepcopy
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as cp
- from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
- from mmcv.cnn.bricks.transformer import FFN, build_dropout
- from mmcv.cnn.utils.weight_init import trunc_normal_
- from mmcv.runner import _load_checkpoint
- from mmcv.utils import to_2tuple
-
- from ...utils import get_root_logger
- from ..builder import BACKBONES
- from ..utils.transformer import PatchEmbed, PatchMerging
- from .base_backbone import BaseBackbone
- from .utils.ckpt_convert import swin_converter
-
-
- class WindowMSA(nn.Module):
- """Window based multi-head self-attention (W-MSA) module with relative
- position bias.
-
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (tuple[int]): The height and width of the window.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Default: 0.0
- proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0.,
- proj_drop_rate=0.):
-
- super().__init__()
- self.embed_dims = embed_dims
- self.window_size = window_size # Wh, Ww
- self.num_heads = num_heads
- head_embed_dims = embed_dims // num_heads
- self.scale = qk_scale or head_embed_dims**-0.5
-
- # define a parameter table of relative position bias
- self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
- num_heads)) # 2*Wh-1 * 2*Ww-1, nH
-
- # About 2x faster than original impl
- Wh, Ww = self.window_size
- rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
- rel_position_index = rel_index_coords + rel_index_coords.T
- rel_position_index = rel_position_index.flip(1).contiguous()
- self.register_buffer('relative_position_index', rel_position_index)
-
- self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
- self.attn_drop = nn.Dropout(attn_drop_rate)
- self.proj = nn.Linear(embed_dims, embed_dims)
- self.proj_drop = nn.Dropout(proj_drop_rate)
-
- self.softmax = nn.Softmax(dim=-1)
-
- def init_weights(self):
- trunc_normal_(self.relative_position_bias_table, std=0.02)
-
- def forward(self, x, mask=None):
- """
- Args:
-
- x (tensor): input features with shape of (num_windows*B, N, C)
- mask (tensor | None, Optional): mask with shape of (num_windows,
- Wh*Ww, Wh*Ww), value should be between (-inf, 0].
- """
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
- C // self.num_heads).permute(2, 0, 3, 1, 4)
- # make torchscript happy (cannot use tensor as tuple)
- q, k, v = qkv[0], qkv[1], qkv[2]
-
- q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
-
- relative_position_bias = self.relative_position_bias_table[
- self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1],
- self.window_size[0] * self.window_size[1],
- -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(
- 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B // nW, nW, self.num_heads, N,
- N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- @staticmethod
- def double_step_seq(step1, len1, step2, len2):
- seq1 = torch.arange(0, step1 * len1, step1)
- seq2 = torch.arange(0, step2 * len2, step2)
- return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
-
-
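The double_step_seq trick above builds the relative-position index without a meshgrid; a standalone sketch for a hypothetical 2x2 window, reusing the deleted code's own formulas:

# Standalone check of the relative-position-index construction (Wh = Ww = 2).
import torch

def double_step_seq(step1, len1, step2, len2):
    seq1 = torch.arange(0, step1 * len1, step1)
    seq2 = torch.arange(0, step2 * len2, step2)
    return (seq1[:, None] + seq2[None, :]).reshape(1, -1)

Wh = Ww = 2
rel_index_coords = double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
# One index into the (2*Wh-1)*(2*Ww-1) = 9-entry bias table per token pair:
assert rel_position_index.shape == (Wh * Ww, Wh * Ww)
assert int(rel_position_index.max()) <= (2 * Wh - 1) * (2 * Ww - 1) - 1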
- class ShiftWindowMSA(nn.Module):
- """Shifted Window Multihead Self-Attention Module.
-
- Args:
- embed_dims (int): Number of input channels.
- num_heads (int): Number of attention heads.
- window_size (int): The height and width of the window.
- shift_size (int, optional): The shift step of each window towards
- right-bottom. If zero, act as regular window-msa. Defaults to 0.
- qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
- Default: True
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Defaults: None.
- attn_drop_rate (float, optional): Dropout ratio of attention weight.
- Defaults: 0.
- proj_drop_rate (float, optional): Dropout ratio of output.
- Defaults: 0.
- dropout_layer (dict, optional): The dropout_layer used before output.
- Defaults: dict(type='DropPath', drop_prob=0.).
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- window_size,
- shift_size=0,
- qkv_bias=True,
- qk_scale=None,
- attn_drop_rate=0,
- proj_drop_rate=0,
- dropout_layer=dict(type='DropPath', drop_prob=0.)):
- super().__init__()
-
- self.window_size = window_size
- self.shift_size = shift_size
- assert 0 <= self.shift_size < self.window_size
-
- self.w_msa = WindowMSA(
- embed_dims=embed_dims,
- num_heads=num_heads,
- window_size=to_2tuple(window_size),
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop_rate=attn_drop_rate,
- proj_drop_rate=proj_drop_rate)
-
- self.drop = build_dropout(dropout_layer)
-
- def forward(self, query, hw_shape):
- B, L, C = query.shape
- H, W = hw_shape
- assert L == H * W, 'input feature has wrong size'
- query = query.view(B, H, W, C)
-
- # pad feature maps to multiples of window size
- pad_r = (self.window_size - W % self.window_size) % self.window_size
- pad_b = (self.window_size - H % self.window_size) % self.window_size
- query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
- H_pad, W_pad = query.shape[1], query.shape[2]
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_query = torch.roll(
- query,
- shifts=(-self.shift_size, -self.shift_size),
- dims=(1, 2))
-
- # calculate attention mask for SW-MSA
- img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- # nW, window_size, window_size, 1
- mask_windows = self.window_partition(img_mask)
- mask_windows = mask_windows.view(
- -1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0,
- float(-100.0)).masked_fill(
- attn_mask == 0, float(0.0))
- else:
- shifted_query = query
- attn_mask = None
-
- # nW*B, window_size, window_size, C
- query_windows = self.window_partition(shifted_query)
- # nW*B, window_size*window_size, C
- query_windows = query_windows.view(-1, self.window_size**2, C)
-
- # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
- attn_windows = self.w_msa(query_windows, mask=attn_mask)
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size,
- self.window_size, C)
-
- # B H' W' C
- shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(
- shifted_x,
- shifts=(self.shift_size, self.shift_size),
- dims=(1, 2))
- else:
- x = shifted_x
-
- if pad_r > 0 or pad_b:
- x = x[:, :H, :W, :].contiguous()
-
- x = x.view(B, H * W, C)
-
- x = self.drop(x)
- return x
-
- def window_reverse(self, windows, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- window_size = self.window_size
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size,
- window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
- def window_partition(self, x):
- """
- Args:
- x: (B, H, W, C)
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- window_size = self.window_size
- x = x.view(B, H // window_size, window_size, W // window_size,
- window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
- windows = windows.view(-1, window_size, window_size, C)
- return windows
-
-
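window_partition and window_reverse above are exact inverses whenever H and W are multiples of the window size; a standalone roundtrip check with illustrative shapes:

# Roundtrip check of the partition/reverse pair deleted above.
import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size,
               window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
        -1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size,
                     window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 14, 14, 96)   # B, H, W, C with H, W divisible by 7
wins = window_partition(x, 7)    # -> (2 * 2 * 2, 7, 7, 96)
assert wins.shape == (8, 7, 7, 96)
assert torch.equal(window_reverse(wins, 7, 14, 14), x)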
- class SwinBlock(nn.Module):
- """"
- Args:
- embed_dims (int): The feature dimension.
- num_heads (int): Parallel attention heads.
- feedforward_channels (int): The hidden dimension for FFNs.
- window_size (int, optional): The local window scale. Default: 7.
- shift (bool, optional): whether to shift window or not. Default False.
- qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- drop_rate (float, optional): Dropout rate. Default: 0.
- attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
- drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
- act_cfg (dict, optional): The config dict of activation function.
- Default: dict(type='GELU').
- norm_cfg (dict, optional): The config dict of normalization.
- Default: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- feedforward_channels,
- window_size=7,
- shift=False,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False):
-
- super(SwinBlock, self).__init__()
-
- self.with_cp = with_cp
-
- self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
- self.attn = ShiftWindowMSA(
- embed_dims=embed_dims,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=window_size // 2 if shift else 0,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- attn_drop_rate=attn_drop_rate,
- proj_drop_rate=drop_rate,
- dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
-
- self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
- self.ffn = FFN(
- embed_dims=embed_dims,
- feedforward_channels=feedforward_channels,
- num_fcs=2,
- ffn_drop=drop_rate,
- dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
- act_cfg=act_cfg,
- add_identity=True,
- init_cfg=None)
-
- def forward(self, x, hw_shape):
-
- def _inner_forward(x):
- identity = x
- x = self.norm1(x)
- x = self.attn(x, hw_shape)
-
- x = x + identity
-
- identity = x
- x = self.norm2(x)
- x = self.ffn(x, identity=identity)
-
- return x
-
- if self.with_cp and x.requires_grad:
- x = cp.checkpoint(_inner_forward, x)
- else:
- x = _inner_forward(x)
-
- return x
-
-
- class SwinBlockSequence(nn.Module):
- """Implements one stage in Swin Transformer.
-
- Args:
- embed_dims (int): The feature dimension.
- num_heads (int): Parallel attention heads.
- feedforward_channels (int): The hidden dimension for FFNs.
- depth (int): The number of blocks in this stage.
- window_size (int, optional): The local window scale. Default: 7.
- qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- drop_rate (float, optional): Dropout rate. Default: 0.
- attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
- drop_path_rate (float | list[float], optional): Stochastic depth
- rate. Default: 0.
- downsample (nn.Module | None, optional): The downsample operation
- module. Default: None.
- act_cfg (dict, optional): The config dict of activation function.
- Default: dict(type='GELU').
- norm_cfg (dict, optional): The config dict of normalization.
- Default: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- """
-
- def __init__(self,
- embed_dims,
- num_heads,
- feedforward_channels,
- depth,
- window_size=7,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- downsample=None,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False):
- super().__init__()
-
- if isinstance(drop_path_rate, list):
- drop_path_rates = drop_path_rate
- assert len(drop_path_rates) == depth
- else:
- drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
-
- self.blocks = nn.ModuleList()
- for i in range(depth):
- block = SwinBlock(
- embed_dims=embed_dims,
- num_heads=num_heads,
- feedforward_channels=feedforward_channels,
- window_size=window_size,
- shift=False if i % 2 == 0 else True,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop_rate=drop_rate,
- attn_drop_rate=attn_drop_rate,
- drop_path_rate=drop_path_rates[i],
- act_cfg=act_cfg,
- norm_cfg=norm_cfg,
- with_cp=with_cp)
- self.blocks.append(block)
-
- self.downsample = downsample
-
- def forward(self, x, hw_shape):
- for block in self.blocks:
- x = block(x, hw_shape)
-
- if self.downsample:
- x_down, down_hw_shape = self.downsample(x, hw_shape)
- return x_down, down_hw_shape, x, hw_shape
- else:
- return x, hw_shape, x, hw_shape
-
-
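The list form of drop_path_rate accepted above is produced by slicing one linspace per stage in SwinTransformer below; a runnable sketch with the default depths (2, 2, 6, 2) and drop_path_rate=0.1:

# Sketch of the stochastic-depth decay rule, restated from the deleted code.
import torch

depths = (2, 2, 6, 2)
drop_path_rate = 0.1
# One rate per block across the whole network, increasing linearly:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# Per-stage slices, mirroring dpr[sum(depths[:i]):sum(depths[:i + 1])]:
stage_rates = [
    dpr[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))
]
assert [len(r) for r in stage_rates] == [2, 2, 6, 2]
assert abs(stage_rates[-1][-1] - drop_path_rate) < 1e-6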
- @BACKBONES.register_module()
- class SwinTransformer(BaseBackbone):
- """ Swin Transformer
- A PyTorch implement of : `Swin Transformer:
- Hierarchical Vision Transformer using Shifted Windows` -
- https://arxiv.org/abs/2103.14030
-
- Inspiration from
- https://github.com/microsoft/Swin-Transformer
-
- Args:
- pretrain_img_size (int | tuple[int]): The size of input image when
- pretrain. Defaults: 224.
- in_channels (int): The num of input channels.
- Defaults: 3.
- embed_dims (int): The feature dimension. Default: 96.
- patch_size (int | tuple[int]): Patch size. Default: 4.
- window_size (int): Window size. Default: 7.
- mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
- Default: 4.
- depths (tuple[int]): Depths of each Swin Transformer stage.
- Default: (2, 2, 6, 2).
- num_heads (tuple[int]): Parallel attention heads of each Swin
- Transformer stage. Default: (3, 6, 12, 24).
- strides (tuple[int]): The patch merging or patch embedding stride of
- each Swin Transformer stage. (In swin, we set kernel size equal to
- stride.) Default: (4, 2, 2, 2).
- out_indices (tuple[int]): Output from which stages.
- Default: (0, 1, 2, 3).
- qkv_bias (bool, optional): If True, add a learnable bias to query, key,
- value. Default: True
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- patch_norm (bool): If add a norm layer for patch embed and patch
- merging. Default: True.
- drop_rate (float): Dropout rate. Defaults: 0.
- attn_drop_rate (float): Attention dropout rate. Default: 0.
- drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
- use_abs_pos_embed (bool): If True, add absolute position embedding to
- the patch embedding. Defaults: False.
- act_cfg (dict): Config dict for activation layer.
- Default: dict(type='LN').
- norm_cfg (dict): Config dict for normalization layer at
- output of backone. Defaults: dict(type='LN').
- with_cp (bool, optional): Use checkpoint or not. Using checkpoint
- will save some memory while slowing down the training speed.
- Default: False.
- pretrained (str, optional): model pretrained path. Default: None.
- convert_weights (bool): The flag indicates whether the
- pre-trained model is from the original repo. We may need
- to convert some keys to make it compatible.
- Default: False.
- frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
- Default: -1 (-1 means not freezing any parameters).
- """
-
- def __init__(
- self,
- pretrain_img_size=224,
- in_channels=3,
- embed_dims=96,
- patch_size=4,
- window_size=7,
- mlp_ratio=4,
- depths=(2, 2, 6, 2),
- num_heads=(3, 6, 12, 24),
- strides=(4, 2, 2, 2),
- out_indices=(0, 1, 2, 3),
- qkv_bias=True,
- qk_scale=None,
- patch_norm=True,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- use_abs_pos_embed=False,
- act_cfg=dict(type='GELU'),
- norm_cfg=dict(type='LN'),
- with_cp=False,
- convert_weights=False,
- frozen_stages=-1,
- ):
- self.convert_weights = convert_weights
- self.frozen_stages = frozen_stages
- if isinstance(pretrain_img_size, int):
- pretrain_img_size = to_2tuple(pretrain_img_size)
- elif isinstance(pretrain_img_size, tuple):
- if len(pretrain_img_size) == 1:
- pretrain_img_size = to_2tuple(pretrain_img_size[0])
- assert len(pretrain_img_size) == 2, \
- f'The size of image should have length 1 or 2, ' \
- f'but got {len(pretrain_img_size)}'
-
- super(SwinTransformer, self).__init__()
-
- num_layers = len(depths)
- self.out_indices = out_indices
- self.use_abs_pos_embed = use_abs_pos_embed
-
- assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
-
- self.patch_embed = PatchEmbed(
- in_channels=in_channels,
- embed_dims=embed_dims,
- conv_type='Conv2d',
- kernel_size=patch_size,
- stride=strides[0],
- norm_cfg=norm_cfg if patch_norm else None,
- init_cfg=None)
-
- if self.use_abs_pos_embed:
- patch_row = pretrain_img_size[0] // patch_size
- patch_col = pretrain_img_size[1] // patch_size
- num_patches = patch_row * patch_col
- self.absolute_pos_embed = nn.Parameter(
- torch.zeros((1, num_patches, embed_dims)))
-
- self.drop_after_pos = nn.Dropout(p=drop_rate)
-
- # set stochastic depth decay rule
- total_depth = sum(depths)
- dpr = [
- x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
- ]
-
- self.stages = nn.ModuleList()
- in_channels = embed_dims
- for i in range(num_layers):
- if i < num_layers - 1:
- downsample = PatchMerging(
- in_channels=in_channels,
- out_channels=2 * in_channels,
- stride=strides[i + 1],
- norm_cfg=norm_cfg if patch_norm else None,
- init_cfg=None)
- else:
- downsample = None
-
- stage = SwinBlockSequence(
- embed_dims=in_channels,
- num_heads=num_heads[i],
- feedforward_channels=mlp_ratio * in_channels,
- depth=depths[i],
- window_size=window_size,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop_rate=drop_rate,
- attn_drop_rate=attn_drop_rate,
- drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
- downsample=downsample,
- act_cfg=act_cfg,
- norm_cfg=norm_cfg,
- with_cp=with_cp)
- self.stages.append(stage)
- if downsample:
- in_channels = downsample.out_channels
-
- self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
- # Add a norm layer for each output
- for i in out_indices:
- layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
- layer_name = f'norm{i}'
- self.add_module(layer_name, layer)
-
- def train(self, mode=True):
- """Convert the model into training mode while keep layers freezed."""
- super(SwinTransformer, self).train(mode)
- self._freeze_stages()
-
- def _freeze_stages(self):
- if self.frozen_stages >= 0:
- self.patch_embed.eval()
- for param in self.patch_embed.parameters():
- param.requires_grad = False
- if self.use_abs_pos_embed:
- self.absolute_pos_embed.requires_grad = False
- self.drop_after_pos.eval()
-
- for i in range(1, self.frozen_stages + 1):
-
- if (i - 1) in self.out_indices:
- norm_layer = getattr(self, f'norm{i-1}')
- norm_layer.eval()
- for param in norm_layer.parameters():
- param.requires_grad = False
-
- m = self.stages[i - 1]
- m.eval()
- for param in m.parameters():
- param.requires_grad = False
-
- def init_weights(self, pretrained=None):
- """Initialize the weights in backbone.
-
- Args:
- pretrained (str, optional): Path to pre-trained weights.
- Defaults to None.
- """
- if isinstance(pretrained, str):
- logger = get_root_logger()
- ckpt = _load_checkpoint(
- pretrained, logger=logger, map_location='cpu')
- if 'state_dict' in ckpt:
- _state_dict = ckpt['state_dict']
- elif 'model' in ckpt:
- _state_dict = ckpt['model']
- else:
- _state_dict = ckpt
- if self.convert_weights:
- # supported loading weight from original repo,
- _state_dict = swin_converter(_state_dict)
-
- state_dict = OrderedDict()
- for k, v in _state_dict.items():
- if k.startswith('backbone.'):
- state_dict[k[9:]] = v
-
- # strip prefix of state_dict
- if list(state_dict.keys())[0].startswith('module.'):
- state_dict = {k[7:]: v for k, v in state_dict.items()}
-
- # reshape absolute position embedding
- if state_dict.get('absolute_pos_embed') is not None:
- absolute_pos_embed = state_dict['absolute_pos_embed']
- N1, L, C1 = absolute_pos_embed.size()
- N2, C2, H, W = self.absolute_pos_embed.size()
- if N1 != N2 or C1 != C2 or L != H * W:
- logger.warning('Error in loading absolute_pos_embed, pass')
- else:
- state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
- N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
-
- # interpolate position bias table if needed
- relative_position_bias_table_keys = [
- k for k in state_dict.keys()
- if 'relative_position_bias_table' in k
- ]
- for table_key in relative_position_bias_table_keys:
- table_pretrained = state_dict[table_key]
- table_current = self.state_dict()[table_key]
- L1, nH1 = table_pretrained.size()
- L2, nH2 = table_current.size()
- if nH1 != nH2:
- logger.warning(f'Error in loading {table_key}, pass')
- elif L1 != L2:
- S1 = int(L1**0.5)
- S2 = int(L2**0.5)
- table_pretrained_resized = F.interpolate(
- table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
- size=(S2, S2),
- mode='bicubic')
- state_dict[table_key] = table_pretrained_resized.view(
- nH2, L2).permute(1, 0).contiguous()
-
- # load state_dict
- self.load_state_dict(state_dict, False)
- elif pretrained is None:
- if self.use_abs_pos_embed:
- trunc_normal_(self.absolute_pos_embed, std=0.02)
- for m in self.modules():
- if isinstance(m, nn.Linear):
- trunc_normal_init(m, std=.02, bias=0.)
- elif isinstance(m, nn.LayerNorm):
- constant_init(m, 1.0)
- else:
- raise TypeError('pretrained must be a str or None')
-
- def forward(self, x):
- x, hw_shape = self.patch_embed(x)
-
- if self.use_abs_pos_embed:
- x = x + self.absolute_pos_embed
- x = self.drop_after_pos(x)
-
- outs = []
- for i, stage in enumerate(self.stages):
- x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
- if i in self.out_indices:
- norm_layer = getattr(self, f'norm{i}')
- out = norm_layer(out)
- out = out.view(-1, *out_hw_shape,
- self.num_features[i]).permute(0, 3, 1,
- 2).contiguous()
- outs.append(out)
-
- return outs
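The bias-table resizing branch in init_weights above reduces to one bicubic interpolation; a standalone sketch with assumed window sizes (7 -> 12) and 4 heads:

# Resizing a relative-position bias table, restated from the deleted branch.
import torch
import torch.nn.functional as F

table_pretrained = torch.randn(169, 4)   # (2*7-1)**2 rows, nH = 4 heads
L1, nH1 = table_pretrained.shape
L2 = 529                                  # (2*12-1)**2 target rows
S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
resized = F.interpolate(
    table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
    size=(S2, S2),
    mode='bicubic')
table_new = resized.view(nH1, L2).permute(1, 0).contiguous()
assert table_new.shape == (529, 4)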
main/transformer_utils/mmpose/models/backbones/tcformer.py DELETED
@@ -1,283 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import math
-
- import torch
- import torch.nn as nn
- from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
- trunc_normal_init)
- from mmcv.runner import _load_checkpoint, load_state_dict
-
- from ...utils import get_root_logger
- from ..builder import BACKBONES
- from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
- TokenConv, cluster_dpc_knn, merge_tokens,
- tcformer_convert, token2map)
-
-
- class CTM(nn.Module):
- """Clustering-based Token Merging module in TCFormer.
-
- Args:
- sample_ratio (float): The sample ratio of tokens.
- embed_dim (int): Input token feature dimension.
- dim_out (int): Output token feature dimension.
- k (int): number of the nearest neighbor used i DPC-knn algorithm.
- """
-
- def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
- super().__init__()
- self.sample_ratio = sample_ratio
- self.dim_out = dim_out
- self.conv = TokenConv(
- in_channels=embed_dim,
- out_channels=dim_out,
- kernel_size=3,
- stride=2,
- padding=1)
- self.norm = nn.LayerNorm(self.dim_out)
- self.score = nn.Linear(self.dim_out, 1)
- self.k = k
-
- def forward(self, token_dict):
- token_dict = token_dict.copy()
- x = self.conv(token_dict)
- x = self.norm(x)
- token_score = self.score(x)
- token_weight = token_score.exp()
-
- token_dict['x'] = x
- B, N, C = x.shape
- token_dict['token_score'] = token_score
-
- cluster_num = max(math.ceil(N * self.sample_ratio), 1)
- idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
- self.k)
- down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
- token_weight)
-
- H, W = token_dict['map_size']
- H = math.floor((H - 1) / 2 + 1)
- W = math.floor((W - 1) / 2 + 1)
- down_dict['map_size'] = [H, W]
-
- return down_dict, token_dict
-
-
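A worked sketch of the CTM bookkeeping above, with illustrative token counts (the DPC-kNN clustering itself lives in cluster_dpc_knn and is not reproduced here):

# Cluster-count and map-size arithmetic, restated from CTM.forward.
import math

N, sample_ratio = 3136, 0.25          # e.g. a 56x56 token grid, keep 25%
cluster_num = max(math.ceil(N * sample_ratio), 1)
assert cluster_num == 784             # tokens remaining after merging

# The stride-2 TokenConv halves the map like a k=3, s=2, p=1 convolution:
H = W = 56
H_down = math.floor((H - 1) / 2 + 1)
W_down = math.floor((W - 1) / 2 + 1)
assert (H_down, W_down) == (28, 28)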
- @BACKBONES.register_module()
- class TCFormer(nn.Module):
- """Token Clustering Transformer (TCFormer)
-
- Implementation of `Not All Tokens Are Equal: Human-centric Visual
- Analysis via Token Clustering Transformer
- <https://arxiv.org/abs/2204.08680>`
-
- Args:
- in_channels (int): Number of input channels. Default: 3.
- embed_dims (list[int]): Embedding dimension. Default:
- [64, 128, 256, 512].
- num_heads (Sequence[int]): The attention heads of each transformer
- encode layer. Default: [1, 2, 5, 8].
- mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
- embedding dim of each transformer block.
- qkv_bias (bool): Enable bias for qkv if True. Default: True.
- qk_scale (float | None, optional): Override default qk scale of
- head_dim ** -0.5 if set. Default: None.
- drop_rate (float): Probability of an element to be zeroed.
- Default 0.0.
- attn_drop_rate (float): The drop out rate for attention layer.
- Default 0.0.
- drop_path_rate (float): stochastic depth rate. Default 0.
- norm_cfg (dict): Config dict for normalization layer.
- Default: dict(type='LN', eps=1e-6).
- num_layers (Sequence[int]): The layer number of each transformer encode
- layer. Default: [3, 4, 6, 3].
- sr_ratios (Sequence[int]): The spatial reduction rate of each
- transformer block. Default: [8, 4, 2, 1].
- num_stages (int): The num of stages. Default: 4.
- pretrained (str, optional): model pretrained path. Default: None.
- k (int): number of the nearest neighbor used for local density.
- sample_ratios (list[float]): The sample ratios of CTM modules.
- Default: [0.25, 0.25, 0.25]
- return_map (bool): If True, transfer dynamic tokens to feature map at
- last. Default: False
- convert_weights (bool): The flag indicates whether the
- pre-trained model is from the original repo. We may need
- to convert some keys to make it compatible.
- Default: True.
- """
-
- def __init__(self,
- in_channels=3,
- embed_dims=[64, 128, 256, 512],
- num_heads=[1, 2, 4, 8],
- mlp_ratios=[4, 4, 4, 4],
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.,
- norm_cfg=dict(type='LN', eps=1e-6),
- num_layers=[3, 4, 6, 3],
- sr_ratios=[8, 4, 2, 1],
- num_stages=4,
- pretrained=None,
- k=5,
- sample_ratios=[0.25, 0.25, 0.25],
- return_map=False,
- convert_weights=True):
- super().__init__()
-
- self.num_layers = num_layers
- self.num_stages = num_stages
- self.grid_stride = sr_ratios[0]
- self.embed_dims = embed_dims
- self.sr_ratios = sr_ratios
- self.mlp_ratios = mlp_ratios
- self.sample_ratios = sample_ratios
- self.return_map = return_map
- self.convert_weights = convert_weights
-
- # stochastic depth decay rule
- dpr = [
- x.item()
- for x in torch.linspace(0, drop_path_rate, sum(num_layers))
- ]
- cur = 0
-
- # In stage 1, use the standard transformer blocks
- for i in range(1):
- patch_embed = PatchEmbed(
- in_channels=in_channels if i == 0 else embed_dims[i - 1],
- embed_dims=embed_dims[i],
- kernel_size=7,
- stride=4,
- padding=3,
- bias=True,
- norm_cfg=dict(type='LN', eps=1e-6))
-
- block = nn.ModuleList([
- TCFormerRegularBlock(
- dim=embed_dims[i],
- num_heads=num_heads[i],
- mlp_ratio=mlp_ratios[i],
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[cur + j],
- norm_cfg=norm_cfg,
- sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
- ])
- norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
-
- cur += num_layers[i]
-
- setattr(self, f'patch_embed{i + 1}', patch_embed)
- setattr(self, f'block{i + 1}', block)
- setattr(self, f'norm{i + 1}', norm)
-
- # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
- for i in range(1, num_stages):
- ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
- k)
-
- block = nn.ModuleList([
- TCFormerDynamicBlock(
- dim=embed_dims[i],
- num_heads=num_heads[i],
- mlp_ratio=mlp_ratios[i],
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop_rate,
- attn_drop=attn_drop_rate,
- drop_path=dpr[cur + j],
- norm_cfg=norm_cfg,
- sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
- ])
- norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
- cur += num_layers[i]
-
- setattr(self, f'ctm{i}', ctm)
- setattr(self, f'block{i + 1}', block)
- setattr(self, f'norm{i + 1}', norm)
-
- self.init_weights(pretrained)
-
- def init_weights(self, pretrained=None):
- if isinstance(pretrained, str):
- logger = get_root_logger()
-
- checkpoint = _load_checkpoint(
- pretrained, logger=logger, map_location='cpu')
- logger.warning(f'Load pre-trained model for '
- f'{self.__class__.__name__} from original repo')
- if 'state_dict' in checkpoint:
- state_dict = checkpoint['state_dict']
- elif 'model' in checkpoint:
- state_dict = checkpoint['model']
- else:
- state_dict = checkpoint
-
- if self.convert_weights:
- # We need to convert pre-trained weights to match this
- # implementation.
- state_dict = tcformer_convert(state_dict)
- load_state_dict(self, state_dict, strict=False, logger=logger)
-
- elif pretrained is None:
- for m in self.modules():
- if isinstance(m, nn.Linear):
- trunc_normal_init(m, std=.02, bias=0.)
- elif isinstance(m, nn.LayerNorm):
- constant_init(m, 1.0)
- elif isinstance(m, nn.Conv2d):
- fan_out = m.kernel_size[0] * m.kernel_size[
- 1] * m.out_channels
- fan_out //= m.groups
- normal_init(m, 0, math.sqrt(2.0 / fan_out))
- else:
- raise TypeError('pretrained must be a str or None')
-
- def forward(self, x):
- outs = []
-
- i = 0
- patch_embed = getattr(self, f'patch_embed{i + 1}')
- block = getattr(self, f'block{i + 1}')
- norm = getattr(self, f'norm{i + 1}')
- x, (H, W) = patch_embed(x)
- for blk in block:
- x = blk(x, H, W)
- x = norm(x)
-
- # init token dict
- B, N, _ = x.shape
- device = x.device
- idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device)
- agg_weight = x.new_ones(B, N, 1)
- token_dict = {
- 'x': x,
- 'token_num': N,
- 'map_size': [H, W],
- 'init_grid_size': [H, W],
- 'idx_token': idx_token,
- 'agg_weight': agg_weight
- }
- outs.append(token_dict.copy())
-
- # stage 2~4
- for i in range(1, self.num_stages):
- ctm = getattr(self, f'ctm{i}')
- block = getattr(self, f'block{i + 1}')
- norm = getattr(self, f'norm{i + 1}')
-
- token_dict = ctm(token_dict) # down sample
- for j, blk in enumerate(block):
- token_dict = blk(token_dict)
-
- token_dict['x'] = norm(token_dict['x'])
- outs.append(token_dict)
-
- if self.return_map:
- outs = [token2map(token_dict) for token_dict in outs]
- return outs
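The initial token_dict built in forward above has one token per patch-grid cell; a standalone sketch with assumed sizes:

# Initial token dictionary, restated from the deleted forward() (sizes are
# illustrative, not taken from a real input).
import torch

B, H, W, C = 2, 64, 48, 64
N = H * W
x = torch.randn(B, N, C)
idx_token = torch.arange(N)[None, :].repeat(B, 1)
agg_weight = x.new_ones(B, N, 1)
token_dict = {
    'x': x,                          # token features
    'token_num': N,                  # one token per grid cell initially
    'map_size': [H, W],              # current spatial layout
    'init_grid_size': [H, W],        # layout before any merging
    'idx_token': idx_token,          # token -> grid-cell assignment
    'agg_weight': agg_weight,        # per-token aggregation weight
}
assert token_dict['idx_token'].shape == (B, N)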
main/transformer_utils/mmpose/models/backbones/tcn.py DELETED
@@ -1,267 +0,0 @@
- # Copyright (c) OpenMMLab. All rights reserved.
- import copy
-
- import torch.nn as nn
- from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
- from mmcv.utils.parrots_wrapper import _BatchNorm
-
- from mmpose.core import WeightNormClipHook
- from ..builder import BACKBONES
- from .base_backbone import BaseBackbone
-
-
- class BasicTemporalBlock(nn.Module):
- """Basic block for VideoPose3D.
-
- Args:
- in_channels (int): Input channels of this block.
- out_channels (int): Output channels of this block.
- mid_channels (int): The output channels of conv1. Default: 1024.
- kernel_size (int): Size of the convolving kernel. Default: 3.
- dilation (int): Spacing between kernel elements. Default: 3.
- dropout (float): Dropout rate. Default: 0.25.
- causal (bool): Use causal convolutions instead of symmetric
- convolutions (for real-time applications). Default: False.
- residual (bool): Use residual connection. Default: True.
- use_stride_conv (bool): Use optimized TCN that designed
- specifically for single-frame batching, i.e. where batches have
- input length = receptive field, and output length = 1. This
- implementation replaces dilated convolutions with strided
- convolutions to avoid generating unused intermediate results.
- Default: False.
- conv_cfg (dict): dictionary to construct and config conv layer.
- Default: dict(type='Conv1d').
- norm_cfg (dict): dictionary to construct and config norm layer.
- Default: dict(type='BN1d').
- """
-
- def __init__(self,
- in_channels,
- out_channels,
- mid_channels=1024,
- kernel_size=3,
- dilation=3,
- dropout=0.25,
- causal=False,
- residual=True,
- use_stride_conv=False,
- conv_cfg=dict(type='Conv1d'),
- norm_cfg=dict(type='BN1d')):
- # Protect mutable default arguments
- conv_cfg = copy.deepcopy(conv_cfg)
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.mid_channels = mid_channels
- self.kernel_size = kernel_size
- self.dilation = dilation
- self.dropout = dropout
- self.causal = causal
- self.residual = residual
- self.use_stride_conv = use_stride_conv
-
- self.pad = (kernel_size - 1) * dilation // 2
- if use_stride_conv:
- self.stride = kernel_size
- self.causal_shift = kernel_size // 2 if causal else 0
- self.dilation = 1
- else:
- self.stride = 1
- self.causal_shift = kernel_size // 2 * dilation if causal else 0
-
- self.conv1 = nn.Sequential(
- ConvModule(
- in_channels,
- mid_channels,
- kernel_size=kernel_size,
- stride=self.stride,
- dilation=self.dilation,
- bias='auto',
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg))
- self.conv2 = nn.Sequential(
- ConvModule(
- mid_channels,
- out_channels,
- kernel_size=1,
- bias='auto',
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg))
-
- if residual and in_channels != out_channels:
- self.short_cut = build_conv_layer(conv_cfg, in_channels,
- out_channels, 1)
- else:
- self.short_cut = None
-
- self.dropout = nn.Dropout(dropout) if dropout > 0 else None
-
- def forward(self, x):
- """Forward function."""
- if self.use_stride_conv:
- assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
- else:
- assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
- self.pad + self.causal_shift <= x.shape[2]
-
- out = self.conv1(x)
- if self.dropout is not None:
- out = self.dropout(out)
-
- out = self.conv2(out)
- if self.dropout is not None:
- out = self.dropout(out)
-
- if self.residual:
- if self.use_stride_conv:
- res = x[:, :, self.causal_shift +
- self.kernel_size // 2::self.kernel_size]
- else:
- res = x[:, :,
- (self.pad + self.causal_shift):(x.shape[2] - self.pad +
- self.causal_shift)]
-
- if self.short_cut is not None:
- res = self.short_cut(res)
- out = out + res
-
- return out
-
-
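A worked check of the symmetric-padding arithmetic above (use_stride_conv=False, causal=False; kernel and dilation values assumed):

# Padding/receptive-field arithmetic, restated from BasicTemporalBlock.
kernel_size, dilation = 3, 3
pad = (kernel_size - 1) * dilation // 2        # frames trimmed per side
assert pad == 3
# A dilated conv shortens a length-T sequence by (kernel_size - 1) * dilation,
# exactly the 2 * pad frames sliced off the residual branch:
T = 241
T_out = T - (kernel_size - 1) * dilation
assert T_out == 235 and T_out == T - 2 * pad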
- @BACKBONES.register_module()
- class TCN(BaseBackbone):
- """TCN backbone.
-
- Temporal Convolutional Networks.
- More details can be found in the
- `paper <https://arxiv.org/abs/1811.11742>`__ .
-
- Args:
- in_channels (int): Number of input channels, which equals to
- num_keypoints * num_features.
- stem_channels (int): Number of feature channels. Default: 1024.
- num_blocks (int): NUmber of basic temporal convolutional blocks.
- Default: 2.
- kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
- each basic block. Default: ``(3, 3, 3)``.
- dropout (float): Dropout rate. Default: 0.25.
- causal (bool): Use causal convolutions instead of symmetric
- convolutions (for real-time applications).
- Default: False.
- residual (bool): Use residual connection. Default: True.
- use_stride_conv (bool): Use TCN backbone optimized for
- single-frame batching, i.e. where batches have input length =
- receptive field, and output length = 1. This implementation
- replaces dilated convolutions with strided convolutions to avoid
- generating unused intermediate results. The weights are
- interchangeable with the reference implementation. Default: False
- conv_cfg (dict): dictionary to construct and config conv layer.
- Default: dict(type='Conv1d').
- norm_cfg (dict): dictionary to construct and config norm layer.
- Default: dict(type='BN1d').
- max_norm (float|None): if not None, the weight of convolution layers
- will be clipped to have a maximum norm of max_norm.
-
- Example:
- >>> from mmpose.models import TCN
- >>> import torch
- >>> self = TCN(in_channels=34)
- >>> self.eval()
- >>> inputs = torch.rand(1, 34, 243)
- >>> level_outputs = self.forward(inputs)
- >>> for level_out in level_outputs:
- ... print(tuple(level_out.shape))
- (1, 1024, 235)
- (1, 1024, 217)
- """
-
- def __init__(self,
- in_channels,
- stem_channels=1024,
- num_blocks=2,
- kernel_sizes=(3, 3, 3),
- dropout=0.25,
- causal=False,
- residual=True,
- use_stride_conv=False,
- conv_cfg=dict(type='Conv1d'),
- norm_cfg=dict(type='BN1d'),
- max_norm=None):
- # Protect mutable default arguments
- conv_cfg = copy.deepcopy(conv_cfg)
- norm_cfg = copy.deepcopy(norm_cfg)
- super().__init__()
- self.in_channels = in_channels
- self.stem_channels = stem_channels
- self.num_blocks = num_blocks
- self.kernel_sizes = kernel_sizes
- self.dropout = dropout
- self.causal = causal
- self.residual = residual
- self.use_stride_conv = use_stride_conv
- self.max_norm = max_norm
-
- assert num_blocks == len(kernel_sizes) - 1
- for ks in kernel_sizes:
- assert ks % 2 == 1, 'Only odd filter widths are supported.'
-
- self.expand_conv = ConvModule(
- in_channels,
- stem_channels,
- kernel_size=kernel_sizes[0],
- stride=kernel_sizes[0] if use_stride_conv else 1,
- bias='auto',
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg)
-
- dilation = kernel_sizes[0]
- self.tcn_blocks = nn.ModuleList()
- for i in range(1, num_blocks + 1):
- self.tcn_blocks.append(
- BasicTemporalBlock(
- in_channels=stem_channels,
- out_channels=stem_channels,
- mid_channels=stem_channels,
- kernel_size=kernel_sizes[i],
- dilation=dilation,
- dropout=dropout,
- causal=causal,
- residual=residual,
- use_stride_conv=use_stride_conv,
- conv_cfg=conv_cfg,
- norm_cfg=norm_cfg))
- dilation *= kernel_sizes[i]
-
- if self.max_norm is not None:
- # Apply weight norm clip to conv layers
- weight_clip = WeightNormClipHook(self.max_norm)
- for module in self.modules():
- if isinstance(module, nn.modules.conv._ConvNd):
- weight_clip.register(module)
-
- self.dropout = nn.Dropout(dropout) if dropout > 0 else None
-
- def forward(self, x):
- """Forward function."""
- x = self.expand_conv(x)
-
- if self.dropout is not None:
- x = self.dropout(x)
-
- outs = []
- for i in range(self.num_blocks):
- x = self.tcn_blocks[i](x)
- outs.append(x)
-
- return tuple(outs)
-
- def init_weights(self, pretrained=None):
- """Initialize the weights."""
- super().init_weights(pretrained)
- if pretrained is None:
- for m in self.modules():
- if isinstance(m, nn.modules.conv._ConvNd):
- kaiming_init(m, mode='fan_in', nonlinearity='relu')
- elif isinstance(m, _BatchNorm):
- constant_init(m, 1)
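The docstring example above can be verified by hand: with the default kernel_sizes=(3, 3, 3) and use_stride_conv=False, each layer trims (kernel - 1) * dilation frames, and the dilation multiplies by the kernel size at each block:

# Worked check of the output lengths printed in the TCN docstring example.
T = 243                                  # input length
T = T - (3 - 1) * 1                      # expand_conv: kernel 3, dilation 1
assert T == 241
T = T - (3 - 1) * 3                      # block 1: dilation = 3
assert T == 235                          # first output in the docstring
T = T - (3 - 1) * 9                      # block 2: dilation = 3 * 3 = 9
assert T == 217                          # second output in the docstring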
main/transformer_utils/mmpose/models/backbones/utils/utils.py CHANGED
@@ -1,7 +1,7 @@
  # Copyright (c) OpenMMLab. All rights reserved.
  from collections import OrderedDict
 
- from mmcv.runner.checkpoint import _load_checkpoint, load_state_dict
+ from mmengine.runner import load_state_dict
 
 
  # Copyright (c) Open-MMLab. All rights reserved.
@@ -22,11 +22,11 @@ from torch.utils import model_zoo
  from torch.nn import functional as F
 
  import mmcv
- from mmcv.fileio import FileClient
- from mmcv.fileio import load as load_file
- from mmcv.parallel import is_module_wrapper
- from mmcv.utils import mkdir_or_exist
- from mmcv.runner import get_dist_info
+ from mmengine.fileio import FileClient
+ from mmengine.fileio import load as load_file
+ # from mmengine.model.wrappers.utils import is_module_wrapper
+ from mmengine.utils import mkdir_or_exist
+ from mmengine.dist import get_dist_info
 
  from scipy import interpolate
  import numpy as np
@@ -75,8 +75,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
  def load(module, prefix=''):
  # recursively check parallel module in case that the model has a
  # complicated structure, e.g., nn.Module(nn.Module(DDP))
- if is_module_wrapper(module):
- module = module.module
+ # if is_module_wrapper(module):
+ # module = module.module
  local_metadata = {} if metadata is None else metadata.get(
  prefix[:-1], {})
  module._load_from_state_dict(state_dict, prefix, local_metadata, True,
@@ -445,8 +445,8 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
  """
  # recursively check parallel module in case that the model has a
  # complicated structure, e.g., nn.Module(nn.Module(DDP))
- if is_module_wrapper(module):
- module = module.module
+ # if is_module_wrapper(module):
+ # module = module.module
 
  # below is the same as torch.nn.Module.state_dict()
  if destination is None:
@@ -482,8 +482,8 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
  raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
  meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
 
- if is_module_wrapper(model):
- model = model.module
+ # if is_module_wrapper(model):
+ # model = model.module
 
  if hasattr(model, 'CLASSES') and model.CLASSES is not None:
  # save class name to the meta
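Since the is_module_wrapper unwrapping is commented out here rather than ported to mmengine, callers that still pass DDP-wrapped models may need to unwrap by hand; a hedged standalone sketch (not part of this commit):

# Plain-torch replacement for the commented-out mmcv is_module_wrapper check.
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

def unwrap_module(module: nn.Module) -> nn.Module:
    """Return the innermost module if wrapped in (Distributed)DataParallel."""
    while isinstance(module, (DataParallel, DistributedDataParallel)):
        module = module.module
    return module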