g8a9 commited on
Commit
69dfe66
·
1 Parent(s): 88e550b

Add basic files. Try caching the model.

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app.py +16 -2
  3. configuration_hybrid_clip.py +112 -0
  4. modeling_hybrid_clip.py +420 -0
  5. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -1,4 +1,18 @@
1
  import streamlit as st
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from modeling_hybrid_clip import FlaxHybridCLIP
3
 
4
+
5
+ @st.cache
6
+ def get_model():
7
+ return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
8
+
9
+
10
+ """
11
+ # CLIP Italian Demo (Flax Community Week)
12
+ """
13
+
14
+ x = st.slider("Select a value")
15
+ st.write(x, "squared is", x * x)
16
+
17
+ model = get_model()
18
+ st.write(str(model.config["architectures"]))
configuration_hybrid_clip.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class HybridCLIPConfig(PretrainedConfig):
11
+ r"""
12
+ :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
13
+ :class:`~HybridCLIPModel`. It is used to instantiate HybridCLIPModel model according to the specified arguments,
14
+ defining the text model and vision model configs.
15
+
16
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
17
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
18
+
19
+ Args:
20
+ text_config_dict (:obj:`dict`):
21
+ Dictionary of configuration options that defines text model config.
22
+ vision_config_dict (:obj:`dict`):
23
+ Dictionary of configuration options that defines vison model config.
24
+ projection_dim (:obj:`int`, `optional`, defaults to 512):
25
+ Dimentionality of text and vision projection layers.
26
+ kwargs (`optional`):
27
+ Dictionary of keyword arguments.
28
+
29
+ Examples::
30
+
31
+ >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
32
+
33
+ >>> # Initializing a BERT and CLIP configuration
34
+ >>> config_text = BertConfig()
35
+ >>> config_vision = CLIPConfig()
36
+
37
+ >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
38
+
39
+ >>> # Initializing a BERT and CLIPVision model
40
+ >>> model = EncoderDecoderModel(config=config)
41
+
42
+ >>> # Accessing the model configuration
43
+ >>> config_text = model.config.text_config
44
+ >>> config_vision = model.config.vision_config
45
+
46
+ >>> # Saving the model, including its configuration
47
+ >>> model.save_pretrained('my-model')
48
+
49
+ >>> # loading model and config from pretrained folder
50
+ >>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
51
+ >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
52
+ """
53
+
54
+ model_type = "hybrid-clip"
55
+ is_composition = True
56
+
57
+ def __init__(self, projection_dim=512, **kwargs):
58
+ super().__init__(**kwargs)
59
+
60
+ if "text_config" not in kwargs:
61
+ raise ValueError("`text_config` can not be `None`.")
62
+
63
+ if "vision_config" not in kwargs:
64
+ raise ValueError("`vision_config` can not be `None`.")
65
+
66
+ text_config = kwargs.pop("text_config")
67
+ vision_config = kwargs.pop("vision_config")
68
+
69
+ text_model_type = text_config.pop("model_type")
70
+ vision_model_type = vision_config.pop("model_type")
71
+
72
+ from transformers import AutoConfig
73
+
74
+ self.text_config = AutoConfig.for_model(text_model_type, **text_config)
75
+
76
+ if vision_model_type == "clip":
77
+ self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
78
+ elif vision_model_type == "clip_vision_model":
79
+ from transformers import CLIPVisionConfig
80
+
81
+ self.vision_config = CLIPVisionConfig(**vision_config)
82
+ else:
83
+ self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
84
+
85
+ self.projection_dim = projection_dim
86
+ self.initializer_factor = 1.0
87
+
88
+ @classmethod
89
+ def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
90
+ r"""
91
+ Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
92
+ vision model configuration.
93
+
94
+ Returns:
95
+ :class:`HybridCLIPConfig`: An instance of a configuration object
96
+ """
97
+
98
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
99
+
100
+ def to_dict(self):
101
+ """
102
+ Serializes this instance to a Python dictionary. Override the default
103
+ :meth:`~transformers.PretrainedConfig.to_dict`.
104
+
105
+ Returns:
106
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
107
+ """
108
+ output = copy.deepcopy(self.__dict__)
109
+ output["text_config"] = self.text_config.to_dict()
110
+ output["vision_config"] = self.vision_config.to_dict()
111
+ output["model_type"] = self.__class__.model_type
112
+ return output
modeling_hybrid_clip.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import flax.linen as nn
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from configuration_hybrid_clip import HybridCLIPConfig
22
+ from flax.core.frozen_dict import FrozenDict
23
+ from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
24
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
25
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class FlaxHybridCLIPModule(nn.Module):
33
+ config: HybridCLIPConfig
34
+ dtype: jnp.dtype = jnp.float32
35
+
36
+ def setup(self):
37
+ text_config = self.config.text_config
38
+ vision_config = self.config.vision_config
39
+
40
+ self.projection_dim = self.config.projection_dim
41
+ self.text_embed_dim = text_config.hidden_size
42
+ self.vision_embed_dim = vision_config.hidden_size
43
+
44
+ text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
45
+ vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
46
+
47
+ self.text_model = text_module(text_config, dtype=self.dtype)
48
+ self.vision_model = vision_module(vision_config, dtype=self.dtype)
49
+
50
+ self.visual_projection = nn.Dense(
51
+ self.projection_dim,
52
+ dtype=self.dtype,
53
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
54
+ use_bias=False,
55
+ )
56
+ self.text_projection = nn.Dense(
57
+ self.projection_dim,
58
+ dtype=self.dtype,
59
+ kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
60
+ use_bias=False,
61
+ )
62
+ self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
63
+
64
+ def __call__(
65
+ self,
66
+ input_ids=None,
67
+ pixel_values=None,
68
+ attention_mask=None,
69
+ position_ids=None,
70
+ token_type_ids=None,
71
+ deterministic: bool = True,
72
+ output_attentions=None,
73
+ output_hidden_states=None,
74
+ return_dict=None,
75
+ ):
76
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
77
+
78
+ vision_outputs = self.vision_model(
79
+ pixel_values=pixel_values,
80
+ deterministic=deterministic,
81
+ output_attentions=output_attentions,
82
+ output_hidden_states=output_hidden_states,
83
+ return_dict=return_dict,
84
+ )
85
+
86
+ text_outputs = self.text_model(
87
+ input_ids=input_ids,
88
+ attention_mask=attention_mask,
89
+ token_type_ids=token_type_ids,
90
+ position_ids=position_ids,
91
+ deterministic=deterministic,
92
+ output_attentions=output_attentions,
93
+ output_hidden_states=output_hidden_states,
94
+ return_dict=return_dict,
95
+ )
96
+
97
+ image_embeds = vision_outputs[1]
98
+ image_embeds = self.visual_projection(image_embeds)
99
+
100
+ text_embeds = text_outputs[1]
101
+ text_embeds = self.text_projection(text_embeds)
102
+
103
+ # normalized features
104
+ image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
105
+ text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
106
+
107
+ # cosine similarity as logits
108
+ logit_scale = jnp.exp(self.logit_scale)
109
+ logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
110
+ logits_per_image = logits_per_text.T
111
+
112
+ if not return_dict:
113
+ return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
114
+
115
+ return FlaxCLIPOutput(
116
+ logits_per_image=logits_per_image,
117
+ logits_per_text=logits_per_text,
118
+ text_embeds=text_embeds,
119
+ image_embeds=image_embeds,
120
+ text_model_output=text_outputs,
121
+ vision_model_output=vision_outputs,
122
+ )
123
+
124
+
125
+ class FlaxHybridCLIP(FlaxPreTrainedModel):
126
+ config_class = HybridCLIPConfig
127
+ module_class = FlaxHybridCLIPModule
128
+
129
+ def __init__(
130
+ self,
131
+ config: HybridCLIPConfig,
132
+ input_shape: Optional[Tuple] = None,
133
+ seed: int = 0,
134
+ dtype: jnp.dtype = jnp.float32,
135
+ **kwargs
136
+ ):
137
+ if input_shape is None:
138
+ input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
139
+
140
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
141
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
142
+
143
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
144
+ # init input tensor
145
+ input_ids = jnp.zeros(input_shape[0], dtype="i4")
146
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
147
+ token_type_ids = jnp.ones_like(input_ids)
148
+ attention_mask = jnp.ones_like(input_ids)
149
+
150
+ pixel_values = jax.random.normal(rng, input_shape[1])
151
+
152
+ params_rng, dropout_rng = jax.random.split(rng)
153
+ rngs = {"params": params_rng, "dropout": dropout_rng}
154
+
155
+ return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
156
+
157
+ def __call__(
158
+ self,
159
+ input_ids,
160
+ pixel_values,
161
+ attention_mask=None,
162
+ position_ids=None,
163
+ token_type_ids=None,
164
+ params: dict = None,
165
+ dropout_rng: jax.random.PRNGKey = None,
166
+ train: bool = False,
167
+ output_attentions: Optional[bool] = None,
168
+ output_hidden_states: Optional[bool] = None,
169
+ return_dict: Optional[bool] = None,
170
+ ):
171
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
172
+ output_hidden_states = (
173
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
174
+ )
175
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
176
+
177
+ if position_ids is None:
178
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
179
+
180
+ if token_type_ids is None:
181
+ token_type_ids = jnp.zeros_like(input_ids)
182
+
183
+ if attention_mask is None:
184
+ attention_mask = jnp.ones_like(input_ids)
185
+
186
+ # Handle any PRNG if needed
187
+ rngs = {}
188
+ if dropout_rng is not None:
189
+ rngs["dropout"] = dropout_rng
190
+
191
+ return self.module.apply(
192
+ {"params": params or self.params},
193
+ jnp.array(input_ids, dtype="i4"),
194
+ jnp.array(pixel_values, dtype=jnp.float32),
195
+ jnp.array(attention_mask, dtype="i4"),
196
+ jnp.array(position_ids, dtype="i4"),
197
+ jnp.array(token_type_ids, dtype="i4"),
198
+ not train,
199
+ output_attentions,
200
+ output_hidden_states,
201
+ return_dict,
202
+ rngs=rngs,
203
+ )
204
+
205
+ def get_text_features(
206
+ self,
207
+ input_ids,
208
+ attention_mask=None,
209
+ position_ids=None,
210
+ token_type_ids=None,
211
+ dropout_rng: jax.random.PRNGKey = None,
212
+ train=False,
213
+ ):
214
+ r"""
215
+ Args:
216
+ input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
217
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
218
+ provide it.
219
+
220
+ Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
221
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
222
+ for details.
223
+
224
+ `What are input IDs? <../glossary.html#input-ids>`__
225
+
226
+ Returns:
227
+ text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The text embeddings
228
+ obtained by applying the projection layer to the pooled output of text model.
229
+ """
230
+ if position_ids is None:
231
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
232
+
233
+ if token_type_ids is None:
234
+ token_type_ids = jnp.zeros_like(input_ids)
235
+
236
+ if attention_mask is None:
237
+ attention_mask = jnp.ones_like(input_ids)
238
+
239
+ # Handle any PRNG if needed
240
+ rngs = {}
241
+ if dropout_rng is not None:
242
+ rngs["dropout"] = dropout_rng
243
+
244
+ def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
245
+ text_outputs = module.text_model(
246
+ input_ids=input_ids,
247
+ attention_mask=attention_mask,
248
+ position_ids=position_ids,
249
+ token_type_ids=token_type_ids,
250
+ deterministic=deterministic,
251
+ )
252
+ pooled_output = text_outputs[1]
253
+ text_features = module.text_projection(pooled_output)
254
+ return text_features
255
+
256
+ return self.module.apply(
257
+ {"params": self.params},
258
+ jnp.array(input_ids, dtype="i4"),
259
+ jnp.array(attention_mask, dtype="i4"),
260
+ jnp.array(position_ids, dtype="i4"),
261
+ jnp.array(token_type_ids, dtype="i4"),
262
+ not train,
263
+ method=_get_features,
264
+ rngs=rngs,
265
+ )
266
+
267
+ def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
268
+ r"""
269
+ Args:
270
+ pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
271
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
272
+ using :class:`~transformers.ImageFeatureExtractionMixin`. See
273
+ :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
274
+
275
+ Returns:
276
+ image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim`): The image embeddings
277
+ obtained by applying the projection layer to the pooled output of vision model.
278
+ """
279
+
280
+ # Handle any PRNG if needed
281
+ rngs = {}
282
+ if dropout_rng is not None:
283
+ rngs["dropout"] = dropout_rng
284
+
285
+ def _get_features(module, pixel_values, deterministic):
286
+ vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
287
+ pooled_output = vision_outputs[1] # pooled_output
288
+ image_features = module.visual_projection(pooled_output)
289
+ return image_features
290
+
291
+ return self.module.apply(
292
+ {"params": self.params},
293
+ jnp.array(pixel_values, dtype=jnp.float32),
294
+ not train,
295
+ method=_get_features,
296
+ rngs=rngs,
297
+ )
298
+
299
+ @classmethod
300
+ def from_text_vision_pretrained(
301
+ cls,
302
+ text_model_name_or_path: str = None,
303
+ vision_model_name_or_path: str = None,
304
+ *model_args,
305
+ **kwargs,
306
+ ) -> FlaxPreTrainedModel:
307
+ """
308
+ Params:
309
+ text_model_name_or_path (:obj: `str`, `optional`):
310
+ Information necessary to initiate the text model. Can be either:
311
+
312
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
313
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
314
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
315
+ - A path to a `directory` containing model weights saved using
316
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
317
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
318
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
319
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
320
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
321
+
322
+ vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
323
+ Information necessary to initiate the vision model. Can be either:
324
+
325
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
326
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
327
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
328
+ - A path to a `directory` containing model weights saved using
329
+ :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
330
+ - A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
331
+ this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
332
+ as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
333
+ a Flax model using the provided conversion scripts and loading the Flax model afterwards.
334
+
335
+ model_args (remaining positional arguments, `optional`):
336
+ All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
337
+
338
+ kwargs (remaining dictionary of keyword arguments, `optional`):
339
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
340
+ :obj:`output_attentions=True`).
341
+
342
+ - To update the text configuration, use the prefix `text_` for each configuration parameter.
343
+ - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
344
+ - To update the parent model configuration, do not use a prefix for each configuration parameter.
345
+
346
+ Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
347
+
348
+ Example::
349
+
350
+ >>> from transformers import FlaxHybridCLIP
351
+ >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
352
+ >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
353
+ >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
354
+ >>> # saving model after fine-tuning
355
+ >>> model.save_pretrained("./bert-clip")
356
+ >>> # load fine-tuned model
357
+ >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
358
+ """
359
+
360
+ kwargs_text = {
361
+ argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
362
+ }
363
+
364
+ kwargs_vision = {
365
+ argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
366
+ }
367
+
368
+ # remove text, vision kwargs from kwargs
369
+ for key in kwargs_text.keys():
370
+ del kwargs["text_" + key]
371
+ for key in kwargs_vision.keys():
372
+ del kwargs["vision_" + key]
373
+
374
+ # Load and initialize the text and vision model
375
+ text_model = kwargs_text.pop("model", None)
376
+ if text_model is None:
377
+ assert (
378
+ text_model_name_or_path is not None
379
+ ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
380
+ from transformers import FlaxAutoModel
381
+
382
+ if "config" not in kwargs_text:
383
+ from transformers import AutoConfig
384
+
385
+ text_config = AutoConfig.from_pretrained(text_model_name_or_path)
386
+ kwargs_text["config"] = text_config
387
+
388
+ text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
389
+
390
+ vision_model = kwargs_vision.pop("model", None)
391
+ if vision_model is None:
392
+ assert (
393
+ vision_model_name_or_path is not None
394
+ ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
395
+ from transformers import FlaxAutoModel
396
+
397
+ if "config" not in kwargs_vision:
398
+ from transformers import AutoConfig
399
+
400
+ vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
401
+ kwargs_vision["config"] = vision_config
402
+
403
+ vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
404
+
405
+ # instantiate config with corresponding kwargs
406
+ dtype = kwargs.pop("dtype", jnp.float32)
407
+ config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
408
+
409
+ # init model
410
+ model = cls(config, *model_args, dtype=dtype, **kwargs)
411
+
412
+ if vision_config.model_type == "clip":
413
+ model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
414
+ model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
415
+ else:
416
+ model.params["vision_model"] = vision_model.params
417
+
418
+ model.params["text_model"] = text_model.params
419
+
420
+ return model
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ jax==0.2.17
2
+ flax==0.3.4
3
+ transformers==4.8.2