ZIP / models /clip_ebc /resnet.py
Yiming-M's picture
2025-08-01 10:49 🚀
c628976
from torch import nn, Tensor
import open_clip
from ..utils import ConvRefine, ConvUpsample
from ..utils import _get_norm_layer, _get_activation
# Weight tags accepted for every supported CLIP ResNet variant.
# A fresh list is built per model so the entries stay independent objects.
resnet_names_and_weights = {
    name: ["openai", "yfcc15m", "cc12m"]
    for name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64")
}

# Channel count handed to the refiner (as both in_channels and out_channels)
# for each variant.
refiner_channels = dict(
    RN50=2048,
    RN101=2048,
    RN50x4=2560,
    RN50x16=3072,
    RN50x64=4096,
)

# Number of convolution groups used by the refiner: one group per 512
# channels, i.e. 4, 4, 5, 6 and 8 for the variants above.
refiner_groups = {
    name: channels // 512
    for name, channels in refiner_channels.items()
}
class ResNet(nn.Module):
def __init__(
self,
model_name: str,
weight_name: str,
block_size: int = 16,
norm: str = "none",
act: str = "none"
) -> None:
super(ResNet, self).__init__()
assert model_name in resnet_names_and_weights, f"Model name should be one of {list(resnet_names_and_weights.keys())}, but got {model_name}."
assert weight_name in resnet_names_and_weights[model_name], f"Pretrained should be one of {resnet_names_and_weights[model_name]}, but got {weight_name}."
assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
self.model_name, self.weight_name = model_name, weight_name
self.block_size = block_size
# model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
model = open_clip.create_model(model_name=model_name, pretrained=False, load_weights=False).visual
# Stem
self.conv1 = model.conv1
self.bn1 = model.bn1
self.act1 = model.act1
self.conv2 = model.conv2
self.bn2 = model.bn2
self.act2 = model.act2
self.conv3 = model.conv3
self.bn3 = model.bn3
self.act3 = model.act3
self.avgpool = model.avgpool
# Stem: reduction = 4
# Layers
for idx in range(1, 5):
setattr(self, f"layer{idx}", getattr(model, f"layer{idx}"))
self.in_features = model.attnpool.c_proj.weight.shape[1]
self.out_features = model.attnpool.c_proj.weight.shape[0]
if norm == "bn":
norm_layer = nn.BatchNorm2d
elif norm == "ln":
norm_layer = nn.LayerNorm
else:
norm_layer = _get_norm_layer(model)
if act == "relu":
activation = nn.ReLU(inplace=True)
elif act == "gelu":
activation = nn.GELU()
else:
activation = _get_activation(model)
if block_size == 32:
self.refiner = ConvRefine(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
)
elif block_size == 16:
self.refiner = ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
)
else: # block_size == 8
self.refiner = nn.Sequential(
ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
),
ConvUpsample(
in_channels=self.in_features,
out_channels=self.in_features,
norm_layer=norm_layer,
activation=activation,
groups=refiner_groups[self.model_name],
),
)
def stem(self, x: Tensor) -> Tensor:
x = self.act1(self.bn1(self.conv1(x)))
x = self.act2(self.bn2(self.conv2(x)))
x = self.act3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x: Tensor) -> Tensor:
x = self.stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.refiner(x)
return x
def _resnet(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    norm: str = "none",
    act: str = "none"
) -> ResNet:
    """Build a ``ResNet`` backbone from the given configuration.

    All arguments are forwarded unchanged to the ``ResNet`` constructor.

    Args:
        model_name: Name of the CLIP ResNet variant.
        weight_name: Weight tag associated with ``model_name``.
        block_size: Refiner selector passed through to ``ResNet``.
        norm: Normalization choice passed through to ``ResNet``.
        act: Activation choice passed through to ``ResNet``.

    Returns:
        The newly constructed ``ResNet`` instance.
    """
    return ResNet(
        model_name,
        weight_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )