Spaces:
Runtime error
Runtime error
File size: 6,933 Bytes
a5f8a35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
from typing import Any, Dict
import torch
from torch import nn
import torchvision
class VisualBackbone(nn.Module):
    r"""
    Common base class shared by every visual backbone in this module.

    Concrete backbones could derive from :class:`~torch.nn.Module` directly;
    this thin layer exists so that all of them share a single type for
    annotations and expose the same ``visual_feature_size`` attribute.

    Parameters
    ----------
    visual_feature_size: int
        Size of the channel dimension of the visual features produced by the
        backbone's forward pass.
    """

    def __init__(self, visual_feature_size: int):
        super().__init__()
        # Kept as a plain (non-parameter, non-buffer) attribute on the module.
        self.visual_feature_size: int = visual_feature_size
class TorchvisionVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from the `Torchvision model zoo
    <https://pytorch.org/vision/stable/models.html>`_. The model is selected
    by the name of its builder function in :mod:`torchvision.models`.

    This class assumes a ResNet-style model: the builder is called with
    ``zero_init_residual=True``, :meth:`forward` returns the output of the
    child module named ``layer4``, and :meth:`detectron2_backbone_state_dict`
    renames ResNet module names (``layer1``..``layer4``, ``bn1``..``bn3``,
    ``downsample``).

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model builder in :mod:`torchvision.models`.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load ImageNet pretrained weights from Torchvision.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training. A frozen backbone
        has ``requires_grad=False`` on every parameter of the wrapped CNN, and
        the CNN stays in eval mode even when :meth:`train` is called on this
        module or on a parent module.
    """

    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)
        # Remembered so that ``train()`` can keep a frozen CNN in eval mode.
        self._frozen = frozen

        # ``pretrained`` is passed by keyword, not by position: recent
        # torchvision builders only take keyword arguments and merely map a
        # positional value through a deprecation shim.
        self.cnn = getattr(torchvision.models, name)(
            pretrained=pretrained, zero_init_residual=True
        )
        # Do nothing after the final residual stage.
        self.cnn.fc = nn.Identity()

        # Freeze all weights if specified.
        if frozen:
            self.cnn.requires_grad_(False)
            self.cnn.eval()

    def train(self, mode: bool = True):
        r"""
        Set the training mode, but keep the wrapped CNN in eval mode if frozen.

        :meth:`torch.nn.Module.train` recursively switches every submodule.
        Without this override, it would undo the ``self.cnn.eval()`` call made
        in ``__init__``, and normalization layers (e.g. BatchNorm) of a frozen
        backbone would use batch statistics and update their running stats.

        Parameters
        ----------
        mode: bool, optional (default = True)
            Whether to set training mode (``True``) or eval mode (``False``).

        Returns
        -------
        TorchvisionVisualBackbone
            ``self``, matching :meth:`torch.nn.Module.train`.
        """
        super().train(mode)
        if self._frozen:
            self.cnn.eval()
        return self

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            A tensor of shape ``(batch_size, channels, height, width)``, for
            example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50 with
            a ``224 x 224`` input.

        Raises
        ------
        RuntimeError
            If the wrapped model has no child module named ``layer4``.
        """
        out = image
        for name, layer in self.cnn.named_children():
            out = layer(out)
            # These are the spatial features we need; the children after
            # ``layer4`` (global pooling and classifier) are skipped.
            if name == "layer4":
                # shape: (batch_size, channels, height, width)
                return out

        # Previously this fell through and silently returned ``None``.
        raise RuntimeError(
            f"{type(self.cnn).__name__} has no child module named 'layer4'; "
            "TorchvisionVisualBackbone only supports ResNet-style models."
        )

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        r"""
        Return state dict of visual backbone which can be loaded with
        `Detectron2 <https://github.com/facebookresearch/detectron2>`_.

        This is useful for downstream tasks based on Detectron2 (such as
        object detection and instance segmentation). This method renames
        entries of the CNN's state dict (parameters and buffers) from
        Torchvision-style to Detectron2-style names.

        Returns
        -------
        Dict[str, Any]
            A dict with three keys: ``{"model", "__author__", "matching_heuristics"}``.
            These are necessary keys for loading this state dict properly with
            Detectron2.
        """
        # Detectron2 backbones have slightly different module names, this mapping
        # lists substrings of module names required to be renamed for loading a
        # torchvision model into Detectron2. Replacements are applied in order.
        DETECTRON2_RENAME_MAPPING: Dict[str, str] = {
            "layer1": "res2",
            "layer2": "res3",
            "layer3": "res4",
            "layer4": "res5",
            "bn1": "conv1.norm",
            "bn2": "conv2.norm",
            "bn3": "conv3.norm",
            "downsample.0": "shortcut",
            "downsample.1": "shortcut.norm",
        }
        # Populate this dict by renaming module names.
        d2_backbone_dict: Dict[str, torch.Tensor] = {}
        for name, tensor in self.cnn.state_dict().items():
            for old, new in DETECTRON2_RENAME_MAPPING.items():
                name = name.replace(old, new)

            # First conv and bn module parameters are prefixed with "stem.".
            if not name.startswith("res"):
                name = f"stem.{name}"

            d2_backbone_dict[name] = tensor

        return {
            "model": d2_backbone_dict,
            "__author__": "Karan Desai",
            "matching_heuristics": True,
        }
