Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2022 The IDEA Authors. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ------------------------------------------------------------------------------------------------ | |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
# # Copyright (c) OpenMMLab. All rights reserved. | |
# ------------------------------------------------------------------------------------------------ | |
# Support TIMM Backbone | |
# Modified from: | |
# https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/backbones/timm_backbone.py | |
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/backbone.py | |
# ------------------------------------------------------------------------------------------------ | |
import warnings | |
from typing import Tuple | |
import torch.nn as nn | |
from detectron2.modeling.backbone import Backbone | |
from detectron2.utils import comm | |
from detectron2.utils.logger import setup_logger | |
try: | |
import timm | |
except ImportError: | |
timm = None | |
def log_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone to help development and debugging.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            the ``feature_info`` attribute read from a timm model, or None when
            the model has none.
    """
    timm_logger = setup_logger(name="timm backbone")

    if feature_info is None:
        timm_logger.warning("This backbone does not have feature_info")
        return

    if isinstance(feature_info, list):
        # A plain list: log each entry on its own line, tagged with its position.
        for position, entry in enumerate(feature_info):
            timm_logger.info(f"backbone feature_info[{position}]: {entry}")
        return

    # Anything else is treated as a FeatureInfo-like object; if it lacks the
    # expected attributes/methods, warn instead of failing.
    try:
        timm_logger.info(f"backbone out_indices: {feature_info.out_indices}")
        timm_logger.info(f"backbone out_channels: {feature_info.channels()}")
        timm_logger.info(f"backbone out_strides: {feature_info.reduction()}")
    except AttributeError:
        timm_logger.warning("Unexpected format of backbone feature_info")
class TimmBackbone(Backbone):
    """A wrapper for using a backbone from the timm library as a detectron2 ``Backbone``.

    Please see the document for `feature extraction with timm
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_
    for more details.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer of each stage). Default: True.
        pretrained (bool): Whether to load pretrained weights. Default: False.
        checkpoint_path (str): Path of a checkpoint passed to
            ``timm.create_model``. Default: "" (no checkpoint).
        in_channels (int): The number of input channels. Default: 3.
        out_indices (tuple[int]): The extracted feature indices which select
            specific feature levels or limit the stride of the feature extractor.
            Each selected level is returned under the key ``"p{index}"``, e.g.
            ``(0, 1)`` yields ``{"p0": ..., "p1": ...}``. Default: (0, 1, 2, 3).
        norm_layer (nn.Module): Set the specified norm layer for feature extractor,
            e.g., set ``norm_layer=FrozenBatchNorm2d`` to freeze the norm layer
            in feature extractor. Default: None.

    Raises:
        RuntimeError: if timm could not be imported.
        TypeError: if ``pretrained`` is not a bool.
        AttributeError: if model creation fails with an error mentioning
            ``feature_info``.
        ValueError: if model creation fails with an error mentioning
            ``norm_layer``.
        Exception: any other error from ``timm.create_model`` is logged and
            re-raised unchanged.
    """

    def __init__(
        self,
        model_name: str,
        features_only: bool = True,
        pretrained: bool = False,
        checkpoint_path: str = "",
        in_channels: int = 3,
        out_indices: Tuple[int, ...] = (0, 1, 2, 3),
        norm_layer: nn.Module = None,
    ):
        super().__init__()
        logger = setup_logger(name="timm backbone")
        if timm is None:
            raise RuntimeError('Failed to import timm. Please run "pip install timm". ')
        if not isinstance(pretrained, bool):
            raise TypeError("pretrained must be bool, not str for model path")
        if features_only and checkpoint_path:
            warnings.warn(
                "Using both features_only and checkpoint_path may cause error"
                " in timm. See "
                "https://github.com/rwightman/pytorch-image-models/issues/488"
            )

        try:
            self.timm_model = timm.create_model(
                model_name=model_name,
                features_only=features_only,
                pretrained=pretrained,
                in_chans=in_channels,
                out_indices=out_indices,
                checkpoint_path=checkpoint_path,
                norm_layer=norm_layer,
            )
        except Exception as error:
            # Translate two known timm failure modes into clearer errors, keeping
            # the original exception as the cause for debugging.
            if "feature_info" in str(error):
                raise AttributeError(
                    "Using features_only may cause attribute error"
                    " in timm, cause there's no feature_info attribute in some models. See "
                    "https://github.com/rwightman/pytorch-image-models/issues/1438"
                ) from error
            if "norm_layer" in str(error):
                raise ValueError(
                    f"{model_name} does not support specified norm layer, please set 'norm_layer=None'"
                ) from error
            # Anything else is surfaced to the caller with its traceback instead of
            # terminating the whole process via exit().
            logger.error(f"Failed to create timm model '{model_name}': {error}")
            raise

        self.out_indices = out_indices
        feature_info = getattr(self.timm_model, "feature_info", None)
        if comm.get_rank() == 0:
            log_timm_feature_info(feature_info)

        # One output name per requested index, in out_indices order. An ordered,
        # de-duplicated list keeps the feature order deterministic (a set did not).
        feature_names = [f"p{index}" for index in out_indices]
        self._out_features = list(dict.fromkeys(feature_names))

        if feature_info is not None:
            # The i-th entry of channels()/reduction() belongs to out_indices[i];
            # indexing (rather than zip) keeps an IndexError on a length mismatch.
            channels = feature_info.channels()
            strides = feature_info.reduction()
            self._out_feature_channels = {
                name: channels[i] for i, name in enumerate(feature_names)
            }
            self._out_feature_strides = {
                name: strides[i] for i, name in enumerate(feature_names)
            }

    def forward(self, x):
        """Forward function of `TimmBackbone`.

        Args:
            x (torch.Tensor): the input tensor for feature extraction.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "p1") to tensor

        Note:
            Assumes ``self.timm_model(x)`` returns an indexable sequence whose
            i-th item corresponds to ``self.out_indices[i]`` (as with
            ``features_only=True``). TODO(review): confirm for
            ``features_only=False``.
        """
        features = self.timm_model(x)
        return {
            f"p{index}": features[i] for i, index in enumerate(self.out_indices)
        }