mrdbourke committed (verified)
Commit 64305ae · Parent(s): 0b2dd4d

Upload 4 files

Files changed (5):
  1. .gitattributes +2 -0
  2. app.py +159 -0
  3. baklava.jpg +3 -0
  4. cat.jpg +3 -0
  5. mobileclip/modules/common/mobileone.py +341 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+baklava.jpg filter=lfs diff=lfs merge=lfs -text
+cat.jpg filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,159 @@
"""MobileCLIP2 Zero-Shot Classification Demo"""
import torch
import open_clip
import gradio as gr
from mobileclip.modules.common.mobileone import reparameterize_model

################################################################################
# Model Configuration
################################################################################
AVAILABLE_MODELS = {
    "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"),
    "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"),
    "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"),
    "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"),
    "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"),
    "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"),
}

# Cache for loaded models
model_cache = {}

################################################################################
# Model Loading
################################################################################
def load_model(model_name):
    """Load and cache MobileCLIP2 model"""
    if model_name in model_cache:
        return model_cache[model_name]

    model_id, pretrained = AVAILABLE_MODELS[model_name]

    # Create model and preprocessing transforms
    model, _, preprocess = open_clip.create_model_and_transforms(
        model_id,
        pretrained=pretrained
    )
    tokenizer = open_clip.get_tokenizer(model_id)

    # Reparameterize model for inference
    model = reparameterize_model(model.eval())

    # Cache the model components
    model_cache[model_name] = {
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer
    }

    return model_cache[model_name]

################################################################################
# Inference
################################################################################
def classify_image(image, candidate_labels, model_name):
    """
    Classify image using selected MobileCLIP2 model

    Args:
        image: PIL Image
        candidate_labels: comma-separated string of labels
        model_name: selected model from dropdown

    Returns:
        Dictionary of label probabilities
    """
    if image is None:
        return {}

    # Parse labels
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]

    if not labels:
        return {}

    # Load model components
    model_components = load_model(model_name)
    model = model_components["model"]
    preprocess = model_components["preprocess"]
    tokenizer = model_components["tokenizer"]

    # Preprocess image
    image_tensor = preprocess(image.convert('RGB')).unsqueeze(0)

    # Tokenize text
    text_tokens = tokenizer(labels)

    # Run inference
    # with torch.no_grad(), torch.cuda.amp.autocast():
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        text_features = model.encode_text(text_tokens)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Compute similarity and probabilities
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Format output as dictionary
    output = {labels[i]: float(text_probs[0][i]) for i in range(len(labels))}

    return output

################################################################################
# Gradio Interface
################################################################################
with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification"
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird"
            )
            run_button = gr.Button("Classify", variant="primary")

        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5
            )

    # Examples
    examples = [
        ["MobileCLIP2-S2", "./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
        ["MobileCLIP2-S2", "./cat.jpg", "a cat, two cats, three cats"],
        ["MobileCLIP2-S2", "./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
    ]

    gr.Examples(
        examples=examples,
        inputs=[model_dropdown, image_input, text_input],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False
    )

    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label]
    )

if __name__ == "__main__":
    demo.launch()
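
For reference, the same zero-shot pipeline can be run outside the Gradio UI. The sketch below is illustrative rather than part of the commit: it assumes the mobileclip module from this repository is importable, that open_clip resolves the MobileCLIP2 weights under the "dfndr2b" tag (as used in app.py above), and that cat.jpg has been pulled from LFS.

import torch
import open_clip
from PIL import Image
from mobileclip.modules.common.mobileone import reparameterize_model

# Load and fuse the model the same way load_model() does in app.py.
model, _, preprocess = open_clip.create_model_and_transforms(
    "MobileCLIP2-S2", pretrained="dfndr2b"
)
tokenizer = open_clip.get_tokenizer("MobileCLIP2-S2")
model = reparameterize_model(model.eval())

labels = ["a cat", "two cats", "three cats"]
image = preprocess(Image.open("cat.jpg").convert("RGB")).unsqueeze(0)
text = tokenizer(labels)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

for label, p in zip(labels, probs[0].tolist()):
    print(f"{label}: {p:.3f}")
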
baklava.jpg ADDED

Git LFS Details

  • SHA256: c7b83d3f4d8e57b63c94783c3054d064073e9dbaae524d32764ea2f470b65582
  • Pointer size: 131 Bytes
  • Size of remote file: 148 kB
cat.jpg ADDED

Git LFS Details

  • SHA256: dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e
  • Pointer size: 131 Bytes
  • Size of remote file: 173 kB
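
The LFS entries above store only pointer files; after `git lfs pull`, the local images can be checked against the recorded digests. A small illustrative sketch (not part of the commit), using the cat.jpg SHA256 listed above:

import hashlib

with open("cat.jpg", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

# Expected digest from the Git LFS details above.
print(digest == "dea9e7ef97386345f7cff32f9055da4982da5471c48d575146c796ab4563b04e")
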
mobileclip/modules/common/mobileone.py ADDED
@@ -0,0 +1,341 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
from typing import Union, Tuple

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["MobileOneBlock", "reparameterize_model"]


class SEBlock(nn.Module):
    """Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.

        Args:
            in_channels: Number of input channels.
            rd_ratio: Input channel reduction ratio.
        """
        super(SEBlock, self).__init__()
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * rd_ratio),
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=int(in_channels * rd_ratio),
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        b, c, h, w = inputs.size()
        x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x


class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        activation: nn.Module = nn.GELU(),
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_channels: Number of channels in the input.
            out_channels: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel.
            stride: Stride size.
            padding: Zero-padding size.
            dilation: Kernel dilation factor.
            groups: Group number.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Check if SE-ReLU is requested
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        if use_act:
            self.activation = activation
        else:
            self.activation = nn.Identity()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable conv branches
            if num_conv_branches > 0:
                rbr_conv = list()
                for _ in range(self.num_conv_branches):
                    rbr_conv.append(
                        self._conv_bn(kernel_size=kernel_size, padding=padding)
                    )
                self.rbr_conv = nn.ModuleList(rbr_conv)
            else:
                self.rbr_conv = None

            # Re-parameterizable scale branch
            self.rbr_scale = None
            if not isinstance(kernel_size, int):
                kernel_size = kernel_size[0]
            if (kernel_size > 1) and use_scale_branch:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference mode forward pass.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        # Scale branch output
        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        # Other branches
        out = scale_out + identity_out
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                out += self.rbr_conv[ix](x)

        return self.activation(self.se(out))

    def reparameterize(self):
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.inference_mode:
            return
        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for para in self.parameters():
            para.detach_()
        self.__delattr__("rbr_conv")
        self.__delattr__("rbr_scale")
        if hasattr(self, "rbr_skip"):
            self.__delattr__("rbr_skip")

        self.inference_mode = True

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.kernel_size // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        if self.rbr_conv is not None:
            for ix in range(self.num_conv_branches):
                _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
                kernel_conv += _kernel
                bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(
        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        Args:
            branch: Sequence of ops to be fused.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, "id_tensor"):
                input_dim = self.in_channels // self.groups

                kernel_size = self.kernel_size
                if isinstance(self.kernel_size, int):
                    kernel_size = (self.kernel_size, self.kernel_size)

                kernel_value = torch.zeros(
                    (self.in_channels, input_dim, kernel_size[0], kernel_size[1]),
                    dtype=branch.weight.dtype,
                    device=branch.weight.device,
                )
                for i in range(self.in_channels):
                    kernel_value[
                        i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2
                    ] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list


def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Method returns a model where a multi-branched structure
    used in training is re-parameterized into a single branch
    for inference.

    Args:
        model: MobileOne model in train mode.

    Returns:
        MobileOne model in inference mode.
    """
    # Avoid editing original graph
    model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, "reparameterize"):
            module.reparameterize()
    return model
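
Because the fusion in reparameterize() is exact (each BatchNorm is folded into its preceding conv and the branches are summed into a single kernel and bias), a fused block should reproduce the multi-branch block's outputs once both are in eval mode, where BatchNorm uses running statistics. A small sanity-check sketch; the channel sizes, branch count, and tolerance are arbitrary illustration choices, not from the source:

import torch
from mobileclip.modules.common.mobileone import MobileOneBlock, reparameterize_model

torch.manual_seed(0)
block = MobileOneBlock(
    in_channels=16,
    out_channels=16,
    kernel_size=3,
    padding=1,
    num_conv_branches=2,  # multi-branched train-time structure
    use_se=True,
).eval()  # eval mode so BatchNorm uses running statistics

fused = reparameterize_model(block)  # deep-copies the block, then fuses every branch

x = torch.randn(1, 16, 32, 32)
with torch.no_grad():
    y_multi_branch = block(x)
    y_fused = fused(x)

print(torch.allclose(y_multi_branch, y_fused, atol=1e-5))  # expect True (up to float32 rounding)
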