class TimmVisualBackbone(VisualBackbone):
    r"""
    A visual backbone from `Timm model zoo
    <https://rwightman.github.io/pytorch-image-models/models/>`_.

    This class is a generic wrapper over the ``timm`` library, and supports
    all models provided by the library. Check ``timm.list_models()`` for all
    supported model names.

    Parameters
    ----------
    name: str, optional (default = "resnet50")
        Name of the model from Timm model zoo.
    visual_feature_size: int, optional (default = 2048)
        Size of the channel dimension of output visual features from forward pass.
    pretrained: bool, optional (default = False)
        Whether to load the pretrained weights that Timm provides for ``name``.
    frozen: bool, optional (default = False)
        Whether to keep all weights frozen during training. A frozen backbone
        has ``requires_grad=False`` on every parameter of the wrapped CNN, and
        the CNN stays in eval mode even when :meth:`train` is called on this
        module or on a parent module.
    """

    def __init__(
        self,
        name: str = "resnet50",
        visual_feature_size: int = 2048,
        pretrained: bool = False,
        frozen: bool = False,
    ):
        super().__init__(visual_feature_size)
        # Remembered so that ``train()`` can keep a frozen CNN in eval mode.
        self._frozen = frozen

        # Limit the scope of library import inside class definition.
        import timm

        # Create the model without any global pooling and softmax classifier.
        self.cnn = timm.create_model(
            name, pretrained=pretrained, num_classes=0, global_pool=""
        )
        # Freeze all weights if specified.
        if frozen:
            self.cnn.requires_grad_(False)
            self.cnn.eval()

    def train(self, mode: bool = True):
        r"""
        Set the training mode, but keep the wrapped CNN in eval mode if frozen.

        :meth:`torch.nn.Module.train` recursively switches every submodule.
        Without this override, it would undo the ``self.cnn.eval()`` call made
        in ``__init__``, and normalization layers of a frozen backbone would
        update their running statistics during training.

        Parameters
        ----------
        mode: bool, optional (default = True)
            Whether to set training mode (``True``) or eval mode (``False``).

        Returns
        -------
        TimmVisualBackbone
            ``self``, matching :meth:`torch.nn.Module.train`.
        """
        super().train(mode)
        if self._frozen:
            self.cnn.eval()
        return self

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        r"""
        Compute visual features for a batch of input images.

        Parameters
        ----------
        image: torch.Tensor
            Batch of input images. A tensor of shape
            ``(batch_size, 3, height, width)``.

        Returns
        -------
        torch.Tensor
            For convolutional models, a tensor of shape
            ``(batch_size, channels, height, width)``, for example it will be
            ``(batch_size, 2048, 7, 7)`` for ResNet-50 with a ``224 x 224``
            input. NOTE(review): non-convolutional timm models (e.g. vision
            transformers) may return differently shaped features — confirm
            before using them here.
        """
        # shape: (batch_size, channels, height, width) for CNNs, since the
        # model was created with ``global_pool=""`` and ``num_classes=0``.
        return self.cnn(image)

    def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
        r"""
        Not supported for Timm backbones.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Detectron2 may not support all timm models out of the box. These
        # backbones won't be transferred to downstream detection tasks anyway.
        raise NotImplementedError(
            "Detectron2 state dict conversion is not supported for Timm backbones."
        )
|