Spaces
Build error
onescotch committed
Commit 010a8bc
Parent(s): d200058

clean up for zero gpus

This view is limited to 50 files because it contains too many changes. See raw diff.
- app.py +8 -15
- common/base.py +1 -1
- common/utils/distribute_utils.py +1 -1
- main/SMPLer_X.py +1 -1
- main/config.py +2 -1
- main/inference.py +7 -1
- main/transformer_utils/mmpose/__init__.py +1 -1
- main/transformer_utils/mmpose/core/camera/camera_base.py +1 -1
- main/transformer_utils/mmpose/core/distributed_wrapper.py +1 -1
- main/transformer_utils/mmpose/core/evaluation/eval_hooks.py +511 -2
- main/transformer_utils/mmpose/core/fp16/hooks.py +77 -2
- main/transformer_utils/mmpose/core/optimizers/builder.py +20 -7
- main/transformer_utils/mmpose/core/optimizers/layer_decay_optimizer_constructor.py +2 -2
- main/transformer_utils/mmpose/core/post_processing/smoother.py +2 -2
- main/transformer_utils/mmpose/core/post_processing/temporal_filters/builder.py +1 -1
- main/transformer_utils/mmpose/core/post_processing/temporal_filters/smoothnet_filter.py +1 -1
- main/transformer_utils/mmpose/core/utils/dist_utils.py +1 -1
- main/transformer_utils/mmpose/core/utils/model_util_hooks.py +2 -2
- main/transformer_utils/mmpose/core/visualization/image.py +1 -1
- main/transformer_utils/mmpose/models/__init__.py +1 -1
- main/transformer_utils/mmpose/models/backbones/__init__.py +39 -38
- main/transformer_utils/mmpose/models/backbones/alexnet.py +0 -56
- main/transformer_utils/mmpose/models/backbones/cpm.py +0 -186
- main/transformer_utils/mmpose/models/backbones/hourglass.py +0 -212
- main/transformer_utils/mmpose/models/backbones/hourglass_ae.py +0 -212
- main/transformer_utils/mmpose/models/backbones/hrformer.py +0 -746
- main/transformer_utils/mmpose/models/backbones/hrnet.py +0 -604
- main/transformer_utils/mmpose/models/backbones/hrt.py +0 -676
- main/transformer_utils/mmpose/models/backbones/hrt_checkpoint.py +0 -500
- main/transformer_utils/mmpose/models/backbones/i3d.py +0 -215
- main/transformer_utils/mmpose/models/backbones/litehrnet.py +0 -984
- main/transformer_utils/mmpose/models/backbones/mobilenet_v2.py +0 -275
- main/transformer_utils/mmpose/models/backbones/mobilenet_v3.py +0 -188
- main/transformer_utils/mmpose/models/backbones/modules/basic_block.py +1 -3
- main/transformer_utils/mmpose/models/backbones/mspn.py +0 -513
- main/transformer_utils/mmpose/models/backbones/pvt.py +0 -592
- main/transformer_utils/mmpose/models/backbones/regnet.py +0 -317
- main/transformer_utils/mmpose/models/backbones/resnest.py +0 -338
- main/transformer_utils/mmpose/models/backbones/resnet.py +3 -3
- main/transformer_utils/mmpose/models/backbones/resnext.py +0 -162
- main/transformer_utils/mmpose/models/backbones/rsn.py +0 -616
- main/transformer_utils/mmpose/models/backbones/scnet.py +0 -248
- main/transformer_utils/mmpose/models/backbones/seresnet.py +0 -125
- main/transformer_utils/mmpose/models/backbones/seresnext.py +0 -168
- main/transformer_utils/mmpose/models/backbones/shufflenet_v1.py +0 -329
- main/transformer_utils/mmpose/models/backbones/shufflenet_v2.py +0 -302
- main/transformer_utils/mmpose/models/backbones/swin.py +0 -733
- main/transformer_utils/mmpose/models/backbones/tcformer.py +0 -283
- main/transformer_utils/mmpose/models/backbones/tcn.py +0 -267
- main/transformer_utils/mmpose/models/backbones/utils/utils.py +12 -12
app.py
CHANGED
@@ -13,29 +13,22 @@ try:
 except:
     os.system('pip install /home/user/app/main/transformer_utils')
 hf_hub_download(repo_id="caizhongang/SMPLer-X", filename="smpler_x_h32.pth.tar", local_dir="/home/user/app/pretrained_models")
-os.system('cp -rf /home/user/app/assets/conversions.py /
+os.system('cp -rf /home/user/app/assets/conversions.py /usr/local/lib/python3.10/site-packages/torchgeometry/core/conversions.py')
 DEFAULT_MODEL='smpler_x_h32'
 OUT_FOLDER = '/home/user/app/demo_out'
 os.makedirs(OUT_FOLDER, exist_ok=True)
-
-
-
-
-
-
-
+num_gpus = 1 if torch.cuda.is_available() else -1
+print("!!!", torch.cuda.is_available())
+print(torch.cuda.device_count())
+print(torch.version.cuda)
+index = torch.cuda.current_device()
+print(index)
+print(torch.cuda.get_device_name(index))
 # from main.inference import Inferer
 # inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
 
 @spaces.GPU(enable_queue=True)
 def infer(video_input, in_threshold=0.5, num_people="Single person", render_mesh=False):
-    num_gpus = 1 if torch.cuda.is_available() else -1
-    print("!!!", torch.cuda.is_available())
-    print(torch.cuda.device_count())
-    print(torch.version.cuda)
-    index = torch.cuda.current_device()
-    print(index)
-    print(torch.cuda.get_device_name(index))
     from main.inference import Inferer
     inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
     os.system(f'rm -rf {OUT_FOLDER}/*')
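Note: the hunk above moves the one-time CUDA probing out of the per-request infer() body to module scope, so it runs once at startup instead of on every call. A minimal sketch of the resulting pattern, with names taken from the diff and a hypothetical, trimmed handler body:

import spaces
import torch

# Probe the CUDA environment once at import time.
num_gpus = 1 if torch.cuda.is_available() else -1

@spaces.GPU(enable_queue=True)
def infer(video_input, in_threshold=0.5, num_people="Single person", render_mesh=False):
    # Heavy imports stay inside the handler so the module still loads
    # even when the inference stack is not importable at build time.
    from main.inference import Inferer
    inferer = Inferer('smpler_x_h32', num_gpus, '/home/user/app/demo_out')
    ...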
common/base.py
CHANGED
@@ -17,7 +17,7 @@ import torch.utils.data.distributed
 from utils.distribute_utils import (
     get_rank, is_main_process, time_synchronized, get_group_idx, get_process_groups
 )
-
+
 
 class Base(object):
     __metaclass__ = abc.ABCMeta
common/utils/distribute_utils.py
CHANGED
@@ -7,7 +7,7 @@ import tempfile
 import time
 import torch
 import torch.distributed as dist
-from
+from mmengine.dist import get_dist_info
 import random
 import numpy as np
 import subprocess
main/SMPLer_X.py
CHANGED
@@ -9,7 +9,7 @@ from config import cfg
 import math
 import copy
 from mmpose.models import build_posenet
-from
+from mmengine.config import Config
 
 class Model(nn.Module):
     def __init__(self, encoder, body_position_net, body_rotation_net, box_net, hand_position_net, hand_roi_net,
main/config.py
CHANGED
@@ -2,7 +2,8 @@ import os
 import os.path as osp
 import sys
 import datetime
-from
+from mmengine.config import Config as MMConfig
+
 
 class Config:
     def get_config_fromfile(self, config_path):
main/inference.py
CHANGED
@@ -53,8 +53,14 @@ class Inferer:
 
         ## mmdet inference
         mmdet_results = inference_detector(self.model, original_img)
-        mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)
 
+        pred_instance = mmdet_results.pred_instances.cpu().numpy()
+        bboxes = np.concatenate(
+            (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+        bboxes = bboxes[pred_instance.labels == 0]
+        bboxes = np.expand_dims(bboxes, axis=0)
+        mmdet_box = process_mmdet_results(bboxes, cat_id=0, multi_person=True)
+
         # save original image if no bbox
         if len(mmdet_box[0])<1:
             return original_img, [], []
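Note: this hunk tracks the mmdet 3.x API, where inference_detector returns a DetDataSample whose pred_instances field carries bboxes, scores, and labels as tensors, instead of the per-class [N, 5] arrays that process_mmdet_results expects; the added lines rebuild that array layout. A minimal sketch of the same conversion, assuming mmdet >= 3.0 and hypothetical det_model/img variables:

import numpy as np
from mmdet.apis import inference_detector

result = inference_detector(det_model, img)   # returns a DetDataSample
inst = result.pred_instances.cpu().numpy()    # bboxes, scores, labels as arrays
person = inst.labels == 0                     # COCO category 0 = person
bboxes = np.concatenate((inst.bboxes[person], inst.scores[person][:, None]), axis=1)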
main/transformer_utils/mmpose/__init__.py
CHANGED
@@ -17,7 +17,7 @@ def digit_version(version_str):
 
 
 mmcv_minimum_version = '1.3.8'
-mmcv_maximum_version = '
+mmcv_maximum_version = '2.3.0'
 mmcv_version = digit_version(mmcv.__version__)
 
 
main/transformer_utils/mmpose/core/camera/camera_base.py
CHANGED
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
 
-from
+from mmengine import Registry
 
 CAMERAS = Registry('camera')
 
main/transformer_utils/mmpose/core/distributed_wrapper.py
CHANGED
@@ -4,7 +4,7 @@ import torch.nn as nn
 from mmcv.parallel import MODULE_WRAPPERS as MMCV_MODULE_WRAPPERS
 from mmcv.parallel import MMDistributedDataParallel
 from mmcv.parallel.scatter_gather import scatter_kwargs
-from
+from mmengine import Registry
 from torch.cuda._utils import _get_device_index
 
 MODULE_WRAPPERS = Registry('module wrapper', parent=MMCV_MODULE_WRAPPERS)
main/transformer_utils/mmpose/core/evaluation/eval_hooks.py
CHANGED
@@ -1,8 +1,18 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+import os.path as osp
+import warnings
+from math import inf
+from typing import Callable, List, Optional
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from mmengine.fileio import FileClient
+from mmengine.utils import is_seq_of
+from mmengine.hooks import Hook, LoggerHook
 
-from mmcv.runner import DistEvalHook as _DistEvalHook
-from mmcv.runner import EvalHook as _EvalHook
 
 MMPOSE_GREATER_KEYS = [
     'acc', 'ap', 'ar', 'pck', 'auc', '3dpck', 'p-3dpck', '3dauc', 'p-3dauc',
@@ -10,6 +20,505 @@ MMPOSE_GREATER_KEYS = [
 ]
 MMPOSE_LESS_KEYS = ['loss', 'epe', 'nme', 'mpjpe', 'p-mpjpe', 'n-mpjpe']
 
+class _EvalHook(Hook):
+    """Non-Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in non-distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch or iteration.
+            It enables evaluation before the training starts if ``start`` <=
+            the resuming epoch or iteration. If None, whether to evaluate is
+            merely decided by ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+            be inferred by 'less' rule. Options are 'greater', 'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader, and return the test results. If ``None``, the default
+            test function ``mmcv.engine.single_gpu_test`` will be used.
+            (default: ``None``)
+        greater_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'greater' comparison rule. If ``None``,
+            _default_greater_keys will be used. (default: ``None``)
+        less_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'less' comparison rule. If ``None``, _default_less_keys
+            will be used. (default: ``None``)
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+            `New in version 1.3.16.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.16.`
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+
+    Note:
+        If new arguments are added for EvalHook, tools/test.py,
+        tools/eval_metric.py may be affected.
+    """
+
+    # Since the key for determine greater or less is related to the downstream
+    # tasks, downstream repos may need to overwrite the following inner
+    # variable accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
+    def __init__(self,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 **eval_kwargs):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+
+        if interval <= 0:
+            raise ValueError(f'interval must be a positive number, '
+                             f'but got {interval}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+        if start is not None and start < 0:
+            raise ValueError(f'The evaluation start epoch {start} is smaller '
+                             f'than 0')
+
+        self.dataloader = dataloader
+        self.interval = interval
+        self.start = start
+        self.by_epoch = by_epoch
+
+        assert isinstance(save_best, str) or save_best is None, \
+            '""save_best"" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+        self.eval_kwargs = eval_kwargs
+        self.initial_flag = True
+
+        if test_fn is None:
+            from mmcv.engine import single_gpu_test
+            self.test_fn = single_gpu_test
+        else:
+            self.test_fn = test_fn
+
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            if not isinstance(greater_keys, (list, tuple)):
+                assert isinstance(greater_keys, str)
+                greater_keys = (greater_keys, )
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys
+
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                assert isinstance(greater_keys, str)
+                less_keys = (less_keys, )
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
+        self.out_dir = out_dir
+        self.file_client_args = file_client_args
+
+    def _init_rule(self, rule: Optional[str], key_indicator: str):
+        """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which rule is used for key indicator
+        when the rule is not specific (note that the key indicator matching
+        is case-insensitive):
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if any one item in ``self.greater_keys`` is a substring of
+           key_indicator , the rule will be specified as 'greater'.
+        4. Or if any one item in ``self.less_keys`` is a substring of
+           key_indicator , the rule will be specified as 'less'.
+
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+        """
+        if rule not in self.rule_map and rule is not None:
+            raise KeyError(f'rule must be greater, less or None, '
+                           f'but got {rule}.')
+
+        if rule is None:
+            if key_indicator != 'auto':
+                # `_lc` here means we use the lower case of keys for
+                # case-insensitive matching
+                assert isinstance(key_indicator, str)
+                key_indicator_lc = key_indicator.lower()
+                greater_keys = [key.lower() for key in self.greater_keys]
+                less_keys = [key.lower() for key in self.less_keys]
+
+                if key_indicator_lc in greater_keys:
+                    rule = 'greater'
+                elif key_indicator_lc in less_keys:
+                    rule = 'less'
+                elif any(key in key_indicator_lc for key in greater_keys):
+                    rule = 'greater'
+                elif any(key in key_indicator_lc for key in less_keys):
+                    rule = 'less'
+                else:
+                    raise ValueError(f'Cannot infer the rule for key '
+                                     f'{key_indicator}, thus a specific rule '
+                                     f'must be specified.')
+        self.rule = rule
+        self.key_indicator = key_indicator
+        if self.rule is not None:
+            self.compare_func = self.rule_map[self.rule]
+
+    def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                f'The best checkpoint will be saved to {self.out_dir} by '
+                f'{self.file_client.name}')
+
+        if self.save_best is not None:
+            if runner.meta is None:
+                warnings.warn('runner.meta is None. Creating an empty one.')
+                runner.meta = dict()
+            runner.meta.setdefault('hook_msgs', dict())
+            self.best_ckpt_path = runner.meta['hook_msgs'].get(
+                'best_ckpt', None)
+
+    def before_train_iter(self, runner):
+        """Evaluate the model only at the start of training by iteration."""
+        if self.by_epoch or not self.initial_flag:
+            return
+        if self.start is not None and runner.iter >= self.start:
+            self.after_train_iter(runner)
+        self.initial_flag = False
+
+    def before_train_epoch(self, runner):
+        """Evaluate the model only at the start of training by epoch."""
+        if not (self.by_epoch and self.initial_flag):
+            return
+        if self.start is not None and runner.epoch >= self.start:
+            self.after_train_epoch(runner)
+        self.initial_flag = False
+
+    def after_train_iter(self, runner):
+        """Called after every training iter to evaluate the results."""
+        if not self.by_epoch and self._should_evaluate(runner):
+            # Because the priority of EvalHook is higher than LoggerHook, the
+            # training log and the evaluating log are mixed. Therefore,
+            # we need to dump the training log and clear it before evaluating
+            # log is generated. In addition, this problem will only appear in
+            # `IterBasedRunner` whose `self.by_epoch` is False, because
+            # `EpochBasedRunner` whose `self.by_epoch` is True calls
+            # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+            # the training log has been printed, so it will not cause any
+            # problem. more details at
+            # https://github.com/open-mmlab/mmsegmentation/issues/694
+            for hook in runner._hooks:
+                if isinstance(hook, LoggerHook):
+                    hook.after_train_iter(runner)
+            runner.log_buffer.clear()
+
+            self._do_evaluate(runner)
+
+    def after_train_epoch(self, runner):
+        """Called after every training epoch to evaluate the results."""
+        if self.by_epoch and self._should_evaluate(runner):
+            self._do_evaluate(runner)
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        results = self.test_fn(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        # the key_score may be `None` so it needs to skip the action to save
+        # the best checkpoint
+        if self.save_best and key_score:
+            self._save_ckpt(runner, key_score)
+
+    def _should_evaluate(self, runner):
+        """Judge whether to perform evaluation.
+
+        Here is the rule to judge whether to perform evaluation:
+        1. It will not perform evaluation during the epoch/iteration interval,
+           which is determined by ``self.interval``.
+        2. It will not perform evaluation if the start time is larger than
+           current time.
+        3. It will not perform evaluation when current time is larger than
+           the start time but during epoch/iteration interval.
+
+        Returns:
+            bool: The flag indicating whether to perform evaluation.
+        """
+        if self.by_epoch:
+            current = runner.epoch
+            check_time = self.every_n_epochs
+        else:
+            current = runner.iter
+            check_time = self.every_n_iters
+
+        if self.start is None:
+            if not check_time(runner, self.interval):
+                # No evaluation during the interval.
+                return False
+        elif (current + 1) < self.start:
+            # No evaluation if start is larger than the current time.
+            return False
+        else:
+            # Evaluation only at epochs/iters 3, 5, 7...
+            # if start==3 and interval==2
+            if (current + 1 - self.start) % self.interval:
+                return False
+        return True
+
+    def _save_ckpt(self, runner, key_score):
+        """Save the best checkpoint.
+
+        It will compare the score according to the compare function, write
+        related information (best score, best checkpoint path) and save the
+        best checkpoint into ``work_dir``.
+        """
+        if self.by_epoch:
+            current = f'epoch_{runner.epoch + 1}'
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            current = f'iter_{runner.iter + 1}'
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        best_score = runner.meta['hook_msgs'].get(
+            'best_score', self.init_value_map[self.rule])
+        if self.compare_func(key_score, best_score):
+            best_score = key_score
+            runner.meta['hook_msgs']['best_score'] = best_score
+
+            if self.best_ckpt_path and self.file_client.isfile(
+                    self.best_ckpt_path):
+                self.file_client.remove(self.best_ckpt_path)
+                runner.logger.info(
+                    f'The previous best checkpoint {self.best_ckpt_path} was '
+                    'removed')
+
+            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+            self.best_ckpt_path = self.file_client.join_path(
+                self.out_dir, best_ckpt_name)
+            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+            runner.save_checkpoint(
+                self.out_dir,
+                filename_tmpl=best_ckpt_name,
+                create_symlink=False)
+            runner.logger.info(
+                f'Now best checkpoint is saved as {best_ckpt_name}.')
+            runner.logger.info(
+                f'Best {self.key_indicator} is {best_score:0.4f} '
+                f'at {cur_time} {cur_type}.')
+
+    def evaluate(self, runner, results):
+        """Evaluate the results.
+
+        Args:
+            runner (:obj:`mmcv.Runner`): The underlined training runner.
+            results (list): Output results.
+        """
+        eval_res = self.dataloader.dataset.evaluate(
+            results, logger=runner.logger, **self.eval_kwargs)
+
+        for name, val in eval_res.items():
+            runner.log_buffer.output[name] = val
+        runner.log_buffer.ready = True
+
+        if self.save_best is not None:
+            # If the performance of model is poor, the `eval_res` may be an
+            # empty dict and it will raise exception when `self.save_best` is
+            # not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+
+            if self.key_indicator == 'auto':
+                # infer from eval_results
+                self._init_rule(self.rule, list(eval_res.keys())[0])
+            return eval_res[self.key_indicator]
+
+        return None
+
+
+class _DistEvalHook(_EvalHook):
+    """Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
+            If set to True, it will perform by epoch. Otherwise, by iteration.
+            default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top'
+            .etc will be inferred by 'greater' rule. Keys contain 'loss' will
+            be inferred by 'less' rule. Options are 'greater', 'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader in a multi-gpu manner, and return the test results. If
+            ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+            will be used. (default: ``None``)
+        tmpdir (str | None): Temporary directory to save the results of all
+            processes. Default: None.
+        gpu_collect (bool): Whether to use gpu or cpu to collect results.
+            Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer(running_mean and running_var) of rank 0 to other rank
+            before evaluation. Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+    """
+
+    def __init__(self,
+                 dataloader: DataLoader,
+                 start: Optional[int] = None,
+                 interval: int = 1,
+                 by_epoch: bool = True,
+                 save_best: Optional[str] = None,
+                 rule: Optional[str] = None,
+                 test_fn: Optional[Callable] = None,
+                 greater_keys: Optional[List[str]] = None,
+                 less_keys: Optional[List[str]] = None,
+                 broadcast_bn_buffer: bool = True,
+                 tmpdir: Optional[str] = None,
+                 gpu_collect: bool = False,
+                 out_dir: Optional[str] = None,
+                 file_client_args: Optional[dict] = None,
+                 **eval_kwargs):
+
+        if test_fn is None:
+            from mmcv.engine import multi_gpu_test
+            test_fn = multi_gpu_test
+
+        super().__init__(
+            dataloader,
+            start=start,
+            interval=interval,
+            by_epoch=by_epoch,
+            save_best=save_best,
+            rule=rule,
+            test_fn=test_fn,
+            greater_keys=greater_keys,
+            less_keys=less_keys,
+            out_dir=out_dir,
+            file_client_args=file_client_args,
+            **eval_kwargs)
+
+        self.broadcast_bn_buffer = broadcast_bn_buffer
+        self.tmpdir = tmpdir
+        self.gpu_collect = gpu_collect
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+        results = self.test_fn(
+            runner.model,
+            self.dataloader,
+            tmpdir=tmpdir,
+            gpu_collect=self.gpu_collect)
+        if runner.rank == 0:
+            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            # the key_score may be `None` so it needs to skip the action to
+            # save the best checkpoint
+            if self.save_best and key_score:
+                self._save_ckpt(runner, key_score)
 
 class EvalHook(_EvalHook):
 
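Note: the two classes added above appear to be vendored from mmcv 1.x (mmcv/runner/hooks/evaluation.py), rebased onto mmengine.hooks.Hook so the removed mmcv.runner imports are no longer needed; the mmcv.engine test functions are still imported lazily. A hedged usage sketch, with hypothetical val_loader and runner objects and assuming an mmcv-style runner that exposes log_buffer and meta:

# 'pck' is in MMPOSE_GREATER_KEYS, so rule='greater' would also be inferred.
eval_hook = EvalHook(val_loader, interval=1, save_best='pck')
runner.register_hook(eval_hook, priority='LOW')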
main/transformer_utils/mmpose/core/fp16/hooks.py
CHANGED
@@ -1,15 +1,90 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import logging
+from typing import Optional
 
 import torch
 import torch.nn as nn
-from
-from
+from torch import Tensor
+from torch.nn.utils import clip_grad
+from mmengine.hooks import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 from ..utils.dist_utils import allreduce_grads
 from .utils import cast_tensor_type
 
 
+class OptimizerHook(Hook):
+    """A hook contains custom operations for the optimizer.
+
+    Args:
+        grad_clip (dict, optional): A config dict to control the clip_grad.
+            Default: None.
+        detect_anomalous_params (bool): This option is only used for
+            debugging which will slow down the training speed.
+            Detect anomalous parameters that are not included in
+            the computational graph with `loss` as the root.
+            There are two cases
+
+                - Parameters were not used during
+                  forward pass.
+                - Parameters were not used to produce
+                  loss.
+            Default: False.
+    """
+
+    def __init__(self,
+                 grad_clip: Optional[dict] = None,
+                 detect_anomalous_params: bool = False):
+        self.grad_clip = grad_clip
+        self.detect_anomalous_params = detect_anomalous_params
+
+    def clip_grads(self, params):
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+    def after_train_iter(self, runner):
+        runner.optimizer.zero_grad()
+        if self.detect_anomalous_params:
+            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
+        runner.outputs['loss'].backward()
+
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(runner.model.parameters())
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
+
+    def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
+        logger = runner.logger
+        parameters_in_graph = set()
+        visited = set()
+
+        def traverse(grad_fn):
+            if grad_fn is None:
+                return
+            if grad_fn not in visited:
+                visited.add(grad_fn)
+                if hasattr(grad_fn, 'variable'):
+                    parameters_in_graph.add(grad_fn.variable)
+                parents = grad_fn.next_functions
+                if parents is not None:
+                    for parent in parents:
+                        grad_fn = parent[0]
+                        traverse(grad_fn)
+
+        traverse(loss.grad_fn)
+        for n, p in runner.model.named_parameters():
+            if p not in parameters_in_graph and p.requires_grad:
+                logger.log(
+                    level=logging.ERROR,
+                    msg=f'{n} with shape {p.size()} is not '
+                    f'in the computational graph \n')
+
 class Fp16OptimizerHook(OptimizerHook):
     """FP16 optimizer hook.
 
main/transformer_utils/mmpose/core/optimizers/builder.py
CHANGED
@@ -1,24 +1,37 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
-from
-from mmcv.
+import copy
+from typing import Dict
+# from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
+from mmengine import Registry
+from mmengine.registry import build_from_cfg
 
 OPTIMIZERS = Registry('optimizers')
-OPTIMIZER_BUILDERS = Registry(
-    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
 
 
 def build_optimizer_constructor(cfg):
     constructor_type = cfg.get('type')
     if constructor_type in OPTIMIZER_BUILDERS:
         return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
-    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
     else:
         raise KeyError(f'{constructor_type} is not registered '
                        'in the optimizer builder registry.')
 
 
+def build_optimizer(model, cfg: Dict):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
+
+
 def build_optimizers(model, cfgs):
     """Build multiple optimizers from configs.
 
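Note: dropping parent=MMCV_OPTIMIZER_BUILDERS makes the registry self-contained, so the new local build_optimizer can only resolve constructors registered in this repo (such as the LayerDecayOptimizerConstructor touched in the next file). A hedged usage sketch with a hypothetical model and hypothetical paramwise_cfg keys:

# 'constructor' and 'paramwise_cfg' are popped by build_optimizer; the rest
# of the dict is forwarded to the constructor as optimizer_cfg.
optim_cfg = dict(type='AdamW', lr=1e-4, weight_decay=0.05,
                 constructor='LayerDecayOptimizerConstructor',
                 paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.75))
optimizer = build_optimizer(model, optim_cfg)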
main/transformer_utils/mmpose/core/optimizers/layer_decay_optimizer_constructor.py
CHANGED
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
 import warnings
-
-from mmcv.runner import DefaultOptimizerConstructor
+from mmengine.dist import get_dist_info
+from mmcv.runner import DefaultOptimizerConstructor
 
 from mmpose.utils import get_root_logger
 from .builder import OPTIMIZER_BUILDERS
main/transformer_utils/mmpose/core/post_processing/smoother.py
CHANGED
@@ -4,8 +4,8 @@ import warnings
 from typing import Dict, Union
 
 import numpy as np
-from
-
+from mmengine.config import Config
+from mmengine.utils import is_seq_of
 from mmpose.core.post_processing.temporal_filters import build_filter
 
 
main/transformer_utils/mmpose/core/post_processing/temporal_filters/builder.py
CHANGED
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from
+from mmengine import Registry
 
 FILTERS = Registry('filters')
 
main/transformer_utils/mmpose/core/post_processing/temporal_filters/smoothnet_filter.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Optional
 
 import numpy as np
 import torch
-from
+from mmengine.runner import load_checkpoint
 from torch import Tensor, nn
 
 from .builder import FILTERS
main/transformer_utils/mmpose/core/utils/dist_utils.py
CHANGED
@@ -4,7 +4,7 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import torch.distributed as dist
-from
+from mmengine.dist import get_dist_info
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
main/transformer_utils/mmpose/core/utils/model_util_hooks.py
CHANGED
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from
-
+from mmengine.registry import HOOKS
+from mmengine.hooks import Hook
 
 @HOOKS.register_module()
 class ModelSetEpochHook(Hook):
main/transformer_utils/mmpose/core/visualization/image.py
CHANGED
@@ -7,7 +7,7 @@ import cv2
 import mmcv
 import numpy as np
 from matplotlib import pyplot as plt
-from
+from mmengine.utils import deprecated_api_warning
 from mmcv.visualization.color import color_val
 
 try:
main/transformer_utils/mmpose/models/__init__.py
CHANGED
@@ -3,9 +3,9 @@ from .builder import (BACKBONES, HEADS, LOSSES, MESH_MODELS, NECKS, POSENETS,
                       build_backbone, build_head, build_loss, build_mesh_model,
                       build_neck, build_posenet)
 from .detectors import *  # noqa
+from .backbones import *
 from .heads import *  # noqa
 from .losses import *  # noqa
-from .necks import *  # noqa
 from .utils import *  # noqa
 
 
main/transformer_utils/mmpose/models/backbones/__init__.py
CHANGED
@@ -1,41 +1,42 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .alexnet import AlexNet
-from .cpm import CPM
-from .hourglass import HourglassNet
-from .hourglass_ae import HourglassAENet
-from .hrformer import HRFormer
-from .hrnet import HRNet
-from .i3d import I3D
-from .litehrnet import LiteHRNet
-from .mobilenet_v2 import MobileNetV2
-from .mobilenet_v3 import MobileNetV3
-from .mspn import MSPN
-from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
-from .regnet import RegNet
-from .resnest import ResNeSt
-from .resnet import ResNet, ResNetV1d
-from .resnext import ResNeXt
-from .rsn import RSN
-from .scnet import SCNet
-from .seresnet import SEResNet
-from .seresnext import SEResNeXt
-from .shufflenet_v1 import ShuffleNetV1
-from .shufflenet_v2 import ShuffleNetV2
-from .swin import SwinTransformer
-from .tcformer import TCFormer
-from .tcn import TCN
-from .v2v_net import V2VNet
-from .vgg import VGG
-from .vipnas_mbv3 import ViPNAS_MobileNetV3
-from .vipnas_resnet import ViPNAS_ResNet
-from .hrt import HRT
+# from .alexnet import AlexNet
+# from .cpm import CPM
+# from .hourglass import HourglassNet
+# from .hourglass_ae import HourglassAENet
+# from .hrformer import HRFormer
+# from .hrnet import HRNet
+# from .i3d import I3D
+# from .litehrnet import LiteHRNet
+# from .mobilenet_v2 import MobileNetV2
+# from .mobilenet_v3 import MobileNetV3
+# from .mspn import MSPN
+# from .pvt import PyramidVisionTransformer, PyramidVisionTransformerV2
+# from .regnet import RegNet
+# from .resnest import ResNeSt
+# from .resnet import ResNet, ResNetV1d
+# from .resnext import ResNeXt
+# from .rsn import RSN
+# from .scnet import SCNet
+# from .seresnet import SEResNet
+# from .seresnext import SEResNeXt
+# from .shufflenet_v1 import ShuffleNetV1
+# from .shufflenet_v2 import ShuffleNetV2
+# from .swin import SwinTransformer
+# from .tcformer import TCFormer
+# from .tcn import TCN
+# from .v2v_net import V2VNet
+# from .vgg import VGG
+# from .vipnas_mbv3 import ViPNAS_MobileNetV3
+# from .vipnas_resnet import ViPNAS_ResNet
+# from .hrt import HRT
 from .vit import ViT
 
-__all__ = [
-    'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
-    'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
-    'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
-    'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
-    'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
-    'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer', 'ViT'
-]
+# __all__ = [
+#     'AlexNet', 'HourglassNet', 'HourglassAENet', 'HRNet', 'MobileNetV2',
+#     'MobileNetV3', 'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SCNet',
+#     'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
+#     'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
+#     'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
+#     'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer', 'ViT'
+# ]
+__all__ = ['ViT']
main/transformer_utils/mmpose/models/backbones/alexnet.py
DELETED
@@ -1,56 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch.nn as nn
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-@BACKBONES.register_module()
-class AlexNet(BaseBackbone):
-    """`AlexNet <https://en.wikipedia.org/wiki/AlexNet>`__ backbone.
-
-    The input for AlexNet is a 224x224 RGB image.
-
-    Args:
-        num_classes (int): number of classes for classification.
-            The default value is -1, which uses the backbone as
-            a feature extractor without the top classifier.
-    """
-
-    def __init__(self, num_classes=-1):
-        super().__init__()
-        self.num_classes = num_classes
-        self.features = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-            nn.Conv2d(64, 192, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-            nn.Conv2d(192, 384, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(384, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, kernel_size=3, padding=1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(kernel_size=3, stride=2),
-        )
-        if self.num_classes > 0:
-            self.classifier = nn.Sequential(
-                nn.Dropout(),
-                nn.Linear(256 * 6 * 6, 4096),
-                nn.ReLU(inplace=True),
-                nn.Dropout(),
-                nn.Linear(4096, 4096),
-                nn.ReLU(inplace=True),
-                nn.Linear(4096, num_classes),
-            )
-
-    def forward(self, x):
-
-        x = self.features(x)
-        if self.num_classes > 0:
-            x = x.view(x.size(0), 256 * 6 * 6)
-            x = self.classifier(x)
-
-        return x
main/transformer_utils/mmpose/models/backbones/cpm.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
-
import copy
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
from mmcv.cnn import ConvModule, constant_init, normal_init
|
7 |
-
from torch.nn.modules.batchnorm import _BatchNorm
|
8 |
-
|
9 |
-
from mmpose.utils import get_root_logger
|
10 |
-
from ..builder import BACKBONES
|
11 |
-
from .base_backbone import BaseBackbone
|
12 |
-
from .utils import load_checkpoint
|
13 |
-
|
14 |
-
|
15 |
-
class CpmBlock(nn.Module):
|
16 |
-
"""CpmBlock for Convolutional Pose Machine.
|
17 |
-
|
18 |
-
Args:
|
19 |
-
in_channels (int): Input channels of this block.
|
20 |
-
        channels (list): Output channels of each conv module.
        kernels (list): Kernel sizes of each conv module.
    """

    def __init__(self,
                 in_channels,
                 channels=(128, 128, 128),
                 kernels=(11, 11, 11),
                 norm_cfg=None):
        super().__init__()

        assert len(channels) == len(kernels)
        layers = []
        for i in range(len(channels)):
            if i == 0:
                input_channels = in_channels
            else:
                input_channels = channels[i - 1]
            layers.append(
                ConvModule(
                    input_channels,
                    channels[i],
                    kernels[i],
                    padding=(kernels[i] - 1) // 2,
                    norm_cfg=norm_cfg))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Model forward function."""
        out = self.model(x)
        return out


@BACKBONES.register_module()
class CPM(BaseBackbone):
    """CPM backbone.

    Convolutional Pose Machines.
    More details can be found in the `paper
    <https://arxiv.org/abs/1602.00134>`__ .

    Args:
        in_channels (int): The input channels of the CPM.
        out_channels (int): The output channels of the CPM.
        feat_channels (int): Feature channel of each CPM stage.
        middle_channels (int): Feature channel of conv after the middle stage.
        num_stages (int): Number of stages.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import CPM
        >>> import torch
        >>> self = CPM(3, 17)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 368, 368)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
        (1, 17, 46, 46)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 feat_channels=128,
                 middle_channels=32,
                 num_stages=6,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        assert in_channels == 3

        self.num_stages = num_stages
        assert self.num_stages >= 1

        self.stem = nn.Sequential(
            ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ConvModule(128, 32, 5, padding=2, norm_cfg=norm_cfg),
            ConvModule(32, 512, 9, padding=4, norm_cfg=norm_cfg),
            ConvModule(512, 512, 1, padding=0, norm_cfg=norm_cfg),
            ConvModule(512, out_channels, 1, padding=0, act_cfg=None))

        self.middle = nn.Sequential(
            ConvModule(in_channels, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ConvModule(128, 128, 9, padding=4, norm_cfg=norm_cfg),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        self.cpm_stages = nn.ModuleList([
            CpmBlock(
                middle_channels + out_channels,
                channels=[feat_channels, feat_channels, feat_channels],
                kernels=[11, 11, 11],
                norm_cfg=norm_cfg) for _ in range(num_stages - 1)
        ])

        self.middle_conv = nn.ModuleList([
            nn.Sequential(
                ConvModule(
                    128, middle_channels, 5, padding=2, norm_cfg=norm_cfg))
            for _ in range(num_stages - 1)
        ])

        self.out_convs = nn.ModuleList([
            nn.Sequential(
                ConvModule(
                    feat_channels,
                    feat_channels,
                    1,
                    padding=0,
                    norm_cfg=norm_cfg),
                ConvModule(feat_channels, out_channels, 1, act_cfg=None))
            for _ in range(num_stages - 1)
        ])

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Model forward function."""
        stage1_out = self.stem(x)
        middle_out = self.middle(x)
        out_feats = []

        out_feats.append(stage1_out)

        for ind in range(self.num_stages - 1):
            single_stage = self.cpm_stages[ind]
            out_conv = self.out_convs[ind]

            inp_feat = torch.cat(
                [out_feats[-1], self.middle_conv[ind](middle_out)], 1)
            cpm_feat = single_stage(inp_feat)
            out_feat = out_conv(cpm_feat)
            out_feats.append(out_feat)

        return out_feats
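Note: the cascade pattern the deleted CPM backbone implemented (each stage consuming shared image features plus the previous stage's belief maps) is easy to restate outside mmpose. A minimal sketch in plain PyTorch, using a hypothetical TinyCascade module that is not part of this repo:

import torch
import torch.nn as nn

class TinyCascade(nn.Module):
    """Hypothetical mini cascade illustrating the CPM stage pattern."""

    def __init__(self, feat_ch=32, out_ch=17, num_stages=3):
        super().__init__()
        self.stem = nn.Conv2d(3, feat_ch, 3, padding=1)
        self.first = nn.Conv2d(feat_ch, out_ch, 1)
        # every later stage sees the shared features plus the previous maps
        self.stages = nn.ModuleList([
            nn.Conv2d(feat_ch + out_ch, out_ch, 3, padding=1)
            for _ in range(num_stages - 1)
        ])

    def forward(self, x):
        feat = self.stem(x)
        outs = [self.first(feat)]
        for stage in self.stages:
            outs.append(stage(torch.cat([feat, outs[-1]], dim=1)))
        return outs

outs = TinyCascade()(torch.rand(1, 3, 64, 64))
print([tuple(o.shape) for o in outs])  # 3 x (1, 17, 64, 64)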
main/transformer_utils/mmpose/models/backbones/hourglass.py
DELETED
@@ -1,212 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .resnet import BasicBlock, ResLayer
from .utils import load_checkpoint


class HourglassModule(nn.Module):
    """Hourglass Module for HourglassNet backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in current and
            follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 stage_blocks,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.depth = depth

        cur_block = stage_blocks[0]
        next_block = stage_blocks[1]

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        self.up1 = ResLayer(
            BasicBlock, cur_block, cur_channel, cur_channel, norm_cfg=norm_cfg)

        self.low1 = ResLayer(
            BasicBlock,
            cur_block,
            cur_channel,
            next_channel,
            stride=2,
            norm_cfg=norm_cfg)

        if self.depth > 1:
            self.low2 = HourglassModule(depth - 1, stage_channels[1:],
                                        stage_blocks[1:])
        else:
            self.low2 = ResLayer(
                BasicBlock,
                next_block,
                next_channel,
                next_channel,
                norm_cfg=norm_cfg)

        self.low3 = ResLayer(
            BasicBlock,
            cur_block,
            next_channel,
            cur_channel,
            norm_cfg=norm_cfg,
            downsample_first=False)

        self.up2 = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Model forward function."""
        up1 = self.up1(x)
        low1 = self.low1(x)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return up1 + up2


@BACKBONES.register_module()
class HourglassNet(BaseBackbone):
    """HourglassNet backbone.

    Stacked Hourglass Networks for Human Pose Estimation.
    More details can be found in the `paper
    <https://arxiv.org/abs/1603.06937>`__ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channel (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import HourglassNet
        >>> import torch
        >>> self = HourglassNet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 256, 128, 128)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 downsample_times=5,
                 num_stacks=2,
                 stage_channels=(256, 256, 384, 384, 384, 512),
                 stage_blocks=(2, 2, 2, 2, 2, 4),
                 feat_channel=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) == len(stage_blocks)
        assert len(stage_channels) > downsample_times

        cur_channel = stage_channels[0]

        self.stem = nn.Sequential(
            ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ResLayer(BasicBlock, 1, 128, 256, stride=2, norm_cfg=norm_cfg))

        self.hourglass_modules = nn.ModuleList([
            HourglassModule(downsample_times, stage_channels, stage_blocks)
            for _ in range(num_stacks)
        ])

        self.inters = ResLayer(
            BasicBlock,
            num_stacks - 1,
            cur_channel,
            cur_channel,
            norm_cfg=norm_cfg)

        self.conv1x1s = nn.ModuleList([
            ConvModule(
                cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
            for _ in range(num_stacks)
        ])

        self.remap_convs = nn.ModuleList([
            ConvModule(
                feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
            for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Model forward function."""
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            if ind < self.num_stacks - 1:
                inter_feat = self.conv1x1s[ind](
                    inter_feat) + self.remap_convs[ind](
                        out_feat)
                inter_feat = self.inters[ind](self.relu(inter_feat))

        return out_feats
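Note: the deleted HourglassModule builds each level recursively: a skip branch at the current resolution plus a downsample, recurse, upsample branch, summed at the end. A minimal standalone sketch of that recursion (hypothetical MiniHourglass, illustration only, not the removed implementation):

import torch
import torch.nn as nn

class MiniHourglass(nn.Module):
    """Hypothetical sketch of the recursive hourglass pattern."""

    def __init__(self, depth, channels=64):
        super().__init__()
        self.skip = nn.Conv2d(channels, channels, 3, padding=1)
        self.down = nn.Conv2d(channels, channels, 3, stride=2, padding=1)
        self.inner = (MiniHourglass(depth - 1, channels) if depth > 1 else
                      nn.Conv2d(channels, channels, 3, padding=1))
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x):
        # skip branch at full resolution + recursed low-resolution branch
        return self.skip(x) + self.up(self.inner(self.down(x)))

y = MiniHourglass(depth=3)(torch.rand(1, 64, 64, 64))
print(tuple(y.shape))  # (1, 64, 64, 64); side length must divide by 2**depth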
main/transformer_utils/mmpose/models/backbones/hourglass_ae.py
DELETED
@@ -1,212 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import ConvModule, MaxPool2d, constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import load_checkpoint


class HourglassAEModule(nn.Module):
    """Modified Hourglass Module for HourglassNet_AE backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.depth = depth

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        self.up1 = ConvModule(
            cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.pool1 = MaxPool2d(2, 2)

        self.low1 = ConvModule(
            cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        if self.depth > 1:
            self.low2 = HourglassAEModule(depth - 1, stage_channels[1:])
        else:
            self.low2 = ConvModule(
                next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.low3 = ConvModule(
            next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        """Model forward function."""
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        return up1 + up2


@BACKBONES.register_module()
class HourglassAENet(BaseBackbone):
    """Hourglass-AE Network proposed by Newell et al.

    Associative Embedding: End-to-End Learning for Joint
    Detection and Grouping.

    More details can be found in the `paper
    <https://arxiv.org/abs/1611.05424>`__ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        stage_blocks (list[int]): Number of sub-modules stacked in a
            HourglassModule.
        feat_channels (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import HourglassAENet
        >>> import torch
        >>> self = HourglassAENet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 512, 512)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 34, 128, 128)
    """

    def __init__(self,
                 downsample_times=4,
                 num_stacks=1,
                 out_channels=34,
                 stage_channels=(256, 384, 512, 640, 768),
                 feat_channels=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        assert len(stage_channels) > downsample_times

        cur_channels = stage_channels[0]

        self.stem = nn.Sequential(
            ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
            MaxPool2d(2, 2),
            ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
            ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
        )

        self.hourglass_modules = nn.ModuleList([
            nn.Sequential(
                HourglassAEModule(
                    downsample_times, stage_channels, norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg)) for _ in range(num_stacks)
        ])

        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channels,
                out_channels,
                1,
                padding=0,
                norm_cfg=None,
                act_cfg=None) for _ in range(num_stacks)
        ])

        self.remap_out_convs = nn.ModuleList([
            ConvModule(
                out_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        self.remap_feature_convs = nn.ModuleList([
            ConvModule(
                feat_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Model forward function."""
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            if ind < self.num_stacks - 1:
                inter_feat = inter_feat + self.remap_out_convs[ind](
                    out_feat) + self.remap_feature_convs[ind](
                        hourglass_feat)

        return out_feats
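Note: HourglassAENet's default out_channels=34 per stack is conventionally read, in the associative-embedding formulation, as 17 keypoint heatmaps plus 17 grouping-tag maps; the backbone itself does not enforce that split. A small sketch assuming the 17+17 reading:

import torch

num_keypoints = 17
stack_out = torch.rand(1, 2 * num_keypoints, 128, 128)  # stand-in output
heatmaps, tags = stack_out.split(num_keypoints, dim=1)
print(tuple(heatmaps.shape), tuple(tags.shape))
# (1, 17, 128, 128) (1, 17, 128, 128)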
main/transformer_utils/mmpose/models/backbones/hrformer.py
DELETED
@@ -1,746 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.

import math

import torch
import torch.nn as nn
# from timm.models.layers import to_2tuple, trunc_normal_
from mmcv.cnn import (build_activation_layer, build_conv_layer,
                      build_norm_layer, trunc_normal_init)
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.runner import BaseModule
from torch.nn.functional import pad

from ..builder import BACKBONES
from .hrnet import Bottleneck, HRModule, HRNet


def nlc_to_nchw(x, hw_shape):
    """Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of output feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len doesn\'t match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W)


def nchw_to_nlc(x):
    """Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()


def build_drop_path(drop_path_rate):
    """Build drop path layer."""
    return build_dropout(dict(type='DropPath', drop_prob=drop_path_rate))


class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        self.with_rpe = with_rpe
        if self.with_rpe:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                    num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            Wh, Ww = self.window_size
            rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
            rel_position_index = rel_index_coords + rel_index_coords.T
            rel_position_index = rel_position_index.flip(1).contiguous()
            self.register_buffer('relative_position_index', rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        trunc_normal_init(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """
        Args:

            x (tensor): input features with shape of (B*num_windows, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.with_rpe:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    self.window_size[0] * self.window_size[1],
                    self.window_size[0] * self.window_size[1],
                    -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)


class LocalWindowSelfAttention(BaseModule):
    r""" Local-window Self Attention (LSA) module with relative position bias.

    This module is the short-range self-attention module in the
    `Interlaced Sparse Self-Attention <https://arxiv.org/abs/1907.12273>`_.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int] | int): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        with_pad_mask (bool, optional): If True, mask out the padded tokens in
            the attention process. Default: False.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 with_pad_mask=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        self.window_size = window_size
        self.with_pad_mask = with_pad_mask
        self.attn = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            with_rpe=with_rpe,
            init_cfg=init_cfg)

    def forward(self, x, H, W, **kwargs):
        """Forward function."""
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        Wh, Ww = self.window_size

        # center-pad the feature on H and W axes
        pad_h = math.ceil(H / Wh) * Wh - H
        pad_w = math.ceil(W / Ww) * Ww - W
        x = pad(x, (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2))

        # permute
        x = x.view(B, math.ceil(H / Wh), Wh, math.ceil(W / Ww), Ww, C)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(-1, Wh * Ww, C)  # (B*num_window, Wh*Ww, C)

        # attention
        if self.with_pad_mask and pad_h > 0 and pad_w > 0:
            pad_mask = x.new_zeros(1, H, W, 1)
            pad_mask = pad(
                pad_mask, [
                    0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ],
                value=-float('inf'))
            pad_mask = pad_mask.view(1, math.ceil(H / Wh), Wh,
                                     math.ceil(W / Ww), Ww, 1)
            pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5)
            pad_mask = pad_mask.reshape(-1, Wh * Ww)
            pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1])
            out = self.attn(x, pad_mask, **kwargs)
        else:
            out = self.attn(x, **kwargs)

        # reverse permutation
        out = out.reshape(B, math.ceil(H / Wh), math.ceil(W / Ww), Wh, Ww, C)
        out = out.permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, H + pad_h, W + pad_w, C)

        # de-pad
        out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2]
        return out.reshape(B, N, C)


class CrossFFN(BaseModule):
    r"""FFN with Depthwise Conv of HRFormer.

    Args:
        in_features (int): The feature dimension.
        hidden_features (int, optional): The hidden dimension of FFNs.
            Defaults: The same as in_features.
        act_cfg (dict, optional): Config of activation layer.
            Default: dict(type='GELU').
        dw_act_cfg (dict, optional): Config of activation layer appended
            right after DW Conv. Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 dw_act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        self.act1 = build_activation_layer(act_cfg)
        self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
        self.dw3x3 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            groups=hidden_features,
            padding=1)
        self.act2 = build_activation_layer(dw_act_cfg)
        self.norm2 = build_norm_layer(norm_cfg, hidden_features)[1]
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.act3 = build_activation_layer(act_cfg)
        self.norm3 = build_norm_layer(norm_cfg, out_features)[1]

        # put the modules together
        self.layers = [
            self.fc1, self.norm1, self.act1, self.dw3x3, self.norm2, self.act2,
            self.fc2, self.norm3, self.act3
        ]

    def forward(self, x, H, W):
        """Forward function."""
        x = nlc_to_nchw(x, (H, W))
        for layer in self.layers:
            x = layer(x)
        x = nchw_to_nlc(x)
        return x


class HRFormerBlock(BaseModule):
    """High-Resolution Block for HRFormer.

    Args:
        in_features (int): The input dimension.
        out_features (int): The output dimension.
        num_heads (int): The number of head within each LSA.
        window_size (int, optional): The window size for the LSA.
            Default: 7
        mlp_ratio (int, optional): The expansion ratio of FFN.
            Default: 4
        act_cfg (dict, optional): Config of activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        transformer_norm_cfg (dict, optional): Config of transformer norm
            layer. Default: dict(type='LN', eps=1e-6).
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    expansion = 1

    def __init__(self,
                 in_features,
                 out_features,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.0,
                 drop_path=0.0,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 init_cfg=None,
                 **kwargs):
        super(HRFormerBlock, self).__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)[1]
        self.attn = LocalWindowSelfAttention(
            in_features,
            num_heads=num_heads,
            window_size=window_size,
            init_cfg=None,
            **kwargs)

        self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)[1]
        self.ffn = CrossFFN(
            in_features=in_features,
            hidden_features=int(in_features * mlp_ratio),
            out_features=out_features,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dw_act_cfg=act_cfg,
            init_cfg=None)

        self.drop_path = build_drop_path(
            drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        """Forward function."""
        B, C, H, W = x.size()
        # Attention
        x = x.view(B, C, -1).permute(0, 2, 1)
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        # FFN
        x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
        x = x.permute(0, 2, 1).view(B, C, H, W)
        return x

    def extra_repr(self):
        """(Optional) Set the extra information about this module."""
        return 'num_heads={}, window_size={}, mlp_ratio={}'.format(
            self.num_heads, self.window_size, self.mlp_ratio)


class HRFomerModule(HRModule):
    """High-Resolution Module for HRFormer.

    Args:
        num_branches (int): The number of branches in the HRFormerModule.
        block (nn.Module): The building block of HRFormer.
            The block should be the HRFormerBlock.
        num_blocks (tuple): The number of blocks in each branch.
            The length must be equal to num_branches.
        num_inchannels (tuple): The number of input channels in each branch.
            The length must be equal to num_branches.
        num_channels (tuple): The number of channels in each branch.
            The length must be equal to num_branches.
        num_heads (tuple): The number of heads within the LSAs.
        num_window_sizes (tuple): The window size for the LSAs.
        num_mlp_ratios (tuple): The expansion ratio for the FFNs.
        drop_path (int, optional): The drop path rate of HRFormer.
            Default: 0.0
        multiscale_output (bool, optional): Whether to output multi-level
            features produced by multiple branches. If False, only the first
            level feature will be output. Default: True.
        conv_cfg (dict, optional): Config of the conv layers.
            Default: None.
        norm_cfg (dict, optional): Config of the norm layers appended
            right after conv. Default: dict(type='SyncBN', requires_grad=True)
        transformer_norm_cfg (dict, optional): Config of the norm layers.
            Default: dict(type='LN', eps=1e-6)
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False
        upsample_cfg(dict, optional): The config of upsample layers in fuse
            layers. Default: dict(mode='bilinear', align_corners=False)
    """

    def __init__(self,
                 num_branches,
                 block,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 num_heads,
                 num_window_sizes,
                 num_mlp_ratios,
                 multiscale_output=True,
                 drop_paths=0.0,
                 with_rpe=True,
                 with_pad_mask=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 with_cp=False,
                 upsample_cfg=dict(mode='bilinear', align_corners=False)):

        self.transformer_norm_cfg = transformer_norm_cfg
        self.drop_paths = drop_paths
        self.num_heads = num_heads
        self.num_window_sizes = num_window_sizes
        self.num_mlp_ratios = num_mlp_ratios
        self.with_rpe = with_rpe
        self.with_pad_mask = with_pad_mask

        super().__init__(num_branches, block, num_blocks, num_inchannels,
                         num_channels, multiscale_output, with_cp, conv_cfg,
                         norm_cfg, upsample_cfg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch."""
        # HRFormerBlock does not support down sample layer yet.
        assert stride == 1 and self.in_channels[branch_index] == num_channels[
            branch_index]
        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                num_heads=self.num_heads[branch_index],
                window_size=self.num_window_sizes[branch_index],
                mlp_ratio=self.num_mlp_ratios[branch_index],
                drop_path=self.drop_paths[0],
                norm_cfg=self.norm_cfg,
                transformer_norm_cfg=self.transformer_norm_cfg,
                init_cfg=None,
                with_rpe=self.with_rpe,
                with_pad_mask=self.with_pad_mask))

        self.in_channels[
            branch_index] = self.in_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    num_heads=self.num_heads[branch_index],
                    window_size=self.num_window_sizes[branch_index],
                    mlp_ratio=self.num_mlp_ratios[branch_index],
                    drop_path=self.drop_paths[i],
                    norm_cfg=self.norm_cfg,
                    transformer_norm_cfg=self.transformer_norm_cfg,
                    init_cfg=None,
                    with_rpe=self.with_rpe,
                    with_pad_mask=self.with_pad_mask))
        return nn.Sequential(*layers)

    def _make_fuse_layers(self):
        """Build fuse layers."""
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        num_inchannels = self.in_channels
        fuse_layers = []
        for i in range(num_branches if self.multiscale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[i],
                                kernel_size=1,
                                stride=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i),
                                mode=self.upsample_cfg['mode'],
                                align_corners=self.
                                upsample_cfg['align_corners'])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            with_out_act = False
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            with_out_act = True
                        sub_modules = [
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[j],
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=num_inchannels[j],
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[j])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_outchannels_conv3x3,
                                kernel_size=1,
                                stride=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_outchannels_conv3x3)[1]
                        ]
                        if with_out_act:
                            sub_modules.append(nn.ReLU(False))
                        conv3x3s.append(nn.Sequential(*sub_modules))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the number of input channels."""
        return self.in_channels


@BACKBONES.register_module()
class HRFormer(HRNet):
    """HRFormer backbone.

    This backbone is the implementation of `HRFormer: High-Resolution
    Transformer for Dense Prediction <https://arxiv.org/abs/2110.09408>`_.

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, the configuration for each stage must have
            5 keys:

                - num_modules (int): The number of HRModule in this stage.
                - num_branches (int): The number of branches in the HRModule.
                - block (str): The type of block.
                - num_blocks (tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels (tuple): The number of channels in each branch.
                    The length must be equal to num_branches.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Config of norm layer.
            Use `SyncBN` by default.
        transformer_norm_cfg (dict): Config of transformer norm layer.
            Use `LN` by default.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
    Example:
        >>> from mmpose.models import HRFormer
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(2, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='HRFORMER',
        >>>         window_sizes=(7, 7),
        >>>         num_heads=(1, 2),
        >>>         mlp_ratios=(4, 4),
        >>>         num_blocks=(2, 2),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='HRFORMER',
        >>>         window_sizes=(7, 7, 7),
        >>>         num_heads=(1, 2, 4),
        >>>         mlp_ratios=(4, 4, 4),
        >>>         num_blocks=(2, 2, 2),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=2,
        >>>         num_branches=4,
        >>>         block='HRFORMER',
        >>>         window_sizes=(7, 7, 7, 7),
        >>>         num_heads=(1, 2, 4, 8),
        >>>         mlp_ratios=(4, 4, 4, 4),
        >>>         num_blocks=(2, 2, 2, 2),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRFormer(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 frozen_stages=-1):

        # stochastic depth
        depths = [
            extra[stage]['num_blocks'][0] * extra[stage]['num_modules']
            for stage in ['stage2', 'stage3', 'stage4']
        ]
        depth_s2, depth_s3, _ = depths
        drop_path_rate = extra['drop_path_rate']
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]
        extra['stage2']['drop_path_rates'] = dpr[0:depth_s2]
        extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3]
        extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:]

        # HRFormer use bilinear upsample as default
        upsample_cfg = extra.get('upsample', {
            'mode': 'bilinear',
            'align_corners': False
        })
        extra['upsample'] = upsample_cfg
        self.transformer_norm_cfg = transformer_norm_cfg
        self.with_rpe = extra.get('with_rpe', True)
        self.with_pad_mask = extra.get('with_pad_mask', False)

        super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
                         with_cp, zero_init_residual, frozen_stages)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multiscale_output=True):
        """Make each stage."""
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]
        num_heads = layer_config['num_heads']
        num_window_sizes = layer_config['window_sizes']
        num_mlp_ratios = layer_config['mlp_ratios']
        drop_path_rates = layer_config['drop_path_rates']

        modules = []
        for i in range(num_modules):
            # multiscale_output is only used at the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                HRFomerModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    num_heads,
                    num_window_sizes,
                    num_mlp_ratios,
                    reset_multiscale_output,
                    drop_paths=drop_path_rates[num_blocks[0] *
                                               i:num_blocks[0] * (i + 1)],
                    with_rpe=self.with_rpe,
                    with_pad_mask=self.with_pad_mask,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    transformer_norm_cfg=self.transformer_norm_cfg,
                    with_cp=self.with_cp,
                    upsample_cfg=self.upsample_cfg))
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels
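Note: much of the deleted hrformer.py moves between token layout [N, L, C] and feature-map layout [N, C, H, W] around each attention and FFN block. A standalone restatement of the nlc_to_nchw / nchw_to_nlc round trip, kept for reference now that the module is gone:

import torch

def nlc_to_nchw(x, hw_shape):
    # tokens [N, L, C] -> feature map [N, C, H, W]; requires L == H * W
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W
    return x.transpose(1, 2).reshape(B, C, H, W)

def nchw_to_nlc(x):
    # feature map [N, C, H, W] -> tokens [N, H*W, C]
    return x.flatten(2).transpose(1, 2).contiguous()

x = torch.rand(2, 7 * 7, 32)  # tokens of a 7x7 map, 32 channels
assert torch.equal(nchw_to_nlc(nlc_to_nchw(x, (7, 7))), x)  # exact round trip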
main/transformer_utils/mmpose/models/backbones/hrnet.py
DELETED
@@ -1,604 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck, get_expansion
from .utils import load_checkpoint


class HRModule(nn.Module):
    """High-Resolution Module for HRNet.

    In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
    is in this module.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 in_channels,
                 num_channels,
                 multiscale_output=False,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 upsample_cfg=dict(mode='nearest', align_corners=None)):

        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        self._check_branches(num_branches, num_blocks, in_channels,
                             num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.upsample_cfg = upsample_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _check_branches(num_branches, num_blocks, in_channels, num_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(num_blocks):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_BLOCKS({len(num_blocks)})'
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_CHANNELS({len(num_channels)})'
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Make one branch."""
        downsample = None
        if stride != 1 or \
                self.in_channels[branch_index] != \
                num_channels[branch_index] * get_expansion(block):
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * get_expansion(block),
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(
                    self.norm_cfg,
                    num_channels[branch_index] * get_expansion(block))[1])

        layers = []
        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index] * get_expansion(block),
                stride=stride,
                downsample=downsample,
                with_cp=self.with_cp,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg))
        self.in_channels[branch_index] = \
            num_channels[branch_index] * get_expansion(block)
        for _ in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index] * get_expansion(block),
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Make branches."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1

        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i),
                                mode=self.upsample_cfg['mode'],
                                align_corners=self.
                                upsample_cfg['align_corners'])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = 0
            for j in range(self.num_branches):
                if i == j:
                    y += x[j]
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse


@BACKBONES.register_module()
class HRNet(nn.Module):
    """HRNet backbone.

    `High-Resolution Representations for Labeling Pixels and Regions
    <https://arxiv.org/abs/1904.04514>`__

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        zero_init_residual (bool): whether to use zero init for last norm layer
            in resblocks to let them behave as identity.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.

    Example:
        >>> from mmpose.models import HRNet
        >>> import torch
        >>> extra = dict(
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(4, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=3,
        >>>         num_branches=4,
        >>>         block='BASIC',
        >>>         num_blocks=(4, 4, 4, 4),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
|
268 |
-
>>> for level_out in level_outputs:
|
269 |
-
... print(tuple(level_out.shape))
|
270 |
-
(1, 32, 8, 8)
|
271 |
-
"""
|
272 |
-
|
273 |
-
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
|
274 |
-
|
275 |
-
def __init__(self,
|
276 |
-
extra,
|
277 |
-
in_channels=3,
|
278 |
-
conv_cfg=None,
|
279 |
-
norm_cfg=dict(type='BN'),
|
280 |
-
norm_eval=False,
|
281 |
-
with_cp=False,
|
282 |
-
zero_init_residual=False,
|
283 |
-
frozen_stages=-1):
|
284 |
-
# Protect mutable default arguments
|
285 |
-
norm_cfg = copy.deepcopy(norm_cfg)
|
286 |
-
super().__init__()
|
287 |
-
self.extra = extra
|
288 |
-
self.conv_cfg = conv_cfg
|
289 |
-
self.norm_cfg = norm_cfg
|
290 |
-
self.norm_eval = norm_eval
|
291 |
-
self.with_cp = with_cp
|
292 |
-
self.zero_init_residual = zero_init_residual
|
293 |
-
self.frozen_stages = frozen_stages
|
294 |
-
|
295 |
-
# stem net
|
296 |
-
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
|
297 |
-
self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
|
298 |
-
|
299 |
-
self.conv1 = build_conv_layer(
|
300 |
-
self.conv_cfg,
|
301 |
-
in_channels,
|
302 |
-
64,
|
303 |
-
kernel_size=3,
|
304 |
-
stride=2,
|
305 |
-
padding=1,
|
306 |
-
bias=False)
|
307 |
-
|
308 |
-
self.add_module(self.norm1_name, norm1)
|
309 |
-
self.conv2 = build_conv_layer(
|
310 |
-
self.conv_cfg,
|
311 |
-
64,
|
312 |
-
64,
|
313 |
-
kernel_size=3,
|
314 |
-
stride=2,
|
315 |
-
padding=1,
|
316 |
-
bias=False)
|
317 |
-
|
318 |
-
self.add_module(self.norm2_name, norm2)
|
319 |
-
self.relu = nn.ReLU(inplace=True)
|
320 |
-
|
321 |
-
self.upsample_cfg = self.extra.get('upsample', {
|
322 |
-
'mode': 'nearest',
|
323 |
-
'align_corners': None
|
324 |
-
})
|
325 |
-
|
326 |
-
# stage 1
|
327 |
-
self.stage1_cfg = self.extra['stage1']
|
328 |
-
num_channels = self.stage1_cfg['num_channels'][0]
|
329 |
-
block_type = self.stage1_cfg['block']
|
330 |
-
num_blocks = self.stage1_cfg['num_blocks'][0]
|
331 |
-
|
332 |
-
block = self.blocks_dict[block_type]
|
333 |
-
stage1_out_channels = num_channels * get_expansion(block)
|
334 |
-
self.layer1 = self._make_layer(block, 64, stage1_out_channels,
|
335 |
-
num_blocks)
|
336 |
-
|
337 |
-
# stage 2
|
338 |
-
self.stage2_cfg = self.extra['stage2']
|
339 |
-
num_channels = self.stage2_cfg['num_channels']
|
340 |
-
block_type = self.stage2_cfg['block']
|
341 |
-
|
342 |
-
block = self.blocks_dict[block_type]
|
343 |
-
num_channels = [
|
344 |
-
channel * get_expansion(block) for channel in num_channels
|
345 |
-
]
|
346 |
-
self.transition1 = self._make_transition_layer([stage1_out_channels],
|
347 |
-
num_channels)
|
348 |
-
self.stage2, pre_stage_channels = self._make_stage(
|
349 |
-
self.stage2_cfg, num_channels)
|
350 |
-
|
351 |
-
# stage 3
|
352 |
-
self.stage3_cfg = self.extra['stage3']
|
353 |
-
num_channels = self.stage3_cfg['num_channels']
|
354 |
-
block_type = self.stage3_cfg['block']
|
355 |
-
|
356 |
-
block = self.blocks_dict[block_type]
|
357 |
-
num_channels = [
|
358 |
-
channel * get_expansion(block) for channel in num_channels
|
359 |
-
]
|
360 |
-
self.transition2 = self._make_transition_layer(pre_stage_channels,
|
361 |
-
num_channels)
|
362 |
-
self.stage3, pre_stage_channels = self._make_stage(
|
363 |
-
self.stage3_cfg, num_channels)
|
364 |
-
|
365 |
-
# stage 4
|
366 |
-
self.stage4_cfg = self.extra['stage4']
|
367 |
-
num_channels = self.stage4_cfg['num_channels']
|
368 |
-
block_type = self.stage4_cfg['block']
|
369 |
-
|
370 |
-
block = self.blocks_dict[block_type]
|
371 |
-
num_channels = [
|
372 |
-
channel * get_expansion(block) for channel in num_channels
|
373 |
-
]
|
374 |
-
self.transition3 = self._make_transition_layer(pre_stage_channels,
|
375 |
-
num_channels)
|
376 |
-
|
377 |
-
self.stage4, pre_stage_channels = self._make_stage(
|
378 |
-
self.stage4_cfg,
|
379 |
-
num_channels,
|
380 |
-
multiscale_output=self.stage4_cfg.get('multiscale_output', False))
|
381 |
-
|
382 |
-
self._freeze_stages()
|
383 |
-
|
384 |
-
@property
|
385 |
-
def norm1(self):
|
386 |
-
"""nn.Module: the normalization layer named "norm1" """
|
387 |
-
return getattr(self, self.norm1_name)
|
388 |
-
|
389 |
-
@property
|
390 |
-
def norm2(self):
|
391 |
-
"""nn.Module: the normalization layer named "norm2" """
|
392 |
-
return getattr(self, self.norm2_name)
|
393 |
-
|
394 |
-
def _make_transition_layer(self, num_channels_pre_layer,
|
395 |
-
num_channels_cur_layer):
|
396 |
-
"""Make transition layer."""
|
397 |
-
num_branches_cur = len(num_channels_cur_layer)
|
398 |
-
num_branches_pre = len(num_channels_pre_layer)
|
399 |
-
|
400 |
-
transition_layers = []
|
401 |
-
for i in range(num_branches_cur):
|
402 |
-
if i < num_branches_pre:
|
403 |
-
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
404 |
-
transition_layers.append(
|
405 |
-
nn.Sequential(
|
406 |
-
build_conv_layer(
|
407 |
-
self.conv_cfg,
|
408 |
-
num_channels_pre_layer[i],
|
409 |
-
num_channels_cur_layer[i],
|
410 |
-
kernel_size=3,
|
411 |
-
stride=1,
|
412 |
-
padding=1,
|
413 |
-
bias=False),
|
414 |
-
build_norm_layer(self.norm_cfg,
|
415 |
-
num_channels_cur_layer[i])[1],
|
416 |
-
nn.ReLU(inplace=True)))
|
417 |
-
else:
|
418 |
-
transition_layers.append(None)
|
419 |
-
else:
|
420 |
-
conv_downsamples = []
|
421 |
-
for j in range(i + 1 - num_branches_pre):
|
422 |
-
in_channels = num_channels_pre_layer[-1]
|
423 |
-
out_channels = num_channels_cur_layer[i] \
|
424 |
-
if j == i - num_branches_pre else in_channels
|
425 |
-
conv_downsamples.append(
|
426 |
-
nn.Sequential(
|
427 |
-
build_conv_layer(
|
428 |
-
self.conv_cfg,
|
429 |
-
in_channels,
|
430 |
-
out_channels,
|
431 |
-
kernel_size=3,
|
432 |
-
stride=2,
|
433 |
-
padding=1,
|
434 |
-
bias=False),
|
435 |
-
build_norm_layer(self.norm_cfg, out_channels)[1],
|
436 |
-
nn.ReLU(inplace=True)))
|
437 |
-
transition_layers.append(nn.Sequential(*conv_downsamples))
|
438 |
-
|
439 |
-
return nn.ModuleList(transition_layers)
|
440 |
-
|
441 |
-
def _make_layer(self, block, in_channels, out_channels, blocks, stride=1):
|
442 |
-
"""Make layer."""
|
443 |
-
downsample = None
|
444 |
-
if stride != 1 or in_channels != out_channels:
|
445 |
-
downsample = nn.Sequential(
|
446 |
-
build_conv_layer(
|
447 |
-
self.conv_cfg,
|
448 |
-
in_channels,
|
449 |
-
out_channels,
|
450 |
-
kernel_size=1,
|
451 |
-
stride=stride,
|
452 |
-
bias=False),
|
453 |
-
build_norm_layer(self.norm_cfg, out_channels)[1])
|
454 |
-
|
455 |
-
layers = []
|
456 |
-
layers.append(
|
457 |
-
block(
|
458 |
-
in_channels,
|
459 |
-
out_channels,
|
460 |
-
stride=stride,
|
461 |
-
downsample=downsample,
|
462 |
-
with_cp=self.with_cp,
|
463 |
-
norm_cfg=self.norm_cfg,
|
464 |
-
conv_cfg=self.conv_cfg))
|
465 |
-
for _ in range(1, blocks):
|
466 |
-
layers.append(
|
467 |
-
block(
|
468 |
-
out_channels,
|
469 |
-
out_channels,
|
470 |
-
with_cp=self.with_cp,
|
471 |
-
norm_cfg=self.norm_cfg,
|
472 |
-
conv_cfg=self.conv_cfg))
|
473 |
-
|
474 |
-
return nn.Sequential(*layers)
|
475 |
-
|
476 |
-
def _make_stage(self, layer_config, in_channels, multiscale_output=True):
|
477 |
-
"""Make stage."""
|
478 |
-
num_modules = layer_config['num_modules']
|
479 |
-
num_branches = layer_config['num_branches']
|
480 |
-
num_blocks = layer_config['num_blocks']
|
481 |
-
num_channels = layer_config['num_channels']
|
482 |
-
block = self.blocks_dict[layer_config['block']]
|
483 |
-
|
484 |
-
hr_modules = []
|
485 |
-
for i in range(num_modules):
|
486 |
-
# multi_scale_output is only used for the last module
|
487 |
-
if not multiscale_output and i == num_modules - 1:
|
488 |
-
reset_multiscale_output = False
|
489 |
-
else:
|
490 |
-
reset_multiscale_output = True
|
491 |
-
|
492 |
-
hr_modules.append(
|
493 |
-
HRModule(
|
494 |
-
num_branches,
|
495 |
-
block,
|
496 |
-
num_blocks,
|
497 |
-
in_channels,
|
498 |
-
num_channels,
|
499 |
-
reset_multiscale_output,
|
500 |
-
with_cp=self.with_cp,
|
501 |
-
norm_cfg=self.norm_cfg,
|
502 |
-
conv_cfg=self.conv_cfg,
|
503 |
-
upsample_cfg=self.upsample_cfg))
|
504 |
-
|
505 |
-
in_channels = hr_modules[-1].in_channels
|
506 |
-
|
507 |
-
return nn.Sequential(*hr_modules), in_channels
|
508 |
-
|
509 |
-
def _freeze_stages(self):
|
510 |
-
"""Freeze parameters."""
|
511 |
-
if self.frozen_stages >= 0:
|
512 |
-
self.norm1.eval()
|
513 |
-
self.norm2.eval()
|
514 |
-
|
515 |
-
for m in [self.conv1, self.norm1, self.conv2, self.norm2]:
|
516 |
-
for param in m.parameters():
|
517 |
-
param.requires_grad = False
|
518 |
-
|
519 |
-
for i in range(1, self.frozen_stages + 1):
|
520 |
-
if i == 1:
|
521 |
-
m = getattr(self, 'layer1')
|
522 |
-
else:
|
523 |
-
m = getattr(self, f'stage{i}')
|
524 |
-
|
525 |
-
m.eval()
|
526 |
-
for param in m.parameters():
|
527 |
-
param.requires_grad = False
|
528 |
-
|
529 |
-
if i < 4:
|
530 |
-
m = getattr(self, f'transition{i}')
|
531 |
-
m.eval()
|
532 |
-
for param in m.parameters():
|
533 |
-
param.requires_grad = False
|
534 |
-
|
535 |
-
def init_weights(self, pretrained=None):
|
536 |
-
"""Initialize the weights in backbone.
|
537 |
-
|
538 |
-
Args:
|
539 |
-
pretrained (str, optional): Path to pre-trained weights.
|
540 |
-
Defaults to None.
|
541 |
-
"""
|
542 |
-
if isinstance(pretrained, str):
|
543 |
-
logger = get_root_logger()
|
544 |
-
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
545 |
-
elif pretrained is None:
|
546 |
-
for m in self.modules():
|
547 |
-
if isinstance(m, nn.Conv2d):
|
548 |
-
normal_init(m, std=0.001)
|
549 |
-
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
550 |
-
constant_init(m, 1)
|
551 |
-
|
552 |
-
if self.zero_init_residual:
|
553 |
-
for m in self.modules():
|
554 |
-
if isinstance(m, Bottleneck):
|
555 |
-
constant_init(m.norm3, 0)
|
556 |
-
elif isinstance(m, BasicBlock):
|
557 |
-
constant_init(m.norm2, 0)
|
558 |
-
else:
|
559 |
-
raise TypeError('pretrained must be a str or None')
|
560 |
-
|
561 |
-
def forward(self, x):
|
562 |
-
"""Forward function."""
|
563 |
-
x = self.conv1(x)
|
564 |
-
x = self.norm1(x)
|
565 |
-
x = self.relu(x)
|
566 |
-
x = self.conv2(x)
|
567 |
-
x = self.norm2(x)
|
568 |
-
x = self.relu(x)
|
569 |
-
x = self.layer1(x)
|
570 |
-
|
571 |
-
x_list = []
|
572 |
-
for i in range(self.stage2_cfg['num_branches']):
|
573 |
-
if self.transition1[i] is not None:
|
574 |
-
x_list.append(self.transition1[i](x))
|
575 |
-
else:
|
576 |
-
x_list.append(x)
|
577 |
-
y_list = self.stage2(x_list)
|
578 |
-
|
579 |
-
x_list = []
|
580 |
-
for i in range(self.stage3_cfg['num_branches']):
|
581 |
-
if self.transition2[i] is not None:
|
582 |
-
x_list.append(self.transition2[i](y_list[-1]))
|
583 |
-
else:
|
584 |
-
x_list.append(y_list[i])
|
585 |
-
y_list = self.stage3(x_list)
|
586 |
-
|
587 |
-
x_list = []
|
588 |
-
for i in range(self.stage4_cfg['num_branches']):
|
589 |
-
if self.transition3[i] is not None:
|
590 |
-
x_list.append(self.transition3[i](y_list[-1]))
|
591 |
-
else:
|
592 |
-
x_list.append(y_list[i])
|
593 |
-
y_list = self.stage4(x_list)
|
594 |
-
|
595 |
-
return y_list
|
596 |
-
|
597 |
-
def train(self, mode=True):
|
598 |
-
"""Convert the model into training mode."""
|
599 |
-
super().train(mode)
|
600 |
-
self._freeze_stages()
|
601 |
-
if mode and self.norm_eval:
|
602 |
-
for m in self.modules():
|
603 |
-
if isinstance(m, _BatchNorm):
|
604 |
-
m.eval()
|
|
|
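# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the multi-resolution fuse
# rule implemented by HRModule._make_fuse_layers()/forward() above. Branch j
# runs at 1/2**j of the stem resolution, so routing branch j into output
# branch i upsamples by 2**(j - i) when j > i and stacks (i - j) stride-2
# convs when j < i. Channel widths below are assumptions; BN/ReLU and the
# None-for-identity handling of the real module are simplified away.
import torch
import torch.nn as nn

branch_channels = [32, 64, 128]  # assumed per-branch widths
feats = [torch.rand(1, c, 64 // 2**k, 64 // 2**k)
         for k, c in enumerate(branch_channels)]

def fuse(j, i):
    """Map branch j's feature map to branch i's resolution and width."""
    if j > i:
        return nn.Sequential(
            nn.Conv2d(branch_channels[j], branch_channels[i], 1, bias=False),
            nn.Upsample(scale_factor=2**(j - i), mode='nearest'))
    if j == i:
        return nn.Identity()
    layers = [nn.Conv2d(branch_channels[j], branch_channels[j], 3, stride=2,
                        padding=1, bias=False) for _ in range(i - j - 1)]
    layers.append(nn.Conv2d(branch_channels[j], branch_channels[i], 3,
                            stride=2, padding=1, bias=False))
    return nn.Sequential(*layers)

for i in range(len(branch_channels)):
    fused = sum(fuse(j, i)(feats[j]) for j in range(len(branch_channels)))
    print(i, tuple(fused.shape))  # every summand matches branch i's shape
# ---------------------------------------------------------------------------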
main/transformer_utils/mmpose/models/backbones/hrt.py
DELETED
@@ -1,676 +0,0 @@
# --------------------------------------------------------
# High Resolution Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Rao Fu, RainbowSecret
# --------------------------------------------------------

import pdb
import torch
import torch.nn as nn
from mmcv.cnn import (
    build_conv_layer,
    build_norm_layer,
    constant_init,
    kaiming_init,
    normal_init,
)
# from mmcv.runner import load_checkpoint
from .hrt_checkpoint import load_checkpoint
from mmcv.runner.checkpoint import load_state_dict
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmpose.models.utils.ops import resize
from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .modules.bottleneck_block import Bottleneck
from .modules.transformer_block import GeneralTransformerBlock


class HighResolutionTransformerModule(nn.Module):
    def __init__(
        self,
        num_branches,
        blocks,
        num_blocks,
        in_channels,
        num_channels,
        multiscale_output,
        with_cp=False,
        conv_cfg=None,
        norm_cfg=dict(type="BN", requires_grad=True),
        num_heads=None,
        num_window_sizes=None,
        num_mlp_ratios=None,
        drop_paths=0.0,
    ):
        super(HighResolutionTransformerModule, self).__init__()
        self._check_branches(num_branches, num_blocks, in_channels, num_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.multiscale_output = multiscale_output
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp
        self.branches = self._make_branches(
            num_branches,
            blocks,
            num_blocks,
            num_channels,
            num_heads,
            num_window_sizes,
            num_mlp_ratios,
            drop_paths,
        )
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

        # MHSA parameters
        self.num_heads = num_heads
        self.num_window_sizes = num_window_sizes
        self.num_mlp_ratios = num_mlp_ratios

    def _check_branches(self, num_branches, num_blocks, in_channels, num_channels):
        logger = get_root_logger()
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
                num_branches, len(num_blocks)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
                num_branches, len(num_channels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(in_channels):
            error_msg = "NUM_BRANCHES({}) <> IN_CHANNELS({})".format(
                num_branches, len(in_channels)
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(
        self,
        branch_index,
        block,
        num_blocks,
        num_channels,
        num_heads,
        num_window_sizes,
        num_mlp_ratios,
        drop_paths,
        stride=1,
    ):
        """Make one branch."""
        downsample = None
        if (
            stride != 1
            or self.in_channels[branch_index]
            != num_channels[branch_index] * block.expansion
        ):
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    self.in_channels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                build_norm_layer(
                    self.norm_cfg, num_channels[branch_index] * block.expansion
                )[1],
            )

        layers = []

        layers.append(
            block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                num_heads=num_heads[branch_index],
                window_size=num_window_sizes[branch_index],
                mlp_ratio=num_mlp_ratios[branch_index],
                drop_path=drop_paths[0],
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
            )
        )
        self.in_channels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(
                    self.in_channels[branch_index],
                    num_channels[branch_index],
                    num_heads=num_heads[branch_index],
                    window_size=num_window_sizes[branch_index],
                    mlp_ratio=num_mlp_ratios[branch_index],
                    drop_path=drop_paths[i],
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                )
            )

        return nn.Sequential(*layers)

    def _make_branches(
        self,
        num_branches,
        block,
        num_blocks,
        num_channels,
        num_heads,
        num_window_sizes,
        num_mlp_ratios,
        drop_paths,
    ):
        """Make branches."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(
                    i,
                    block,
                    num_blocks,
                    num_channels,
                    num_heads,
                    num_window_sizes,
                    num_mlp_ratios,
                    drop_paths,
                )
            )

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2 ** (j - i),
                                mode="bilinear",
                                align_corners=False,
                            ),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[i])[1],
                                )
                            )
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        bias=False,
                                    ),
                                    build_norm_layer(self.norm_cfg, in_channels[j])[1],
                                    nn.ReLU(inplace=True),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y += x[j]
                elif j > i:
                    y = y + resize(
                        self.fuse_layers[i][j](x[j]),
                        size=x[i].shape[2:],
                        mode="bilinear",
                        align_corners=False,
                    )
                else:
                    y += self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse


@BACKBONES.register_module()
class HRT(nn.Module):
    """HRT backbone.
    High Resolution Transformer Backbone
    """

    blocks_dict = {
        "BOTTLENECK": Bottleneck,
        "TRANSFORMER_BLOCK": GeneralTransformerBlock,
    }

    def __init__(
        self,
        extra,
        in_channels=3,
        conv_cfg=None,
        norm_cfg=dict(type="BN", requires_grad=True),
        norm_eval=False,
        with_cp=False,
        zero_init_residual=False,
    ):
        super(HRT, self).__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        # stem net
        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.add_module(self.norm1_name, norm1)

        self.conv2 = build_conv_layer(
            self.conv_cfg, 64, 64, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.add_module(self.norm2_name, norm2)
        self.relu = nn.ReLU(inplace=True)

        # generate drop path rate list
        depth_s2 = (
            self.extra["stage2"]["num_blocks"][0] * self.extra["stage2"]["num_modules"]
        )
        depth_s3 = (
            self.extra["stage3"]["num_blocks"][0] * self.extra["stage3"]["num_modules"]
        )
        depth_s4 = (
            self.extra["stage4"]["num_blocks"][0] * self.extra["stage4"]["num_modules"]
        )
        depths = [depth_s2, depth_s3, depth_s4]
        drop_path_rate = self.extra["drop_path_rate"]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        logger = get_root_logger()
        logger.info(dpr)

        # stage 1
        self.stage1_cfg = self.extra["stage1"]
        num_channels = self.stage1_cfg["num_channels"][0]
        block_type = self.stage1_cfg["block"]
        num_blocks = self.stage1_cfg["num_blocks"][0]

        block = self.blocks_dict[block_type]
        stage1_out_channels = num_channels * block.expansion
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)

        # stage 2
        self.stage2_cfg = self.extra["stage2"]
        num_channels = self.stage2_cfg["num_channels"]
        block_type = self.stage2_cfg["block"]

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition1 = self._make_transition_layer(
            [stage1_out_channels], num_channels
        )
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels, drop_paths=dpr[0:depth_s2]
        )

        # stage 3
        self.stage3_cfg = self.extra["stage3"]
        num_channels = self.stage3_cfg["num_channels"]
        block_type = self.stage3_cfg["block"]

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg,
            num_channels,
            drop_paths=dpr[depth_s2 : depth_s2 + depth_s3],
        )

        # stage 4
        self.stage4_cfg = self.extra["stage4"]
        num_channels = self.stage4_cfg["num_channels"]
        block_type = self.stage4_cfg["block"]

        block = self.blocks_dict[block_type]
        num_channels = [channel * block.expansion for channel in num_channels]
        self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg,
            num_channels,
            multiscale_output=self.stage4_cfg.get("multiscale_output", True),
            drop_paths=dpr[depth_s2 + depth_s3 :],
        )

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, num_channels_cur_layer[i])[
                                1
                            ],
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = (
                        num_channels_cur_layer[i]
                        if j == i - num_branches_pre
                        else in_channels
                    )
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_layer(
        self,
        block,
        inplanes,
        planes,
        blocks,
        stride=1,
        num_heads=1,
        window_size=7,
        mlp_ratio=4.0,
    ):
        """Make each layer."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                build_conv_layer(
                    self.conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1],
            )

        layers = []
        if isinstance(block, GeneralTransformerBlock):
            layers.append(
                block(
                    inplanes,
                    planes,
                    num_heads=num_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                )
            )
        else:
            layers.append(
                block(
                    inplanes,
                    planes,
                    stride,
                    downsample=downsample,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                )
            )
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                )
            )

        return nn.Sequential(*layers)

    def _make_stage(
        self, layer_config, in_channels, multiscale_output=True, drop_paths=0.0
    ):
        """Make each stage."""
        num_modules = layer_config["num_modules"]
        num_branches = layer_config["num_branches"]
        num_blocks = layer_config["num_blocks"]
        num_channels = layer_config["num_channels"]
        block = self.blocks_dict[layer_config["block"]]

        num_heads = layer_config["num_heads"]
        num_window_sizes = layer_config["num_window_sizes"]
        num_mlp_ratios = layer_config["num_mlp_ratios"]

        hr_modules = []
        for i in range(num_modules):
            # multi_scale_output is only used for the last module
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            hr_modules.append(
                HighResolutionTransformerModule(
                    num_branches,
                    block,
                    num_blocks,
                    in_channels,
                    num_channels,
                    reset_multiscale_output,
                    with_cp=self.with_cp,
                    norm_cfg=self.norm_cfg,
                    conv_cfg=self.conv_cfg,
                    num_heads=num_heads,
                    num_window_sizes=num_window_sizes,
                    num_mlp_ratios=num_mlp_ratios,
                    drop_paths=drop_paths[num_blocks[0] * i : num_blocks[0] * (i + 1)],
                )
            )

        return nn.Sequential(*hr_modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            ckpt = load_checkpoint(self, pretrained, strict=False)
            if "model" in ckpt:
                msg = self.load_state_dict(ckpt["model"], strict=False)
                logger.info(msg)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    """mmseg: kaiming_init(m)"""
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError("pretrained must be a str or None")

    def forward(self, x):
        """Forward function."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.stage2_cfg["num_branches"]):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg["num_branches"]):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg["num_branches"]):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)

        return y_list

    def train(self, mode=True):
        """Convert the model into training mode."""
        super(HRT, self).train(mode)
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
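# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the linear stochastic-depth
# schedule built in HRT.__init__ above. drop_path_rate and the stage sizes
# below are assumed values; the real ones come from the model config's
# `extra` dict.
import torch

extra = {
    'drop_path_rate': 0.2,                            # assumed
    'stage2': {'num_modules': 1, 'num_blocks': [2]},  # assumed
    'stage3': {'num_modules': 4, 'num_blocks': [2]},  # assumed
    'stage4': {'num_modules': 2, 'num_blocks': [2]},  # assumed
}
depths = [extra[s]['num_blocks'][0] * extra[s]['num_modules']
          for s in ('stage2', 'stage3', 'stage4')]
dpr = [x.item() for x in torch.linspace(0, extra['drop_path_rate'],
                                        sum(depths))]

# Each stage consumes a contiguous slice, so deeper blocks get larger drop
# probabilities -- exactly how _make_stage slices dpr above.
depth_s2, depth_s3 = depths[0], depths[1]
print(dpr[0:depth_s2])                    # stage 2 slice
print(dpr[depth_s2:depth_s2 + depth_s3])  # stage 3 slice
print(dpr[depth_s2 + depth_s3:])          # stage 4 slice
# ---------------------------------------------------------------------------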
main/transformer_utils/mmpose/models/backbones/hrt_checkpoint.py
DELETED
@@ -1,500 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
import io
import os
import os.path as osp
import pkgutil
import time
import warnings
from collections import OrderedDict
from importlib import import_module
from tempfile import TemporaryDirectory

import torch
import torchvision
from torch.optim import Optimizer
from torch.utils import model_zoo
from torch.nn import functional as F

import mmcv
from mmcv.fileio import FileClient
from mmcv.fileio import load as load_file
from mmcv.parallel import is_module_wrapper
from mmcv.utils import mkdir_or_exist
from mmcv.runner import get_dist_info

ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'


def _get_mmcv_home():
    mmcv_home = os.path.expanduser(
        os.getenv(
            ENV_MMCV_HOME,
            os.path.join(
                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))

    mkdir_or_exist(mmcv_home)
    return mmcv_home


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def load_pavimodel_dist(model_path, map_location=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    try:
        from pavi import modelcloud
    except ImportError:
        raise ImportError(
            'Please install pavi to load checkpoint from modelcloud.')
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        model = modelcloud.get(model_path)
        with TemporaryDirectory() as tmp_dir:
            downloaded_file = osp.join(tmp_dir, model.name)
            model.download(downloaded_file)
            checkpoint = torch.load(downloaded_file, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            model = modelcloud.get(model_path)
            with TemporaryDirectory() as tmp_dir:
                downloaded_file = osp.join(tmp_dir, model.name)
                model.download(downloaded_file)
                checkpoint = torch.load(
                    downloaded_file, map_location=map_location)
    return checkpoint


def load_fileclient_dist(filename, backend, map_location):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    allowed_backends = ['ceph']
    if backend not in allowed_backends:
        raise ValueError(f'Load from Backend {backend} is not supported.')
    if rank == 0:
        fileclient = FileClient(backend=backend)
        buffer = io.BytesIO(fileclient.get(filename))
        checkpoint = torch.load(buffer, map_location=map_location)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            fileclient = FileClient(backend=backend)
            buffer = io.BytesIO(fileclient.get(filename))
            checkpoint = torch.load(buffer, map_location=map_location)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def get_external_models():
    mmcv_home = _get_mmcv_home()
    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
    default_urls = load_file(default_json_path)
    assert isinstance(default_urls, dict)
    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
    if osp.exists(external_json_path):
        external_urls = load_file(external_json_path)
        assert isinstance(external_urls, dict)
        default_urls.update(external_urls)

    return default_urls


def get_mmcls_models():
    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
    mmcls_urls = load_file(mmcls_json_path)

    return mmcls_urls


def get_deprecated_model_names():
    deprecate_json_path = osp.join(mmcv.__path__[0],
                                   'model_zoo/deprecated.json')
    deprecate_urls = load_file(deprecate_json_path)
    assert isinstance(deprecate_urls, dict)

    return deprecate_urls


def _process_mmcls_checkpoint(checkpoint):
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            new_state_dict[k[9:]] = v
    new_checkpoint = dict(state_dict=new_state_dict)

    return new_checkpoint


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('torchvision://'):
        model_urls = get_torchvision_models()
        model_name = filename[14:]
        checkpoint = load_url_dist(model_urls[model_name])
    elif filename.startswith('open-mmlab://'):
        model_urls = get_external_models()
        model_name = filename[13:]
        deprecated_urls = get_deprecated_model_names()
        if model_name in deprecated_urls:
            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
                          f'of open-mmlab://{deprecated_urls[model_name]}')
            model_name = deprecated_urls[model_name]
        model_url = model_urls[model_name]
        # check if is url
        if model_url.startswith(('http://', 'https://')):
            checkpoint = load_url_dist(model_url)
        else:
            filename = osp.join(_get_mmcv_home(), model_url)
            if not osp.isfile(filename):
                raise IOError(f'{filename} is not a checkpoint file')
            checkpoint = torch.load(filename, map_location=map_location)
    elif filename.startswith('mmcls://'):
        model_urls = get_mmcls_models()
        model_name = filename[8:]
        checkpoint = load_url_dist(model_urls[model_name])
        checkpoint = _process_mmcls_checkpoint(checkpoint)
    elif filename.startswith(('http://', 'https://')):
        checkpoint = load_url_dist(filename)
    elif filename.startswith('pavi://'):
        model_path = filename[7:]
        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
    elif filename.startswith('s3://'):
        checkpoint = load_fileclient_dist(
            filename, backend='ceph', map_location=map_location)
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H*W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    # relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    # for table_key in relative_position_bias_table_keys:
    #     table_pretrained = state_dict[table_key]
    #     table_current = model.state_dict()[table_key]
    #     L1, nH1 = table_pretrained.size()
    #     L2, nH2 = table_current.size()
    #     if nH1 != nH2:
    #         logger.warning(f"Error in loading {table_key}, pass")
    #     else:
    #         if L1 != L2:
    #             S1 = int(L1 ** 0.5)
    #             S2 = int(L2 ** 0.5)
    #             table_pretrained_resized = F.interpolate(
    #                 table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
    #                 size=(S2, S2), mode='bicubic')
    #             state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint


def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Args:
        state_dict (OrderedDict): Model weights on GPU.

    Returns:
        OrderedDict: Model weights on CPU.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
|
449 |
-
"""
|
450 |
-
if meta is None:
|
451 |
-
meta = {}
|
452 |
-
elif not isinstance(meta, dict):
|
453 |
-
raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
|
454 |
-
meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
|
455 |
-
|
456 |
-
if is_module_wrapper(model):
|
457 |
-
model = model.module
|
458 |
-
|
459 |
-
if hasattr(model, 'CLASSES') and model.CLASSES is not None:
|
460 |
-
# save class name to the meta
|
461 |
-
meta.update(CLASSES=model.CLASSES)
|
462 |
-
|
463 |
-
checkpoint = {
|
464 |
-
'meta': meta,
|
465 |
-
'state_dict': weights_to_cpu(get_state_dict(model))
|
466 |
-
}
|
467 |
-
# save optimizer state dict in the checkpoint
|
468 |
-
if isinstance(optimizer, Optimizer):
|
469 |
-
checkpoint['optimizer'] = optimizer.state_dict()
|
470 |
-
elif isinstance(optimizer, dict):
|
471 |
-
checkpoint['optimizer'] = {}
|
472 |
-
for name, optim in optimizer.items():
|
473 |
-
checkpoint['optimizer'][name] = optim.state_dict()
|
474 |
-
|
475 |
-
if filename.startswith('pavi://'):
|
476 |
-
try:
|
477 |
-
from pavi import modelcloud
|
478 |
-
from pavi.exception import NodeNotFoundError
|
479 |
-
except ImportError:
|
480 |
-
raise ImportError(
|
481 |
-
'Please install pavi to load checkpoint from modelcloud.')
|
482 |
-
model_path = filename[7:]
|
483 |
-
root = modelcloud.Folder()
|
484 |
-
model_dir, model_name = osp.split(model_path)
|
485 |
-
try:
|
486 |
-
model = modelcloud.get(model_dir)
|
487 |
-
except NodeNotFoundError:
|
488 |
-
model = root.create_training_model(model_dir)
|
489 |
-
with TemporaryDirectory() as tmp_dir:
|
490 |
-
checkpoint_file = osp.join(tmp_dir, model_name)
|
491 |
-
with open(checkpoint_file, 'wb') as f:
|
492 |
-
torch.save(checkpoint, f)
|
493 |
-
f.flush()
|
494 |
-
model.create_file(checkpoint_file, name=model_name)
|
495 |
-
else:
|
496 |
-
mmcv.mkdir_or_exist(osp.dirname(filename))
|
497 |
-
# immediately flush buffer
|
498 |
-
with open(filename, 'wb') as f:
|
499 |
-
torch.save(checkpoint, f)
|
500 |
-
f.flush()
|
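For context on what this deleted loader did to checkpoint keys before handing them to load_state_dict, here is a minimal, runnable sketch of its two remapping steps on a toy state_dict; the keys and tensors below are made up for illustration and stand in for a real checkpoint file.

# Sketch of load_checkpoint's key remapping (toy data, not a real checkpoint).
from collections import OrderedDict

import torch

state_dict = OrderedDict([
    ('module.encoder.layer1.weight', torch.zeros(1)),  # DataParallel + MoBY-style keys
    ('module.encoder.layer2.weight', torch.zeros(1)),
    ('module.head.fc.weight', torch.zeros(1)),
])

# Step 1: strip the 'module.' prefix added by (Distributed)DataParallel.
if next(iter(state_dict)).startswith('module.'):
    state_dict = OrderedDict((k[7:], v) for k, v in state_dict.items())

# Step 2: for a MoBY-style checkpoint, keep only the online branch and
# drop the 'encoder.' prefix; non-encoder keys (e.g. the head) are discarded.
if sorted(state_dict)[0].startswith('encoder'):
    state_dict = OrderedDict((k.replace('encoder.', ''), v)
                             for k, v in state_dict.items()
                             if k.startswith('encoder.'))

print(list(state_dict))  # ['layer1.weight', 'layer2.weight']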
main/transformer_utils/mmpose/models/backbones/i3d.py
DELETED
@@ -1,215 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-# Code is modified from `Third-party pytorch implementation of i3d
-# <https://github.com/hassony2/kinetics_i3d_pytorch>`.
-
-import torch
-import torch.nn as nn
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-class Conv3dBlock(nn.Module):
-    """Basic 3d convolution block for I3D.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-        kernel_size (tuple[int]): kernel size of the 3d convolution layer.
-            Default: (1, 1, 1).
-        stride (tuple[int]): stride of the block. Default: (1, 1, 1)
-        padding (tuple[int]): padding of the input tensor. Default: (0, 0, 0)
-        use_bias (bool): whether to enable bias in 3d convolution layer.
-            Default: False
-        use_bn (bool): whether to use Batch Normalization after 3d convolution
-            layer. Default: True
-        use_relu (bool): whether to use ReLU after Batch Normalization layer.
-            Default: True
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 expansion=1.0,
-                 kernel_size=(1, 1, 1),
-                 stride=(1, 1, 1),
-                 padding=(0, 0, 0),
-                 use_bias=False,
-                 use_bn=True,
-                 use_relu=True):
-        super().__init__()
-
-        in_channels = int(in_channels * expansion)
-        out_channels = int(out_channels * expansion)
-
-        self.conv3d = nn.Conv3d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            padding=padding,
-            stride=stride,
-            bias=use_bias)
-
-        self.use_bn = use_bn
-        self.use_relu = use_relu
-
-        if self.use_bn:
-            self.batch3d = nn.BatchNorm3d(out_channels)
-
-        if self.use_relu:
-            self.activation = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        """Forward function."""
-        out = self.conv3d(x)
-        if self.use_bn:
-            out = self.batch3d(out)
-        if self.use_relu:
-            out = self.activation(out)
-        return out
-
-
-class Mixed(nn.Module):
-    """Inception block for I3D.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-    """
-
-    def __init__(self, in_channels, out_channels, expansion=1.0):
-        super(Mixed, self).__init__()
-        # Branch 0
-        self.branch_0 = Conv3dBlock(
-            in_channels, out_channels[0], expansion, kernel_size=(1, 1, 1))
-
-        # Branch 1
-        branch_1_conv1 = Conv3dBlock(
-            in_channels, out_channels[1], expansion, kernel_size=(1, 1, 1))
-        branch_1_conv2 = Conv3dBlock(
-            out_channels[1],
-            out_channels[2],
-            expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.branch_1 = nn.Sequential(branch_1_conv1, branch_1_conv2)
-
-        # Branch 2
-        branch_2_conv1 = Conv3dBlock(
-            in_channels, out_channels[3], expansion, kernel_size=(1, 1, 1))
-        branch_2_conv2 = Conv3dBlock(
-            out_channels[3],
-            out_channels[4],
-            expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.branch_2 = nn.Sequential(branch_2_conv1, branch_2_conv2)
-
-        # Branch3
-        branch_3_pool = nn.MaxPool3d(
-            kernel_size=(3, 3, 3),
-            stride=(1, 1, 1),
-            padding=(1, 1, 1),
-            ceil_mode=True)
-        branch_3_conv2 = Conv3dBlock(
-            in_channels, out_channels[5], expansion, kernel_size=(1, 1, 1))
-        self.branch_3 = nn.Sequential(branch_3_pool, branch_3_conv2)
-
-    def forward(self, x):
-        """Forward function."""
-        out_0 = self.branch_0(x)
-        out_1 = self.branch_1(x)
-        out_2 = self.branch_2(x)
-        out_3 = self.branch_3(x)
-        out = torch.cat((out_0, out_1, out_2, out_3), 1)
-        return out
-
-
-@BACKBONES.register_module()
-class I3D(BaseBackbone):
-    """I3D backbone.
-
-    Please refer to the `paper <https://arxiv.org/abs/1705.07750>`__ for
-    details.
-
-    Args:
-        in_channels (int): Input channels of the backbone, which is decided
-            on the input modality.
-        expansion (float): The multiplier of in_channels and out_channels.
-            Default: 1.
-    """
-
-    def __init__(self, in_channels=3, expansion=1.0):
-        super(I3D, self).__init__()
-
-        # expansion must be an integer multiple of 1/8
-        expansion = round(8 * expansion) / 8.0
-
-        # Input layer
-        self.conv3d_1a_7x7 = Conv3dBlock(
-            out_channels=64,
-            in_channels=in_channels / expansion,
-            expansion=expansion,
-            kernel_size=(7, 7, 7),
-            stride=(2, 2, 2),
-            padding=(2, 3, 3))
-        self.maxPool3d_2a_3x3 = nn.MaxPool3d(
-            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
-
-        # Layer 2
-        self.conv3d_2b_1x1 = Conv3dBlock(
-            out_channels=64,
-            in_channels=64,
-            expansion=expansion,
-            kernel_size=(1, 1, 1))
-        self.conv3d_2c_3x3 = Conv3dBlock(
-            out_channels=192,
-            in_channels=64,
-            expansion=expansion,
-            kernel_size=(3, 3, 3),
-            padding=(1, 1, 1))
-        self.maxPool3d_3a_3x3 = nn.MaxPool3d(
-            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
-
-        # Mixed_3b
-        self.mixed_3b = Mixed(192, [64, 96, 128, 16, 32, 32], expansion)
-        self.mixed_3c = Mixed(256, [128, 128, 192, 32, 96, 64], expansion)
-        self.maxPool3d_4a_3x3 = nn.MaxPool3d(
-            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
-
-        # Mixed 4
-        self.mixed_4b = Mixed(480, [192, 96, 208, 16, 48, 64], expansion)
-        self.mixed_4c = Mixed(512, [160, 112, 224, 24, 64, 64], expansion)
-        self.mixed_4d = Mixed(512, [128, 128, 256, 24, 64, 64], expansion)
-        self.mixed_4e = Mixed(512, [112, 144, 288, 32, 64, 64], expansion)
-        self.mixed_4f = Mixed(528, [256, 160, 320, 32, 128, 128], expansion)
-
-        self.maxPool3d_5a_2x2 = nn.MaxPool3d(
-            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
-
-        # Mixed 5
-        self.mixed_5b = Mixed(832, [256, 160, 320, 32, 128, 128], expansion)
-        self.mixed_5c = Mixed(832, [384, 192, 384, 48, 128, 128], expansion)
-
-    def forward(self, x):
-        out = self.conv3d_1a_7x7(x)
-        out = self.maxPool3d_2a_3x3(out)
-        out = self.conv3d_2b_1x1(out)
-        out = self.conv3d_2c_3x3(out)
-        out = self.maxPool3d_3a_3x3(out)
-        out = self.mixed_3b(out)
-        out = self.mixed_3c(out)
-        out = self.maxPool3d_4a_3x3(out)
-        out = self.mixed_4b(out)
-        out = self.mixed_4c(out)
-        out = self.mixed_4d(out)
-        out = self.mixed_4e(out)
-        out = self.mixed_4f(out)
-        out = self.maxPool3d_5a_2x2(out)
-        out = self.mixed_5b(out)
-        out = self.mixed_5c(out)
-        return out
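A quick smoke test of the backbone above. The import path is an assumption (this commit removes the file, so it only works against a tree that still ships it), and the expected output shape is derived from the stride/pool schedule rather than taken from the source.

import torch
from mmpose.models import I3D  # assumed import path; this commit deletes the backbone

model = I3D(in_channels=3, expansion=1.0).eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 16, 224, 224))  # (batch, C, T, H, W)
# 1024 = 384 + 384 + 128 + 128 channels out of mixed_5c; spatially 224 -> 7 after
# the stride-2 conv plus four stride-2 pools; temporally 16 -> 2 after the conv
# and the two pools with temporal stride 2.
print(out.shape)  # expected: torch.Size([1, 1024, 2, 7, 7])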
main/transformer_utils/mmpose/models/backbones/litehrnet.py
DELETED
@@ -1,984 +0,0 @@
-# ------------------------------------------------------------------------------
-# Adapted from https://github.com/HRNet/Lite-HRNet
-# Original licence: Apache License 2.0.
-# ------------------------------------------------------------------------------
-
-import mmcv
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
-                      build_conv_layer, build_norm_layer, constant_init,
-                      normal_init)
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from mmpose.utils import get_root_logger
-from ..builder import BACKBONES
-from .utils import channel_shuffle, load_checkpoint
-
-
-class SpatialWeighting(nn.Module):
-    """Spatial weighting module.
-
-    Args:
-        channels (int): The channels of the module.
-        ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: None.
-        act_cfg (dict): Config dict for activation layer.
-            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
-            The last ConvModule uses Sigmoid by default.
-    """
-
-    def __init__(self,
-                 channels,
-                 ratio=16,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
-        super().__init__()
-        if isinstance(act_cfg, dict):
-            act_cfg = (act_cfg, act_cfg)
-        assert len(act_cfg) == 2
-        assert mmcv.is_tuple_of(act_cfg, dict)
-        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
-        self.conv1 = ConvModule(
-            in_channels=channels,
-            out_channels=int(channels / ratio),
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[0])
-        self.conv2 = ConvModule(
-            in_channels=int(channels / ratio),
-            out_channels=channels,
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[1])
-
-    def forward(self, x):
-        out = self.global_avgpool(x)
-        out = self.conv1(out)
-        out = self.conv2(out)
-        return x * out
-
-
-class CrossResolutionWeighting(nn.Module):
-    """Cross-resolution channel weighting module.
-
-    Args:
-        channels (int): The channels of the module.
-        ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: None.
-        act_cfg (dict): Config dict for activation layer.
-            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
-            The last ConvModule uses Sigmoid by default.
-    """
-
-    def __init__(self,
-                 channels,
-                 ratio=16,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
-        super().__init__()
-        if isinstance(act_cfg, dict):
-            act_cfg = (act_cfg, act_cfg)
-        assert len(act_cfg) == 2
-        assert mmcv.is_tuple_of(act_cfg, dict)
-        self.channels = channels
-        total_channel = sum(channels)
-        self.conv1 = ConvModule(
-            in_channels=total_channel,
-            out_channels=int(total_channel / ratio),
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[0])
-        self.conv2 = ConvModule(
-            in_channels=int(total_channel / ratio),
-            out_channels=total_channel,
-            kernel_size=1,
-            stride=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg[1])
-
-    def forward(self, x):
-        mini_size = x[-1].size()[-2:]
-        out = [F.adaptive_avg_pool2d(s, mini_size) for s in x[:-1]] + [x[-1]]
-        out = torch.cat(out, dim=1)
-        out = self.conv1(out)
-        out = self.conv2(out)
-        out = torch.split(out, self.channels, dim=1)
-        out = [
-            s * F.interpolate(a, size=s.size()[-2:], mode='nearest')
-            for s, a in zip(x, out)
-        ]
-        return out
-
-
-class ConditionalChannelWeighting(nn.Module):
-    """Conditional channel weighting block.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        stride (int): Stride of the 3x3 convolution layer.
-        reduce_ratio (int): channel reduction ratio.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 stride,
-                 reduce_ratio,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 with_cp=False):
-        super().__init__()
-        self.with_cp = with_cp
-        self.stride = stride
-        assert stride in [1, 2]
-
-        branch_channels = [channel // 2 for channel in in_channels]
-
-        self.cross_resolution_weighting = CrossResolutionWeighting(
-            branch_channels,
-            ratio=reduce_ratio,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg)
-
-        self.depthwise_convs = nn.ModuleList([
-            ConvModule(
-                channel,
-                channel,
-                kernel_size=3,
-                stride=self.stride,
-                padding=1,
-                groups=channel,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None) for channel in branch_channels
-        ])
-
-        self.spatial_weighting = nn.ModuleList([
-            SpatialWeighting(channels=channel, ratio=4)
-            for channel in branch_channels
-        ])
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            x = [s.chunk(2, dim=1) for s in x]
-            x1 = [s[0] for s in x]
-            x2 = [s[1] for s in x]
-
-            x2 = self.cross_resolution_weighting(x2)
-            x2 = [dw(s) for s, dw in zip(x2, self.depthwise_convs)]
-            x2 = [sw(s) for s, sw in zip(x2, self.spatial_weighting)]
-
-            out = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(x1, x2)]
-            out = [channel_shuffle(s, 2) for s in out]
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-class Stem(nn.Module):
-    """Stem network block.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        stem_channels (int): Output channels of the stem layer.
-        out_channels (int): The output channels of the block.
-        expand_ratio (int): adjusts number of channels of the hidden layer
-            in InvertedResidual by this amount.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 stem_channels,
-                 out_channels,
-                 expand_ratio,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 with_cp=False):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.with_cp = with_cp
-
-        self.conv1 = ConvModule(
-            in_channels=in_channels,
-            out_channels=stem_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg,
-            act_cfg=dict(type='ReLU'))
-
-        mid_channels = int(round(stem_channels * expand_ratio))
-        branch_channels = stem_channels // 2
-        if stem_channels == self.out_channels:
-            inc_channels = self.out_channels - branch_channels
-        else:
-            inc_channels = self.out_channels - stem_channels
-
-        self.branch1 = nn.Sequential(
-            ConvModule(
-                branch_channels,
-                branch_channels,
-                kernel_size=3,
-                stride=2,
-                padding=1,
-                groups=branch_channels,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None),
-            ConvModule(
-                branch_channels,
-                inc_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=dict(type='ReLU')),
-        )
-
-        self.expand_conv = ConvModule(
-            branch_channels,
-            mid_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=dict(type='ReLU'))
-        self.depthwise_conv = ConvModule(
-            mid_channels,
-            mid_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            groups=mid_channels,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=None)
-        self.linear_conv = ConvModule(
-            mid_channels,
-            branch_channels
-            if stem_channels == self.out_channels else stem_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=dict(type='ReLU'))
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            x = self.conv1(x)
-            x1, x2 = x.chunk(2, dim=1)
-
-            x2 = self.expand_conv(x2)
-            x2 = self.depthwise_conv(x2)
-            x2 = self.linear_conv(x2)
-
-            out = torch.cat((self.branch1(x1), x2), dim=1)
-
-            out = channel_shuffle(out, 2)
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-class IterativeHead(nn.Module):
-    """Extra iterative head for feature learning.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-    """
-
-    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
-        super().__init__()
-        projects = []
-        num_branchs = len(in_channels)
-        self.in_channels = in_channels[::-1]
-
-        for i in range(num_branchs):
-            if i != num_branchs - 1:
-                projects.append(
-                    DepthwiseSeparableConvModule(
-                        in_channels=self.in_channels[i],
-                        out_channels=self.in_channels[i + 1],
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        norm_cfg=norm_cfg,
-                        act_cfg=dict(type='ReLU'),
-                        dw_act_cfg=None,
-                        pw_act_cfg=dict(type='ReLU')))
-            else:
-                projects.append(
-                    DepthwiseSeparableConvModule(
-                        in_channels=self.in_channels[i],
-                        out_channels=self.in_channels[i],
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        norm_cfg=norm_cfg,
-                        act_cfg=dict(type='ReLU'),
-                        dw_act_cfg=None,
-                        pw_act_cfg=dict(type='ReLU')))
-        self.projects = nn.ModuleList(projects)
-
-    def forward(self, x):
-        x = x[::-1]
-
-        y = []
-        last_x = None
-        for i, s in enumerate(x):
-            if last_x is not None:
-                last_x = F.interpolate(
-                    last_x,
-                    size=s.size()[-2:],
-                    mode='bilinear',
-                    align_corners=True)
-                s = s + last_x
-            s = self.projects[i](s)
-            y.append(s)
-            last_x = s
-
-        return y[::-1]
-
-
-class ShuffleUnit(nn.Module):
-    """InvertedResidual block for ShuffleNetV2 backbone.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        out_channels (int): The output channels of the block.
-        stride (int): Stride of the 3x3 convolution layer. Default: 1
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride=1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU'),
-                 with_cp=False):
-        super().__init__()
-        self.stride = stride
-        self.with_cp = with_cp
-
-        branch_features = out_channels // 2
-        if self.stride == 1:
-            assert in_channels == branch_features * 2, (
-                f'in_channels ({in_channels}) should equal to '
-                f'branch_features * 2 ({branch_features * 2}) '
-                'when stride is 1')
-
-        if in_channels != branch_features * 2:
-            assert self.stride != 1, (
-                f'stride ({self.stride}) should not equal 1 when '
-                f'in_channels != branch_features * 2')
-
-        if self.stride > 1:
-            self.branch1 = nn.Sequential(
-                ConvModule(
-                    in_channels,
-                    in_channels,
-                    kernel_size=3,
-                    stride=self.stride,
-                    padding=1,
-                    groups=in_channels,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=None),
-                ConvModule(
-                    in_channels,
-                    branch_features,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg),
-            )
-
-        self.branch2 = nn.Sequential(
-            ConvModule(
-                in_channels if (self.stride > 1) else branch_features,
-                branch_features,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg),
-            ConvModule(
-                branch_features,
-                branch_features,
-                kernel_size=3,
-                stride=self.stride,
-                padding=1,
-                groups=branch_features,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None),
-            ConvModule(
-                branch_features,
-                branch_features,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg))
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            if self.stride > 1:
-                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
-            else:
-                x1, x2 = x.chunk(2, dim=1)
-                out = torch.cat((x1, self.branch2(x2)), dim=1)
-
-            out = channel_shuffle(out, 2)
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-class LiteHRModule(nn.Module):
-    """High-Resolution Module for LiteHRNet.
-
-    It contains conditional channel weighting blocks and
-    shuffle blocks.
-
-
-    Args:
-        num_branches (int): Number of branches in the module.
-        num_blocks (int): Number of blocks in the module.
-        in_channels (list(int)): Number of input image channels.
-        reduce_ratio (int): Channel reduction ratio.
-        module_type (str): 'LITE' or 'NAIVE'
-        multiscale_output (bool): Whether to output multi-scale features.
-        with_fuse (bool): Whether to use fuse layers.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-        norm_cfg (dict): dictionary to construct and config norm layer.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-    """
-
-    def __init__(
-            self,
-            num_branches,
-            num_blocks,
-            in_channels,
-            reduce_ratio,
-            module_type,
-            multiscale_output=False,
-            with_fuse=True,
-            conv_cfg=None,
-            norm_cfg=dict(type='BN'),
-            with_cp=False,
-    ):
-        super().__init__()
-        self._check_branches(num_branches, in_channels)
-
-        self.in_channels = in_channels
-        self.num_branches = num_branches
-
-        self.module_type = module_type
-        self.multiscale_output = multiscale_output
-        self.with_fuse = with_fuse
-        self.norm_cfg = norm_cfg
-        self.conv_cfg = conv_cfg
-        self.with_cp = with_cp
-
-        if self.module_type.upper() == 'LITE':
-            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
-        elif self.module_type.upper() == 'NAIVE':
-            self.layers = self._make_naive_branches(num_branches, num_blocks)
-        else:
-            raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
-        if self.with_fuse:
-            self.fuse_layers = self._make_fuse_layers()
-            self.relu = nn.ReLU()
-
-    def _check_branches(self, num_branches, in_channels):
-        """Check input to avoid ValueError."""
-        if num_branches != len(in_channels):
-            error_msg = f'NUM_BRANCHES({num_branches}) ' \
-                f'!= NUM_INCHANNELS({len(in_channels)})'
-            raise ValueError(error_msg)
-
-    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
-        """Make channel weighting blocks."""
-        layers = []
-        for i in range(num_blocks):
-            layers.append(
-                ConditionalChannelWeighting(
-                    self.in_channels,
-                    stride=stride,
-                    reduce_ratio=reduce_ratio,
-                    conv_cfg=self.conv_cfg,
-                    norm_cfg=self.norm_cfg,
-                    with_cp=self.with_cp))
-
-        return nn.Sequential(*layers)
-
-    def _make_one_branch(self, branch_index, num_blocks, stride=1):
-        """Make one branch."""
-        layers = []
-        layers.append(
-            ShuffleUnit(
-                self.in_channels[branch_index],
-                self.in_channels[branch_index],
-                stride=stride,
-                conv_cfg=self.conv_cfg,
-                norm_cfg=self.norm_cfg,
-                act_cfg=dict(type='ReLU'),
-                with_cp=self.with_cp))
-        for i in range(1, num_blocks):
-            layers.append(
-                ShuffleUnit(
-                    self.in_channels[branch_index],
-                    self.in_channels[branch_index],
-                    stride=1,
-                    conv_cfg=self.conv_cfg,
-                    norm_cfg=self.norm_cfg,
-                    act_cfg=dict(type='ReLU'),
-                    with_cp=self.with_cp))
-
-        return nn.Sequential(*layers)
-
-    def _make_naive_branches(self, num_branches, num_blocks):
-        """Make branches."""
-        branches = []
-
-        for i in range(num_branches):
-            branches.append(self._make_one_branch(i, num_blocks))
-
-        return nn.ModuleList(branches)
-
-    def _make_fuse_layers(self):
-        """Make fuse layer."""
-        if self.num_branches == 1:
-            return None
-
-        num_branches = self.num_branches
-        in_channels = self.in_channels
-        fuse_layers = []
-        num_out_branches = num_branches if self.multiscale_output else 1
-        for i in range(num_out_branches):
-            fuse_layer = []
-            for j in range(num_branches):
-                if j > i:
-                    fuse_layer.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                in_channels[j],
-                                in_channels[i],
-                                kernel_size=1,
-                                stride=1,
-                                padding=0,
-                                bias=False),
-                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
-                            nn.Upsample(
-                                scale_factor=2**(j - i), mode='nearest')))
-                elif j == i:
-                    fuse_layer.append(None)
-                else:
-                    conv_downsamples = []
-                    for k in range(i - j):
-                        if k == i - j - 1:
-                            conv_downsamples.append(
-                                nn.Sequential(
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=1,
-                                        groups=in_channels[j],
-                                        bias=False),
-                                    build_norm_layer(self.norm_cfg,
-                                                     in_channels[j])[1],
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[i],
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0,
-                                        bias=False),
-                                    build_norm_layer(self.norm_cfg,
-                                                     in_channels[i])[1]))
-                        else:
-                            conv_downsamples.append(
-                                nn.Sequential(
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=3,
-                                        stride=2,
-                                        padding=1,
-                                        groups=in_channels[j],
-                                        bias=False),
-                                    build_norm_layer(self.norm_cfg,
-                                                     in_channels[j])[1],
-                                    build_conv_layer(
-                                        self.conv_cfg,
-                                        in_channels[j],
-                                        in_channels[j],
-                                        kernel_size=1,
-                                        stride=1,
-                                        padding=0,
-                                        bias=False),
-                                    build_norm_layer(self.norm_cfg,
-                                                     in_channels[j])[1],
-                                    nn.ReLU(inplace=True)))
-                    fuse_layer.append(nn.Sequential(*conv_downsamples))
-            fuse_layers.append(nn.ModuleList(fuse_layer))
-
-        return nn.ModuleList(fuse_layers)
-
-    def forward(self, x):
-        """Forward function."""
-        if self.num_branches == 1:
-            return [self.layers[0](x[0])]
-
-        if self.module_type.upper() == 'LITE':
-            out = self.layers(x)
-        elif self.module_type.upper() == 'NAIVE':
-            for i in range(self.num_branches):
-                x[i] = self.layers[i](x[i])
-            out = x
-
-        if self.with_fuse:
-            out_fuse = []
-            for i in range(len(self.fuse_layers)):
-                # `y = 0` will lead to decreased accuracy (0.5~1 mAP)
-                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
-                for j in range(self.num_branches):
-                    if i == j:
-                        y += out[j]
-                    else:
-                        y += self.fuse_layers[i][j](out[j])
-                out_fuse.append(self.relu(y))
-            out = out_fuse
-        if not self.multiscale_output:
-            out = [out[0]]
-        return out
-
-
-@BACKBONES.register_module()
-class LiteHRNet(nn.Module):
-    """Lite-HRNet backbone.
-
-    `Lite-HRNet: A Lightweight High-Resolution Network
-    <https://arxiv.org/abs/2104.06403>`_.
-
-    Code adapted from 'https://github.com/HRNet/Lite-HRNet'.
-
-    Args:
-        extra (dict): detailed configuration for each stage of HRNet.
-        in_channels (int): Number of input image channels. Default: 3.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-        norm_cfg (dict): dictionary to construct and config norm layer.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-
-    Example:
-        >>> from mmpose.models import LiteHRNet
-        >>> import torch
-        >>> extra=dict(
-        >>>     stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
-        >>>     num_stages=3,
-        >>>     stages_spec=dict(
-        >>>         num_modules=(2, 4, 2),
-        >>>         num_branches=(2, 3, 4),
-        >>>         num_blocks=(2, 2, 2),
-        >>>         module_type=('LITE', 'LITE', 'LITE'),
-        >>>         with_fuse=(True, True, True),
-        >>>         reduce_ratios=(8, 8, 8),
-        >>>         num_channels=(
-        >>>             (40, 80),
-        >>>             (40, 80, 160),
-        >>>             (40, 80, 160, 320),
-        >>>         )),
-        >>>     with_head=False)
-        >>> self = LiteHRNet(extra, in_channels=1)
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 1, 32, 32)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 40, 8, 8)
-    """
-
-    def __init__(self,
-                 extra,
-                 in_channels=3,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 norm_eval=False,
-                 with_cp=False):
-        super().__init__()
-        self.extra = extra
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
-        self.stem = Stem(
-            in_channels,
-            stem_channels=self.extra['stem']['stem_channels'],
-            out_channels=self.extra['stem']['out_channels'],
-            expand_ratio=self.extra['stem']['expand_ratio'],
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg)
-
-        self.num_stages = self.extra['num_stages']
-        self.stages_spec = self.extra['stages_spec']
-
-        num_channels_last = [
-            self.stem.out_channels,
-        ]
-        for i in range(self.num_stages):
-            num_channels = self.stages_spec['num_channels'][i]
-            num_channels = [num_channels[i] for i in range(len(num_channels))]
-            setattr(
-                self, f'transition{i}',
-                self._make_transition_layer(num_channels_last, num_channels))
-
-            stage, num_channels_last = self._make_stage(
-                self.stages_spec, i, num_channels, multiscale_output=True)
-            setattr(self, f'stage{i}', stage)
-
-        self.with_head = self.extra['with_head']
-        if self.with_head:
-            self.head_layer = IterativeHead(
-                in_channels=num_channels_last,
-                norm_cfg=self.norm_cfg,
-            )
-
-    def _make_transition_layer(self, num_channels_pre_layer,
-                               num_channels_cur_layer):
-        """Make transition layer."""
-        num_branches_cur = len(num_channels_cur_layer)
-        num_branches_pre = len(num_channels_pre_layer)
-
-        transition_layers = []
-        for i in range(num_branches_cur):
-            if i < num_branches_pre:
-                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
-                    transition_layers.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                num_channels_pre_layer[i],
-                                num_channels_pre_layer[i],
-                                kernel_size=3,
-                                stride=1,
-                                padding=1,
-                                groups=num_channels_pre_layer[i],
-                                bias=False),
-                            build_norm_layer(self.norm_cfg,
-                                             num_channels_pre_layer[i])[1],
-                            build_conv_layer(
-                                self.conv_cfg,
-                                num_channels_pre_layer[i],
-                                num_channels_cur_layer[i],
-                                kernel_size=1,
-                                stride=1,
-                                padding=0,
-                                bias=False),
-                            build_norm_layer(self.norm_cfg,
-                                             num_channels_cur_layer[i])[1],
-                            nn.ReLU()))
-                else:
-                    transition_layers.append(None)
-            else:
-                conv_downsamples = []
-                for j in range(i + 1 - num_branches_pre):
-                    in_channels = num_channels_pre_layer[-1]
-                    out_channels = num_channels_cur_layer[i] \
-                        if j == i - num_branches_pre else in_channels
-                    conv_downsamples.append(
-                        nn.Sequential(
-                            build_conv_layer(
-                                self.conv_cfg,
-                                in_channels,
-                                in_channels,
-                                kernel_size=3,
-                                stride=2,
-                                padding=1,
-                                groups=in_channels,
-                                bias=False),
-                            build_norm_layer(self.norm_cfg, in_channels)[1],
-                            build_conv_layer(
-                                self.conv_cfg,
-                                in_channels,
-                                out_channels,
-                                kernel_size=1,
-                                stride=1,
-                                padding=0,
-                                bias=False),
-                            build_norm_layer(self.norm_cfg, out_channels)[1],
-                            nn.ReLU()))
-                transition_layers.append(nn.Sequential(*conv_downsamples))
-
-        return nn.ModuleList(transition_layers)
-
-    def _make_stage(self,
-                    stages_spec,
-                    stage_index,
-                    in_channels,
-                    multiscale_output=True):
-        num_modules = stages_spec['num_modules'][stage_index]
-        num_branches = stages_spec['num_branches'][stage_index]
-        num_blocks = stages_spec['num_blocks'][stage_index]
-        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
-        with_fuse = stages_spec['with_fuse'][stage_index]
-        module_type = stages_spec['module_type'][stage_index]
-
-        modules = []
-        for i in range(num_modules):
-            # multi_scale_output is only used last module
-            if not multiscale_output and i == num_modules - 1:
-                reset_multiscale_output = False
-            else:
-                reset_multiscale_output = True
-
-            modules.append(
-                LiteHRModule(
-                    num_branches,
-                    num_blocks,
-                    in_channels,
-                    reduce_ratio,
-                    module_type,
-                    multiscale_output=reset_multiscale_output,
-                    with_fuse=with_fuse,
-                    conv_cfg=self.conv_cfg,
-                    norm_cfg=self.norm_cfg,
-                    with_cp=self.with_cp))
-            in_channels = modules[-1].in_channels
-
-        return nn.Sequential(*modules), in_channels
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    normal_init(m, std=0.001)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-    def forward(self, x):
-        """Forward function."""
-        x = self.stem(x)
-
-        y_list = [x]
-        for i in range(self.num_stages):
-            x_list = []
-            transition = getattr(self, f'transition{i}')
-            for j in range(self.stages_spec['num_branches'][i]):
-                if transition[j]:
-                    if j >= len(y_list):
-                        x_list.append(transition[j](y_list[-1]))
-                    else:
-                        x_list.append(transition[j](y_list[j]))
-                else:
-                    x_list.append(y_list[j])
-            y_list = getattr(self, f'stage{i}')(x_list)
-
-        x = y_list
-        if self.with_head:
-            x = self.head_layer(x)
-
-        return [x[0]]
-
-    def train(self, mode=True):
-        """Convert the model into training mode."""
-        super().train(mode)
-        if mode and self.norm_eval:
-            for m in self.modules():
-                if isinstance(m, _BatchNorm):
-                    m.eval()
main/transformer_utils/mmpose/models/backbones/mobilenet_v2.py
DELETED
@@ -1,275 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import logging
-
-import torch.nn as nn
-import torch.utils.checkpoint as cp
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-from .utils import load_checkpoint, make_divisible
-
-
-class InvertedResidual(nn.Module):
-    """InvertedResidual block for MobileNetV2.
-
-    Args:
-        in_channels (int): The input channels of the InvertedResidual block.
-        out_channels (int): The output channels of the InvertedResidual block.
-        stride (int): Stride of the middle (first) 3x3 convolution.
-        expand_ratio (int): adjusts number of channels of the hidden layer
-            in InvertedResidual by this amount.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU6').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 expand_ratio,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6'),
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.stride = stride
-        assert stride in [1, 2], f'stride must in [1, 2]. ' \
-            f'But received {stride}.'
-        self.with_cp = with_cp
-        self.use_res_connect = self.stride == 1 and in_channels == out_channels
-        hidden_dim = int(round(in_channels * expand_ratio))
-
-        layers = []
-        if expand_ratio != 1:
-            layers.append(
-                ConvModule(
-                    in_channels=in_channels,
-                    out_channels=hidden_dim,
-                    kernel_size=1,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg))
-        layers.extend([
-            ConvModule(
-                in_channels=hidden_dim,
-                out_channels=hidden_dim,
-                kernel_size=3,
-                stride=stride,
-                padding=1,
-                groups=hidden_dim,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg),
-            ConvModule(
-                in_channels=hidden_dim,
-                out_channels=out_channels,
-                kernel_size=1,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None)
-        ])
-        self.conv = nn.Sequential(*layers)
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            if self.use_res_connect:
-                return x + self.conv(x)
-            return self.conv(x)
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-@BACKBONES.register_module()
-class MobileNetV2(BaseBackbone):
-    """MobileNetV2 backbone.
-
-    Args:
-        widen_factor (float): Width multiplier, multiply number of
-            channels in each layer by this amount. Default: 1.0.
-        out_indices (None or Sequence[int]): Output from which stages.
-            Default: (7, ).
-        frozen_stages (int): Stages to be frozen (all param fixed).
-            Default: -1, which means not freezing any parameters.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU6').
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    # Parameters to build layers. 4 parameters are needed to construct a
-    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
-    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
-                     [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
-                     [6, 320, 1, 1]]
-
-    def __init__(self,
-                 widen_factor=1.,
-                 out_indices=(7, ),
-                 frozen_stages=-1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU6'),
-                 norm_eval=False,
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.widen_factor = widen_factor
-        self.out_indices = out_indices
-        for index in out_indices:
-            if index not in range(0, 8):
-                raise ValueError('the item in out_indices must in '
-                                 f'range(0, 8). But received {index}')
-
-        if frozen_stages not in range(-1, 8):
-            raise ValueError('frozen_stages must be in range(-1, 8). '
-                             f'But received {frozen_stages}')
-        self.out_indices = out_indices
-        self.frozen_stages = frozen_stages
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.act_cfg = act_cfg
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
-        self.in_channels = make_divisible(32 * widen_factor, 8)
-
-        self.conv1 = ConvModule(
-            in_channels=3,
-            out_channels=self.in_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg,
-            act_cfg=self.act_cfg)
-
-        self.layers = []
-
-        for i, layer_cfg in enumerate(self.arch_settings):
-            expand_ratio, channel, num_blocks, stride = layer_cfg
-            out_channels = make_divisible(channel * widen_factor, 8)
-            inverted_res_layer = self.make_layer(
-                out_channels=out_channels,
-                num_blocks=num_blocks,
-                stride=stride,
-                expand_ratio=expand_ratio)
-            layer_name = f'layer{i + 1}'
-            self.add_module(layer_name, inverted_res_layer)
-            self.layers.append(layer_name)
-
-        if widen_factor > 1.0:
-            self.out_channel = int(1280 * widen_factor)
-        else:
-            self.out_channel = 1280
-
-        layer = ConvModule(
-            in_channels=self.in_channels,
-            out_channels=self.out_channel,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg,
-            act_cfg=self.act_cfg)
-
self.add_module('conv2', layer)
|
202 |
-
self.layers.append('conv2')
|
203 |
-
|
204 |
-
def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
|
205 |
-
"""Stack InvertedResidual blocks to build a layer for MobileNetV2.
|
206 |
-
|
207 |
-
Args:
|
208 |
-
out_channels (int): out_channels of block.
|
209 |
-
num_blocks (int): number of blocks.
|
210 |
-
stride (int): stride of the first block. Default: 1
|
211 |
-
expand_ratio (int): Expand the number of channels of the
|
212 |
-
hidden layer in InvertedResidual by this ratio. Default: 6.
|
213 |
-
"""
|
214 |
-
layers = []
|
215 |
-
for i in range(num_blocks):
|
216 |
-
if i >= 1:
|
217 |
-
stride = 1
|
218 |
-
layers.append(
|
219 |
-
InvertedResidual(
|
220 |
-
self.in_channels,
|
221 |
-
out_channels,
|
222 |
-
stride,
|
223 |
-
expand_ratio=expand_ratio,
|
224 |
-
conv_cfg=self.conv_cfg,
|
225 |
-
norm_cfg=self.norm_cfg,
|
226 |
-
act_cfg=self.act_cfg,
|
227 |
-
with_cp=self.with_cp))
|
228 |
-
self.in_channels = out_channels
|
229 |
-
|
230 |
-
return nn.Sequential(*layers)
|
231 |
-
|
232 |
-
def init_weights(self, pretrained=None):
|
233 |
-
if isinstance(pretrained, str):
|
234 |
-
logger = logging.getLogger()
|
235 |
-
load_checkpoint(self, pretrained, strict=False, logger=logger)
|
236 |
-
elif pretrained is None:
|
237 |
-
for m in self.modules():
|
238 |
-
if isinstance(m, nn.Conv2d):
|
239 |
-
kaiming_init(m)
|
240 |
-
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
241 |
-
constant_init(m, 1)
|
242 |
-
else:
|
243 |
-
raise TypeError('pretrained must be a str or None')
|
244 |
-
|
245 |
-
def forward(self, x):
|
246 |
-
x = self.conv1(x)
|
247 |
-
|
248 |
-
outs = []
|
249 |
-
for i, layer_name in enumerate(self.layers):
|
250 |
-
layer = getattr(self, layer_name)
|
251 |
-
x = layer(x)
|
252 |
-
if i in self.out_indices:
|
253 |
-
outs.append(x)
|
254 |
-
|
255 |
-
if len(outs) == 1:
|
256 |
-
return outs[0]
|
257 |
-
return tuple(outs)
|
258 |
-
|
259 |
-
def _freeze_stages(self):
|
260 |
-
if self.frozen_stages >= 0:
|
261 |
-
for param in self.conv1.parameters():
|
262 |
-
param.requires_grad = False
|
263 |
-
for i in range(1, self.frozen_stages + 1):
|
264 |
-
layer = getattr(self, f'layer{i}')
|
265 |
-
layer.eval()
|
266 |
-
for param in layer.parameters():
|
267 |
-
param.requires_grad = False
|
268 |
-
|
269 |
-
def train(self, mode=True):
|
270 |
-
super().train(mode)
|
271 |
-
self._freeze_stages()
|
272 |
-
if mode and self.norm_eval:
|
273 |
-
for m in self.modules():
|
274 |
-
if isinstance(m, _BatchNorm):
|
275 |
-
m.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
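Note: the deleted MobileNetV2 registered itself in the BACKBONES registry, so configs built it by name before this commit. A minimal standalone sketch of how the backbone behaved (assuming the pre-commit import path `mmpose.models.MobileNetV2`; with the default out_indices=(7, ), forward() returns only the final conv2 feature map):

import torch
from mmpose.models import MobileNetV2  # import path before this commit

model = MobileNetV2(widen_factor=1.)
model.init_weights()  # Kaiming init for convs, constant init for norm layers
model.eval()
with torch.no_grad():
    feat = model(torch.rand(1, 3, 224, 224))
print(feat.shape)  # torch.Size([1, 1280, 7, 7]): five stride-2 stages -> /32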
main/transformer_utils/mmpose/models/backbones/mobilenet_v3.py
DELETED
@@ -1,188 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .utils import InvertedResidual, load_checkpoint


@BACKBONES.register_module()
class MobileNetV3(BaseBackbone):
    """MobileNetV3 backbone.

    Args:
        arch (str): Architecture of mobilnetv3, from {small, big}.
            Default: small.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        out_indices (None or Sequence[int]): Output from which stages.
            Default: (-1, ), which means output tensors from final stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
    """
    # Parameters to build each block:
    # [kernel size, mid channels, out channels, with_se, act type, stride]
    arch_settings = {
        'small': [[3, 16, 16, True, 'ReLU', 2],
                  [3, 72, 24, False, 'ReLU', 2],
                  [3, 88, 24, False, 'ReLU', 1],
                  [5, 96, 40, True, 'HSwish', 2],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 240, 40, True, 'HSwish', 1],
                  [5, 120, 48, True, 'HSwish', 1],
                  [5, 144, 48, True, 'HSwish', 1],
                  [5, 288, 96, True, 'HSwish', 2],
                  [5, 576, 96, True, 'HSwish', 1],
                  [5, 576, 96, True, 'HSwish', 1]],
        'big': [[3, 16, 16, False, 'ReLU', 1],
                [3, 64, 24, False, 'ReLU', 2],
                [3, 72, 24, False, 'ReLU', 1],
                [5, 72, 40, True, 'ReLU', 2],
                [5, 120, 40, True, 'ReLU', 1],
                [5, 120, 40, True, 'ReLU', 1],
                [3, 240, 80, False, 'HSwish', 2],
                [3, 200, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 184, 80, False, 'HSwish', 1],
                [3, 480, 112, True, 'HSwish', 1],
                [3, 672, 112, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 1],
                [5, 672, 160, True, 'HSwish', 2],
                [5, 960, 160, True, 'HSwish', 1]]
    }  # yapf: disable

    def __init__(self,
                 arch='small',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 out_indices=(-1, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 with_cp=False):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        assert arch in self.arch_settings
        for index in out_indices:
            if index not in range(-len(self.arch_settings[arch]),
                                  len(self.arch_settings[arch])):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, {len(self.arch_settings[arch])}). '
                                 f'But received {index}')

        if frozen_stages not in range(-1, len(self.arch_settings[arch])):
            raise ValueError('frozen_stages must be in range(-1, '
                             f'{len(self.arch_settings[arch])}). '
                             f'But received {frozen_stages}')
        self.arch = arch
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        self.in_channels = 16
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='HSwish'))

        self.layers = self._make_layer()
        self.feat_dim = self.arch_settings[arch][-1][2]

    def _make_layer(self):
        layers = []
        layer_setting = self.arch_settings[self.arch]
        for i, params in enumerate(layer_setting):
            (kernel_size, mid_channels, out_channels, with_se, act,
             stride) = params
            if with_se:
                se_cfg = dict(
                    channels=mid_channels,
                    ratio=4,
                    act_cfg=(dict(type='ReLU'), dict(type='HSigmoid')))
            else:
                se_cfg = None

            layer = InvertedResidual(
                in_channels=self.in_channels,
                out_channels=out_channels,
                mid_channels=mid_channels,
                kernel_size=kernel_size,
                stride=stride,
                se_cfg=se_cfg,
                with_expand_conv=True,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type=act),
                with_cp=self.with_cp)
            self.in_channels = out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, layer)
            layers.append(layer_name)
        return layers

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        x = self.conv1(x)

        outs = []
        for i, layer_name in enumerate(self.layers):
            layer = getattr(self, layer_name)
            x = layer(x)
            if i in self.out_indices or \
                    i - len(self.layers) in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
            for i in range(1, self.frozen_stages + 1):
                layer = getattr(self, f'layer{i}')
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
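Note: in the file above, out_indices accepts negative stage indices (default (-1, )), resolved in forward() by the `i - len(self.layers) in self.out_indices` check. A small illustrative sketch of that resolution, assuming the 11-block 'small' arch:

layers = [f'layer{i + 1}' for i in range(11)]  # 'small' arch: 11 blocks
out_indices = (-1, )

# Mirrors MobileNetV3.forward(): a stage is kept if either its positive
# or its negative index appears in out_indices.
selected = [
    i for i in range(len(layers))
    if i in out_indices or i - len(layers) in out_indices
]
print(selected)  # [10] -> only the final stage is returned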
main/transformer_utils/mmpose/models/backbones/modules/basic_block.py
CHANGED
@@ -12,13 +12,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as cp
 from .transformer_block import TransformerBlock
-
+from mmengine.model import constant_init, kaiming_init
 from mmcv.cnn import (
     build_conv_layer,
     build_norm_layer,
     build_plugin_layer,
-    constant_init,
-    kaiming_init,
 )
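Note: this import move tracks the mmcv 2.x reorganization, where the weight-initialization helpers left mmcv.cnn and now live in mmengine.model. For code that must run against both stacks, a hedged compatibility shim (a sketch, not part of this commit):

try:
    from mmengine.model import constant_init, kaiming_init  # mmcv >= 2.0
except ImportError:
    from mmcv.cnn import constant_init, kaiming_init  # mmcv 1.x fallback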
main/transformer_utils/mmpose/models/backbones/mspn.py
DELETED
@@ -1,513 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
from collections import OrderedDict

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
                      normal_init)
from mmcv.runner.checkpoint import load_state_dict

from mmpose.utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
from .resnet import Bottleneck as _Bottleneck
from .utils.utils import get_state_dict


class Bottleneck(_Bottleneck):
    expansion = 4
    """Bottleneck block for MSPN.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        stride (int): stride of the block. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(in_channels, out_channels * 4, **kwargs)


class DownsampleModule(nn.Module):
    """Downsample module for MSPN.

    Args:
        block (nn.Module): Downsample block.
        num_blocks (list): Number of blocks in each downsample unit.
        num_units (int): Numbers of downsample units. Default: 4
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the input feature to
            downsample module. Default: 64
    """

    def __init__(self,
                 block,
                 num_blocks,
                 num_units=4,
                 has_skip=False,
                 norm_cfg=dict(type='BN'),
                 in_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.has_skip = has_skip
        self.in_channels = in_channels
        assert len(num_blocks) == num_units
        self.num_blocks = num_blocks
        self.num_units = num_units
        self.norm_cfg = norm_cfg
        self.layer1 = self._make_layer(block, in_channels, num_blocks[0])
        for i in range(1, num_units):
            module_name = f'layer{i + 1}'
            self.add_module(
                module_name,
                self._make_layer(
                    block, in_channels * pow(2, i), num_blocks[i], stride=2))

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = ConvModule(
                self.in_channels,
                out_channels * block.expansion,
                kernel_size=1,
                stride=stride,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        units = list()
        units.append(
            block(
                self.in_channels,
                out_channels,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            units.append(block(self.in_channels, out_channels))

        return nn.Sequential(*units)

    def forward(self, x, skip1, skip2):
        out = list()
        for i in range(self.num_units):
            module_name = f'layer{i + 1}'
            module_i = getattr(self, module_name)
            x = module_i(x)
            if self.has_skip:
                x = x + skip1[i] + skip2[i]
            out.append(x)
        out.reverse()

        return tuple(out)


class UpsampleUnit(nn.Module):
    """Upsample unit for upsample module.

    Args:
        ind (int): Indicates whether to interpolate (>0) and whether to
            generate feature map for the next hourglass-like module.
        num_units (int): Number of units that form a upsample module. Along
            with ind and gen_cross_conv, nm_units is used to decide whether
            to generate feature map for the next hourglass-like module.
        in_channels (int): Channel number of the skip-in feature maps from
            the corresponding downsample unit.
        unit_channels (int): Channel number in this unit. Default:256.
        gen_skip: (bool): Whether or not to generate skips for the posterior
            downsample module. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module. Default:64
    """

    def __init__(self,
                 ind,
                 num_units,
                 in_channels,
                 unit_channels=256,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.num_units = num_units
        self.norm_cfg = norm_cfg
        self.in_skip = ConvModule(
            in_channels,
            unit_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=self.norm_cfg,
            act_cfg=None,
            inplace=True)
        self.relu = nn.ReLU(inplace=True)

        self.ind = ind
        if self.ind > 0:
            self.up_conv = ConvModule(
                unit_channels,
                unit_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                act_cfg=None,
                inplace=True)

        self.gen_skip = gen_skip
        if self.gen_skip:
            self.out_skip1 = ConvModule(
                in_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

            self.out_skip2 = ConvModule(
                unit_channels,
                in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

        self.gen_cross_conv = gen_cross_conv
        if self.ind == num_units - 1 and self.gen_cross_conv:
            self.cross_conv = ConvModule(
                unit_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                norm_cfg=self.norm_cfg,
                inplace=True)

    def forward(self, x, up_x):
        out = self.in_skip(x)

        if self.ind > 0:
            up_x = F.interpolate(
                up_x,
                size=(x.size(2), x.size(3)),
                mode='bilinear',
                align_corners=True)
            up_x = self.up_conv(up_x)
            out = out + up_x
        out = self.relu(out)

        skip1 = None
        skip2 = None
        if self.gen_skip:
            skip1 = self.out_skip1(x)
            skip2 = self.out_skip2(out)

        cross_conv = None
        if self.ind == self.num_units - 1 and self.gen_cross_conv:
            cross_conv = self.cross_conv(out)

        return out, skip1, skip2, cross_conv


class UpsampleModule(nn.Module):
    """Upsample module for MSPN.

    Args:
        unit_channels (int): Channel number in the upsample units.
            Default:256.
        num_units (int): Numbers of upsample units. Default: 4
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        out_channels (int): Number of channels of feature output by upsample
            module. Must equal to in_channels of downsample module. Default:64
    """

    def __init__(self,
                 unit_channels=256,
                 num_units=4,
                 gen_skip=False,
                 gen_cross_conv=False,
                 norm_cfg=dict(type='BN'),
                 out_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.in_channels = list()
        for i in range(num_units):
            self.in_channels.append(Bottleneck.expansion * out_channels *
                                    pow(2, i))
        self.in_channels.reverse()
        self.num_units = num_units
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.norm_cfg = norm_cfg
        for i in range(num_units):
            module_name = f'up{i + 1}'
            self.add_module(
                module_name,
                UpsampleUnit(
                    i,
                    self.num_units,
                    self.in_channels[i],
                    unit_channels,
                    self.gen_skip,
                    self.gen_cross_conv,
                    norm_cfg=self.norm_cfg,
                    out_channels=64))

    def forward(self, x):
        out = list()
        skip1 = list()
        skip2 = list()
        cross_conv = None
        for i in range(self.num_units):
            module_i = getattr(self, f'up{i + 1}')
            if i == 0:
                outi, skip1_i, skip2_i, _ = module_i(x[i], None)
            elif i == self.num_units - 1:
                outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
            else:
                outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
            out.append(outi)
            skip1.append(skip1_i)
            skip2.append(skip2_i)
        skip1.reverse()
        skip2.reverse()

        return out, skip1, skip2, cross_conv


class SingleStageNetwork(nn.Module):
    """Single_stage Network.

    Args:
        unit_channels (int): Channel number in the upsample units. Default:256.
        num_units (int): Numbers of downsample/upsample units. Default: 4
        gen_skip (bool): Whether to generate skip for posterior downsample
            module or not. Default:False
        gen_cross_conv (bool): Whether to generate feature map for the next
            hourglass-like module. Default:False
        has_skip (bool): Have skip connections from prior upsample
            module or not. Default:False
        num_blocks (list): Number of blocks in each downsample unit.
            Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        in_channels (int): Number of channels of the feature from ResNetTop.
            Default: 64.
    """

    def __init__(self,
                 has_skip=False,
                 gen_skip=False,
                 gen_cross_conv=False,
                 unit_channels=256,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 in_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        assert len(num_blocks) == num_units
        self.has_skip = has_skip
        self.gen_skip = gen_skip
        self.gen_cross_conv = gen_cross_conv
        self.num_units = num_units
        self.unit_channels = unit_channels
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        self.downsample = DownsampleModule(Bottleneck, num_blocks, num_units,
                                           has_skip, norm_cfg, in_channels)
        self.upsample = UpsampleModule(unit_channels, num_units, gen_skip,
                                       gen_cross_conv, norm_cfg, in_channels)

    def forward(self, x, skip1, skip2):
        mid = self.downsample(x, skip1, skip2)
        out, skip1, skip2, cross_conv = self.upsample(mid)

        return out, skip1, skip2, cross_conv


class ResNetTop(nn.Module):
    """ResNet top for MSPN.

    Args:
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        channels (int): Number of channels of the feature output by ResNetTop.
    """

    def __init__(self, norm_cfg=dict(type='BN'), channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        super().__init__()
        self.top = nn.Sequential(
            ConvModule(
                3,
                channels,
                kernel_size=7,
                stride=2,
                padding=3,
                norm_cfg=norm_cfg,
                inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))

    def forward(self, img):
        return self.top(img)


@BACKBONES.register_module()
class MSPN(BaseBackbone):
    """MSPN backbone. Paper ref: Li et al. "Rethinking on Multi-Stage Networks
    for Human Pose Estimation" (CVPR 2020).

    Args:
        unit_channels (int): Number of Channels in an upsample unit.
            Default: 256
        num_stages (int): Number of stages in a multi-stage MSPN. Default: 4
        num_units (int): Number of downsample/upsample units in a single-stage
            network. Default: 4
            Note: Make sure num_units == len(self.num_blocks)
        num_blocks (list): Number of bottlenecks in each
            downsample unit. Default: [2, 2, 2, 2]
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        res_top_channels (int): Number of channels of feature from ResNetTop.
            Default: 64.

    Example:
        >>> from mmpose.models import MSPN
        >>> import torch
        >>> self = MSPN(num_stages=2,num_units=2,num_blocks=[2,2])
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 511, 511)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     for feature in level_output:
        ...         print(tuple(feature.shape))
        ...
        (1, 256, 64, 64)
        (1, 256, 128, 128)
        (1, 256, 64, 64)
        (1, 256, 128, 128)
    """

    def __init__(self,
                 unit_channels=256,
                 num_stages=4,
                 num_units=4,
                 num_blocks=[2, 2, 2, 2],
                 norm_cfg=dict(type='BN'),
                 res_top_channels=64):
        # Protect mutable default arguments
        norm_cfg = cp.deepcopy(norm_cfg)
        num_blocks = cp.deepcopy(num_blocks)
        super().__init__()
        self.unit_channels = unit_channels
        self.num_stages = num_stages
        self.num_units = num_units
        self.num_blocks = num_blocks
        self.norm_cfg = norm_cfg

        assert self.num_stages > 0
        assert self.num_units > 1
        assert self.num_units == len(self.num_blocks)
        self.top = ResNetTop(norm_cfg=norm_cfg)
        self.multi_stage_mspn = nn.ModuleList([])
        for i in range(self.num_stages):
            if i == 0:
                has_skip = False
            else:
                has_skip = True
            if i != self.num_stages - 1:
                gen_skip = True
                gen_cross_conv = True
            else:
                gen_skip = False
                gen_cross_conv = False
            self.multi_stage_mspn.append(
                SingleStageNetwork(has_skip, gen_skip, gen_cross_conv,
                                   unit_channels, num_units, num_blocks,
                                   norm_cfg, res_top_channels))

    def forward(self, x):
        """Model forward function."""
        out_feats = []
        skip1 = None
        skip2 = None
        x = self.top(x)
        for i in range(self.num_stages):
            out, skip1, skip2, x = self.multi_stage_mspn[i](x, skip1, skip2)
            out_feats.append(out)

        return out_feats

    def init_weights(self, pretrained=None):
        """Initialize model weights."""
        if isinstance(pretrained, str):
            logger = get_root_logger()
            state_dict_tmp = get_state_dict(pretrained)
            state_dict = OrderedDict()
            state_dict['top'] = OrderedDict()
            state_dict['bottlenecks'] = OrderedDict()
            for k, v in state_dict_tmp.items():
                if k.startswith('layer'):
                    if 'downsample.0' in k:
                        state_dict['bottlenecks'][k.replace(
                            'downsample.0', 'downsample.conv')] = v
                    elif 'downsample.1' in k:
                        state_dict['bottlenecks'][k.replace(
                            'downsample.1', 'downsample.bn')] = v
                    else:
                        state_dict['bottlenecks'][k] = v
                elif k.startswith('conv1'):
                    state_dict['top'][k.replace('conv1', 'top.0.conv')] = v
                elif k.startswith('bn1'):
                    state_dict['top'][k.replace('bn1', 'top.0.bn')] = v

            load_state_dict(
                self.top, state_dict['top'], strict=False, logger=logger)
            for i in range(self.num_stages):
                load_state_dict(
                    self.multi_stage_mspn[i].downsample,
                    state_dict['bottlenecks'],
                    strict=False,
                    logger=logger)
        else:
            for m in self.multi_stage_mspn.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)

            for m in self.top.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
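Note: the deleted MSPN.init_weights remapped torchvision-style ResNet keys onto the ConvModule naming used here before loading. An illustrative sketch of that renaming (hypothetical key list; values elided):

ckpt_keys = ['layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight',
             'conv1.weight', 'bn1.running_mean']

renamed = []
for k in ckpt_keys:
    if k.startswith('layer'):
        # downsample.0 is the conv, downsample.1 the batch norm
        renamed.append(k.replace('downsample.0', 'downsample.conv')
                        .replace('downsample.1', 'downsample.bn'))
    elif k.startswith('conv1'):
        renamed.append(k.replace('conv1', 'top.0.conv'))
    elif k.startswith('bn1'):
        renamed.append(k.replace('bn1', 'top.0.bn'))
print(renamed)
# ['layer1.0.downsample.conv.weight', 'layer1.0.downsample.bn.weight',
#  'top.0.conv.weight', 'top.0.bn.running_mean']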
main/transformer_utils/mmpose/models/backbones/pvt.py
DELETED
@@ -1,592 +0,0 @@
|
|
1 |
-
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
-
import math
|
3 |
-
import warnings
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from mmcv.cnn import (Conv2d, build_activation_layer, build_norm_layer,
|
10 |
-
constant_init, normal_init, trunc_normal_init)
|
11 |
-
from mmcv.cnn.bricks.drop import build_dropout
|
12 |
-
from mmcv.cnn.bricks.transformer import MultiheadAttention
|
13 |
-
from mmcv.cnn.utils.weight_init import trunc_normal_
|
14 |
-
from mmcv.runner import (BaseModule, ModuleList, Sequential, _load_checkpoint,
|
15 |
-
load_state_dict)
|
16 |
-
from torch.nn.modules.utils import _pair as to_2tuple
|
17 |
-
|
18 |
-
from ...utils import get_root_logger
|
19 |
-
from ..builder import BACKBONES
|
20 |
-
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw, pvt_convert
|
21 |
-
|
22 |
-
|
23 |
-
class MixFFN(BaseModule):
|
24 |
-
"""An implementation of MixFFN of PVT.
|
25 |
-
|
26 |
-
The differences between MixFFN & FFN:
|
27 |
-
1. Use 1X1 Conv to replace Linear layer.
|
28 |
-
2. Introduce 3X3 Depth-wise Conv to encode positional information.
|
29 |
-
|
30 |
-
Args:
|
31 |
-
embed_dims (int): The feature dimension. Same as
|
32 |
-
`MultiheadAttention`.
|
33 |
-
feedforward_channels (int): The hidden dimension of FFNs.
|
34 |
-
act_cfg (dict, optional): The activation config for FFNs.
|
35 |
-
Default: dict(type='GELU').
|
36 |
-
ffn_drop (float, optional): Probability of an element to be
|
37 |
-
zeroed in FFN. Default 0.0.
|
38 |
-
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
39 |
-
when adding the shortcut.
|
40 |
-
Default: None.
|
41 |
-
use_conv (bool): If True, add 3x3 DWConv between two Linear layers.
|
42 |
-
Defaults: False.
|
43 |
-
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
44 |
-
Default: None.
|
45 |
-
"""
|
46 |
-
|
47 |
-
def __init__(self,
|
48 |
-
embed_dims,
|
49 |
-
feedforward_channels,
|
50 |
-
act_cfg=dict(type='GELU'),
|
51 |
-
ffn_drop=0.,
|
52 |
-
dropout_layer=None,
|
53 |
-
use_conv=False,
|
54 |
-
init_cfg=None):
|
55 |
-
super(MixFFN, self).__init__(init_cfg=init_cfg)
|
56 |
-
|
57 |
-
self.embed_dims = embed_dims
|
58 |
-
self.feedforward_channels = feedforward_channels
|
59 |
-
self.act_cfg = act_cfg
|
60 |
-
activate = build_activation_layer(act_cfg)
|
61 |
-
|
62 |
-
in_channels = embed_dims
|
63 |
-
fc1 = Conv2d(
|
64 |
-
in_channels=in_channels,
|
65 |
-
out_channels=feedforward_channels,
|
66 |
-
kernel_size=1,
|
67 |
-
stride=1,
|
68 |
-
bias=True)
|
69 |
-
if use_conv:
|
70 |
-
# 3x3 depth wise conv to provide positional encode information
|
71 |
-
dw_conv = Conv2d(
|
72 |
-
in_channels=feedforward_channels,
|
73 |
-
out_channels=feedforward_channels,
|
74 |
-
kernel_size=3,
|
75 |
-
stride=1,
|
76 |
-
padding=(3 - 1) // 2,
|
77 |
-
bias=True,
|
78 |
-
groups=feedforward_channels)
|
79 |
-
fc2 = Conv2d(
|
80 |
-
in_channels=feedforward_channels,
|
81 |
-
out_channels=in_channels,
|
82 |
-
kernel_size=1,
|
83 |
-
stride=1,
|
84 |
-
bias=True)
|
85 |
-
drop = nn.Dropout(ffn_drop)
|
86 |
-
layers = [fc1, activate, drop, fc2, drop]
|
87 |
-
if use_conv:
|
88 |
-
layers.insert(1, dw_conv)
|
89 |
-
self.layers = Sequential(*layers)
|
90 |
-
self.dropout_layer = build_dropout(
|
91 |
-
dropout_layer) if dropout_layer else torch.nn.Identity()
|
92 |
-
|
93 |
-
def forward(self, x, hw_shape, identity=None):
|
94 |
-
out = nlc_to_nchw(x, hw_shape)
|
95 |
-
out = self.layers(out)
|
96 |
-
out = nchw_to_nlc(out)
|
97 |
-
if identity is None:
|
98 |
-
identity = x
|
99 |
-
return identity + self.dropout_layer(out)
|
100 |
-
|
101 |
-
|
102 |
-
class SpatialReductionAttention(MultiheadAttention):
|
103 |
-
"""An implementation of Spatial Reduction Attention of PVT.
|
104 |
-
|
105 |
-
This module is modified from MultiheadAttention which is a module from
|
106 |
-
mmcv.cnn.bricks.transformer.
|
107 |
-
|
108 |
-
Args:
|
109 |
-
embed_dims (int): The embedding dimension.
|
110 |
-
num_heads (int): Parallel attention heads.
|
111 |
-
attn_drop (float): A Dropout layer on attn_output_weights.
|
112 |
-
Default: 0.0.
|
113 |
-
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
|
114 |
-
Default: 0.0.
|
115 |
-
dropout_layer (obj:`ConfigDict`): The dropout_layer used
|
116 |
-
when adding the shortcut. Default: None.
|
117 |
-
batch_first (bool): Key, Query and Value are shape of
|
118 |
-
(batch, n, embed_dim)
|
119 |
-
or (n, batch, embed_dim). Default: False.
|
120 |
-
qkv_bias (bool): enable bias for qkv if True. Default: True.
|
121 |
-
norm_cfg (dict): Config dict for normalization layer.
|
122 |
-
Default: dict(type='LN').
|
123 |
-
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
|
124 |
-
Attention of PVT. Default: 1.
|
125 |
-
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
126 |
-
Default: None.
|
127 |
-
"""
|
128 |
-
|
129 |
-
def __init__(self,
|
130 |
-
embed_dims,
|
131 |
-
num_heads,
|
132 |
-
attn_drop=0.,
|
133 |
-
proj_drop=0.,
|
134 |
-
dropout_layer=None,
|
135 |
-
batch_first=True,
|
136 |
-
qkv_bias=True,
|
137 |
-
norm_cfg=dict(type='LN'),
|
138 |
-
sr_ratio=1,
|
139 |
-
init_cfg=None):
|
140 |
-
super().__init__(
|
141 |
-
embed_dims,
|
142 |
-
num_heads,
|
143 |
-
attn_drop,
|
144 |
-
proj_drop,
|
145 |
-
batch_first=batch_first,
|
146 |
-
dropout_layer=dropout_layer,
|
147 |
-
bias=qkv_bias,
|
148 |
-
init_cfg=init_cfg)
|
149 |
-
|
150 |
-
self.sr_ratio = sr_ratio
|
151 |
-
if sr_ratio > 1:
|
152 |
-
self.sr = Conv2d(
|
153 |
-
in_channels=embed_dims,
|
154 |
-
out_channels=embed_dims,
|
155 |
-
kernel_size=sr_ratio,
|
156 |
-
stride=sr_ratio)
|
157 |
-
# The ret[0] of build_norm_layer is norm name.
|
158 |
-
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
159 |
-
|
160 |
-
# handle the BC-breaking from https://github.com/open-mmlab/mmcv/pull/1418 # noqa
|
161 |
-
from mmpose import digit_version, mmcv_version
|
162 |
-
if mmcv_version < digit_version('1.3.17'):
|
163 |
-
warnings.warn('The legacy version of forward function in'
|
164 |
-
'SpatialReductionAttention is deprecated in'
|
165 |
-
'mmcv>=1.3.17 and will no longer support in the'
|
166 |
-
'future. Please upgrade your mmcv.')
|
167 |
-
self.forward = self.legacy_forward
|
168 |
-
|
169 |
-
def forward(self, x, hw_shape, identity=None):
|
170 |
-
|
171 |
-
x_q = x
|
172 |
-
if self.sr_ratio > 1:
|
173 |
-
x_kv = nlc_to_nchw(x, hw_shape)
|
174 |
-
x_kv = self.sr(x_kv)
|
175 |
-
x_kv = nchw_to_nlc(x_kv)
|
176 |
-
x_kv = self.norm(x_kv)
|
177 |
-
else:
|
178 |
-
x_kv = x
|
179 |
-
|
180 |
-
if identity is None:
|
181 |
-
identity = x_q
|
182 |
-
|
183 |
-
# Because the dataflow('key', 'query', 'value') of
|
184 |
-
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
|
185 |
-
# embed_dims), We should adjust the shape of dataflow from
|
186 |
-
# batch_first (batch, num_query, embed_dims) to num_query_first
|
187 |
-
# (num_query ,batch, embed_dims), and recover ``attn_output``
|
188 |
-
# from num_query_first to batch_first.
|
189 |
-
if self.batch_first:
|
190 |
-
x_q = x_q.transpose(0, 1)
|
191 |
-
x_kv = x_kv.transpose(0, 1)
|
192 |
-
|
193 |
-
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
194 |
-
|
195 |
-
if self.batch_first:
|
196 |
-
out = out.transpose(0, 1)
|
197 |
-
|
198 |
-
return identity + self.dropout_layer(self.proj_drop(out))
|
199 |
-
|
200 |
-
def legacy_forward(self, x, hw_shape, identity=None):
|
201 |
-
"""multi head attention forward in mmcv version < 1.3.17."""
|
202 |
-
x_q = x
|
203 |
-
if self.sr_ratio > 1:
|
204 |
-
x_kv = nlc_to_nchw(x, hw_shape)
|
205 |
-
x_kv = self.sr(x_kv)
|
206 |
-
x_kv = nchw_to_nlc(x_kv)
|
207 |
-
x_kv = self.norm(x_kv)
|
208 |
-
else:
|
209 |
-
x_kv = x
|
210 |
-
|
211 |
-
if identity is None:
|
212 |
-
identity = x_q
|
213 |
-
|
214 |
-
out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
|
215 |
-
|
216 |
-
return identity + self.dropout_layer(self.proj_drop(out))
|
217 |
-
|
218 |
-
|
219 |
-
class PVTEncoderLayer(BaseModule):
|
220 |
-
"""Implements one encoder layer in PVT.
|
221 |
-
|
222 |
-
Args:
|
223 |
-
embed_dims (int): The feature dimension.
|
224 |
-
num_heads (int): Parallel attention heads.
|
225 |
-
feedforward_channels (int): The hidden dimension for FFNs.
|
226 |
-
drop_rate (float): Probability of an element to be zeroed.
|
227 |
-
after the feed forward layer. Default: 0.0.
|
228 |
-
attn_drop_rate (float): The drop out rate for attention layer.
|
229 |
-
Default: 0.0.
|
230 |
-
drop_path_rate (float): stochastic depth rate. Default: 0.0.
|
231 |
-
qkv_bias (bool): enable bias for qkv if True.
|
232 |
-
Default: True.
|
233 |
-
act_cfg (dict): The activation config for FFNs.
|
234 |
-
Default: dict(type='GELU').
|
235 |
-
norm_cfg (dict): Config dict for normalization layer.
|
236 |
-
Default: dict(type='LN').
|
237 |
-
sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
|
238 |
-
Attention of PVT. Default: 1.
|
239 |
-
use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
|
240 |
-
Default: False.
|
241 |
-
init_cfg (dict, optional): Initialization config dict.
|
242 |
-
Default: None.
|
243 |
-
"""
|
244 |
-
|
245 |
-
def __init__(self,
|
246 |
-
embed_dims,
|
247 |
-
num_heads,
|
248 |
-
feedforward_channels,
|
249 |
-
drop_rate=0.,
|
250 |
-
attn_drop_rate=0.,
|
251 |
-
drop_path_rate=0.,
|
252 |
-
qkv_bias=True,
|
253 |
-
act_cfg=dict(type='GELU'),
|
254 |
-
norm_cfg=dict(type='LN'),
|
255 |
-
sr_ratio=1,
|
256 |
-
use_conv_ffn=False,
|
257 |
-
init_cfg=None):
|
258 |
-
super(PVTEncoderLayer, self).__init__(init_cfg=init_cfg)
|
259 |
-
|
260 |
-
# The ret[0] of build_norm_layer is norm name.
|
261 |
-
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
|
262 |
-
|
263 |
-
self.attn = SpatialReductionAttention(
|
264 |
-
embed_dims=embed_dims,
|
265 |
-
num_heads=num_heads,
|
266 |
-
attn_drop=attn_drop_rate,
|
267 |
-
proj_drop=drop_rate,
|
268 |
-
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
269 |
-
qkv_bias=qkv_bias,
|
270 |
-
norm_cfg=norm_cfg,
|
271 |
-
sr_ratio=sr_ratio)
|
272 |
-
|
273 |
-
# The ret[0] of build_norm_layer is norm name.
|
274 |
-
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
|
275 |
-
|
276 |
-
self.ffn = MixFFN(
|
277 |
-
embed_dims=embed_dims,
|
278 |
-
feedforward_channels=feedforward_channels,
|
279 |
-
ffn_drop=drop_rate,
|
280 |
-
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
281 |
-
use_conv=use_conv_ffn,
|
282 |
-
act_cfg=act_cfg)
|
283 |
-
|
284 |
-
def forward(self, x, hw_shape):
|
285 |
-
x = self.attn(self.norm1(x), hw_shape, identity=x)
|
286 |
-
x = self.ffn(self.norm2(x), hw_shape, identity=x)
|
287 |
-
|
288 |
-
return x
|
289 |
-
|
290 |
-
|
291 |
-
class AbsolutePositionEmbedding(BaseModule):
|
292 |
-
"""An implementation of the absolute position embedding in PVT.
|
293 |
-
|
294 |
-
Args:
|
295 |
-
pos_shape (int): The shape of the absolute position embedding.
|
296 |
-
pos_dim (int): The dimension of the absolute position embedding.
|
297 |
-
drop_rate (float): Probability of an element to be zeroed.
|
298 |
-
Default: 0.0.
|
299 |
-
"""
|
300 |
-
|
301 |
-
def __init__(self, pos_shape, pos_dim, drop_rate=0., init_cfg=None):
|
302 |
-
super().__init__(init_cfg=init_cfg)
|
303 |
-
|
304 |
-
if isinstance(pos_shape, int):
|
305 |
-
pos_shape = to_2tuple(pos_shape)
|
306 |
-
elif isinstance(pos_shape, tuple):
|
307 |
-
if len(pos_shape) == 1:
|
308 |
-
pos_shape = to_2tuple(pos_shape[0])
|
309 |
-
assert len(pos_shape) == 2, \
|
310 |
-
f'The size of image should have length 1 or 2, ' \
|
311 |
-
f'but got {len(pos_shape)}'
|
312 |
-
self.pos_shape = pos_shape
|
313 |
-
self.pos_dim = pos_dim
|
314 |
-
|
315 |
-
self.pos_embed = nn.Parameter(
|
316 |
-
torch.zeros(1, pos_shape[0] * pos_shape[1], pos_dim))
|
317 |
-
self.drop = nn.Dropout(p=drop_rate)
|
318 |
-
|
319 |
-
def init_weights(self):
|
320 |
-
trunc_normal_(self.pos_embed, std=0.02)
|
321 |
-
|
322 |
-
def resize_pos_embed(self, pos_embed, input_shape, mode='bilinear'):
|
323 |
-
"""Resize pos_embed weights.
|
324 |
-
|
325 |
-
Resize pos_embed using bilinear interpolate method.
|
326 |
-
|
327 |
-
Args:
|
328 |
-
pos_embed (torch.Tensor): Position embedding weights.
|
329 |
-
input_shape (tuple): Tuple for (downsampled input image height,
|
330 |
-
downsampled input image width).
|
331 |
-
mode (str): Algorithm used for upsampling:
|
332 |
-
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
333 |
-
``'trilinear'``. Default: ``'bilinear'``.
|
334 |
-
|
335 |
-
Return:
|
336 |
-
torch.Tensor: The resized pos_embed of shape [B, L_new, C].
|
337 |
-
"""
|
338 |
-
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
339 |
-
pos_h, pos_w = self.pos_shape
|
340 |
-
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
341 |
-
pos_embed_weight = pos_embed_weight.reshape(
|
342 |
-
1, pos_h, pos_w, self.pos_dim).permute(0, 3, 1, 2).contiguous()
|
343 |
-
pos_embed_weight = F.interpolate(
|
344 |
-
pos_embed_weight, size=input_shape, mode=mode)
|
345 |
-
pos_embed_weight = torch.flatten(pos_embed_weight,
|
346 |
-
2).transpose(1, 2).contiguous()
|
347 |
-
pos_embed = pos_embed_weight
|
348 |
-
|
349 |
-
return pos_embed
|
350 |
-
|
351 |
-
def forward(self, x, hw_shape, mode='bilinear'):
|
352 |
-
pos_embed = self.resize_pos_embed(self.pos_embed, hw_shape, mode)
|
353 |
-
return self.drop(x + pos_embed)
|
354 |
-
|
355 |
-
|
356 |
-
@BACKBONES.register_module()
|
357 |
-
class PyramidVisionTransformer(BaseModule):
|
358 |
-
"""Pyramid Vision Transformer (PVT)
|
359 |
-
|
360 |
-
Implementation of `Pyramid Vision Transformer: A Versatile Backbone for
|
361 |
-
Dense Prediction without Convolutions
|
362 |
-
<https://arxiv.org/pdf/2102.12122.pdf>`_.
|
363 |
-
|
364 |
-
Args:
|
365 |
-
pretrain_img_size (int | tuple[int]): The size of input image when
|
366 |
-
pretrain. Defaults: 224.
|
367 |
-
in_channels (int): Number of input channels. Default: 3.
|
368 |
-
embed_dims (int): Embedding dimension. Default: 64.
|
369 |
-
num_stags (int): The num of stages. Default: 4.
|
370 |
-
num_layers (Sequence[int]): The layer number of each transformer encode
|
371 |
-
layer. Default: [3, 4, 6, 3].
|
372 |
-
num_heads (Sequence[int]): The attention heads of each transformer
|
373 |
-
encode layer. Default: [1, 2, 5, 8].
|
374 |
-
patch_sizes (Sequence[int]): The patch_size of each patch embedding.
|
375 |
-
Default: [4, 2, 2, 2].
|
376 |
-
strides (Sequence[int]): The stride of each patch embedding.
|
377 |
-
Default: [4, 2, 2, 2].
|
378 |
-
paddings (Sequence[int]): The padding of each patch embedding.
|
379 |
-
Default: [0, 0, 0, 0].
|
380 |
-
sr_ratios (Sequence[int]): The spatial reduction rate of each
|
381 |
-
transformer encode layer. Default: [8, 4, 2, 1].
|
382 |
-
out_indices (Sequence[int] | int): Output from which stages.
|
383 |
-
Default: (0, 1, 2, 3).
|
384 |
-
mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
|
385 |
-
embedding dim of each transformer encode layer.
|
386 |
-
Default: [8, 8, 4, 4].
|
387 |
-
qkv_bias (bool): Enable bias for qkv if True. Default: True.
|
388 |
-
drop_rate (float): Probability of an element to be zeroed.
|
389 |
-
Default 0.0.
|
390 |
-
attn_drop_rate (float): The drop out rate for attention layer.
|
391 |
-
Default 0.0.
|
392 |
-
drop_path_rate (float): stochastic depth rate. Default 0.1.
|
393 |
-
use_abs_pos_embed (bool): If True, add absolute position embedding to
|
394 |
-
the patch embedding. Defaults: True.
|
395 |
-
use_conv_ffn (bool): If True, use Convolutional FFN to replace FFN.
|
396 |
-
Default: False.
|
397 |
-
act_cfg (dict): The activation config for FFNs.
|
398 |
-
Default: dict(type='GELU').
|
399 |
-
norm_cfg (dict): Config dict for normalization layer.
|
400 |
-
Default: dict(type='LN').
|
401 |
-
pretrained (str, optional): model pretrained path. Default: None.
|
402 |
-
convert_weights (bool): The flag indicates whether the
|
403 |
-
pre-trained model is from the original repo. We may need
|
404 |
-
to convert some keys to make it compatible.
|
405 |
-
Default: True.
|
406 |
-
init_cfg (dict or list[dict], optional): Initialization config dict.
|
407 |
-
Default: None.
|
408 |
-
"""
|
409 |
-
|
410 |
-
def __init__(self,
|
411 |
-
pretrain_img_size=224,
|
412 |
-
in_channels=3,
|
413 |
-
embed_dims=64,
|
414 |
-
num_stages=4,
|
415 |
-
num_layers=[3, 4, 6, 3],
|
416 |
-
num_heads=[1, 2, 5, 8],
|
417 |
-
patch_sizes=[4, 2, 2, 2],
|
418 |
-
strides=[4, 2, 2, 2],
|
419 |
-
paddings=[0, 0, 0, 0],
|
420 |
-
sr_ratios=[8, 4, 2, 1],
|
421 |
-
out_indices=(0, 1, 2, 3),
|
422 |
-
mlp_ratios=[8, 8, 4, 4],
|
423 |
-
qkv_bias=True,
|
424 |
-
drop_rate=0.,
|
425 |
-
attn_drop_rate=0.,
|
426 |
-
drop_path_rate=0.1,
|
427 |
-
use_abs_pos_embed=True,
|
428 |
-
norm_after_stage=False,
|
429 |
-
use_conv_ffn=False,
|
430 |
-
act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN', eps=1e-6),
-                 pretrained=None,
-                 convert_weights=True,
-                 init_cfg=None):
-        super().__init__(init_cfg=init_cfg)
-
-        self.convert_weights = convert_weights
-        if isinstance(pretrain_img_size, int):
-            pretrain_img_size = to_2tuple(pretrain_img_size)
-        elif isinstance(pretrain_img_size, tuple):
-            if len(pretrain_img_size) == 1:
-                pretrain_img_size = to_2tuple(pretrain_img_size[0])
-            assert len(pretrain_img_size) == 2, \
-                f'The size of image should have length 1 or 2, ' \
-                f'but got {len(pretrain_img_size)}'
-
-        assert not (init_cfg and pretrained), \
-            'init_cfg and pretrained cannot be setting at the same time'
-        if isinstance(pretrained, str):
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-        elif pretrained is None:
-            self.init_cfg = init_cfg
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-        self.embed_dims = embed_dims
-
-        self.num_stages = num_stages
-        self.num_layers = num_layers
-        self.num_heads = num_heads
-        self.patch_sizes = patch_sizes
-        self.strides = strides
-        self.sr_ratios = sr_ratios
-        assert num_stages == len(num_layers) == len(num_heads) \
-            == len(patch_sizes) == len(strides) == len(sr_ratios)
-
-        self.out_indices = out_indices
-        assert max(out_indices) < self.num_stages
-        self.pretrained = pretrained
-
-        # transformer encoder
-        dpr = [
-            x.item()
-            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
-        ]  # stochastic num_layer decay rule
-
-        cur = 0
-        self.layers = ModuleList()
-        for i, num_layer in enumerate(num_layers):
-            embed_dims_i = embed_dims * num_heads[i]
-            patch_embed = PatchEmbed(
-                in_channels=in_channels,
-                embed_dims=embed_dims_i,
-                kernel_size=patch_sizes[i],
-                stride=strides[i],
-                padding=paddings[i],
-                bias=True,
-                norm_cfg=norm_cfg)
-
-            layers = ModuleList()
-            if use_abs_pos_embed:
-                pos_shape = pretrain_img_size // np.prod(patch_sizes[:i + 1])
-                pos_embed = AbsolutePositionEmbedding(
-                    pos_shape=pos_shape,
-                    pos_dim=embed_dims_i,
-                    drop_rate=drop_rate)
-                layers.append(pos_embed)
-            layers.extend([
-                PVTEncoderLayer(
-                    embed_dims=embed_dims_i,
-                    num_heads=num_heads[i],
-                    feedforward_channels=mlp_ratios[i] * embed_dims_i,
-                    drop_rate=drop_rate,
-                    attn_drop_rate=attn_drop_rate,
-                    drop_path_rate=dpr[cur + idx],
-                    qkv_bias=qkv_bias,
-                    act_cfg=act_cfg,
-                    norm_cfg=norm_cfg,
-                    sr_ratio=sr_ratios[i],
-                    use_conv_ffn=use_conv_ffn) for idx in range(num_layer)
-            ])
-            in_channels = embed_dims_i
-            # The ret[0] of build_norm_layer is norm name.
-            if norm_after_stage:
-                norm = build_norm_layer(norm_cfg, embed_dims_i)[1]
-            else:
-                norm = nn.Identity()
-            self.layers.append(ModuleList([patch_embed, layers, norm]))
-            cur += num_layer
-
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
-
-        logger = get_root_logger()
-        if self.init_cfg is None:
-            logger.warn(f'No pre-trained weights for '
-                        f'{self.__class__.__name__}, '
-                        f'training start from scratch')
-            for m in self.modules():
-                if isinstance(m, nn.Linear):
-                    trunc_normal_init(m, std=.02, bias=0.)
-                elif isinstance(m, nn.LayerNorm):
-                    constant_init(m, 1.0)
-                elif isinstance(m, nn.Conv2d):
-                    fan_out = m.kernel_size[0] * m.kernel_size[
-                        1] * m.out_channels
-                    fan_out //= m.groups
-                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
-                elif isinstance(m, AbsolutePositionEmbedding):
-                    m.init_weights()
-        else:
-            assert 'checkpoint' in self.init_cfg, f'Only support ' \
-                f'specify `Pretrained` in ' \
-                f'`init_cfg` in ' \
-                f'{self.__class__.__name__} '
-            checkpoint = _load_checkpoint(
-                self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
-            logger.warn(f'Load pre-trained model for '
-                        f'{self.__class__.__name__} from original repo')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
-            else:
-                state_dict = checkpoint
-            if self.convert_weights:
-                # Because pvt backbones are not supported by mmcls,
-                # so we need to convert pre-trained weights to match this
-                # implementation.
-                state_dict = pvt_convert(state_dict)
-            load_state_dict(self, state_dict, strict=False, logger=logger)
-
-    def forward(self, x):
-        outs = []
-
-        for i, layer in enumerate(self.layers):
-            x, hw_shape = layer[0](x)
-
-            for block in layer[1]:
-                x = block(x, hw_shape)
-            x = layer[2](x)
-            x = nlc_to_nchw(x, hw_shape)
-            if i in self.out_indices:
-                outs.append(x)
-
-        return outs
-
-
-@BACKBONES.register_module()
-class PyramidVisionTransformerV2(PyramidVisionTransformer):
-    """Implementation of `PVTv2: Improved Baselines with Pyramid Vision
-    Transformer <https://arxiv.org/pdf/2106.13797.pdf>`_."""
-
-    def __init__(self, **kwargs):
-        super(PyramidVisionTransformerV2, self).__init__(
-            patch_sizes=[7, 3, 3, 3],
-            paddings=[3, 1, 1, 1],
-            use_abs_pos_embed=False,
-            norm_after_stage=True,
-            use_conv_ffn=True,
-            **kwargs)
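Reviewer note: the deleted PVTv2 backbone returned one NCHW feature map per stage listed in out_indices. A minimal usage sketch, assuming a pre-cleanup checkout where mmpose.models still exported the class (it no longer does after this commit); the exact channel counts in the comment follow the default embed_dims=64 / num_heads=[1, 2, 5, 8] configuration and are an assumption:

# Hypothetical sketch of the removed backbone; not runnable after this commit.
import torch
from mmpose.models import PyramidVisionTransformerV2

model = PyramidVisionTransformerV2(out_indices=(0, 1, 2, 3))
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
for f in feats:
    # Four pyramid levels at strides 4/8/16/32, e.g. (1, 64, 56, 56) ... (1, 512, 7, 7)
    print(tuple(f.shape))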
main/transformer_utils/mmpose/models/backbones/regnet.py
DELETED
@@ -1,317 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-
-import numpy as np
-import torch.nn as nn
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from ..builder import BACKBONES
-from .resnet import ResNet
-from .resnext import Bottleneck
-
-
-@BACKBONES.register_module()
-class RegNet(ResNet):
-    """RegNet backbone.
-
-    More details can be found in `paper <https://arxiv.org/abs/2003.13678>`__ .
-
-    Args:
-        arch (dict): The parameter of RegNets.
-            - w0 (int): initial width
-            - wa (float): slope of width
-            - wm (float): quantization parameter to quantize the width
-            - depth (int): depth of the backbone
-            - group_w (int): width of group
-            - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
-        strides (Sequence[int]): Strides of the first block of each stage.
-        base_channels (int): Base channels after stem layer.
-        in_channels (int): Number of input image channels. Default: 3.
-        dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer. Default: "pytorch".
-        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
-            not freezing any parameters. Default: -1.
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN', requires_grad=True).
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-
-    Example:
-        >>> from mmpose.models import RegNet
-        >>> import torch
-        >>> self = RegNet(
-        >>>     arch=dict(
-        >>>         w0=88,
-        >>>         wa=26.31,
-        >>>         wm=2.25,
-        >>>         group_w=48,
-        >>>         depth=25,
-        >>>         bot_mul=1.0),
-        >>>     out_indices=(0, 1, 2, 3))
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 32, 32)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 96, 8, 8)
-        (1, 192, 4, 4)
-        (1, 432, 2, 2)
-        (1, 1008, 1, 1)
-    """
-    arch_settings = {
-        'regnetx_400mf':
-        dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
-        'regnetx_800mf':
-        dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
-        'regnetx_1.6gf':
-        dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
-        'regnetx_3.2gf':
-        dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
-        'regnetx_4.0gf':
-        dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
-        'regnetx_6.4gf':
-        dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
-        'regnetx_8.0gf':
-        dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
-        'regnetx_12gf':
-        dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
-    }
-
-    def __init__(self,
-                 arch,
-                 in_channels=3,
-                 stem_channels=32,
-                 base_channels=32,
-                 strides=(2, 2, 2, 2),
-                 dilations=(1, 1, 1, 1),
-                 out_indices=(3, ),
-                 style='pytorch',
-                 deep_stem=False,
-                 avg_down=False,
-                 frozen_stages=-1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN', requires_grad=True),
-                 norm_eval=False,
-                 with_cp=False,
-                 zero_init_residual=True):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        super(ResNet, self).__init__()
-
-        # Generate RegNet parameters first
-        if isinstance(arch, str):
-            assert arch in self.arch_settings, \
-                f'"arch": "{arch}" is not one of the' \
-                ' arch_settings'
-            arch = self.arch_settings[arch]
-        elif not isinstance(arch, dict):
-            raise TypeError('Expect "arch" to be either a string '
-                            f'or a dict, got {type(arch)}')
-
-        widths, num_stages = self.generate_regnet(
-            arch['w0'],
-            arch['wa'],
-            arch['wm'],
-            arch['depth'],
-        )
-        # Convert to per stage format
-        stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
-        # Generate group widths and bot muls
-        group_widths = [arch['group_w'] for _ in range(num_stages)]
-        self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
-        # Adjust the compatibility of stage_widths and group_widths
-        stage_widths, group_widths = self.adjust_width_group(
-            stage_widths, self.bottleneck_ratio, group_widths)
-
-        # Group params by stage
-        self.stage_widths = stage_widths
-        self.group_widths = group_widths
-        self.depth = sum(stage_blocks)
-        self.stem_channels = stem_channels
-        self.base_channels = base_channels
-        self.num_stages = num_stages
-        assert 1 <= num_stages <= 4
-        self.strides = strides
-        self.dilations = dilations
-        assert len(strides) == len(dilations) == num_stages
-        self.out_indices = out_indices
-        assert max(out_indices) < num_stages
-        self.style = style
-        self.deep_stem = deep_stem
-        if self.deep_stem:
-            raise NotImplementedError(
-                'deep_stem has not been implemented for RegNet')
-        self.avg_down = avg_down
-        self.frozen_stages = frozen_stages
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.with_cp = with_cp
-        self.norm_eval = norm_eval
-        self.zero_init_residual = zero_init_residual
-        self.stage_blocks = stage_blocks[:num_stages]
-
-        self._make_stem_layer(in_channels, stem_channels)
-
-        _in_channels = stem_channels
-        self.res_layers = []
-        for i, num_blocks in enumerate(self.stage_blocks):
-            stride = self.strides[i]
-            dilation = self.dilations[i]
-            group_width = self.group_widths[i]
-            width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
-            stage_groups = width // group_width
-
-            res_layer = self.make_res_layer(
-                block=Bottleneck,
-                num_blocks=num_blocks,
-                in_channels=_in_channels,
-                out_channels=self.stage_widths[i],
-                expansion=1,
-                stride=stride,
-                dilation=dilation,
-                style=self.style,
-                avg_down=self.avg_down,
-                with_cp=self.with_cp,
-                conv_cfg=self.conv_cfg,
-                norm_cfg=self.norm_cfg,
-                base_channels=self.stage_widths[i],
-                groups=stage_groups,
-                width_per_group=group_width)
-            _in_channels = self.stage_widths[i]
-            layer_name = f'layer{i + 1}'
-            self.add_module(layer_name, res_layer)
-            self.res_layers.append(layer_name)
-
-        self._freeze_stages()
-
-        self.feat_dim = stage_widths[-1]
-
-    def _make_stem_layer(self, in_channels, base_channels):
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            in_channels,
-            base_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            bias=False)
-        self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, base_channels, postfix=1)
-        self.add_module(self.norm1_name, norm1)
-        self.relu = nn.ReLU(inplace=True)
-
-    @staticmethod
-    def generate_regnet(initial_width,
-                        width_slope,
-                        width_parameter,
-                        depth,
-                        divisor=8):
-        """Generates per block width from RegNet parameters.
-
-        Args:
-            initial_width ([int]): Initial width of the backbone
-            width_slope ([float]): Slope of the quantized linear function
-            width_parameter ([int]): Parameter used to quantize the width.
-            depth ([int]): Depth of the backbone.
-            divisor (int, optional): The divisor of channels. Defaults to 8.
-
-        Returns:
-            list, int: return a list of widths of each stage and the number of
-            stages
-        """
-        assert width_slope >= 0
-        assert initial_width > 0
-        assert width_parameter > 1
-        assert initial_width % divisor == 0
-        widths_cont = np.arange(depth) * width_slope + initial_width
-        ks = np.round(
-            np.log(widths_cont / initial_width) / np.log(width_parameter))
-        widths = initial_width * np.power(width_parameter, ks)
-        widths = np.round(np.divide(widths, divisor)) * divisor
-        num_stages = len(np.unique(widths))
-        widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
-        return widths, num_stages
-
-    @staticmethod
-    def quantize_float(number, divisor):
-        """Converts a float to closest non-zero int divisible by divisor.
-
-        Args:
-            number (int): Original number to be quantized.
-            divisor (int): Divisor used to quantize the number.
-
-        Returns:
-            int: quantized number that is divisible by divisor.
-        """
-        return int(round(number / divisor) * divisor)
-
-    def adjust_width_group(self, widths, bottleneck_ratio, groups):
-        """Adjusts the compatibility of widths and groups.
-
-        Args:
-            widths (list[int]): Width of each stage.
-            bottleneck_ratio (float): Bottleneck ratio.
-            groups (int): number of groups in each stage
-
-        Returns:
-            tuple(list): The adjusted widths and groups of each stage.
-        """
-        bottleneck_width = [
-            int(w * b) for w, b in zip(widths, bottleneck_ratio)
-        ]
-        groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
-        bottleneck_width = [
-            self.quantize_float(w_bot, g)
-            for w_bot, g in zip(bottleneck_width, groups)
-        ]
-        widths = [
-            int(w_bot / b)
-            for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
-        ]
-        return widths, groups
-
-    def get_stages_from_blocks(self, widths):
-        """Gets widths/stage_blocks of network at each stage.
-
-        Args:
-            widths (list[int]): Width in each stage.
-
-        Returns:
-            tuple(list): width and depth of each stage
-        """
-        width_diff = [
-            width != width_prev
-            for width, width_prev in zip(widths + [0], [0] + widths)
-        ]
-        stage_widths = [
-            width for width, diff in zip(widths, width_diff[:-1]) if diff
-        ]
-        stage_blocks = np.diff([
-            depth for depth, diff in zip(range(len(width_diff)), width_diff)
-            if diff
-        ]).tolist()
-        return stage_widths, stage_blocks
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.norm1(x)
-        x = self.relu(x)
-
-        outs = []
-        for i, layer_name in enumerate(self.res_layers):
-            res_layer = getattr(self, layer_name)
-            x = res_layer(x)
-            if i in self.out_indices:
-                outs.append(x)
-
-        if len(outs) == 1:
-            return outs[0]
-        return tuple(outs)
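Reviewer note: the quantized-width rule in generate_regnet above is easy to sanity-check in isolation. A standalone sketch that mirrors its arithmetic with plain NumPy (the method is static, so this reproduces rather than imports it); the numbers are the regnetx_400mf entry from the deleted arch_settings:

# Standalone sketch of the RegNet width-quantization rule deleted above.
import numpy as np

def generate_regnet(w0, wa, wm, depth, divisor=8):
    widths_cont = np.arange(depth) * wa + w0                # continuous linear widths
    ks = np.round(np.log(widths_cont / w0) / np.log(wm))    # quantization exponents
    widths = w0 * np.power(wm, ks)                          # snap onto a geometric grid
    widths = np.round(widths / divisor) * divisor           # round to multiples of 8
    return widths.astype(int).tolist(), len(np.unique(widths))

widths, num_stages = generate_regnet(w0=24, wa=24.48, wm=2.54, depth=22)
print(num_stages)  # 4 distinct widths -> 4 stages for regnetx_400mf
print(widths)      # per-block widths, constant within a stage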
main/transformer_utils/mmpose/models/backbones/resnest.py
DELETED
@@ -1,338 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from ..builder import BACKBONES
-from .resnet import Bottleneck as _Bottleneck
-from .resnet import ResLayer, ResNetV1d
-
-
-class RSoftmax(nn.Module):
-    """Radix Softmax module in ``SplitAttentionConv2d``.
-
-    Args:
-        radix (int): Radix of input.
-        groups (int): Groups of input.
-    """
-
-    def __init__(self, radix, groups):
-        super().__init__()
-        self.radix = radix
-        self.groups = groups
-
-    def forward(self, x):
-        batch = x.size(0)
-        if self.radix > 1:
-            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
-            x = F.softmax(x, dim=1)
-            x = x.reshape(batch, -1)
-        else:
-            x = torch.sigmoid(x)
-        return x
-
-
-class SplitAttentionConv2d(nn.Module):
-    """Split-Attention Conv2d.
-
-    Args:
-        in_channels (int): Same as nn.Conv2d.
-        out_channels (int): Same as nn.Conv2d.
-        kernel_size (int | tuple[int]): Same as nn.Conv2d.
-        stride (int | tuple[int]): Same as nn.Conv2d.
-        padding (int | tuple[int]): Same as nn.Conv2d.
-        dilation (int | tuple[int]): Same as nn.Conv2d.
-        groups (int): Same as nn.Conv2d.
-        radix (int): Radix of SpltAtConv2d. Default: 2
-        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-            Default: 4.
-        conv_cfg (dict): Config dict for convolution layer. Default: None,
-            which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer. Default: None.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 radix=2,
-                 reduction_factor=4,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN')):
-        super().__init__()
-        inter_channels = max(in_channels * radix // reduction_factor, 32)
-        self.radix = radix
-        self.groups = groups
-        self.channels = channels
-        self.conv = build_conv_layer(
-            conv_cfg,
-            in_channels,
-            channels * radix,
-            kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups * radix,
-            bias=False)
-        self.norm0_name, norm0 = build_norm_layer(
-            norm_cfg, channels * radix, postfix=0)
-        self.add_module(self.norm0_name, norm0)
-        self.relu = nn.ReLU(inplace=True)
-        self.fc1 = build_conv_layer(
-            None, channels, inter_channels, 1, groups=self.groups)
-        self.norm1_name, norm1 = build_norm_layer(
-            norm_cfg, inter_channels, postfix=1)
-        self.add_module(self.norm1_name, norm1)
-        self.fc2 = build_conv_layer(
-            None, inter_channels, channels * radix, 1, groups=self.groups)
-        self.rsoftmax = RSoftmax(radix, groups)
-
-    @property
-    def norm0(self):
-        return getattr(self, self.norm0_name)
-
-    @property
-    def norm1(self):
-        return getattr(self, self.norm1_name)
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.norm0(x)
-        x = self.relu(x)
-
-        batch, rchannel = x.shape[:2]
-        if self.radix > 1:
-            splits = x.view(batch, self.radix, -1, *x.shape[2:])
-            gap = splits.sum(dim=1)
-        else:
-            gap = x
-        gap = F.adaptive_avg_pool2d(gap, 1)
-        gap = self.fc1(gap)
-
-        gap = self.norm1(gap)
-        gap = self.relu(gap)
-
-        atten = self.fc2(gap)
-        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
-
-        if self.radix > 1:
-            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
-            out = torch.sum(attens * splits, dim=1)
-        else:
-            out = atten * x
-        return out.contiguous()
-
-
-class Bottleneck(_Bottleneck):
-    """Bottleneck block for ResNeSt.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        groups (int): Groups of conv2.
-        width_per_group (int): Width per group of conv2. 64x4d indicates
-            ``groups=64, width_per_group=4`` and 32x8d indicates
-            ``groups=32, width_per_group=8``.
-        radix (int): Radix of SpltAtConv2d. Default: 2
-        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-            Default: 4.
-        avg_down_stride (bool): Whether to use average pool for stride in
-            Bottleneck. Default: True.
-        stride (int): stride of the block. Default: 1
-        dilation (int): dilation of convolution. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 groups=1,
-                 width_per_group=4,
-                 base_channels=64,
-                 radix=2,
-                 reduction_factor=4,
-                 avg_down_stride=True,
-                 **kwargs):
-        super().__init__(in_channels, out_channels, **kwargs)
-
-        self.groups = groups
-        self.width_per_group = width_per_group
-
-        # For ResNet bottleneck, middle channels are determined by expansion
-        # and out_channels, but for ResNeXt bottleneck, it is determined by
-        # groups and width_per_group and the stage it is located in.
-        if groups != 1:
-            assert self.mid_channels % base_channels == 0
-            self.mid_channels = (
-                groups * width_per_group * self.mid_channels // base_channels)
-
-        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
-
-        self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=1)
-        self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, self.out_channels, postfix=3)
-
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            self.in_channels,
-            self.mid_channels,
-            kernel_size=1,
-            stride=self.conv1_stride,
-            bias=False)
-        self.add_module(self.norm1_name, norm1)
-        self.conv2 = SplitAttentionConv2d(
-            self.mid_channels,
-            self.mid_channels,
-            kernel_size=3,
-            stride=1 if self.avg_down_stride else self.conv2_stride,
-            padding=self.dilation,
-            dilation=self.dilation,
-            groups=groups,
-            radix=radix,
-            reduction_factor=reduction_factor,
-            conv_cfg=self.conv_cfg,
-            norm_cfg=self.norm_cfg)
-        delattr(self, self.norm2_name)
-
-        if self.avg_down_stride:
-            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
-
-        self.conv3 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels,
-            self.out_channels,
-            kernel_size=1,
-            bias=False)
-        self.add_module(self.norm3_name, norm3)
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            identity = x
-
-            out = self.conv1(x)
-            out = self.norm1(out)
-            out = self.relu(out)
-
-            out = self.conv2(out)
-
-            if self.avg_down_stride:
-                out = self.avd_layer(out)
-
-            out = self.conv3(out)
-            out = self.norm3(out)
-
-            if self.downsample is not None:
-                identity = self.downsample(x)
-
-            out += identity
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        out = self.relu(out)
-
-        return out
-
-
-@BACKBONES.register_module()
-class ResNeSt(ResNetV1d):
-    """ResNeSt backbone.
-
-    Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__
-    for details.
-
-    Args:
-        depth (int): Network depth, from {50, 101, 152, 200}.
-        groups (int): Groups of conv2 in Bottleneck. Default: 32.
-        width_per_group (int): Width per group of conv2 in Bottleneck.
-            Default: 4.
-        radix (int): Radix of SpltAtConv2d. Default: 2
-        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
-            Default: 4.
-        avg_down_stride (bool): Whether to use average pool for stride in
-            Bottleneck. Default: True.
-        in_channels (int): Number of input image channels. Default: 3.
-        stem_channels (int): Output channels of the stem layer. Default: 64.
-        num_stages (int): Stages of the network. Default: 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-            Default: ``(1, 2, 2, 2)``.
-        dilations (Sequence[int]): Dilation of each stage.
-            Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-            Default: False.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters. Default: -1.
-        conv_cfg (dict | None): The config dict for conv layers. Default: None.
-        norm_cfg (dict): The config dict for norm layers.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-    """
-
-    arch_settings = {
-        50: (Bottleneck, (3, 4, 6, 3)),
-        101: (Bottleneck, (3, 4, 23, 3)),
-        152: (Bottleneck, (3, 8, 36, 3)),
-        200: (Bottleneck, (3, 24, 36, 3)),
-        269: (Bottleneck, (3, 30, 48, 8))
-    }
-
-    def __init__(self,
-                 depth,
-                 groups=1,
-                 width_per_group=4,
-                 radix=2,
-                 reduction_factor=4,
-                 avg_down_stride=True,
-                 **kwargs):
-        self.groups = groups
-        self.width_per_group = width_per_group
-        self.radix = radix
-        self.reduction_factor = reduction_factor
-        self.avg_down_stride = avg_down_stride
-        super().__init__(depth=depth, **kwargs)
-
-    def make_res_layer(self, **kwargs):
-        return ResLayer(
-            groups=self.groups,
-            width_per_group=self.width_per_group,
-            base_channels=self.base_channels,
-            radix=self.radix,
-            reduction_factor=self.reduction_factor,
-            avg_down_stride=self.avg_down_stride,
-            **kwargs)
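Reviewer note: the RSoftmax reshape above is the subtle step of split attention: within each group, a softmax is taken across the radix splits, so the splits of a channel compete for attention weight. A self-contained sketch of just that step, mirroring (not importing) the deleted module:

# Standalone sketch of the radix softmax from the deleted SplitAttentionConv2d.
import torch
import torch.nn.functional as F

def radix_softmax(x, radix, groups):
    # x: (batch, radix * groups * c) attention logits, one score per split.
    batch = x.size(0)
    x = x.view(batch, groups, radix, -1).transpose(1, 2)  # (B, radix, groups, c)
    x = F.softmax(x, dim=1)  # splits of each group compete with each other
    return x.reshape(batch, -1)

logits = torch.randn(2, 2 * 4 * 8)  # radix=2, groups=4, c=8
weights = radix_softmax(logits, radix=2, groups=4)
# After the reshape the layout is (radix, groups, c); the radix weights
# of every channel sum to 1:
print(weights.view(2, 2, 4, 8).sum(dim=1))  # all ones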
main/transformer_utils/mmpose/models/backbones/resnet.py
CHANGED
@@ -3,9 +3,9 @@ import copy
 
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from
-
-from
+from mmengine.model import constant_init, kaiming_init
+from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer)
+from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
 
 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone
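Reviewer note: this is the one file in the batch that is rewritten rather than deleted. The old import lines are truncated in this diff view, but the added lines move the helpers to their mmcv 2.x / mmengine homes. A quick sanity-check sketch of the new import paths, assuming mmengine and mmcv>=2.0 are installed (the exact version pins are an assumption):

# Sketch verifying the import locations used after this commit.
import torch.nn as nn
from mmengine.model import constant_init, kaiming_init
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

conv = ConvModule(3, 8, kernel_size=3, padding=1, norm_cfg=dict(type='BN'))
kaiming_init(conv.conv)   # weight-init helpers now live in mmengine.model
constant_init(conv.norm, 1)
print(isinstance(conv.norm, _BatchNorm))  # True: BN subclasses the wrapper base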
main/transformer_utils/mmpose/models/backbones/resnext.py
DELETED
@@ -1,162 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from ..builder import BACKBONES
-from .resnet import Bottleneck as _Bottleneck
-from .resnet import ResLayer, ResNet
-
-
-class Bottleneck(_Bottleneck):
-    """Bottleneck block for ResNeXt.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        groups (int): Groups of conv2.
-        width_per_group (int): Width per group of conv2. 64x4d indicates
-            ``groups=64, width_per_group=4`` and 32x8d indicates
-            ``groups=32, width_per_group=8``.
-        stride (int): stride of the block. Default: 1
-        dilation (int): dilation of convolution. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 base_channels=64,
-                 groups=32,
-                 width_per_group=4,
-                 **kwargs):
-        super().__init__(in_channels, out_channels, **kwargs)
-        self.groups = groups
-        self.width_per_group = width_per_group
-
-        # For ResNet bottleneck, middle channels are determined by expansion
-        # and out_channels, but for ResNeXt bottleneck, it is determined by
-        # groups and width_per_group and the stage it is located in.
-        if groups != 1:
-            assert self.mid_channels % base_channels == 0
-            self.mid_channels = (
-                groups * width_per_group * self.mid_channels // base_channels)
-
-        self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=2)
-        self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, self.out_channels, postfix=3)
-
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            self.in_channels,
-            self.mid_channels,
-            kernel_size=1,
-            stride=self.conv1_stride,
-            bias=False)
-        self.add_module(self.norm1_name, norm1)
-        self.conv2 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels,
-            self.mid_channels,
-            kernel_size=3,
-            stride=self.conv2_stride,
-            padding=self.dilation,
-            dilation=self.dilation,
-            groups=groups,
-            bias=False)
-
-        self.add_module(self.norm2_name, norm2)
-        self.conv3 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels,
-            self.out_channels,
-            kernel_size=1,
-            bias=False)
-        self.add_module(self.norm3_name, norm3)
-
-
-@BACKBONES.register_module()
-class ResNeXt(ResNet):
-    """ResNeXt backbone.
-
-    Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
-    details.
-
-    Args:
-        depth (int): Network depth, from {50, 101, 152}.
-        groups (int): Groups of conv2 in Bottleneck. Default: 32.
-        width_per_group (int): Width per group of conv2 in Bottleneck.
-            Default: 4.
-        in_channels (int): Number of input image channels. Default: 3.
-        stem_channels (int): Output channels of the stem layer. Default: 64.
-        num_stages (int): Stages of the network. Default: 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-            Default: ``(1, 2, 2, 2)``.
-        dilations (Sequence[int]): Dilation of each stage.
-            Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-            Default: False.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters. Default: -1.
-        conv_cfg (dict | None): The config dict for conv layers. Default: None.
-        norm_cfg (dict): The config dict for norm layers.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-
-    Example:
-        >>> from mmpose.models import ResNeXt
-        >>> import torch
-        >>> self = ResNeXt(depth=50, out_indices=(0, 1, 2, 3))
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 32, 32)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 256, 8, 8)
-        (1, 512, 4, 4)
-        (1, 1024, 2, 2)
-        (1, 2048, 1, 1)
-    """
-
-    arch_settings = {
-        50: (Bottleneck, (3, 4, 6, 3)),
-        101: (Bottleneck, (3, 4, 23, 3)),
-        152: (Bottleneck, (3, 8, 36, 3))
-    }
-
-    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
-        self.groups = groups
-        self.width_per_group = width_per_group
-        super().__init__(depth, **kwargs)
-
-    def make_res_layer(self, **kwargs):
-        return ResLayer(
-            groups=self.groups,
-            width_per_group=self.width_per_group,
-            base_channels=self.base_channels,
-            **kwargs)
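Reviewer note: the only substantive difference from the plain ResNet bottleneck is the mid_channels rescaling in __init__ above. Worked through for the standard 32x4d setting, as a standalone arithmetic sketch (not code from the repo):

# Worked example of the ResNeXt bottleneck width rule deleted above,
# for the first stage of a 32x4d model (groups=32, width_per_group=4).
groups, width_per_group, base_channels = 32, 4, 64
out_channels, expansion = 256, 4

mid_channels = out_channels // expansion  # 64, as in plain ResNet
mid_channels = groups * width_per_group * mid_channels // base_channels
print(mid_channels)  # 128: conv2 runs 32 groups of 4 channels, scaled per stage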
main/transformer_utils/mmpose/models/backbones/rsn.py
DELETED
@@ -1,616 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy as cp
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from mmcv.cnn import (ConvModule, MaxPool2d, constant_init, kaiming_init,
-                      normal_init)
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-class RSB(nn.Module):
-    """Residual Steps block for RSN. Paper ref: Cai et al. "Learning Delicate
-    Local Representations for Multi-Person Pose Estimation" (ECCV 2020).
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        num_steps (int): Numbers of steps in RSB
-        stride (int): stride of the block. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None.
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        expand_times (int): Times by which the in_channels are expanded.
-            Default:26.
-        res_top_channels (int): Number of channels of feature output by
-            ResNet_top. Default:64.
-    """
-
-    expansion = 1
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 num_steps=4,
-                 stride=1,
-                 downsample=None,
-                 with_cp=False,
-                 norm_cfg=dict(type='BN'),
-                 expand_times=26,
-                 res_top_channels=64):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        super().__init__()
-        assert num_steps > 1
-        self.in_channels = in_channels
-        self.branch_channels = self.in_channels * expand_times
-        self.branch_channels //= res_top_channels
-        self.out_channels = out_channels
-        self.stride = stride
-        self.downsample = downsample
-        self.with_cp = with_cp
-        self.norm_cfg = norm_cfg
-        self.num_steps = num_steps
-        self.conv_bn_relu1 = ConvModule(
-            self.in_channels,
-            self.num_steps * self.branch_channels,
-            kernel_size=1,
-            stride=self.stride,
-            padding=0,
-            norm_cfg=self.norm_cfg,
-            inplace=False)
-        for i in range(self.num_steps):
-            for j in range(i + 1):
-                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
-                self.add_module(
-                    module_name,
-                    ConvModule(
-                        self.branch_channels,
-                        self.branch_channels,
-                        kernel_size=3,
-                        stride=1,
-                        padding=1,
-                        norm_cfg=self.norm_cfg,
-                        inplace=False))
-        self.conv_bn3 = ConvModule(
-            self.num_steps * self.branch_channels,
-            self.out_channels * self.expansion,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            act_cfg=None,
-            norm_cfg=self.norm_cfg,
-            inplace=False)
-        self.relu = nn.ReLU(inplace=False)
-
-    def forward(self, x):
-        """Forward function."""
-
-        identity = x
-        x = self.conv_bn_relu1(x)
-        spx = torch.split(x, self.branch_channels, 1)
-        outputs = list()
-        outs = list()
-        for i in range(self.num_steps):
-            outputs_i = list()
-            outputs.append(outputs_i)
-            for j in range(i + 1):
-                if j == 0:
-                    inputs = spx[i]
-                else:
-                    inputs = outputs[i][j - 1]
-                if i > j:
-                    inputs = inputs + outputs[i - 1][j]
-                module_name = f'conv_bn_relu2_{i + 1}_{j + 1}'
-                module_i_j = getattr(self, module_name)
-                outputs[i].append(module_i_j(inputs))
-
-            outs.append(outputs[i][i])
-        out = torch.cat(tuple(outs), 1)
-        out = self.conv_bn3(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(identity)
-        out = out + identity
-
-        out = self.relu(out)
-
-        return out
-
-
-class Downsample_module(nn.Module):
-    """Downsample module for RSN.
-
-    Args:
-        block (nn.Module): Downsample block.
-        num_blocks (list): Number of blocks in each downsample unit.
-        num_units (int): Numbers of downsample units. Default: 4
-        has_skip (bool): Have skip connections from prior upsample
-            module or not. Default:False
-        num_steps (int): Number of steps in a block. Default:4
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        in_channels (int): Number of channels of the input feature to
-            downsample module. Default: 64
-        expand_times (int): Times by which the in_channels are expanded.
-            Default:26.
-    """
-
-    def __init__(self,
-                 block,
-                 num_blocks,
-                 num_steps=4,
-                 num_units=4,
-                 has_skip=False,
-                 norm_cfg=dict(type='BN'),
-                 in_channels=64,
-                 expand_times=26):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        super().__init__()
-        self.has_skip = has_skip
-        self.in_channels = in_channels
-        assert len(num_blocks) == num_units
-        self.num_blocks = num_blocks
-        self.num_units = num_units
-        self.num_steps = num_steps
-        self.norm_cfg = norm_cfg
-        self.layer1 = self._make_layer(
-            block,
-            in_channels,
-            num_blocks[0],
-            expand_times=expand_times,
-            res_top_channels=in_channels)
-        for i in range(1, num_units):
-            module_name = f'layer{i + 1}'
-            self.add_module(
-                module_name,
-                self._make_layer(
-                    block,
-                    in_channels * pow(2, i),
-                    num_blocks[i],
-                    stride=2,
-                    expand_times=expand_times,
-                    res_top_channels=in_channels))
-
-    def _make_layer(self,
-                    block,
-                    out_channels,
-                    blocks,
-                    stride=1,
-                    expand_times=26,
-                    res_top_channels=64):
-        downsample = None
-        if stride != 1 or self.in_channels != out_channels * block.expansion:
-            downsample = ConvModule(
-                self.in_channels,
-                out_channels * block.expansion,
-                kernel_size=1,
-                stride=stride,
-                padding=0,
-                norm_cfg=self.norm_cfg,
-                act_cfg=None,
-                inplace=True)
-
-        units = list()
-        units.append(
-            block(
-                self.in_channels,
-                out_channels,
-                num_steps=self.num_steps,
-                stride=stride,
-                downsample=downsample,
-                norm_cfg=self.norm_cfg,
-                expand_times=expand_times,
-                res_top_channels=res_top_channels))
-        self.in_channels = out_channels * block.expansion
-        for _ in range(1, blocks):
-            units.append(
-                block(
-                    self.in_channels,
-                    out_channels,
-                    num_steps=self.num_steps,
-                    expand_times=expand_times,
-                    res_top_channels=res_top_channels))
-
-        return nn.Sequential(*units)
-
-    def forward(self, x, skip1, skip2):
-        out = list()
-        for i in range(self.num_units):
-            module_name = f'layer{i + 1}'
-            module_i = getattr(self, module_name)
-            x = module_i(x)
-            if self.has_skip:
-                x = x + skip1[i] + skip2[i]
-            out.append(x)
-        out.reverse()
-
-        return tuple(out)
-
-
-class Upsample_unit(nn.Module):
-    """Upsample unit for upsample module.
-
-    Args:
-        ind (int): Indicates whether to interpolate (>0) and whether to
-            generate feature map for the next hourglass-like module.
-        num_units (int): Number of units that form a upsample module. Along
-            with ind and gen_cross_conv, nm_units is used to decide whether
-            to generate feature map for the next hourglass-like module.
-        in_channels (int): Channel number of the skip-in feature maps from
-            the corresponding downsample unit.
-        unit_channels (int): Channel number in this unit. Default:256.
-        gen_skip: (bool): Whether or not to generate skips for the posterior
-            downsample module. Default:False
-        gen_cross_conv (bool): Whether to generate feature map for the next
-            hourglass-like module. Default:False
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        out_channels (in): Number of channels of feature output by upsample
-            module. Must equal to in_channels of downsample module. Default:64
-    """
-
-    def __init__(self,
-                 ind,
-                 num_units,
-                 in_channels,
-                 unit_channels=256,
-                 gen_skip=False,
-                 gen_cross_conv=False,
-                 norm_cfg=dict(type='BN'),
-                 out_channels=64):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        super().__init__()
-        self.num_units = num_units
-        self.norm_cfg = norm_cfg
-        self.in_skip = ConvModule(
-            in_channels,
-            unit_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            norm_cfg=self.norm_cfg,
-            act_cfg=None,
-            inplace=True)
-        self.relu = nn.ReLU(inplace=True)
-
-        self.ind = ind
-        if self.ind > 0:
-            self.up_conv = ConvModule(
-                unit_channels,
-                unit_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                norm_cfg=self.norm_cfg,
-                act_cfg=None,
-                inplace=True)
-
-        self.gen_skip = gen_skip
-        if self.gen_skip:
-            self.out_skip1 = ConvModule(
-                in_channels,
-                in_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                norm_cfg=self.norm_cfg,
-                inplace=True)
-
-            self.out_skip2 = ConvModule(
-                unit_channels,
-                in_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                norm_cfg=self.norm_cfg,
-                inplace=True)
-
-        self.gen_cross_conv = gen_cross_conv
-        if self.ind == num_units - 1 and self.gen_cross_conv:
-            self.cross_conv = ConvModule(
-                unit_channels,
-                out_channels,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                norm_cfg=self.norm_cfg,
-                inplace=True)
-
-    def forward(self, x, up_x):
-        out = self.in_skip(x)
-
-        if self.ind > 0:
-            up_x = F.interpolate(
-                up_x,
-                size=(x.size(2), x.size(3)),
-                mode='bilinear',
-                align_corners=True)
-            up_x = self.up_conv(up_x)
-            out = out + up_x
-        out = self.relu(out)
-
-        skip1 = None
-        skip2 = None
-        if self.gen_skip:
-            skip1 = self.out_skip1(x)
-            skip2 = self.out_skip2(out)
-
-        cross_conv = None
-        if self.ind == self.num_units - 1 and self.gen_cross_conv:
-            cross_conv = self.cross_conv(out)
-
-        return out, skip1, skip2, cross_conv
-
-
-class Upsample_module(nn.Module):
-    """Upsample module for RSN.
-
-    Args:
-        unit_channels (int): Channel number in the upsample units.
-            Default:256.
-        num_units (int): Numbers of upsample units. Default: 4
-        gen_skip (bool): Whether to generate skip for posterior downsample
-            module or not. Default:False
-        gen_cross_conv (bool): Whether to generate feature map for the next
-            hourglass-like module. Default:False
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        out_channels (int): Number of channels of feature output by upsample
-            module. Must equal to in_channels of downsample module. Default:64
-    """
-
-    def __init__(self,
-                 unit_channels=256,
-                 num_units=4,
-                 gen_skip=False,
-                 gen_cross_conv=False,
-                 norm_cfg=dict(type='BN'),
-                 out_channels=64):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        super().__init__()
-        self.in_channels = list()
-        for i in range(num_units):
-            self.in_channels.append(RSB.expansion * out_channels * pow(2, i))
-        self.in_channels.reverse()
-        self.num_units = num_units
-        self.gen_skip = gen_skip
-        self.gen_cross_conv = gen_cross_conv
-        self.norm_cfg = norm_cfg
-        for i in range(num_units):
-            module_name = f'up{i + 1}'
-            self.add_module(
-                module_name,
-                Upsample_unit(
-                    i,
-                    self.num_units,
-                    self.in_channels[i],
-                    unit_channels,
-                    self.gen_skip,
-                    self.gen_cross_conv,
-                    norm_cfg=self.norm_cfg,
-                    out_channels=64))
-
-    def forward(self, x):
-        out = list()
-        skip1 = list()
-        skip2 = list()
-        cross_conv = None
-        for i in range(self.num_units):
-            module_i = getattr(self, f'up{i + 1}')
-            if i == 0:
-                outi, skip1_i, skip2_i, _ = module_i(x[i], None)
-            elif i == self.num_units - 1:
-                outi, skip1_i, skip2_i, cross_conv = module_i(x[i], out[i - 1])
-            else:
-                outi, skip1_i, skip2_i, _ = module_i(x[i], out[i - 1])
-            out.append(outi)
-            skip1.append(skip1_i)
-            skip2.append(skip2_i)
-        skip1.reverse()
-        skip2.reverse()
-
-        return out, skip1, skip2, cross_conv
-
-
-class Single_stage_RSN(nn.Module):
-    """Single_stage Residual Steps Network.
-
-    Args:
-        unit_channels (int): Channel number in the upsample units. Default:256.
-        num_units (int): Numbers of downsample/upsample units. Default: 4
-        gen_skip (bool): Whether to generate skip for posterior downsample
-            module or not. Default:False
-        gen_cross_conv (bool): Whether to generate feature map for the next
-            hourglass-like module. Default:False
-        has_skip (bool): Have skip connections from prior upsample
-            module or not. Default:False
-        num_steps (int): Number of steps in RSB. Default: 4
-        num_blocks (list): Number of blocks in each downsample unit.
-            Default: [2, 2, 2, 2] Note: Make sure num_units==len(num_blocks)
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        in_channels (int): Number of channels of the feature from ResNet_Top.
-            Default: 64.
-        expand_times (int): Times by which the in_channels are expanded in RSB.
-            Default:26.
-    """
-
-    def __init__(self,
-                 has_skip=False,
-                 gen_skip=False,
-                 gen_cross_conv=False,
-                 unit_channels=256,
-                 num_units=4,
-                 num_steps=4,
-                 num_blocks=[2, 2, 2, 2],
-                 norm_cfg=dict(type='BN'),
-                 in_channels=64,
-                 expand_times=26):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        num_blocks = cp.deepcopy(num_blocks)
-        super().__init__()
-        assert len(num_blocks) == num_units
-        self.has_skip = has_skip
-        self.gen_skip = gen_skip
-        self.gen_cross_conv = gen_cross_conv
-        self.num_units = num_units
-        self.num_steps = num_steps
-        self.unit_channels = unit_channels
-        self.num_blocks = num_blocks
-        self.norm_cfg = norm_cfg
-
-        self.downsample = Downsample_module(RSB, num_blocks, num_steps,
-                                            num_units, has_skip, norm_cfg,
-                                            in_channels, expand_times)
-        self.upsample = Upsample_module(unit_channels, num_units, gen_skip,
-                                        gen_cross_conv, norm_cfg, in_channels)
-
-    def forward(self, x, skip1, skip2):
-        mid = self.downsample(x, skip1, skip2)
-        out, skip1, skip2, cross_conv = self.upsample(mid)
-
-        return out, skip1, skip2, cross_conv
-
-
-class ResNet_top(nn.Module):
-    """ResNet top for RSN.
-
-    Args:
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        channels (int): Number of channels of the feature output by ResNet_top.
-    """
-
-    def __init__(self, norm_cfg=dict(type='BN'), channels=64):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        super().__init__()
-        self.top = nn.Sequential(
-            ConvModule(
-                3,
-                channels,
-                kernel_size=7,
-                stride=2,
-                padding=3,
-                norm_cfg=norm_cfg,
-                inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1))
-
-    def forward(self, img):
-        return self.top(img)
-
-
-@BACKBONES.register_module()
-class RSN(BaseBackbone):
-    """Residual Steps Network backbone. Paper ref: Cai et al. "Learning
-    Delicate Local Representations for Multi-Person Pose Estimation" (ECCV
-    2020).
-
-    Args:
-        unit_channels (int): Number of Channels in an upsample unit.
-            Default: 256
-        num_stages (int): Number of stages in a multi-stage RSN. Default: 4
-        num_units (int): NUmber of downsample/upsample units in a single-stage
-            RSN. Default: 4 Note: Make sure num_units == len(self.num_blocks)
-        num_blocks (list): Number of RSBs (Residual Steps Block) in each
-            downsample unit. Default: [2, 2, 2, 2]
-        num_steps (int): Number of steps in a RSB. Default:4
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        res_top_channels (int): Number of channels of feature from ResNet_top.
-            Default: 64.
-        expand_times (int): Times by which the in_channels are expanded in RSB.
-            Default:26.
-    Example:
-        >>> from mmpose.models import RSN
-        >>> import torch
-        >>> self = RSN(num_stages=2,num_units=2,num_blocks=[2,2])
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 511, 511)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_output in level_outputs:
-        ...     for feature in level_output:
-        ...         print(tuple(feature.shape))
-        ...
-        (1, 256, 64, 64)
-        (1, 256, 128, 128)
-        (1, 256, 64, 64)
-        (1, 256, 128, 128)
-    """
-
-    def __init__(self,
-                 unit_channels=256,
-                 num_stages=4,
-                 num_units=4,
-                 num_blocks=[2, 2, 2, 2],
-                 num_steps=4,
-                 norm_cfg=dict(type='BN'),
-                 res_top_channels=64,
-                 expand_times=26):
-        # Protect mutable default arguments
-        norm_cfg = cp.deepcopy(norm_cfg)
-        num_blocks = cp.deepcopy(num_blocks)
-        super().__init__()
-        self.unit_channels = unit_channels
-        self.num_stages = num_stages
-        self.num_units = num_units
-        self.num_blocks = num_blocks
|
566 |
-
self.num_steps = num_steps
|
567 |
-
self.norm_cfg = norm_cfg
|
568 |
-
|
569 |
-
assert self.num_stages > 0
|
570 |
-
assert self.num_steps > 1
|
571 |
-
assert self.num_units > 1
|
572 |
-
assert self.num_units == len(self.num_blocks)
|
573 |
-
self.top = ResNet_top(norm_cfg=norm_cfg)
|
574 |
-
self.multi_stage_rsn = nn.ModuleList([])
|
575 |
-
for i in range(self.num_stages):
|
576 |
-
if i == 0:
|
577 |
-
has_skip = False
|
578 |
-
else:
|
579 |
-
has_skip = True
|
580 |
-
if i != self.num_stages - 1:
|
581 |
-
gen_skip = True
|
582 |
-
gen_cross_conv = True
|
583 |
-
else:
|
584 |
-
gen_skip = False
|
585 |
-
gen_cross_conv = False
|
586 |
-
self.multi_stage_rsn.append(
|
587 |
-
Single_stage_RSN(has_skip, gen_skip, gen_cross_conv,
|
588 |
-
unit_channels, num_units, num_steps,
|
589 |
-
num_blocks, norm_cfg, res_top_channels,
|
590 |
-
expand_times))
|
591 |
-
|
592 |
-
def forward(self, x):
|
593 |
-
"""Model forward function."""
|
594 |
-
out_feats = []
|
595 |
-
skip1 = None
|
596 |
-
skip2 = None
|
597 |
-
x = self.top(x)
|
598 |
-
for i in range(self.num_stages):
|
599 |
-
out, skip1, skip2, x = self.multi_stage_rsn[i](x, skip1, skip2)
|
600 |
-
out_feats.append(out)
|
601 |
-
|
602 |
-
return out_feats
|
603 |
-
|
604 |
-
def init_weights(self, pretrained=None):
|
605 |
-
"""Initialize model weights."""
|
606 |
-
for m in self.multi_stage_rsn.modules():
|
607 |
-
if isinstance(m, nn.Conv2d):
|
608 |
-
kaiming_init(m)
|
609 |
-
elif isinstance(m, nn.BatchNorm2d):
|
610 |
-
constant_init(m, 1)
|
611 |
-
elif isinstance(m, nn.Linear):
|
612 |
-
normal_init(m, std=0.01)
|
613 |
-
|
614 |
-
for m in self.top.modules():
|
615 |
-
if isinstance(m, nn.Conv2d):
|
616 |
-
kaiming_init(m)
|
|
|
|
|
|
|
|
|
|
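The RSN removed above threads two skip tensors and a cross-stage feature through every stage. A minimal stand-alone sketch of that multi-stage dataflow, with a toy stage standing in for Single_stage_RSN (the module body and the 64-channel size are illustrative assumptions, not the deleted implementation):

import torch
import torch.nn as nn

class TinyStage(nn.Module):
    """Toy stand-in: (x, skip1, skip2) -> (out, skip1, skip2, cross_conv)."""

    def __init__(self, channels=64):
        super().__init__()
        self.body = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, skip1, skip2):
        if skip1 is not None:  # stages after the first consume the previous skips
            x = x + skip1 + skip2
        mid = self.body(x)
        # outputs for supervision, fresh skips, and the feature for the next stage
        return [mid], mid, mid, mid

stages = nn.ModuleList(TinyStage() for _ in range(4))
x, skip1, skip2 = torch.rand(1, 64, 64, 64), None, None
out_feats = []
for stage in stages:  # mirrors RSN.forward: skips are threaded stage to stage
    out, skip1, skip2, x = stage(x, skip1, skip2)
    out_feats.append(out)
print(len(out_feats))  # 4: one output list per stage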
main/transformer_utils/mmpose/models/backbones/scnet.py
DELETED
@@ -1,248 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from ..builder import BACKBONES
-from .resnet import Bottleneck, ResNet
-
-
-class SCConv(nn.Module):
-    """SCConv (Self-calibrated Convolution)
-
-    Args:
-        in_channels (int): The input channels of the SCConv.
-        out_channels (int): The output channel of the SCConv.
-        stride (int): stride of SCConv.
-        pooling_r (int): size of pooling for scconv.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride,
-                 pooling_r,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN', momentum=0.1)):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        super().__init__()
-
-        assert in_channels == out_channels
-
-        self.k2 = nn.Sequential(
-            nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),
-            build_conv_layer(
-                conv_cfg,
-                in_channels,
-                in_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False),
-            build_norm_layer(norm_cfg, in_channels)[1],
-        )
-        self.k3 = nn.Sequential(
-            build_conv_layer(
-                conv_cfg,
-                in_channels,
-                in_channels,
-                kernel_size=3,
-                stride=1,
-                padding=1,
-                bias=False),
-            build_norm_layer(norm_cfg, in_channels)[1],
-        )
-        self.k4 = nn.Sequential(
-            build_conv_layer(
-                conv_cfg,
-                in_channels,
-                in_channels,
-                kernel_size=3,
-                stride=stride,
-                padding=1,
-                bias=False),
-            build_norm_layer(norm_cfg, out_channels)[1],
-            nn.ReLU(inplace=True),
-        )
-
-    def forward(self, x):
-        """Forward function."""
-        identity = x
-
-        out = torch.sigmoid(
-            torch.add(identity, F.interpolate(self.k2(x),
-                                              identity.size()[2:])))
-        out = torch.mul(self.k3(x), out)
-        out = self.k4(out)
-
-        return out
-
-
-class SCBottleneck(Bottleneck):
-    """SC(Self-calibrated) Bottleneck.
-
-    Args:
-        in_channels (int): The input channels of the SCBottleneck block.
-        out_channels (int): The output channel of the SCBottleneck block.
-    """
-
-    pooling_r = 4
-
-    def __init__(self, in_channels, out_channels, **kwargs):
-        super().__init__(in_channels, out_channels, **kwargs)
-        self.mid_channels = out_channels // self.expansion // 2
-
-        self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=2)
-        self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, out_channels, postfix=3)
-
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            in_channels,
-            self.mid_channels,
-            kernel_size=1,
-            stride=1,
-            bias=False)
-        self.add_module(self.norm1_name, norm1)
-
-        self.k1 = nn.Sequential(
-            build_conv_layer(
-                self.conv_cfg,
-                self.mid_channels,
-                self.mid_channels,
-                kernel_size=3,
-                stride=self.stride,
-                padding=1,
-                bias=False),
-            build_norm_layer(self.norm_cfg, self.mid_channels)[1],
-            nn.ReLU(inplace=True))
-
-        self.conv2 = build_conv_layer(
-            self.conv_cfg,
-            in_channels,
-            self.mid_channels,
-            kernel_size=1,
-            stride=1,
-            bias=False)
-        self.add_module(self.norm2_name, norm2)
-
-        self.scconv = SCConv(self.mid_channels, self.mid_channels, self.stride,
-                             self.pooling_r, self.conv_cfg, self.norm_cfg)
-
-        self.conv3 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels * 2,
-            out_channels,
-            kernel_size=1,
-            stride=1,
-            bias=False)
-        self.add_module(self.norm3_name, norm3)
-
-    def forward(self, x):
-        """Forward function."""
-
-        def _inner_forward(x):
-            identity = x
-
-            out_a = self.conv1(x)
-            out_a = self.norm1(out_a)
-            out_a = self.relu(out_a)
-
-            out_a = self.k1(out_a)
-
-            out_b = self.conv2(x)
-            out_b = self.norm2(out_b)
-            out_b = self.relu(out_b)
-
-            out_b = self.scconv(out_b)
-
-            out = self.conv3(torch.cat([out_a, out_b], dim=1))
-            out = self.norm3(out)
-
-            if self.downsample is not None:
-                identity = self.downsample(x)
-
-            out += identity
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        out = self.relu(out)
-
-        return out
-
-
-@BACKBONES.register_module()
-class SCNet(ResNet):
-    """SCNet backbone.
-
-    Improving Convolutional Networks with Self-Calibrated Convolutions,
-    Jiang-Jiang Liu, Qibin Hou, Ming-Ming Cheng, Changhu Wang, Jiashi Feng,
-    IEEE CVPR, 2020.
-    http://mftp.mmcheng.net/Papers/20cvprSCNet.pdf
-
-    Args:
-        depth (int): Depth of scnet, from {50, 101}.
-        in_channels (int): Number of input image channels. Normally 3.
-        base_channels (int): Number of base channels of hidden layer.
-        num_stages (int): SCNet stages, normally 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-        dilations (Sequence[int]): Dilation of each stage.
-        out_indices (Sequence[int]): Output from which stages.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters.
-        norm_cfg (dict): Dictionary to construct and config norm layer.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity.
-
-    Example:
-        >>> from mmpose.models import SCNet
-        >>> import torch
-        >>> self = SCNet(depth=50, out_indices=(0, 1, 2, 3))
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 224, 224)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 256, 56, 56)
-        (1, 512, 28, 28)
-        (1, 1024, 14, 14)
-        (1, 2048, 7, 7)
-    """
-
-    arch_settings = {
-        50: (SCBottleneck, [3, 4, 6, 3]),
-        101: (SCBottleneck, [3, 4, 23, 3])
-    }
-
-    def __init__(self, depth, **kwargs):
-        if depth not in self.arch_settings:
-            raise KeyError(f'invalid depth {depth} for SCNet')
-        super().__init__(depth, **kwargs)
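The self-calibration in the deleted SCConv reduces to gating one 3x3 branch by a sigmoid of (identity + upsampled low-resolution context). A minimal plain-PyTorch sketch of that computation (layer names and sizes here are illustrative, not the removed mmcv-based code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniSCConv(nn.Module):
    def __init__(self, channels=32, pooling_r=4):
        super().__init__()
        self.pool = nn.AvgPool2d(pooling_r, pooling_r)  # k2 path: low-res context
        self.k2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.k3 = nn.Conv2d(channels, channels, 3, padding=1)
        self.k4 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        # calibration gate: sigmoid(identity + context upsampled back to x's size)
        gate = torch.sigmoid(
            x + F.interpolate(self.k2(self.pool(x)), x.shape[2:]))
        out = self.k3(x) * gate  # gated main branch
        return F.relu(self.k4(out))

y = MiniSCConv()(torch.rand(1, 32, 56, 56))
print(y.shape)  # torch.Size([1, 32, 56, 56])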
main/transformer_utils/mmpose/models/backbones/seresnet.py
DELETED
@@ -1,125 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch.utils.checkpoint as cp
-
-from ..builder import BACKBONES
-from .resnet import Bottleneck, ResLayer, ResNet
-from .utils.se_layer import SELayer
-
-
-class SEBottleneck(Bottleneck):
-    """SEBottleneck block for SEResNet.
-
-    Args:
-        in_channels (int): The input channels of the SEBottleneck block.
-        out_channels (int): The output channel of the SEBottleneck block.
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16
-    """
-
-    def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
-        super().__init__(in_channels, out_channels, **kwargs)
-        self.se_layer = SELayer(out_channels, ratio=se_ratio)
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            identity = x
-
-            out = self.conv1(x)
-            out = self.norm1(out)
-            out = self.relu(out)
-
-            out = self.conv2(out)
-            out = self.norm2(out)
-            out = self.relu(out)
-
-            out = self.conv3(out)
-            out = self.norm3(out)
-
-            out = self.se_layer(out)
-
-            if self.downsample is not None:
-                identity = self.downsample(x)
-
-            out += identity
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        out = self.relu(out)
-
-        return out
-
-
-@BACKBONES.register_module()
-class SEResNet(ResNet):
-    """SEResNet backbone.
-
-    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
-    details.
-
-    Args:
-        depth (int): Network depth, from {50, 101, 152}.
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
-        in_channels (int): Number of input image channels. Default: 3.
-        stem_channels (int): Output channels of the stem layer. Default: 64.
-        num_stages (int): Stages of the network. Default: 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-            Default: ``(1, 2, 2, 2)``.
-        dilations (Sequence[int]): Dilation of each stage.
-            Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-            Default: False.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters. Default: -1.
-        conv_cfg (dict | None): The config dict for conv layers. Default: None.
-        norm_cfg (dict): The config dict for norm layers.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-
-    Example:
-        >>> from mmpose.models import SEResNet
-        >>> import torch
-        >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 224, 224)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 256, 56, 56)
-        (1, 512, 28, 28)
-        (1, 1024, 14, 14)
-        (1, 2048, 7, 7)
-    """
-
-    arch_settings = {
-        50: (SEBottleneck, (3, 4, 6, 3)),
-        101: (SEBottleneck, (3, 4, 23, 3)),
-        152: (SEBottleneck, (3, 8, 36, 3))
-    }
-
-    def __init__(self, depth, se_ratio=16, **kwargs):
-        if depth not in self.arch_settings:
-            raise KeyError(f'invalid depth {depth} for SEResNet')
-        self.se_ratio = se_ratio
-        super().__init__(depth, **kwargs)
-
-    def make_res_layer(self, **kwargs):
-        return ResLayer(se_ratio=self.se_ratio, **kwargs)
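The only change SEBottleneck makes over the plain ResNet bottleneck is a squeeze-and-excitation gate before the residual add. A minimal SE layer matching what `SELayer(out_channels, ratio=se_ratio)` computes (an assumed reimplementation for illustration, since utils/se_layer.py itself is not shown in this diff):

import torch
import torch.nn as nn

class MiniSELayer(nn.Module):
    def __init__(self, channels, ratio=16):
        super().__init__()
        self.fc = nn.Sequential(           # squeeze -> excite -> gate
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // ratio, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // ratio, channels, 1),
            nn.Sigmoid())

    def forward(self, x):
        return x * self.fc(x)              # per-channel reweighting

out = MiniSELayer(256)(torch.rand(1, 256, 56, 56))
print(out.shape)  # torch.Size([1, 256, 56, 56])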
main/transformer_utils/mmpose/models/backbones/seresnext.py
DELETED
@@ -1,168 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.cnn import build_conv_layer, build_norm_layer
-
-from ..builder import BACKBONES
-from .resnet import ResLayer
-from .seresnet import SEBottleneck as _SEBottleneck
-from .seresnet import SEResNet
-
-
-class SEBottleneck(_SEBottleneck):
-    """SEBottleneck block for SEResNeXt.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        base_channels (int): Middle channels of the first stage. Default: 64.
-        groups (int): Groups of conv2.
-        width_per_group (int): Width per group of conv2. 64x4d indicates
-            ``groups=64, width_per_group=4`` and 32x8d indicates
-            ``groups=32, width_per_group=8``.
-        stride (int): stride of the block. Default: 1
-        dilation (int): dilation of convolution. Default: 1
-        downsample (nn.Module): downsample operation on identity branch.
-            Default: None
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: None
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN')
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 base_channels=64,
-                 groups=32,
-                 width_per_group=4,
-                 se_ratio=16,
-                 **kwargs):
-        super().__init__(in_channels, out_channels, se_ratio, **kwargs)
-        self.groups = groups
-        self.width_per_group = width_per_group
-
-        # We follow the same rational of ResNext to compute mid_channels.
-        # For SEResNet bottleneck, middle channels are determined by expansion
-        # and out_channels, but for SEResNeXt bottleneck, it is determined by
-        # groups and width_per_group and the stage it is located in.
-        if groups != 1:
-            assert self.mid_channels % base_channels == 0
-            self.mid_channels = (
-                groups * width_per_group * self.mid_channels // base_channels)
-
-        self.norm1_name, norm1 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=1)
-        self.norm2_name, norm2 = build_norm_layer(
-            self.norm_cfg, self.mid_channels, postfix=2)
-        self.norm3_name, norm3 = build_norm_layer(
-            self.norm_cfg, self.out_channels, postfix=3)
-
-        self.conv1 = build_conv_layer(
-            self.conv_cfg,
-            self.in_channels,
-            self.mid_channels,
-            kernel_size=1,
-            stride=self.conv1_stride,
-            bias=False)
-        self.add_module(self.norm1_name, norm1)
-        self.conv2 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels,
-            self.mid_channels,
-            kernel_size=3,
-            stride=self.conv2_stride,
-            padding=self.dilation,
-            dilation=self.dilation,
-            groups=groups,
-            bias=False)
-
-        self.add_module(self.norm2_name, norm2)
-        self.conv3 = build_conv_layer(
-            self.conv_cfg,
-            self.mid_channels,
-            self.out_channels,
-            kernel_size=1,
-            bias=False)
-        self.add_module(self.norm3_name, norm3)
-
-
-@BACKBONES.register_module()
-class SEResNeXt(SEResNet):
-    """SEResNeXt backbone.
-
-    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
-    details.
-
-    Args:
-        depth (int): Network depth, from {50, 101, 152}.
-        groups (int): Groups of conv2 in Bottleneck. Default: 32.
-        width_per_group (int): Width per group of conv2 in Bottleneck.
-            Default: 4.
-        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
-        in_channels (int): Number of input image channels. Default: 3.
-        stem_channels (int): Output channels of the stem layer. Default: 64.
-        num_stages (int): Stages of the network. Default: 4.
-        strides (Sequence[int]): Strides of the first block of each stage.
-            Default: ``(1, 2, 2, 2)``.
-        dilations (Sequence[int]): Dilation of each stage.
-            Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
-        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
-            layer is the 3x3 conv layer, otherwise the stride-two layer is
-            the first 1x1 conv layer.
-        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
-            Default: False.
-        avg_down (bool): Use AvgPool instead of stride conv when
-            downsampling in the bottleneck. Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            -1 means not freezing any parameters. Default: -1.
-        conv_cfg (dict | None): The config dict for conv layers. Default: None.
-        norm_cfg (dict): The config dict for norm layers.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-        zero_init_residual (bool): Whether to use zero init for last norm layer
-            in resblocks to let them behave as identity. Default: True.
-
-    Example:
-        >>> from mmpose.models import SEResNeXt
-        >>> import torch
-        >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 3, 224, 224)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 256, 56, 56)
-        (1, 512, 28, 28)
-        (1, 1024, 14, 14)
-        (1, 2048, 7, 7)
-    """
-
-    arch_settings = {
-        50: (SEBottleneck, (3, 4, 6, 3)),
-        101: (SEBottleneck, (3, 4, 23, 3)),
-        152: (SEBottleneck, (3, 8, 36, 3))
-    }
-
-    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
-        self.groups = groups
-        self.width_per_group = width_per_group
-        super().__init__(depth, **kwargs)
-
-    def make_res_layer(self, **kwargs):
-        return ResLayer(
-            groups=self.groups,
-            width_per_group=self.width_per_group,
-            base_channels=self.base_channels,
-            **kwargs)
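The mid_channels arithmetic in the deleted SEBottleneck is worth spelling out: with the default 32x4d setting, a stage with out_channels=256 and expansion 4 gives mid_channels 256/4 = 64 under the SEResNet rule, which the ResNeXt rule rescales to 32 * 4 * 64 / 64 = 128. A one-line check (plain arithmetic, no framework assumptions):

def resnext_mid_channels(out_channels, expansion=4, groups=32,
                         width_per_group=4, base_channels=64):
    mid = out_channels // expansion                          # SEResNet rule
    return groups * width_per_group * mid // base_channels   # SEResNeXt rule

print(resnext_mid_channels(256))   # 128
print(resnext_mid_channels(2048))  # 1024 (last stage)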
main/transformer_utils/mmpose/models/backbones/shufflenet_v1.py
DELETED
@@ -1,329 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import logging
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
-                      normal_init)
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-from .utils import channel_shuffle, load_checkpoint, make_divisible
-
-
-class ShuffleUnit(nn.Module):
-    """ShuffleUnit block.
-
-    ShuffleNet unit with pointwise group convolution (GConv) and channel
-    shuffle.
-
-    Args:
-        in_channels (int): The input channels of the ShuffleUnit.
-        out_channels (int): The output channels of the ShuffleUnit.
-        groups (int, optional): The number of groups to be used in grouped 1x1
-            convolutions in each ShuffleUnit. Default: 3
-        first_block (bool, optional): Whether it is the first ShuffleUnit of a
-            sequential ShuffleUnits. Default: True, which means not using the
-            grouped 1x1 convolution.
-        combine (str, optional): The ways to combine the input and output
-            branches. Default: 'add'.
-        conv_cfg (dict): Config dict for convolution layer. Default: None,
-            which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU').
-        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-
-    Returns:
-        Tensor: The output tensor.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 groups=3,
-                 first_block=True,
-                 combine='add',
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU'),
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.first_block = first_block
-        self.combine = combine
-        self.groups = groups
-        self.bottleneck_channels = self.out_channels // 4
-        self.with_cp = with_cp
-
-        if self.combine == 'add':
-            self.depthwise_stride = 1
-            self._combine_func = self._add
-            assert in_channels == out_channels, (
-                'in_channels must be equal to out_channels when combine '
-                'is add')
-        elif self.combine == 'concat':
-            self.depthwise_stride = 2
-            self._combine_func = self._concat
-            self.out_channels -= self.in_channels
-            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
-        else:
-            raise ValueError(f'Cannot combine tensors with {self.combine}. '
-                             'Only "add" and "concat" are supported')
-
-        self.first_1x1_groups = 1 if first_block else self.groups
-        self.g_conv_1x1_compress = ConvModule(
-            in_channels=self.in_channels,
-            out_channels=self.bottleneck_channels,
-            kernel_size=1,
-            groups=self.first_1x1_groups,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-
-        self.depthwise_conv3x3_bn = ConvModule(
-            in_channels=self.bottleneck_channels,
-            out_channels=self.bottleneck_channels,
-            kernel_size=3,
-            stride=self.depthwise_stride,
-            padding=1,
-            groups=self.bottleneck_channels,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=None)
-
-        self.g_conv_1x1_expand = ConvModule(
-            in_channels=self.bottleneck_channels,
-            out_channels=self.out_channels,
-            kernel_size=1,
-            groups=self.groups,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=None)
-
-        self.act = build_activation_layer(act_cfg)
-
-    @staticmethod
-    def _add(x, out):
-        # residual connection
-        return x + out
-
-    @staticmethod
-    def _concat(x, out):
-        # concatenate along channel axis
-        return torch.cat((x, out), 1)
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            residual = x
-
-            out = self.g_conv_1x1_compress(x)
-            out = self.depthwise_conv3x3_bn(out)
-
-            if self.groups > 1:
-                out = channel_shuffle(out, self.groups)
-
-            out = self.g_conv_1x1_expand(out)
-
-            if self.combine == 'concat':
-                residual = self.avgpool(residual)
-                out = self.act(out)
-                out = self._combine_func(residual, out)
-            else:
-                out = self._combine_func(residual, out)
-                out = self.act(out)
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-@BACKBONES.register_module()
-class ShuffleNetV1(BaseBackbone):
-    """ShuffleNetV1 backbone.
-
-    Args:
-        groups (int, optional): The number of groups to be used in grouped 1x1
-            convolutions in each ShuffleUnit. Default: 3.
-        widen_factor (float, optional): Width multiplier - adjusts the number
-            of channels in each layer by this amount. Default: 1.0.
-        out_indices (Sequence[int]): Output from which stages.
-            Default: (2, )
-        frozen_stages (int): Stages to be frozen (all param fixed).
-            Default: -1, which means not freezing any parameters.
-        conv_cfg (dict): Config dict for convolution layer. Default: None,
-            which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU').
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 groups=3,
-                 widen_factor=1.0,
-                 out_indices=(2, ),
-                 frozen_stages=-1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU'),
-                 norm_eval=False,
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.stage_blocks = [4, 8, 4]
-        self.groups = groups
-
-        for index in out_indices:
-            if index not in range(0, 3):
-                raise ValueError('the item in out_indices must in '
-                                 f'range(0, 3). But received {index}')
-
-        if frozen_stages not in range(-1, 3):
-            raise ValueError('frozen_stages must be in range(-1, 3). '
-                             f'But received {frozen_stages}')
-        self.out_indices = out_indices
-        self.frozen_stages = frozen_stages
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.act_cfg = act_cfg
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
-        if groups == 1:
-            channels = (144, 288, 576)
-        elif groups == 2:
-            channels = (200, 400, 800)
-        elif groups == 3:
-            channels = (240, 480, 960)
-        elif groups == 4:
-            channels = (272, 544, 1088)
-        elif groups == 8:
-            channels = (384, 768, 1536)
-        else:
-            raise ValueError(f'{groups} groups is not supported for 1x1 '
-                             'Grouped Convolutions')
-
-        channels = [make_divisible(ch * widen_factor, 8) for ch in channels]
-
-        self.in_channels = int(24 * widen_factor)
-
-        self.conv1 = ConvModule(
-            in_channels=3,
-            out_channels=self.in_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-
-        self.layers = nn.ModuleList()
-        for i, num_blocks in enumerate(self.stage_blocks):
-            first_block = (i == 0)
-            layer = self.make_layer(channels[i], num_blocks, first_block)
-            self.layers.append(layer)
-
-    def _freeze_stages(self):
-        if self.frozen_stages >= 0:
-            for param in self.conv1.parameters():
-                param.requires_grad = False
-            for i in range(self.frozen_stages):
-                layer = self.layers[i]
-                layer.eval()
-                for param in layer.parameters():
-                    param.requires_grad = False
-
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for name, m in self.named_modules():
-                if isinstance(m, nn.Conv2d):
-                    if 'conv1' in name:
-                        normal_init(m, mean=0, std=0.01)
-                    else:
-                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, val=1, bias=0.0001)
-                    if isinstance(m, _BatchNorm):
-                        if m.running_mean is not None:
-                            nn.init.constant_(m.running_mean, 0)
-        else:
-            raise TypeError('pretrained must be a str or None. But received '
-                            f'{type(pretrained)}')
-
-    def make_layer(self, out_channels, num_blocks, first_block=False):
-        """Stack ShuffleUnit blocks to make a layer.
-
-        Args:
-            out_channels (int): out_channels of the block.
-            num_blocks (int): Number of blocks.
-            first_block (bool, optional): Whether is the first ShuffleUnit of a
-                sequential ShuffleUnits. Default: False, which means using
-                the grouped 1x1 convolution.
-        """
-        layers = []
-        for i in range(num_blocks):
-            first_block = first_block if i == 0 else False
-            combine_mode = 'concat' if i == 0 else 'add'
-            layers.append(
-                ShuffleUnit(
-                    self.in_channels,
-                    out_channels,
-                    groups=self.groups,
-                    first_block=first_block,
-                    combine=combine_mode,
-                    conv_cfg=self.conv_cfg,
-                    norm_cfg=self.norm_cfg,
-                    act_cfg=self.act_cfg,
-                    with_cp=self.with_cp))
-            self.in_channels = out_channels
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.maxpool(x)
-
-        outs = []
-        for i, layer in enumerate(self.layers):
-            x = layer(x)
-            if i in self.out_indices:
-                outs.append(x)
-
-        if len(outs) == 1:
-            return outs[0]
-        return tuple(outs)
-
-    def train(self, mode=True):
-        super().train(mode)
-        self._freeze_stages()
-        if mode and self.norm_eval:
-            for m in self.modules():
-                if isinstance(m, _BatchNorm):
-                    m.eval()
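channel_shuffle, imported from .utils above, is what lets information cross groups between the grouped 1x1 convolutions. A minimal equivalent of that helper (an assumed reimplementation following the standard reshape-transpose definition; the deleted utils module itself is not shown in this diff):

import torch

def channel_shuffle(x, groups):
    """Interleave channels across groups: (B, g*c, H, W) -> shuffled (B, g*c, H, W)."""
    b, c, h, w = x.shape
    assert c % groups == 0
    x = x.view(b, groups, c // groups, h, w)
    x = x.transpose(1, 2).contiguous()  # swap group and per-group channel dims
    return x.view(b, c, h, w)

x = torch.arange(6).float().view(1, 6, 1, 1)
print(channel_shuffle(x, 3).flatten().tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]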
main/transformer_utils/mmpose/models/backbones/shufflenet_v2.py
DELETED
@@ -1,302 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import logging
-
-import torch
-import torch.nn as nn
-import torch.utils.checkpoint as cp
-from mmcv.cnn import ConvModule, constant_init, normal_init
-from torch.nn.modules.batchnorm import _BatchNorm
-
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-from .utils import channel_shuffle, load_checkpoint
-
-
-class InvertedResidual(nn.Module):
-    """InvertedResidual block for ShuffleNetV2 backbone.
-
-    Args:
-        in_channels (int): The input channels of the block.
-        out_channels (int): The output channels of the block.
-        stride (int): Stride of the 3x3 convolution layer. Default: 1
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU').
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 stride=1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU'),
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.stride = stride
-        self.with_cp = with_cp
-
-        branch_features = out_channels // 2
-        if self.stride == 1:
-            assert in_channels == branch_features * 2, (
-                f'in_channels ({in_channels}) should equal to '
-                f'branch_features * 2 ({branch_features * 2}) '
-                'when stride is 1')
-
-        if in_channels != branch_features * 2:
-            assert self.stride != 1, (
-                f'stride ({self.stride}) should not equal 1 when '
-                f'in_channels != branch_features * 2')
-
-        if self.stride > 1:
-            self.branch1 = nn.Sequential(
-                ConvModule(
-                    in_channels,
-                    in_channels,
-                    kernel_size=3,
-                    stride=self.stride,
-                    padding=1,
-                    groups=in_channels,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=None),
-                ConvModule(
-                    in_channels,
-                    branch_features,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg,
-                    act_cfg=act_cfg),
-            )
-
-        self.branch2 = nn.Sequential(
-            ConvModule(
-                in_channels if (self.stride > 1) else branch_features,
-                branch_features,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg),
-            ConvModule(
-                branch_features,
-                branch_features,
-                kernel_size=3,
-                stride=self.stride,
-                padding=1,
-                groups=branch_features,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=None),
-            ConvModule(
-                branch_features,
-                branch_features,
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg))
-
-    def forward(self, x):
-
-        def _inner_forward(x):
-            if self.stride > 1:
-                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
-            else:
-                x1, x2 = x.chunk(2, dim=1)
-                out = torch.cat((x1, self.branch2(x2)), dim=1)
-
-            out = channel_shuffle(out, 2)
-
-            return out
-
-        if self.with_cp and x.requires_grad:
-            out = cp.checkpoint(_inner_forward, x)
-        else:
-            out = _inner_forward(x)
-
-        return out
-
-
-@BACKBONES.register_module()
-class ShuffleNetV2(BaseBackbone):
-    """ShuffleNetV2 backbone.
-
-    Args:
-        widen_factor (float): Width multiplier - adjusts the number of
-            channels in each layer by this amount. Default: 1.0.
-        out_indices (Sequence[int]): Output from which stages.
-            Default: (0, 1, 2, 3).
-        frozen_stages (int): Stages to be frozen (all param fixed).
-            Default: -1, which means not freezing any parameters.
-        conv_cfg (dict): Config dict for convolution layer.
-            Default: None, which means using conv2d.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='BN').
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='ReLU').
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
-            memory while slowing down the training speed. Default: False.
-    """
-
-    def __init__(self,
-                 widen_factor=1.0,
-                 out_indices=(3, ),
-                 frozen_stages=-1,
-                 conv_cfg=None,
-                 norm_cfg=dict(type='BN'),
-                 act_cfg=dict(type='ReLU'),
-                 norm_eval=False,
-                 with_cp=False):
-        # Protect mutable default arguments
-        norm_cfg = copy.deepcopy(norm_cfg)
-        act_cfg = copy.deepcopy(act_cfg)
-        super().__init__()
-        self.stage_blocks = [4, 8, 4]
-        for index in out_indices:
-            if index not in range(0, 4):
-                raise ValueError('the item in out_indices must in '
-                                 f'range(0, 4). But received {index}')
-
-        if frozen_stages not in range(-1, 4):
-            raise ValueError('frozen_stages must be in range(-1, 4). '
-                             f'But received {frozen_stages}')
-        self.out_indices = out_indices
-        self.frozen_stages = frozen_stages
-        self.conv_cfg = conv_cfg
-        self.norm_cfg = norm_cfg
-        self.act_cfg = act_cfg
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-
-        if widen_factor == 0.5:
-            channels = [48, 96, 192, 1024]
-        elif widen_factor == 1.0:
-            channels = [116, 232, 464, 1024]
-        elif widen_factor == 1.5:
-            channels = [176, 352, 704, 1024]
-        elif widen_factor == 2.0:
-            channels = [244, 488, 976, 2048]
-        else:
-            raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
-                             f'But received {widen_factor}')
-
-        self.in_channels = 24
-        self.conv1 = ConvModule(
-            in_channels=3,
-            out_channels=self.in_channels,
-            kernel_size=3,
-            stride=2,
-            padding=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-
-        self.layers = nn.ModuleList()
-        for i, num_blocks in enumerate(self.stage_blocks):
-            layer = self._make_layer(channels[i], num_blocks)
-            self.layers.append(layer)
-
-        output_channels = channels[-1]
-        self.layers.append(
-            ConvModule(
-                in_channels=self.in_channels,
-                out_channels=output_channels,
-                kernel_size=1,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg))
-
-    def _make_layer(self, out_channels, num_blocks):
-        """Stack blocks to make a layer.
-
-        Args:
-            out_channels (int): out_channels of the block.
-            num_blocks (int): number of blocks.
-        """
-        layers = []
-        for i in range(num_blocks):
-            stride = 2 if i == 0 else 1
-            layers.append(
-                InvertedResidual(
-                    in_channels=self.in_channels,
-                    out_channels=out_channels,
-                    stride=stride,
-                    conv_cfg=self.conv_cfg,
-                    norm_cfg=self.norm_cfg,
-                    act_cfg=self.act_cfg,
-                    with_cp=self.with_cp))
-            self.in_channels = out_channels
-
-        return nn.Sequential(*layers)
-
-    def _freeze_stages(self):
-        if self.frozen_stages >= 0:
-            for param in self.conv1.parameters():
-                param.requires_grad = False
-
-            for i in range(self.frozen_stages):
-                m = self.layers[i]
-                m.eval()
-                for param in m.parameters():
-                    param.requires_grad = False
-
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for name, m in self.named_modules():
-                if isinstance(m, nn.Conv2d):
-                    if 'conv1' in name:
-                        normal_init(m, mean=0, std=0.01)
-                    else:
-                        normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m.weight, val=1, bias=0.0001)
-                    if isinstance(m, _BatchNorm):
-                        if m.running_mean is not None:
-                            nn.init.constant_(m.running_mean, 0)
-        else:
-            raise TypeError('pretrained must be a str or None. But received '
-                            f'{type(pretrained)}')
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.maxpool(x)
-
-        outs = []
-        for i, layer in enumerate(self.layers):
-            x = layer(x)
-            if i in self.out_indices:
-                outs.append(x)
-
-        if len(outs) == 1:
-            return outs[0]
-        return tuple(outs)
-
-    def train(self, mode=True):
-        super().train(mode)
-        self._freeze_stages()
-        if mode and self.norm_eval:
-            for m in self.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eval()
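The stride-1 path of the deleted InvertedResidual is the characteristic ShuffleNetV2 pattern: split the channels in half, transform one half, concatenate, then shuffle with 2 groups. A minimal sketch of that dataflow (the single-conv branch is an illustrative stand-in for branch2; channel_shuffle as in the sketch after shufflenet_v1 above):

import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    b, c, h, w = x.shape
    x = x.view(b, groups, c // groups, h, w).transpose(1, 2).contiguous()
    return x.view(b, c, h, w)

class MiniShuffleV2Unit(nn.Module):
    def __init__(self, channels=48):
        super().__init__()
        self.branch2 = nn.Conv2d(channels // 2, channels // 2, 3, padding=1)

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)                   # identity half, transform half
        out = torch.cat((x1, self.branch2(x2)), dim=1)
        return channel_shuffle(out, 2)               # mix the two halves

y = MiniShuffleV2Unit()(torch.rand(1, 48, 28, 28))
print(y.shape)  # torch.Size([1, 48, 28, 28])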
main/transformer_utils/mmpose/models/backbones/swin.py
DELETED
@@ -1,733 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from collections import OrderedDict
-from copy import deepcopy
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as cp
-from mmcv.cnn import build_norm_layer, constant_init, trunc_normal_init
-from mmcv.cnn.bricks.transformer import FFN, build_dropout
-from mmcv.cnn.utils.weight_init import trunc_normal_
-from mmcv.runner import _load_checkpoint
-from mmcv.utils import to_2tuple
-
-from ...utils import get_root_logger
-from ..builder import BACKBONES
-from ..utils.transformer import PatchEmbed, PatchMerging
-from .base_backbone import BaseBackbone
-from .utils.ckpt_convert import swin_converter
-
-
-class WindowMSA(nn.Module):
-    """Window based multi-head self-attention (W-MSA) module with relative
-    position bias.
-
-    Args:
-        embed_dims (int): Number of input channels.
-        num_heads (int): Number of attention heads.
-        window_size (tuple[int]): The height and width of the window.
-        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
-            Default: True.
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Default: None.
-        attn_drop_rate (float, optional): Dropout ratio of attention weight.
-            Default: 0.0
-        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
-    """
-
-    def __init__(self,
-                 embed_dims,
-                 num_heads,
-                 window_size,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 attn_drop_rate=0.,
-                 proj_drop_rate=0.):
-
-        super().__init__()
-        self.embed_dims = embed_dims
-        self.window_size = window_size  # Wh, Ww
-        self.num_heads = num_heads
-        head_embed_dims = embed_dims // num_heads
-        self.scale = qk_scale or head_embed_dims**-0.5
-
-        # define a parameter table of relative position bias
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
-                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-
-        # About 2x faster than original impl
-        Wh, Ww = self.window_size
-        rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
-        rel_position_index = rel_index_coords + rel_index_coords.T
-        rel_position_index = rel_position_index.flip(1).contiguous()
-        self.register_buffer('relative_position_index', rel_position_index)
-
-        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop_rate)
-        self.proj = nn.Linear(embed_dims, embed_dims)
-        self.proj_drop = nn.Dropout(proj_drop_rate)
-
-        self.softmax = nn.Softmax(dim=-1)
-
-    def init_weights(self):
-        trunc_normal_(self.relative_position_bias_table, std=0.02)
-
-    def forward(self, x, mask=None):
-        """
-        Args:
-
-            x (tensor): input features with shape of (num_windows*B, N, C)
-            mask (tensor | None, Optional): mask with shape of (num_windows,
-                Wh*Ww, Wh*Ww), value should be between (-inf, 0].
-        """
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
-                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
-        # make torchscript happy (cannot use tensor as tuple)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        q = q * self.scale
-        attn = (q @ k.transpose(-2, -1))
-
-        relative_position_bias = self.relative_position_bias_table[
-            self.relative_position_index.view(-1)].view(
-                self.window_size[0] * self.window_size[1],
-                self.window_size[0] * self.window_size[1],
-                -1)  # Wh*Ww,Wh*Ww,nH
-        relative_position_bias = relative_position_bias.permute(
-            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
-        attn = attn + relative_position_bias.unsqueeze(0)
-
-        if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B // nW, nW, self.num_heads, N,
-                             N) + mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
-        attn = self.softmax(attn)
-
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-    @staticmethod
-    def double_step_seq(step1, len1, step2, len2):
-        seq1 = torch.arange(0, step1 * len1, step1)
-        seq2 = torch.arange(0, step2 * len2, step2)
-        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
-
-
-class ShiftWindowMSA(nn.Module):
-    """Shifted Window Multihead Self-Attention Module.
-
-    Args:
-        embed_dims (int): Number of input channels.
-        num_heads (int): Number of attention heads.
-        window_size (int): The height and width of the window.
-        shift_size (int, optional): The shift step of each window towards
-            right-bottom. If zero, act as regular window-msa. Defaults to 0.
-        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
-            Default: True
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Defaults: None.
-        attn_drop_rate (float, optional): Dropout ratio of attention weight.
-            Defaults: 0.
-        proj_drop_rate (float, optional): Dropout ratio of output.
-            Defaults: 0.
-        dropout_layer (dict, optional): The dropout_layer used before output.
-            Defaults: dict(type='DropPath', drop_prob=0.).
-    """
-
-    def __init__(self,
-                 embed_dims,
-                 num_heads,
-                 window_size,
-                 shift_size=0,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 attn_drop_rate=0,
-                 proj_drop_rate=0,
-                 dropout_layer=dict(type='DropPath', drop_prob=0.)):
-        super().__init__()
-
-        self.window_size = window_size
-        self.shift_size = shift_size
-        assert 0 <= self.shift_size < self.window_size
-
-        self.w_msa = WindowMSA(
-            embed_dims=embed_dims,
-            num_heads=num_heads,
-            window_size=to_2tuple(window_size),
-            qkv_bias=qkv_bias,
-            qk_scale=qk_scale,
-            attn_drop_rate=attn_drop_rate,
-            proj_drop_rate=proj_drop_rate)
-
-        self.drop = build_dropout(dropout_layer)
-
-    def forward(self, query, hw_shape):
-        B, L, C = query.shape
-        H, W = hw_shape
-        assert L == H * W, 'input feature has wrong size'
-        query = query.view(B, H, W, C)
-
-        # pad feature maps to multiples of window size
-        pad_r = (self.window_size - W % self.window_size) % self.window_size
-        pad_b = (self.window_size - H % self.window_size) % self.window_size
-        query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
-        H_pad, W_pad = query.shape[1], query.shape[2]
-
-        # cyclic shift
-        if self.shift_size > 0:
-            shifted_query = torch.roll(
-                query,
-                shifts=(-self.shift_size, -self.shift_size),
-                dims=(1, 2))
-
-            # calculate attention mask for SW-MSA
-            img_mask = torch.zeros((1, H_pad, W_pad, 1), device=query.device)
-            h_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size,
-                              -self.shift_size), slice(-self.shift_size, None))
-            w_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size,
-                              -self.shift_size), slice(-self.shift_size, None))
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-
-            # nW, window_size, window_size, 1
-            mask_windows = self.window_partition(img_mask)
-            mask_windows = mask_windows.view(
-                -1, self.window_size * self.window_size)
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0,
-                                              float(-100.0)).masked_fill(
-                                                  attn_mask == 0, float(0.0))
-        else:
-            shifted_query = query
-            attn_mask = None
-
-        # nW*B, window_size, window_size, C
-        query_windows = self.window_partition(shifted_query)
-        # nW*B, window_size*window_size, C
-        query_windows = query_windows.view(-1, self.window_size**2, C)
-
-        # W-MSA/SW-MSA (nW*B, window_size*window_size, C)
-        attn_windows = self.w_msa(query_windows, mask=attn_mask)
-
-        # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size,
-                                         self.window_size, C)
-
-        # B H' W' C
-        shifted_x = self.window_reverse(attn_windows, H_pad, W_pad)
-        # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(
-                shifted_x,
-                shifts=(self.shift_size, self.shift_size),
-                dims=(1, 2))
-        else:
-            x = shifted_x
-
-        if pad_r > 0 or pad_b:
-            x = x[:, :H, :W, :].contiguous()
-
-        x = x.view(B, H * W, C)
-
-        x = self.drop(x)
-        return x
-
-    def window_reverse(self, windows, H, W):
-        """
-        Args:
-            windows: (num_windows*B, window_size, window_size, C)
-            H (int): Height of image
-            W (int): Width of image
-        Returns:
-            x: (B, H, W, C)
-        """
-        window_size = self.window_size
-        B = int(windows.shape[0] / (H * W / window_size / window_size))
-        x = windows.view(B, H // window_size, W // window_size, window_size,
-                         window_size, -1)
-        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
-        return x
-
-    def window_partition(self, x):
-        """
-        Args:
-            x: (B, H, W, C)
-        Returns:
-            windows: (num_windows*B, window_size, window_size, C)
-        """
-        B, H, W, C = x.shape
-        window_size = self.window_size
-        x = x.view(B, H // window_size, window_size, W // window_size,
-                   window_size, C)
-        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
-        windows = windows.view(-1, window_size, window_size, C)
-        return windows
-
-
-class SwinBlock(nn.Module):
-    """"
-    Args:
-        embed_dims (int): The feature dimension.
-        num_heads (int): Parallel attention heads.
-        feedforward_channels (int): The hidden dimension for FFNs.
-        window_size (int, optional): The local window scale. Default: 7.
-        shift (bool, optional): whether to shift window or not. Default False.
-        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Default: None.
-        drop_rate (float, optional): Dropout rate. Default: 0.
-        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
-        drop_path_rate (float, optional): Stochastic depth rate. Default: 0.
-        act_cfg (dict, optional): The config dict of activation function.
-            Default: dict(type='GELU').
-        norm_cfg (dict, optional): The config dict of normalization.
-            Default: dict(type='LN').
-        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-    """
-
-    def __init__(self,
-                 embed_dims,
-                 num_heads,
-                 feedforward_channels,
-                 window_size=7,
-                 shift=False,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 drop_rate=0.,
-                 attn_drop_rate=0.,
-                 drop_path_rate=0.,
-                 act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN'),
-                 with_cp=False):
-
-        super(SwinBlock, self).__init__()
-
-        self.with_cp = with_cp
-
-        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
-        self.attn = ShiftWindowMSA(
-            embed_dims=embed_dims,
-            num_heads=num_heads,
-            window_size=window_size,
-            shift_size=window_size // 2 if shift else 0,
-            qkv_bias=qkv_bias,
-            qk_scale=qk_scale,
-            attn_drop_rate=attn_drop_rate,
-            proj_drop_rate=drop_rate,
-            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate))
-
-        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
-        self.ffn = FFN(
-            embed_dims=embed_dims,
-            feedforward_channels=feedforward_channels,
-            num_fcs=2,
-            ffn_drop=drop_rate,
-            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
-            act_cfg=act_cfg,
-            add_identity=True,
-            init_cfg=None)
-
-    def forward(self, x, hw_shape):
-
-        def _inner_forward(x):
-            identity = x
-            x = self.norm1(x)
-            x = self.attn(x, hw_shape)
-
-            x = x + identity
-
-            identity = x
-            x = self.norm2(x)
-            x = self.ffn(x, identity=identity)
-
-            return x
-
-        if self.with_cp and x.requires_grad:
-            x = cp.checkpoint(_inner_forward, x)
-        else:
-            x = _inner_forward(x)
-
-        return x
-
-
-class SwinBlockSequence(nn.Module):
-    """Implements one stage in Swin Transformer.
-
-    Args:
-        embed_dims (int): The feature dimension.
-        num_heads (int): Parallel attention heads.
-        feedforward_channels (int): The hidden dimension for FFNs.
-        depth (int): The number of blocks in this stage.
-        window_size (int, optional): The local window scale. Default: 7.
-        qkv_bias (bool, optional): enable bias for qkv if True. Default: True.
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Default: None.
-        drop_rate (float, optional): Dropout rate. Default: 0.
-        attn_drop_rate (float, optional): Attention dropout rate. Default: 0.
-        drop_path_rate (float | list[float], optional): Stochastic depth
-            rate. Default: 0.
-        downsample (nn.Module | None, optional): The downsample operation
-            module. Default: None.
-        act_cfg (dict, optional): The config dict of activation function.
-            Default: dict(type='GELU').
-        norm_cfg (dict, optional): The config dict of normalization.
-            Default: dict(type='LN').
-        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-    """
-
-    def __init__(self,
-                 embed_dims,
-                 num_heads,
-                 feedforward_channels,
-                 depth,
-                 window_size=7,
-                 qkv_bias=True,
-                 qk_scale=None,
-                 drop_rate=0.,
-                 attn_drop_rate=0.,
-                 drop_path_rate=0.,
-                 downsample=None,
-                 act_cfg=dict(type='GELU'),
-                 norm_cfg=dict(type='LN'),
-                 with_cp=False):
-        super().__init__()
-
-        if isinstance(drop_path_rate, list):
-            drop_path_rates = drop_path_rate
-            assert len(drop_path_rates) == depth
-        else:
-            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
-
-        self.blocks = nn.ModuleList()
-        for i in range(depth):
-            block = SwinBlock(
-                embed_dims=embed_dims,
-                num_heads=num_heads,
-                feedforward_channels=feedforward_channels,
-                window_size=window_size,
-                shift=False if i % 2 == 0 else True,
-                qkv_bias=qkv_bias,
-                qk_scale=qk_scale,
-                drop_rate=drop_rate,
-                attn_drop_rate=attn_drop_rate,
-                drop_path_rate=drop_path_rates[i],
-                act_cfg=act_cfg,
-                norm_cfg=norm_cfg,
-                with_cp=with_cp)
-            self.blocks.append(block)
-
-        self.downsample = downsample
-
-    def forward(self, x, hw_shape):
-        for block in self.blocks:
-            x = block(x, hw_shape)
-
-        if self.downsample:
-            x_down, down_hw_shape = self.downsample(x, hw_shape)
-            return x_down, down_hw_shape, x, hw_shape
-        else:
-            return x, hw_shape, x, hw_shape
-
-
-@BACKBONES.register_module()
-class SwinTransformer(BaseBackbone):
-    """ Swin Transformer
-    A PyTorch implement of : `Swin Transformer:
-    Hierarchical Vision Transformer using Shifted Windows` -
-        https://arxiv.org/abs/2103.14030
-
-    Inspiration from
-    https://github.com/microsoft/Swin-Transformer
-
-    Args:
-        pretrain_img_size (int | tuple[int]): The size of input image when
-            pretrain. Defaults: 224.
-        in_channels (int): The num of input channels.
-            Defaults: 3.
-        embed_dims (int): The feature dimension. Default: 96.
-        patch_size (int | tuple[int]): Patch size. Default: 4.
-        window_size (int): Window size. Default: 7.
-        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
-            Default: 4.
-        depths (tuple[int]): Depths of each Swin Transformer stage.
-            Default: (2, 2, 6, 2).
-        num_heads (tuple[int]): Parallel attention heads of each Swin
-            Transformer stage. Default: (3, 6, 12, 24).
-        strides (tuple[int]): The patch merging or patch embedding stride of
-            each Swin Transformer stage. (In swin, we set kernel size equal to
-            stride.) Default: (4, 2, 2, 2).
-        out_indices (tuple[int]): Output from which stages.
-            Default: (0, 1, 2, 3).
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key,
-            value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Default: None.
-        patch_norm (bool): If add a norm layer for patch embed and patch
-            merging. Default: True.
-        drop_rate (float): Dropout rate. Defaults: 0.
-        attn_drop_rate (float): Attention dropout rate. Default: 0.
-        drop_path_rate (float): Stochastic depth rate. Defaults: 0.1.
-        use_abs_pos_embed (bool): If True, add absolute position embedding to
-            the patch embedding. Defaults: False.
-        act_cfg (dict): Config dict for activation layer.
-            Default: dict(type='LN').
-        norm_cfg (dict): Config dict for normalization layer at
-            output of backone. Defaults: dict(type='LN').
-        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-            will save some memory while slowing down the training speed.
-            Default: False.
-        pretrained (str, optional): model pretrained path. Default: None.
-        convert_weights (bool): The flag indicates whether the
-            pre-trained model is from the original repo. We may need
-            to convert some keys to make it compatible.
-            Default: False.
-        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-            Default: -1 (-1 means not freezing any parameters).
-    """
-
-    def __init__(
-        self,
-        pretrain_img_size=224,
-        in_channels=3,
-        embed_dims=96,
-        patch_size=4,
-        window_size=7,
-        mlp_ratio=4,
-        depths=(2, 2, 6, 2),
-        num_heads=(3, 6, 12, 24),
-        strides=(4, 2, 2, 2),
-        out_indices=(0, 1, 2, 3),
-        qkv_bias=True,
-        qk_scale=None,
-        patch_norm=True,
-        drop_rate=0.,
-        attn_drop_rate=0.,
-        drop_path_rate=0.1,
-        use_abs_pos_embed=False,
-        act_cfg=dict(type='GELU'),
-        norm_cfg=dict(type='LN'),
-        with_cp=False,
-        convert_weights=False,
-        frozen_stages=-1,
-    ):
-        self.convert_weights = convert_weights
-        self.frozen_stages = frozen_stages
-        if isinstance(pretrain_img_size, int):
-            pretrain_img_size = to_2tuple(pretrain_img_size)
-        elif isinstance(pretrain_img_size, tuple):
-            if len(pretrain_img_size) == 1:
-                pretrain_img_size = to_2tuple(pretrain_img_size[0])
-            assert len(pretrain_img_size) == 2, \
-                f'The size of image should have length 1 or 2, ' \
-                f'but got {len(pretrain_img_size)}'
-
-        super(SwinTransformer, self).__init__()
-
-        num_layers = len(depths)
-        self.out_indices = out_indices
-        self.use_abs_pos_embed = use_abs_pos_embed
-
-        assert strides[0] == patch_size, 'Use non-overlapping patch embed.'
-
-        self.patch_embed = PatchEmbed(
-            in_channels=in_channels,
-            embed_dims=embed_dims,
-            conv_type='Conv2d',
-            kernel_size=patch_size,
-            stride=strides[0],
-            norm_cfg=norm_cfg if patch_norm else None,
-            init_cfg=None)
-
-        if self.use_abs_pos_embed:
-            patch_row = pretrain_img_size[0] // patch_size
-            patch_col = pretrain_img_size[1] // patch_size
-            num_patches = patch_row * patch_col
-            self.absolute_pos_embed = nn.Parameter(
-                torch.zeros((1, num_patches, embed_dims)))
-
-        self.drop_after_pos = nn.Dropout(p=drop_rate)
-
-        # set stochastic depth decay rule
-        total_depth = sum(depths)
-        dpr = [
-            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
-        ]
-
-        self.stages = nn.ModuleList()
-        in_channels = embed_dims
-        for i in range(num_layers):
-            if i < num_layers - 1:
-                downsample = PatchMerging(
-                    in_channels=in_channels,
-                    out_channels=2 * in_channels,
-                    stride=strides[i + 1],
-                    norm_cfg=norm_cfg if patch_norm else None,
-                    init_cfg=None)
-            else:
-                downsample = None
-
-            stage = SwinBlockSequence(
-                embed_dims=in_channels,
-                num_heads=num_heads[i],
-                feedforward_channels=mlp_ratio * in_channels,
-                depth=depths[i],
-                window_size=window_size,
-                qkv_bias=qkv_bias,
-                qk_scale=qk_scale,
-                drop_rate=drop_rate,
-                attn_drop_rate=attn_drop_rate,
-                drop_path_rate=dpr[sum(depths[:i]):sum(depths[:i + 1])],
-                downsample=downsample,
-                act_cfg=act_cfg,
-                norm_cfg=norm_cfg,
-                with_cp=with_cp)
-            self.stages.append(stage)
-            if downsample:
-                in_channels = downsample.out_channels
-
-        self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)]
-        # Add a norm layer for each output
-        for i in out_indices:
-            layer = build_norm_layer(norm_cfg, self.num_features[i])[1]
-            layer_name = f'norm{i}'
-            self.add_module(layer_name, layer)
-
-    def train(self, mode=True):
-        """Convert the model into training mode while keep layers freezed."""
-        super(SwinTransformer, self).train(mode)
-        self._freeze_stages()
-
-    def _freeze_stages(self):
-        if self.frozen_stages >= 0:
-            self.patch_embed.eval()
-            for param in self.patch_embed.parameters():
-                param.requires_grad = False
-            if self.use_abs_pos_embed:
-                self.absolute_pos_embed.requires_grad = False
-            self.drop_after_pos.eval()
-
-        for i in range(1, self.frozen_stages + 1):
-
-            if (i - 1) in self.out_indices:
-                norm_layer = getattr(self, f'norm{i-1}')
-                norm_layer.eval()
-                for param in norm_layer.parameters():
-                    param.requires_grad = False
-
-            m = self.stages[i - 1]
-            m.eval()
-            for param in m.parameters():
-                param.requires_grad = False
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights in backbone.
-
-        Args:
-            pretrained (str, optional): Path to pre-trained weights.
-                Defaults to None.
-        """
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-            ckpt = _load_checkpoint(
-                pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in ckpt:
-                _state_dict = ckpt['state_dict']
-            elif 'model' in ckpt:
-                _state_dict = ckpt['model']
-            else:
-                _state_dict = ckpt
-            if self.convert_weights:
-                # supported loading weight from original repo,
-                _state_dict = swin_converter(_state_dict)
-
-            state_dict = OrderedDict()
-            for k, v in _state_dict.items():
-                if k.startswith('backbone.'):
-                    state_dict[k[9:]] = v
-
-            # strip prefix of state_dict
-            if list(state_dict.keys())[0].startswith('module.'):
-                state_dict = {k[7:]: v for k, v in state_dict.items()}
-
-            # reshape absolute position embedding
-            if state_dict.get('absolute_pos_embed') is not None:
-                absolute_pos_embed = state_dict['absolute_pos_embed']
-                N1, L, C1 = absolute_pos_embed.size()
-                N2, C2, H, W = self.absolute_pos_embed.size()
-                if N1 != N2 or C1 != C2 or L != H * W:
-                    logger.warning('Error in loading absolute_pos_embed, pass')
-                else:
-                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
-                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
-
-            # interpolate position bias table if needed
-            relative_position_bias_table_keys = [
-                k for k in state_dict.keys()
-                if 'relative_position_bias_table' in k
-            ]
-            for table_key in relative_position_bias_table_keys:
-                table_pretrained = state_dict[table_key]
-                table_current = self.state_dict()[table_key]
-                L1, nH1 = table_pretrained.size()
-                L2, nH2 = table_current.size()
-                if nH1 != nH2:
-                    logger.warning(f'Error in loading {table_key}, pass')
-                elif L1 != L2:
-                    S1 = int(L1**0.5)
-                    S2 = int(L2**0.5)
-                    table_pretrained_resized = F.interpolate(
-                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
-                        size=(S2, S2),
-                        mode='bicubic')
-                    state_dict[table_key] = table_pretrained_resized.view(
-                        nH2, L2).permute(1, 0).contiguous()
-
-            # load state_dict
-            self.load_state_dict(state_dict, False)
-        elif pretrained is None:
-            if self.use_abs_pos_embed:
-                trunc_normal_(self.absolute_pos_embed, std=0.02)
-            for m in self.modules():
-                if isinstance(m, nn.Linear):
-                    trunc_normal_init(m, std=.02, bias=0.)
-                elif isinstance(m, nn.LayerNorm):
-                    constant_init(m, 1.0)
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-    def forward(self, x):
-        x, hw_shape = self.patch_embed(x)
-
-        if self.use_abs_pos_embed:
-            x = x + self.absolute_pos_embed
-        x = self.drop_after_pos(x)
-
-        outs = []
-        for i, stage in enumerate(self.stages):
-            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
-            if i in self.out_indices:
-                norm_layer = getattr(self, f'norm{i}')
-                out = norm_layer(out)
-                out = out.view(-1, *out_hw_shape,
-                               self.num_features[i]).permute(0, 3, 1,
-                                                             2).contiguous()
-                outs.append(out)
-
-        return outs
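The relative-position indexing in the deleted `WindowMSA` (the "About 2x faster" trick) is compact enough to verify in isolation. A small standalone sketch, assuming the default 7x7 window: `double_step_seq` builds per-axis offset sequences whose broadcast sum, after the flip, yields one bias-table index for every query/key pair in the window.

import torch

def double_step_seq(step1, len1, step2, len2):
    # mirrors WindowMSA.double_step_seq from the deleted file
    seq1 = torch.arange(0, step1 * len1, step1)
    seq2 = torch.arange(0, step2 * len2, step2)
    return (seq1[:, None] + seq2[None, :]).reshape(1, -1)

Wh = Ww = 7
rel_index_coords = double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()

# one bias-table index per query/key pair inside the window
print(rel_position_index.shape)       # torch.Size([49, 49])
print(int(rel_position_index.max()))  # 168 == (2*7-1)*(2*7-1) - 1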
main/transformer_utils/mmpose/models/backbones/tcformer.py
DELETED
@@ -1,283 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import math
-
-import torch
-import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
-                      trunc_normal_init)
-from mmcv.runner import _load_checkpoint, load_state_dict
-
-from ...utils import get_root_logger
-from ..builder import BACKBONES
-from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
-                     TokenConv, cluster_dpc_knn, merge_tokens,
-                     tcformer_convert, token2map)
-
-
-class CTM(nn.Module):
-    """Clustering-based Token Merging module in TCFormer.
-
-    Args:
-        sample_ratio (float): The sample ratio of tokens.
-        embed_dim (int): Input token feature dimension.
-        dim_out (int): Output token feature dimension.
-        k (int): number of the nearest neighbor used i DPC-knn algorithm.
-    """
-
-    def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
-        super().__init__()
-        self.sample_ratio = sample_ratio
-        self.dim_out = dim_out
-        self.conv = TokenConv(
-            in_channels=embed_dim,
-            out_channels=dim_out,
-            kernel_size=3,
-            stride=2,
-            padding=1)
-        self.norm = nn.LayerNorm(self.dim_out)
-        self.score = nn.Linear(self.dim_out, 1)
-        self.k = k
-
-    def forward(self, token_dict):
-        token_dict = token_dict.copy()
-        x = self.conv(token_dict)
-        x = self.norm(x)
-        token_score = self.score(x)
-        token_weight = token_score.exp()
-
-        token_dict['x'] = x
-        B, N, C = x.shape
-        token_dict['token_score'] = token_score
-
-        cluster_num = max(math.ceil(N * self.sample_ratio), 1)
-        idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
-                                                   self.k)
-        down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
-                                 token_weight)
-
-        H, W = token_dict['map_size']
-        H = math.floor((H - 1) / 2 + 1)
-        W = math.floor((W - 1) / 2 + 1)
-        down_dict['map_size'] = [H, W]
-
-        return down_dict, token_dict
-
-
-@BACKBONES.register_module()
-class TCFormer(nn.Module):
-    """Token Clustering Transformer (TCFormer)
-
-    Implementation of `Not All Tokens Are Equal: Human-centric Visual
-    Analysis via Token Clustering Transformer
-    <https://arxiv.org/abs/2204.08680>`
-
-    Args:
-        in_channels (int): Number of input channels. Default: 3.
-        embed_dims (list[int]): Embedding dimension. Default:
-            [64, 128, 256, 512].
-        num_heads (Sequence[int]): The attention heads of each transformer
-            encode layer. Default: [1, 2, 5, 8].
-        mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
-            embedding dim of each transformer block.
-        qkv_bias (bool): Enable bias for qkv if True. Default: True.
-        qk_scale (float | None, optional): Override default qk scale of
-            head_dim ** -0.5 if set. Default: None.
-        drop_rate (float): Probability of an element to be zeroed.
-            Default 0.0.
-        attn_drop_rate (float): The drop out rate for attention layer.
-            Default 0.0.
-        drop_path_rate (float): stochastic depth rate. Default 0.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN', eps=1e-6).
-        num_layers (Sequence[int]): The layer number of each transformer encode
-            layer. Default: [3, 4, 6, 3].
-        sr_ratios (Sequence[int]): The spatial reduction rate of each
-            transformer block. Default: [8, 4, 2, 1].
-        num_stages (int): The num of stages. Default: 4.
-        pretrained (str, optional): model pretrained path. Default: None.
-        k (int): number of the nearest neighbor used for local density.
-        sample_ratios (list[float]): The sample ratios of CTM modules.
-            Default: [0.25, 0.25, 0.25]
-        return_map (bool): If True, transfer dynamic tokens to feature map at
-            last. Default: False
-        convert_weights (bool): The flag indicates whether the
-            pre-trained model is from the original repo. We may need
-            to convert some keys to make it compatible.
-            Default: True.
-    """
-
-    def __init__(self,
-                 in_channels=3,
-                 embed_dims=[64, 128, 256, 512],
-                 num_heads=[1, 2, 4, 8],
-                 mlp_ratios=[4, 4, 4, 4],
-                 qkv_bias=True,
-                 qk_scale=None,
-                 drop_rate=0.,
-                 attn_drop_rate=0.,
-                 drop_path_rate=0.,
-                 norm_cfg=dict(type='LN', eps=1e-6),
-                 num_layers=[3, 4, 6, 3],
-                 sr_ratios=[8, 4, 2, 1],
-                 num_stages=4,
-                 pretrained=None,
-                 k=5,
-                 sample_ratios=[0.25, 0.25, 0.25],
-                 return_map=False,
-                 convert_weights=True):
-        super().__init__()
-
-        self.num_layers = num_layers
-        self.num_stages = num_stages
-        self.grid_stride = sr_ratios[0]
-        self.embed_dims = embed_dims
-        self.sr_ratios = sr_ratios
-        self.mlp_ratios = mlp_ratios
-        self.sample_ratios = sample_ratios
-        self.return_map = return_map
-        self.convert_weights = convert_weights
-
-        # stochastic depth decay rule
-        dpr = [
-            x.item()
-            for x in torch.linspace(0, drop_path_rate, sum(num_layers))
-        ]
-        cur = 0
-
-        # In stage 1, use the standard transformer blocks
-        for i in range(1):
-            patch_embed = PatchEmbed(
-                in_channels=in_channels if i == 0 else embed_dims[i - 1],
-                embed_dims=embed_dims[i],
-                kernel_size=7,
-                stride=4,
-                padding=3,
-                bias=True,
-                norm_cfg=dict(type='LN', eps=1e-6))
-
-            block = nn.ModuleList([
-                TCFormerRegularBlock(
-                    dim=embed_dims[i],
-                    num_heads=num_heads[i],
-                    mlp_ratio=mlp_ratios[i],
-                    qkv_bias=qkv_bias,
-                    qk_scale=qk_scale,
-                    drop=drop_rate,
-                    attn_drop=attn_drop_rate,
-                    drop_path=dpr[cur + j],
-                    norm_cfg=norm_cfg,
-                    sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
-            ])
-            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
-
-            cur += num_layers[i]
-
-            setattr(self, f'patch_embed{i + 1}', patch_embed)
-            setattr(self, f'block{i + 1}', block)
-            setattr(self, f'norm{i + 1}', norm)
-
-        # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
-        for i in range(1, num_stages):
-            ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
-                      k)
-
-            block = nn.ModuleList([
-                TCFormerDynamicBlock(
-                    dim=embed_dims[i],
-                    num_heads=num_heads[i],
-                    mlp_ratio=mlp_ratios[i],
-                    qkv_bias=qkv_bias,
-                    qk_scale=qk_scale,
-                    drop=drop_rate,
-                    attn_drop=attn_drop_rate,
-                    drop_path=dpr[cur + j],
-                    norm_cfg=norm_cfg,
-                    sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
-            ])
-            norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
-            cur += num_layers[i]
-
-            setattr(self, f'ctm{i}', ctm)
-            setattr(self, f'block{i + 1}', block)
-            setattr(self, f'norm{i + 1}', norm)
-
-        self.init_weights(pretrained)
-
-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = get_root_logger()
-
-            checkpoint = _load_checkpoint(
-                pretrained, logger=logger, map_location='cpu')
-            logger.warning(f'Load pre-trained model for '
-                           f'{self.__class__.__name__} from original repo')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
-            else:
-                state_dict = checkpoint
-
-            if self.convert_weights:
-                # We need to convert pre-trained weights to match this
-                # implementation.
-                state_dict = tcformer_convert(state_dict)
-            load_state_dict(self, state_dict, strict=False, logger=logger)
-
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Linear):
-                    trunc_normal_init(m, std=.02, bias=0.)
-                elif isinstance(m, nn.LayerNorm):
-                    constant_init(m, 1.0)
-                elif isinstance(m, nn.Conv2d):
-                    fan_out = m.kernel_size[0] * m.kernel_size[
-                        1] * m.out_channels
-                    fan_out //= m.groups
-                    normal_init(m, 0, math.sqrt(2.0 / fan_out))
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-    def forward(self, x):
-        outs = []
-
-        i = 0
-        patch_embed = getattr(self, f'patch_embed{i + 1}')
-        block = getattr(self, f'block{i + 1}')
-        norm = getattr(self, f'norm{i + 1}')
-        x, (H, W) = patch_embed(x)
-        for blk in block:
-            x = blk(x, H, W)
-        x = norm(x)
-
-        # init token dict
-        B, N, _ = x.shape
-        device = x.device
-        idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device)
-        agg_weight = x.new_ones(B, N, 1)
-        token_dict = {
-            'x': x,
-            'token_num': N,
-            'map_size': [H, W],
-            'init_grid_size': [H, W],
-            'idx_token': idx_token,
-            'agg_weight': agg_weight
-        }
-        outs.append(token_dict.copy())
-
-        # stage 2~4
-        for i in range(1, self.num_stages):
-            ctm = getattr(self, f'ctm{i}')
-            block = getattr(self, f'block{i + 1}')
-            norm = getattr(self, f'norm{i + 1}')
-
-            token_dict = ctm(token_dict)  # down sample
-            for j, blk in enumerate(block):
-                token_dict = blk(token_dict)
-
-            token_dict['x'] = norm(token_dict['x'])
-            outs.append(token_dict)
-
-        if self.return_map:
-            outs = [token2map(token_dict) for token_dict in outs]
-        return outs
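Both deleted backbones build their drop-path schedule the same way: one `torch.linspace` over the total block count, sliced per stage, so deeper blocks get a higher stochastic-depth probability. A short sketch using TCFormer's default depths (`drop_path_rate=0.1` is assumed for illustration):

import torch

num_layers = [3, 4, 6, 3]   # blocks per stage (TCFormer defaults)
drop_path_rate = 0.1

# linearly increasing drop-path probability across all blocks
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))]

cur = 0
for i, depth in enumerate(num_layers):
    print(f'stage {i + 1}:', [round(p, 3) for p in dpr[cur:cur + depth]])
    cur += depth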
main/transformer_utils/mmpose/models/backbones/tcn.py
DELETED
@@ -1,267 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-
-import torch.nn as nn
-from mmcv.cnn import ConvModule, build_conv_layer, constant_init, kaiming_init
-from mmcv.utils.parrots_wrapper import _BatchNorm
-
-from mmpose.core import WeightNormClipHook
-from ..builder import BACKBONES
-from .base_backbone import BaseBackbone
-
-
-class BasicTemporalBlock(nn.Module):
-    """Basic block for VideoPose3D.
-
-    Args:
-        in_channels (int): Input channels of this block.
-        out_channels (int): Output channels of this block.
-        mid_channels (int): The output channels of conv1. Default: 1024.
-        kernel_size (int): Size of the convolving kernel. Default: 3.
-        dilation (int): Spacing between kernel elements. Default: 3.
-        dropout (float): Dropout rate. Default: 0.25.
-        causal (bool): Use causal convolutions instead of symmetric
-            convolutions (for real-time applications). Default: False.
-        residual (bool): Use residual connection. Default: True.
-        use_stride_conv (bool): Use optimized TCN that designed
-            specifically for single-frame batching, i.e. where batches have
-            input length = receptive field, and output length = 1. This
-            implementation replaces dilated convolutions with strided
-            convolutions to avoid generating unused intermediate results.
-            Default: False.
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: dict(type='Conv1d').
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN1d').
-    """
-
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 mid_channels=1024,
-                 kernel_size=3,
-                 dilation=3,
-                 dropout=0.25,
-                 causal=False,
-                 residual=True,
-                 use_stride_conv=False,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d')):
-        # Protect mutable default arguments
-        conv_cfg = copy.deepcopy(conv_cfg)
-        norm_cfg = copy.deepcopy(norm_cfg)
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.mid_channels = mid_channels
-        self.kernel_size = kernel_size
-        self.dilation = dilation
-        self.dropout = dropout
-        self.causal = causal
-        self.residual = residual
-        self.use_stride_conv = use_stride_conv
-
-        self.pad = (kernel_size - 1) * dilation // 2
-        if use_stride_conv:
-            self.stride = kernel_size
-            self.causal_shift = kernel_size // 2 if causal else 0
-            self.dilation = 1
-        else:
-            self.stride = 1
-            self.causal_shift = kernel_size // 2 * dilation if causal else 0
-
-        self.conv1 = nn.Sequential(
-            ConvModule(
-                in_channels,
-                mid_channels,
-                kernel_size=kernel_size,
-                stride=self.stride,
-                dilation=self.dilation,
-                bias='auto',
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg))
-        self.conv2 = nn.Sequential(
-            ConvModule(
-                mid_channels,
-                out_channels,
-                kernel_size=1,
-                bias='auto',
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg))
-
-        if residual and in_channels != out_channels:
-            self.short_cut = build_conv_layer(conv_cfg, in_channels,
-                                              out_channels, 1)
-        else:
-            self.short_cut = None
-
-        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
-
-    def forward(self, x):
-        """Forward function."""
-        if self.use_stride_conv:
-            assert self.causal_shift + self.kernel_size // 2 < x.shape[2]
-        else:
-            assert 0 <= self.pad + self.causal_shift < x.shape[2] - \
-                self.pad + self.causal_shift <= x.shape[2]
-
-        out = self.conv1(x)
-        if self.dropout is not None:
-            out = self.dropout(out)
-
-        out = self.conv2(out)
-        if self.dropout is not None:
-            out = self.dropout(out)
-
-        if self.residual:
-            if self.use_stride_conv:
-                res = x[:, :, self.causal_shift +
-                        self.kernel_size // 2::self.kernel_size]
-            else:
-                res = x[:, :,
-                        (self.pad + self.causal_shift):(x.shape[2] - self.pad +
-                                                        self.causal_shift)]
-
-            if self.short_cut is not None:
-                res = self.short_cut(res)
-            out = out + res
-
-        return out
-
-
-@BACKBONES.register_module()
-class TCN(BaseBackbone):
-    """TCN backbone.
-
-    Temporal Convolutional Networks.
-    More details can be found in the
-    `paper <https://arxiv.org/abs/1811.11742>`__ .
-
-    Args:
-        in_channels (int): Number of input channels, which equals to
-            num_keypoints * num_features.
-        stem_channels (int): Number of feature channels. Default: 1024.
-        num_blocks (int): NUmber of basic temporal convolutional blocks.
-            Default: 2.
-        kernel_sizes (Sequence[int]): Sizes of the convolving kernel of
-            each basic block. Default: ``(3, 3, 3)``.
-        dropout (float): Dropout rate. Default: 0.25.
-        causal (bool): Use causal convolutions instead of symmetric
-            convolutions (for real-time applications).
-            Default: False.
-        residual (bool): Use residual connection. Default: True.
-        use_stride_conv (bool): Use TCN backbone optimized for
-            single-frame batching, i.e. where batches have input length =
-            receptive field, and output length = 1. This implementation
-            replaces dilated convolutions with strided convolutions to avoid
-            generating unused intermediate results. The weights are
-            interchangeable with the reference implementation. Default: False
-        conv_cfg (dict): dictionary to construct and config conv layer.
-            Default: dict(type='Conv1d').
-        norm_cfg (dict): dictionary to construct and config norm layer.
-            Default: dict(type='BN1d').
-        max_norm (float|None): if not None, the weight of convolution layers
-            will be clipped to have a maximum norm of max_norm.
-
-    Example:
-        >>> from mmpose.models import TCN
-        >>> import torch
-        >>> self = TCN(in_channels=34)
-        >>> self.eval()
-        >>> inputs = torch.rand(1, 34, 243)
-        >>> level_outputs = self.forward(inputs)
-        >>> for level_out in level_outputs:
-        ...     print(tuple(level_out.shape))
-        (1, 1024, 235)
-        (1, 1024, 217)
-    """
-
-    def __init__(self,
-                 in_channels,
-                 stem_channels=1024,
-                 num_blocks=2,
-                 kernel_sizes=(3, 3, 3),
-                 dropout=0.25,
-                 causal=False,
-                 residual=True,
-                 use_stride_conv=False,
-                 conv_cfg=dict(type='Conv1d'),
-                 norm_cfg=dict(type='BN1d'),
-                 max_norm=None):
-        # Protect mutable default arguments
-        conv_cfg = copy.deepcopy(conv_cfg)
-        norm_cfg = copy.deepcopy(norm_cfg)
-        super().__init__()
-        self.in_channels = in_channels
-        self.stem_channels = stem_channels
-        self.num_blocks = num_blocks
-        self.kernel_sizes = kernel_sizes
-        self.dropout = dropout
-        self.causal = causal
-        self.residual = residual
-        self.use_stride_conv = use_stride_conv
-        self.max_norm = max_norm
-
-        assert num_blocks == len(kernel_sizes) - 1
-        for ks in kernel_sizes:
-            assert ks % 2 == 1, 'Only odd filter widths are supported.'
-
-        self.expand_conv = ConvModule(
-            in_channels,
-            stem_channels,
-            kernel_size=kernel_sizes[0],
-            stride=kernel_sizes[0] if use_stride_conv else 1,
-            bias='auto',
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg)
-
-        dilation = kernel_sizes[0]
-        self.tcn_blocks = nn.ModuleList()
-        for i in range(1, num_blocks + 1):
-            self.tcn_blocks.append(
-                BasicTemporalBlock(
-                    in_channels=stem_channels,
-                    out_channels=stem_channels,
-                    mid_channels=stem_channels,
-                    kernel_size=kernel_sizes[i],
-                    dilation=dilation,
-                    dropout=dropout,
-                    causal=causal,
-                    residual=residual,
-                    use_stride_conv=use_stride_conv,
-                    conv_cfg=conv_cfg,
-                    norm_cfg=norm_cfg))
-            dilation *= kernel_sizes[i]
-
-        if self.max_norm is not None:
-            # Apply weight norm clip to conv layers
-            weight_clip = WeightNormClipHook(self.max_norm)
-            for module in self.modules():
-                if isinstance(module, nn.modules.conv._ConvNd):
-                    weight_clip.register(module)
-
-        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
-
-    def forward(self, x):
-        """Forward function."""
-        x = self.expand_conv(x)
-
-        if self.dropout is not None:
-            x = self.dropout(x)
-
-        outs = []
-        for i in range(self.num_blocks):
-            x = self.tcn_blocks[i](x)
-            outs.append(x)
-
-        return tuple(outs)
-
-    def init_weights(self, pretrained=None):
-        """Initialize the weights."""
-        super().init_weights(pretrained)
-        if pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.modules.conv._ConvNd):
-                    kaiming_init(m, mode='fan_in', nonlinearity='relu')
-                elif isinstance(m, _BatchNorm):
-                    constant_init(m, 1)
main/transformer_utils/mmpose/models/backbones/utils/utils.py
CHANGED
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import OrderedDict
 
-from …
+from mmengine.runner import load_state_dict
 
 
 # Copyright (c) Open-MMLab. All rights reserved.
@@ -22,11 +22,11 @@ from torch.utils import model_zoo
 from torch.nn import functional as F
 
 import mmcv
-from …
-from …
-from …
-from …
-from …
+from mmengine.fileio import FileClient
+from mmengine.fileio import load as load_file
+# from mmengine.model.wrappers.utils import is_module_wrapper
+from mmengine.utils import mkdir_or_exist
+from mmengine.dist import get_dist_info
 
 from scipy import interpolate
 import numpy as np
@@ -75,8 +75,8 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
     def load(module, prefix=''):
         # recursively check parallel module in case that the model has a
         # complicated structure, e.g., nn.Module(nn.Module(DDP))
-        if is_module_wrapper(module):
-            module = module.module
+        # if is_module_wrapper(module):
+        #     module = module.module
         local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
         module._load_from_state_dict(state_dict, prefix, local_metadata, True,
@@ -445,8 +445,8 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
-    if is_module_wrapper(module):
-        module = module.module
+    # if is_module_wrapper(module):
+    #     module = module.module
 
    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
@@ -482,8 +482,8 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
 
-    if is_module_wrapper(model):
-        model = model.module
+    # if is_module_wrapper(model):
+    #     model = model.module
 
    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
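With the `is_module_wrapper` checks commented out, these helpers no longer unwrap `DataParallel` / `DistributedDataParallel` models before reading or loading their state dict. If that behavior were still needed, a hypothetical local fallback (not part of this commit) could look like the following sketch:

import torch.nn as nn

def unwrap_module(module: nn.Module) -> nn.Module:
    # Hypothetical stand-in for the removed is_module_wrapper check:
    # follow .module through the common PyTorch parallel wrappers.
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    while isinstance(module, wrappers):
        module = module.module
    return module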