# ZIP / models / clip_ebc / mobileclip.py
# Last commit by Yiming-M: c628976 (2025-08-01 10:49) 🚀
from torch import nn, Tensor
import open_clip
from ..utils import ConvRefine, ConvUpsample
from ..utils import _get_norm_layer, _get_activation
# Per-model configuration, kept in one table so the derived lookups below
# cannot drift apart. Each entry is:
#   model name -> (accepted pretrained tags, refiner channels, refiner groups)
_MOBILECLIP_SPECS = {
    "MobileCLIP-S1": (("datacompdr",), 1024, 2),
    "MobileCLIP-S2": (("datacompdr",), 1280, 2),
}

# Model name -> list of accepted pretrained-weight tags (a fresh list per model).
mobileclip_names_and_weights = {
    name: list(tags) for name, (tags, _channels, _groups) in _MOBILECLIP_SPECS.items()
}
# Model name -> channel count for the refiner (not referenced elsewhere in this chunk).
refiner_channels = {
    name: channels for name, (_tags, channels, _groups) in _MOBILECLIP_SPECS.items()
}
# Model name -> `groups` argument passed to the refiner blocks.
refiner_groups = {
    name: groups for name, (_tags, _channels, groups) in _MOBILECLIP_SPECS.items()
}
class MobileCLIP(nn.Module):
    """MobileCLIP image trunk followed by a convolutional refiner head.

    The stem, stages and final conv of an open_clip MobileCLIP visual trunk
    are reused as the backbone. A refiner is then attached whose composition
    depends on ``block_size``:

    * ``32`` -> one ``ConvRefine``
    * ``16`` -> one ``ConvUpsample``
    * ``8``  -> two stacked ``ConvUpsample`` blocks

    The backbone is created with ``pretrained=False, load_weights=False``.
    ``weight_name`` is therefore validated and stored, but no pretrained
    weights are loaded by this class. Presumably a checkpoint is loaded
    elsewhere; confirm with callers.

    Args:
        model_name: A key of ``mobileclip_names_and_weights``.
        weight_name: One of the tags listed for ``model_name``.
        block_size: 32, 16 or 8.
        norm: ``"bn"`` selects ``nn.BatchNorm2d`` and ``"ln"`` selects
            ``nn.LayerNorm``. Any other value falls back to
            ``_get_norm_layer`` on the open_clip visual model.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)`` and ``"gelu"``
            selects ``nn.GELU()``. Any other value falls back to
            ``_get_activation`` on the open_clip visual model.

    Raises:
        AssertionError: If ``model_name``, ``weight_name`` or ``block_size``
            is not one of the supported values.
    """

    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        norm: str = "none",
        act: str = "none"
    ) -> None:
        super().__init__()

        # --- argument validation -------------------------------------------
        known_weights = mobileclip_names_and_weights
        assert model_name in known_weights, f"Model name should be one of {list(known_weights.keys())}, but got {model_name}."
        allowed_tags = known_weights[model_name]
        assert weight_name in allowed_tags, f"Pretrained should be one of {allowed_tags}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"

        self.model_name = model_name
        self.weight_name = weight_name
        self.block_size = block_size

        # --- backbone --------------------------------------------------------
        # Only the architecture is instantiated (load_weights=False); the
        # trunk's classifier head is kept solely to read its feature sizes.
        visual = open_clip.create_model(model_name=model_name, pretrained=False, load_weights=False).visual
        trunk = visual.trunk
        self.stem = trunk.stem
        self.stages = trunk.stages
        self.depth = len(trunk.stages)
        self.final_conv = trunk.final_conv
        head_fc = trunk.head.fc
        self.in_features = head_fc.in_features
        self.out_features = head_fc.out_features

        # --- normalization / activation selection ---------------------------
        # NOTE(review): nn.LayerNorm normalizes over trailing dim(s); confirm
        # that ConvRefine/ConvUpsample adapt it for channels-first feature maps.
        norm_layer = {"bn": nn.BatchNorm2d, "ln": nn.LayerNorm}.get(norm)
        if norm_layer is None:
            norm_layer = _get_norm_layer(visual)

        activation_factories = {
            "relu": lambda: nn.ReLU(inplace=True),
            "gelu": nn.GELU,
        }
        make_activation = activation_factories.get(act)
        if make_activation is None:
            activation = _get_activation(visual)
        else:
            activation = make_activation()

        # --- refiner head ----------------------------------------------------
        groups = refiner_groups[model_name]

        def _block(block_cls):
            # Channel-preserving refiner block sharing this model's norm/act/groups.
            return block_cls(
                in_channels=self.in_features,
                out_channels=self.in_features,
                norm_layer=norm_layer,
                activation=activation,
                groups=groups,
            )

        if block_size == 32:
            self.refiner = _block(ConvRefine)
        elif block_size == 16:
            self.refiner = _block(ConvUpsample)
        else:  # block_size == 8: two upsampling blocks, created in order
            self.refiner = nn.Sequential(_block(ConvUpsample), _block(ConvUpsample))

    def forward(self, x: Tensor) -> Tensor:
        """Apply stem, the first ``self.depth`` stages, final conv, then the refiner."""
        out = self.stem(x)
        for stage_idx in range(self.depth):
            out = self.stages[stage_idx](out)
        out = self.final_conv(out)
        return self.refiner(out)
def _mobileclip(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    norm: str = "none",
    act: str = "none"
) -> MobileCLIP:
    """Construct a :class:`MobileCLIP`, forwarding every argument unchanged.

    Args:
        model_name: Passed through as ``MobileCLIP(model_name=...)``.
        weight_name: Passed through as ``MobileCLIP(weight_name=...)``.
        block_size: Passed through as ``MobileCLIP(block_size=...)``.
        norm: Passed through as ``MobileCLIP(norm=...)``.
        act: Passed through as ``MobileCLIP(act=...)``.

    Returns:
        The newly built ``MobileCLIP`` module.
    """
    return MobileCLIP(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )