from typing import Any, Dict

import torch
from torch import nn
import torchvision


class VisualBackbone(nn.Module):
    r"""
    Base class for all visual backbones. All child classes can simply inherit
    from :class:`~torch.nn.Module`; however, this base class is kept here for
    uniform type annotations.
    """

    def __init__(self, visual_feature_size: int):
        super().__init__()
        self.visual_feature_size = visual_feature_size


class TorchvisionVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from the `Torchvision model zoo
    <https://pytorch.org/docs/stable/torchvision/models.html>`_. Any model can
    be specified using the corresponding method name from the model zoo.

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model from Torchvision model zoo.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load ImageNet pretrained weights from Torchvision.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training.
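
    Examples
    --------
    A minimal usage sketch (the output shape shown assumes ResNet-50 with a
    224x224 input):

    >>> backbone = TorchvisionVisualBackbone("resnet50", pretrained=False)
    >>> features = backbone(torch.randn(2, 3, 224, 224))
    >>> features.shape
    torch.Size([2, 2048, 7, 7])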
""" | |
    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)

        self.cnn = getattr(torchvision.models, name)(
            pretrained=pretrained, zero_init_residual=True
        )
        # Do nothing after the final residual stage.
        self.cnn.fc = nn.Identity()

        # Freeze all weights if specified.
        if frozen:
            for param in self.cnn.parameters():
                param.requires_grad = False
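            # Also switch to eval mode so BatchNorm layers use their frozen
            # running statistics rather than per-batch statistics.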
            self.cnn.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
        """
        for idx, (name, layer) in enumerate(self.cnn.named_children()):
            out = layer(image) if idx == 0 else layer(out)

            # These are the spatial features we need.
            if name == "layer4":
                # shape: (batch_size, channels, height, width)
                return out

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        r"""
        Return a state dict of the visual backbone which can be loaded with
        `Detectron2 <https://github.com/facebookresearch/detectron2>`_.
        This is useful for downstream tasks based on Detectron2 (such as
        object detection and instance segmentation). This method renames
        certain parameters from Torchvision-style to Detectron2-style.

        Returns
        -------
        Dict[str, Any]
            A dict with three keys: ``{"model", "__author__",
            "matching_heuristics"}``. These are necessary keys for loading
            this state dict properly with Detectron2.
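
        Examples
        --------
        A minimal usage sketch (the checkpoint filename here is arbitrary;
        Detectron2 can then consume the saved file, for example through the
        ``MODEL.WEIGHTS`` key of its config):

        >>> backbone = TorchvisionVisualBackbone("resnet50")
        >>> torch.save(backbone.detectron2_backbone_state_dict(), "d2_backbone.pth")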
""" | |
# Detectron2 backbones have slightly different module names, this mapping | |
# lists substrings of module names required to be renamed for loading a | |
# torchvision model into Detectron2. | |
DETECTRON2_RENAME_MAPPING: Dict[str, str] = { | |
"layer1": "res2", | |
"layer2": "res3", | |
"layer3": "res4", | |
"layer4": "res5", | |
"bn1": "conv1.norm", | |
"bn2": "conv2.norm", | |
"bn3": "conv3.norm", | |
"downsample.0": "shortcut", | |
"downsample.1": "shortcut.norm", | |
} | |
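        # For example, "layer1.0.bn1.weight" is renamed to
        # "res2.0.conv1.norm.weight", and the stem parameter "bn1.weight"
        # becomes "stem.conv1.norm.weight" after the prefixing below.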

        # Populate this dict by renaming module names.
        d2_backbone_dict: Dict[str, torch.Tensor] = {}

        for name, param in self.cnn.state_dict().items():
            for old, new in DETECTRON2_RENAME_MAPPING.items():
                name = name.replace(old, new)

            # First conv and bn module parameters are prefixed with "stem.".
            if not name.startswith("res"):
                name = f"stem.{name}"

            d2_backbone_dict[name] = param

        return {
            "model": d2_backbone_dict,
            "__author__": "Karan Desai",
            "matching_heuristics": True,
        }


class TimmVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from the `Timm model zoo
    <https://rwightman.github.io/pytorch-image-models/models/>`_.
    This class is a generic wrapper over the ``timm`` library, and supports
    all models provided by the library. Check ``timm.list_models()`` for all
    supported model names.

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model from Timm model zoo.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load ImageNet pretrained weights through ``timm``.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training.
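
    Examples
    --------
    A minimal usage sketch (assumes ``timm`` is installed; the output shape
    shown is for ``timm``'s "resnet50" with a 224x224 input):

    >>> backbone = TimmVisualBackbone("resnet50", pretrained=False)
    >>> features = backbone(torch.randn(2, 3, 224, 224))
    >>> features.shape
    torch.Size([2, 2048, 7, 7])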
""" | |
    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)

        # Limit the scope of the library import to this class.
        import timm

        # Create the model without any global pooling and softmax classifier.
        self.cnn = timm.create_model(
            name, pretrained=pretrained, num_classes=0, global_pool=""
        )

        # Freeze all weights if specified.
        if frozen:
            for param in self.cnn.parameters():
                param.requires_grad = False
            self.cnn.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
        """
        # shape: (batch_size, channels, height, width)
        return self.cnn(image)

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        # Detectron2 may not support all timm models out of the box. These
        # backbones won't be transferred to downstream detection tasks anyway.
        raise NotImplementedError
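

if __name__ == "__main__":
    # Minimal smoke test sketch: build the torchvision backbone without
    # pretrained weights and check the output feature shape on a dummy batch.
    backbone = TorchvisionVisualBackbone("resnet50", pretrained=False)
    features = backbone(torch.randn(2, 3, 224, 224))
    print(features.shape)  # Expected: torch.Size([2, 2048, 7, 7])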