File size: 7,092 Bytes
3e99b05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# # Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Support TIMM Backbone
# Modified from:
# https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/backbones/timm_backbone.py
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/backbone.py
# ------------------------------------------------------------------------------------------------

import warnings
from typing import Tuple
import torch.nn as nn

from detectron2.modeling.backbone import Backbone
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger

try:
    import timm
except ImportError:
    timm = None


def log_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone to help development and debugging.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            The ``feature_info`` of a timm model. ``None`` produces a warning,
            a list is logged entry by entry, and any other object is expected
            to expose ``out_indices``, ``channels()`` and ``reduction()``; if it
            lacks one of those, a warning is logged instead.
    """
    log = setup_logger(name="timm backbone")

    # Guard clause: nothing to describe.
    if feature_info is None:
        log.warning("This backbone does not have feature_info")
        return

    # List form: log every entry with its position.
    if isinstance(feature_info, list):
        for position, entry in enumerate(feature_info):
            log.info(f"backbone feature_info[{position}]: {entry}")
        return

    # FeatureInfo-like object: report indices, channels and strides.
    try:
        log.info(f"backbone out_indices: {feature_info.out_indices}")
        log.info(f"backbone out_channels: {feature_info.channels()}")
        log.info(f"backbone out_strides: {feature_info.reduction()}")
    except AttributeError:
        log.warning("Unexpected format of backbone feature_info")


class TimmBackbone(Backbone):
    """A wrapper for using backbone from timm library.

    Please see the document for `feature extraction with timm
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_
    for more details.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer of each stage). Default: True.
        pretrained (bool): Whether to load pretrained weights. Default: False.
        checkpoint_path (str): Path of a checkpoint passed to
            ``timm.create_model``; an empty string means no checkpoint.
            Default: "".
        in_channels (int): The number of input channels. Default: 3.
        out_indices (tuple[int]): The extracted feature indices which select
            specific feature levels or limit the stride of the feature extractor.
            Feature index ``i`` is returned under the key ``"p{i}"``, e.g.
            ``(0, 1)`` yields ``{"p0": ..., "p1": ...}``. Default: (0, 1, 2, 3).
        norm_layer (nn.Module): Set the specified norm layer for feature extractor,
            e.g., set ``norm_layer=FrozenBatchNorm2d`` to freeze the norm layer
            in feature extractor. Default: None.

    Raises:
        RuntimeError: if ``timm`` could not be imported.
        TypeError: if ``pretrained`` is not a bool.
        AttributeError: if model creation fails with an error mentioning
            ``feature_info``.
        ValueError: if model creation fails with an error mentioning
            ``norm_layer``.
        Exception: any other error from ``timm.create_model`` is logged and
            re-raised unchanged.
    """

    def __init__(
        self,
        model_name: str,
        features_only: bool = True,
        pretrained: bool = False,
        checkpoint_path: str = "",
        in_channels: int = 3,
        out_indices: Tuple[int, ...] = (0, 1, 2, 3),
        norm_layer: nn.Module = None,
    ):
        super().__init__()
        logger = setup_logger(name="timm backbone")
        if timm is None:
            raise RuntimeError('Failed to import timm. Please run "pip install timm". ')
        if not isinstance(pretrained, bool):
            raise TypeError("pretrained must be bool, not str for model path")
        if features_only and checkpoint_path:
            warnings.warn(
                "Using both features_only and checkpoint_path may cause error"
                " in timm. See "
                "https://github.com/rwightman/pytorch-image-models/issues/488"
            )

        try:
            self.timm_model = timm.create_model(
                model_name=model_name,
                features_only=features_only,
                pretrained=pretrained,
                in_chans=in_channels,
                out_indices=out_indices,
                checkpoint_path=checkpoint_path,
                norm_layer=norm_layer,
            )
        except Exception as error:
            message = str(error)
            if "feature_info" in message:
                raise AttributeError(
                    "Using features_only may cause attribute error"
                    " in timm, cause there's no feature_info attribute in some models. See "
                    "https://github.com/rwightman/pytorch-image-models/issues/1438"
                ) from error
            if "norm_layer" in message:
                raise ValueError(
                    f"{model_name} does not support specified norm layer, please set 'norm_layer=None'"
                ) from error
            # Surface the real failure instead of calling exit(): exit() raises
            # SystemExit, which hides the traceback and bypasses callers'
            # ``except Exception`` handlers.
            logger.error(error)
            raise

        self.out_indices = out_indices

        feature_info = getattr(self.timm_model, "feature_info", None)
        if comm.get_rank() == 0:
            log_timm_feature_info(feature_info)

        if feature_info is not None:
            channels = feature_info.channels()
            strides = feature_info.reduction()
            names = [f"p{idx}" for idx in out_indices]

            # An ordered list (not a set) keeps output_shape() deterministic
            # across processes and aligned with out_indices.
            self._out_features = names
            self._out_feature_channels = {
                name: channels[i] for i, name in enumerate(names)
            }
            self._out_feature_strides = {
                name: strides[i] for i, name in enumerate(names)
            }

    def forward(self, x):
        """Forward function of `TimmBackbone`.

        Args:
            x (torch.Tensor): the input tensor for feature extraction.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "p1") to tensor.
                The i-th output of the timm model is stored under
                ``"p{out_indices[i]}"``.
        """
        features = self.timm_model(x)
        return {f"p{idx}": features[i] for i, idx in enumerate(self.out_indices)}