""" timm model adapter

Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a CLIP model.
"""
from collections import OrderedDict

import torch.nn as nn

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    from timm.models.layers.attention_pool2d import RotAttentionPool2d
    from timm.models.layers.attention_pool2d import (
        AttentionPool2d as AbsAttentionPool2d,
    )
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """timm model adapter

    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

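    # Illustrative usage (a sketch, not from the original source; the model name
    # "resnet50" and the hyperparameter values below are arbitrary assumptions):
    #
    #   model = TimmModel("resnet50", embed_dim=512, image_size=224, pool="avg", proj="linear")
    #   image_features = model(images)  # images: (batch, 3, 224, 224) -> (batch, 512)
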
    def __init__(
        self,
        model_name,
        embed_dim,
        image_size=224,
        pool="avg",
        proj="linear",
        drop=0.0,
        pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)
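        # "pool_size" in the timm default_cfg indicates the trunk produces a spatial (2d)
        # feature map; the attention pooling options below require one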
        feat_size = self.trunk.default_cfg.get("pool_size", None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ("abs_attn", "rot_attn"):
            assert feature_ndim == 2
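            # attention pooling replaces both the classifier and the trunk's default global pool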
            self.trunk.reset_classifier(0, global_pool="")
        else:
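            # reset the global pool if a pool type was requested, otherwise keep the trunk default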
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == "abs_attn":
            head_layers["pool"] = AbsAttentionPool2d(
                prev_chs, feat_size=feat_size, out_features=embed_dim
            )
            prev_chs = embed_dim
        elif pool == "rot_attn":
            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, "projection layer needed if non-attention pooling is used."

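        # NOTE: the attention pool modules above already project to embed_dim, so proj is
        # typically left empty ('') when attention pooling is selected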
if proj == "linear":
|
|
head_layers["drop"] = nn.Dropout(drop)
|
|
head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
|
|
elif proj == "mlp":
|
|
head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
|
|
|
|
self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also freeze BatchNorm2d running statistics in the locked part of the trunk (default: False)
        """
        if not unlocked_groups:
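            # lock the full trunk: freeze every parameter (and optionally BN running stats)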
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
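            # partial unlocking needs timm's layer-group helpers, only present in recent versions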
            try:
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
                )
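            # freeze all layer groups up to max_layer_id, leaving the last `unlocked_groups` trainable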
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    def forward(self, x):
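        # trunk features -> pooling/projection head -> image embedding of size embed_dim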
        x = self.trunk(x)
        x = self.head(x)
        return x