# NOTE(review): hosting-page banner ("Spaces: Running on Zero") captured with this file; it appears not to be part of the module source.
from torch import nn, Tensor | |
import open_clip | |
from ..utils import ConvRefine, ConvUpsample | |
from ..utils import _get_norm_layer, _get_activation | |
# Supported MobileCLIP variants, each mapped to the pretrained-weight tags
# that MobileCLIP.__init__ accepts for it.
mobileclip_names_and_weights = {
    "MobileCLIP-S1": ["datacompdr"],
    "MobileCLIP-S2": ["datacompdr"],
}

# Per-variant channel counts. Not referenced by the code visible in this
# module. NOTE(review): presumably the trunk's final feature width -- confirm
# against callers before relying on it.
refiner_channels = {
    "MobileCLIP-S1": 1024,
    "MobileCLIP-S2": 1280,
}

# Value passed as ``groups`` to the refiner blocks built for each variant.
refiner_groups = {
    "MobileCLIP-S1": 2,
    "MobileCLIP-S2": 2,
}
class MobileCLIP(nn.Module):
    """MobileCLIP visual trunk followed by a convolutional refiner.

    The visual tower is created through ``open_clip.create_model`` with
    ``load_weights=False``; its ``trunk.stem``, ``trunk.stages`` and
    ``trunk.final_conv`` are reused as the feature extractor, and a refiner
    chosen by ``block_size`` is applied to the result.

    Args:
        model_name: Key of ``mobileclip_names_and_weights``.
        weight_name: Weight tag listed for ``model_name``. It is validated and
            stored, but no weights are loaded by this constructor.
        block_size: One of 32, 16 or 8. 32 builds a single ``ConvRefine``,
            16 a single ``ConvUpsample``, and 8 two ``ConvUpsample`` blocks in
            sequence.
        norm: ``"bn"`` selects ``nn.BatchNorm2d`` and ``"ln"`` selects
            ``nn.LayerNorm``; any other value defers to
            ``_get_norm_layer(model)``.
        act: ``"relu"`` selects ``nn.ReLU(inplace=True)`` and ``"gelu"``
            selects ``nn.GELU()``; any other value defers to
            ``_get_activation(model)``.

    Raises:
        AssertionError: If ``model_name``, ``weight_name`` or ``block_size``
            is not one of the supported values.
    """

    def __init__(
        self,
        model_name: str,
        weight_name: str,
        block_size: int = 16,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        # NOTE: these checks are asserts, so they are skipped under ``python -O``.
        # They are kept as asserts so the raised exception type stays the same.
        assert model_name in mobileclip_names_and_weights, f"Model name should be one of {list(mobileclip_names_and_weights.keys())}, but got {model_name}."
        assert weight_name in mobileclip_names_and_weights[model_name], f"Pretrained should be one of {mobileclip_names_and_weights[model_name]}, but got {weight_name}."
        assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
        self.model_name, self.weight_name = model_name, weight_name
        self.block_size = block_size

        # Only the architecture is instantiated here (load_weights=False).
        # NOTE(review): presumably weights are restored later from a
        # checkpoint -- confirm against the training/inference entry points.
        model = open_clip.create_model(model_name=model_name, pretrained=False, load_weights=False).visual
        trunk = model.trunk
        self.stem = trunk.stem
        self.stages = trunk.stages
        self.depth = len(trunk.stages)
        self.final_conv = trunk.final_conv
        head_fc = trunk.head.fc
        self.in_features, self.out_features = head_fc.in_features, head_fc.out_features

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            # Fall back to whatever norm layer the helper derives from the model.
            norm_layer = _get_norm_layer(model)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            # Fall back to whatever activation the helper derives from the model.
            activation = _get_activation(model)

        # Every refiner block is built with the same arguments; the channel
        # count is kept at ``in_features`` on both sides.
        refiner_kwargs = dict(
            in_channels=self.in_features,
            out_channels=self.in_features,
            norm_layer=norm_layer,
            activation=activation,
            groups=refiner_groups[model_name],
        )
        if block_size == 32:
            self.refiner = ConvRefine(**refiner_kwargs)
        elif block_size == 16:
            self.refiner = ConvUpsample(**refiner_kwargs)
        else:  # block_size == 8
            self.refiner = nn.Sequential(
                ConvUpsample(**refiner_kwargs),
                ConvUpsample(**refiner_kwargs),
            )

    def forward(self, x: Tensor) -> Tensor:
        """Run stem, every trunk stage, the final conv and the refiner in order."""
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        x = self.final_conv(x)
        return self.refiner(x)
def _mobileclip(
    model_name: str,
    weight_name: str,
    block_size: int = 16,
    norm: str = "none",
    act: str = "none",
) -> MobileCLIP:
    """Build a ``MobileCLIP`` backbone from the given configuration.

    All arguments are forwarded unchanged, by keyword, to the ``MobileCLIP``
    constructor.

    Args:
        model_name: MobileCLIP variant name.
        weight_name: Weight tag associated with ``model_name``.
        block_size: Block size forwarded to the constructor (default 16).
        norm: Norm-layer selector forwarded to the constructor.
        act: Activation selector forwarded to the constructor.

    Returns:
        The constructed ``MobileCLIP`` module.
    """
    return MobileCLIP(
        model_name=model_name,
        weight_name=weight_name,
        block_size=block_size,
        norm=norm,
        act=act,
    )