Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2022 The IDEA Authors. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------------------------------ | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # # Copyright (c) OpenMMLab. All rights reserved. | |
| # ------------------------------------------------------------------------------------------------ | |
| # Support TIMM Backbone | |
| # Modified from: | |
| # https://github.com/open-mmlab/mmclassification/blob/master/mmcls/models/backbones/timm_backbone.py | |
| # https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/backbone.py | |
| # ------------------------------------------------------------------------------------------------ | |
| import warnings | |
| from typing import Tuple | |
| import torch.nn as nn | |
| from detectron2.modeling.backbone import Backbone | |
| from detectron2.utils import comm | |
| from detectron2.utils.logger import setup_logger | |
| try: | |
| import timm | |
| except ImportError: | |
| timm = None | |
def log_timm_feature_info(feature_info):
    """Log the ``feature_info`` of a timm backbone for development and debugging.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone. ``None`` yields a warning, a list
            is logged entry by entry, and any other object is expected to
            expose ``out_indices``, ``channels()`` and ``reduction()``.
    """
    log = setup_logger(name="timm backbone")

    # Guard clause: the backbone exposes nothing to report.
    if feature_info is None:
        log.warning("This backbone does not have feature_info")
        return

    # Plain list form: report each entry with its position.
    if isinstance(feature_info, list):
        for idx, info in enumerate(feature_info):
            log.info(f"backbone feature_info[{idx}]: {info}")
        return

    # Object form: each attribute is read and logged in turn, so the lines
    # emitted before a missing attribute is hit are still reported.
    try:
        log.info(f"backbone out_indices: {feature_info.out_indices}")
        log.info(f"backbone out_channels: {feature_info.channels()}")
        log.info(f"backbone out_strides: {feature_info.reduction()}")
    except AttributeError:
        log.warning("Unexpected format of backbone feature_info")
class TimmBackbone(Backbone):
    """A wrapper for using backbone from timm library.

    Please see the document for `feature extraction with timm
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_
    for more details.

    The output features are named ``"p{index}"`` after the entries of
    ``out_indices``, e.g., ``out_indices=(0, 1)`` returns
    ``{"p0": feature from index 0, "p1": feature from index 1}``.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer of each stage). Default: True.
        pretrained (bool): Whether to load pretrained weights. Default: False.
        checkpoint_path (str): Checkpoint path forwarded to
            ``timm.create_model``. Default: "".
        in_channels (int): The number of input channels. Default: 3.
        out_indices (tuple[int]): The extracted feature indices which select
            specific feature levels or limit the stride of the feature extractor.
            Default: (0, 1, 2, 3).
        norm_layer (nn.Module): Set the specified norm layer for feature extractor,
            e.g., set ``norm_layer=FrozenBatchNorm2d`` to freeze the norm layer
            in feature extractor. Default: None.

    Raises:
        RuntimeError: If timm could not be imported.
        TypeError: If ``pretrained`` is not a bool.
        AttributeError: If model creation fails with an error mentioning
            ``feature_info``.
        ValueError: If model creation fails with an error mentioning
            ``norm_layer``.
    """

    def __init__(
        self,
        model_name: str,
        features_only: bool = True,
        pretrained: bool = False,
        checkpoint_path: str = "",
        in_channels: int = 3,
        out_indices: Tuple[int, ...] = (0, 1, 2, 3),
        norm_layer: nn.Module = None,
    ):
        super().__init__()
        logger = setup_logger(name="timm backbone")

        if timm is None:
            raise RuntimeError('Failed to import timm. Please run "pip install timm". ')
        if not isinstance(pretrained, bool):
            raise TypeError("pretrained must be bool, not str for model path")
        if features_only and checkpoint_path:
            warnings.warn(
                "Using both features_only and checkpoint_path may cause error"
                " in timm. See "
                "https://github.com/rwightman/pytorch-image-models/issues/488"
            )

        try:
            self.timm_model = timm.create_model(
                model_name=model_name,
                features_only=features_only,
                pretrained=pretrained,
                in_chans=in_channels,
                out_indices=out_indices,
                checkpoint_path=checkpoint_path,
                norm_layer=norm_layer,
            )
        except Exception as error:
            # Translate the two known failure modes into actionable errors,
            # keeping the original exception attached as the cause.
            message = str(error)
            if "feature_info" in message:
                raise AttributeError(
                    "Using features_only may cause attribute error"
                    " in timm, cause there's no feature_info attribute in some models. See "
                    "https://github.com/rwightman/pytorch-image-models/issues/1438"
                ) from error
            if "norm_layer" in message:
                raise ValueError(
                    f"{model_name} does not support specified norm layer, please set 'norm_layer=None'"
                ) from error
            # Anything else is re-raised unchanged so the caller sees the real
            # traceback instead of the whole process being terminated.
            logger.error(error)
            raise

        self.out_indices = out_indices
        feature_info = getattr(self.timm_model, "feature_info", None)
        # Only the main process reports, to avoid duplicated log lines.
        if comm.get_rank() == 0:
            log_timm_feature_info(feature_info)

        # NOTE(review): when the model exposes no feature_info, the
        # _out_features / _out_feature_channels / _out_feature_strides
        # attributes are left unset, exactly as before.
        if feature_info is not None:
            names = ["p{}".format(index) for index in out_indices]
            channels = feature_info.channels()
            strides = feature_info.reduction()
            # Ordered and de-duplicated, so iteration follows out_indices
            # deterministically.
            self._out_features = list(dict.fromkeys(names))
            self._out_feature_channels = {
                name: channels[pos] for pos, name in enumerate(names)
            }
            self._out_feature_strides = {
                name: strides[pos] for pos, name in enumerate(names)
            }

    def forward(self, x):
        """Forward function of `TimmBackbone`.

        Args:
            x (torch.Tensor): the input tensor for feature extraction.

        Returns:
            dict[str->Tensor]: mapping from feature name (e.g., "p1") to tensor
        """
        features = self.timm_model(x)
        # The i-th output of the timm model belongs to the i-th requested index.
        return {
            "p{}".format(index): features[pos]
            for pos, index in enumerate(self.out_indices)
        }