# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Support TIMM Backbone
# Modified from:
# https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/backbones/timm_backbone.py
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/backbone.py
# ------------------------------------------------------------------------------------------------
import warnings
from typing import Optional, Tuple

import torch.nn as nn
from detectron2.modeling.backbone import Backbone
from detectron2.utils import comm
from detectron2.utils.logger import setup_logger

try:
import timm
except ImportError:
timm = None


def log_timm_feature_info(feature_info):
    """Print the feature_info of a timm backbone to help development and debugging.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of the timm backbone.
    """
logger = setup_logger(name="timm backbone")
if feature_info is None:
logger.warning("This backbone does not have feature_info")
elif isinstance(feature_info, list):
for feat_idx, each_info in enumerate(feature_info):
logger.info(f"backbone feature_info[{feat_idx}]: {each_info}")
else:
try:
logger.info(f"backbone out_indices: {feature_info.out_indices}")
logger.info(f"backbone out_channels: {feature_info.channels()}")
logger.info(f"backbone out_strides: {feature_info.reduction()}")
except AttributeError:
logger.warning("Unexpected format of backbone feature_info")


class TimmBackbone(Backbone):
    """A wrapper for using backbones from the timm library.

    Please see the timm documentation on `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_
    for more details.

    Args:
        model_name (str): Name of the timm model to instantiate.
        features_only (bool): Whether to extract the feature pyramid (multi-scale
            feature maps from the deepest layer of each stage). Default: True.
        pretrained (bool): Whether to load pretrained weights. Default: False.
        checkpoint_path (str): Path of a checkpoint to load after the model is
            initialized. Default: "".
        in_channels (int): Number of input channels. Default: 3.
        out_indices (tuple[int]): Indices of the feature levels to extract, which
            select specific stages or limit the stride of the feature extractor.
            The selected features are returned as ``{"p{idx}": feature}``, e.g.,
            ``out_indices=(0, 1)`` yields ``{"p0": ..., "p1": ...}``.
        norm_layer (nn.Module): Norm layer to use in the feature extractor, e.g.,
            set ``norm_layer=FrozenBatchNorm2d`` to freeze the norm layers of the
            feature extractor.
    """
def __init__(
self,
model_name: str,
features_only: bool = True,
pretrained: bool = False,
checkpoint_path: str = "",
in_channels: int = 3,
        out_indices: Tuple[int, ...] = (0, 1, 2, 3),
        norm_layer: Optional[nn.Module] = None,
):
super().__init__()
logger = setup_logger(name="timm backbone")
if timm is None:
            raise RuntimeError('Failed to import timm. Please run "pip install timm".')
        if not isinstance(pretrained, bool):
            raise TypeError("pretrained must be a bool; use checkpoint_path to load weights from a file")
if features_only and checkpoint_path:
warnings.warn(
"Using both features_only and checkpoint_path may cause error"
" in timm. See "
"https://github.com/rwightman/pytorch-image-models/issues/488"
)
try:
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
out_indices=out_indices,
checkpoint_path=checkpoint_path,
norm_layer=norm_layer,
)
except Exception as error:
if "feature_info" in str(error):
raise AttributeError(
"Using features_only may cause attribute error"
" in timm, cause there's no feature_info attribute in some models. See "
"https://github.com/rwightman/pytorch-image-models/issues/1438"
)
elif "norm_layer" in str(error):
raise ValueError(
f"{model_name} does not support specified norm layer, please set 'norm_layer=None'"
)
else:
logger.info(error)
exit()
self.out_indices = out_indices
feature_info = getattr(self.timm_model, "feature_info", None)
if comm.get_rank() == 0:
log_timm_feature_info(feature_info)
if feature_info is not None:
output_feature_channels = {
"p{}".format(out_indices[i]): feature_info.channels()[i]
for i in range(len(out_indices))
}
out_feature_strides = {
"p{}".format(out_indices[i]): feature_info.reduction()[i]
for i in range(len(out_indices))
}
self._out_features = {"p{}".format(out_indices[i]) for i in range(len(out_indices))}
self._out_feature_channels = {
feat: output_feature_channels[feat] for feat in self._out_features
}
self._out_feature_strides = {
feat: out_feature_strides[feat] for feat in self._out_features
}

    def forward(self, x):
        """Forward function of `TimmBackbone`.

        Args:
            x (torch.Tensor): the input tensor for feature extraction.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "p1") to tensor
        """
features = self.timm_model(x)
outs = {}
for i in range(len(self.out_indices)):
out = features[i]
outs["p{}".format(self.out_indices[i])] = out
return outs
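

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module). It
    # assumes torch, timm, and detectron2 are installed; "resnet50" is only an
    # example timm model name and the input size is arbitrary.
    import torch

    backbone = TimmBackbone(model_name="resnet50", pretrained=False)
    dummy = torch.randn(1, 3, 224, 224)
    outputs = backbone(dummy)
    for name, feature in outputs.items():
        # With the default out_indices=(0, 1, 2, 3), keys are "p0".."p3".
        print(name, tuple(feature.shape))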