from typing import Any, Dict

import torch
from torch import nn
import torchvision


class VisualBackbone(nn.Module):
    r"""
    Base class for all visual backbones. All child classes can simply inherit
    from :class:`~torch.nn.Module`, however this is kept here for uniform
    type annotations.
    """

    def __init__(self, visual_feature_size: int):
        super().__init__()
        # Channel dimension of the features produced by ``forward``.
        self.visual_feature_size = visual_feature_size


class TorchvisionVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from `Torchvision model zoo
    <https://pytorch.org/vision/stable/models.html>`_. Any model can be
    specified using corresponding method name from the model zoo.

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model from Torchvision model zoo.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load ImageNet pretrained weights from Torchvision.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training.
    """

    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)

        # NOTE(review): ``zero_init_residual`` is a ResNet-family keyword
        # argument; constructing a non-ResNet model here would raise a
        # TypeError — confirm intended model names are all ResNet variants.
        self.cnn = getattr(torchvision.models, name)(
            pretrained, zero_init_residual=True
        )
        # Do nothing after the final residual stage.
        self.cnn.fc = nn.Identity()

        # Freeze all weights if specified.
        if frozen:
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.cnn.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
        """
        # Run the children (conv stem, bn, pooling, residual stages) in
        # registration order, stopping at the last residual stage so we keep
        # spatial features instead of the pooled/classified output.
        for idx, (name, layer) in enumerate(self.cnn.named_children()):
            out = layer(image) if idx == 0 else layer(out)

            # These are the spatial features we need.
            if name == "layer4":
                # shape: (batch_size, channels, height, width)
                return out

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        r"""
        Return state dict of visual backbone which can be loaded with
        `Detectron2 <https://github.com/facebookresearch/detectron2>`_.
        This is useful for downstream tasks based on Detectron2 (such as
        object detection and instance segmentation). This method renames
        certain parameters from Torchvision-style to Detectron2-style.

        Returns
        -------
        Dict[str, Any]
            A dict with three keys: ``{"model", "__author__",
            "matching_heuristics"}``. These are necessary keys for loading
            this state dict properly with Detectron2.
        """
        # Detectron2 backbones have slightly different module names, this
        # mapping lists substrings of module names required to be renamed for
        # loading a torchvision model into Detectron2.
        DETECTRON2_RENAME_MAPPING: Dict[str, str] = {
            "layer1": "res2",
            "layer2": "res3",
            "layer3": "res4",
            "layer4": "res5",
            "bn1": "conv1.norm",
            "bn2": "conv2.norm",
            "bn3": "conv3.norm",
            "downsample.0": "shortcut",
            "downsample.1": "shortcut.norm",
        }
        # Populate this dict by renaming module names.
        d2_backbone_dict: Dict[str, torch.Tensor] = {}

        for name, param in self.cnn.state_dict().items():
            for old, new in DETECTRON2_RENAME_MAPPING.items():
                name = name.replace(old, new)

            # First conv and bn module parameters are prefixed with "stem.".
            if not name.startswith("res"):
                name = f"stem.{name}"

            d2_backbone_dict[name] = param

        return {
            "model": d2_backbone_dict,
            "__author__": "Karan Desai",
            "matching_heuristics": True,
        }


class TimmVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from `Timm model zoo
    <https://github.com/rwightman/pytorch-image-models>`_. This class is a
    generic wrapper over the ``timm`` library, and supports all models
    provided by the library. Check ``timm.list_models()`` for all supported
    model names.

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model from Timm model zoo.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load ImageNet pretrained weights from Timm.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training.
    """

    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)

        # Limit the scope of library import inside class definition.
        import timm

        # Create the model without any global pooling and softmax classifier.
        self.cnn = timm.create_model(
            name, pretrained=pretrained, num_classes=0, global_pool=""
        )

        # Freeze all weights if specified.
        if frozen:
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.cnn.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
        """
        # shape: (batch_size, channels, height, width)
        return self.cnn(image)

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        # Detectron2 may not support all timm models out of the box. These
        # backbones won't be transferred to downstream detection tasks anyway.
        raise NotImplementedError