edugp committed on
Commit 2b6deb0
1 Parent(s): 2bc8c97

Add actual files instead of symlinks

configuration_hybrid_clip.py DELETED
@@ -1 +0,0 @@
- /home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py
 
 
configuration_hybrid_clip.py ADDED
@@ -0,0 +1,100 @@
+ import copy
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class HybridCLIPConfig(PretrainedConfig):
+     r"""
+     :class:`HybridCLIPConfig` is the configuration class to store the configuration of a
+     :class:`~HybridCLIPModel`. It is used to instantiate a HybridCLIPModel according to the specified arguments,
+     defining the text model and vision model configs.
+     Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
+     outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
+     Args:
+         text_config_dict (:obj:`dict`):
+             Dictionary of configuration options that defines the text model config.
+         vision_config_dict (:obj:`dict`):
+             Dictionary of configuration options that defines the vision model config.
+         projection_dim (:obj:`int`, `optional`, defaults to 512):
+             Dimensionality of the text and vision projection layers.
+         kwargs (`optional`):
+             Dictionary of keyword arguments.
+     Examples::
+         >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
+         >>> # Initializing a BERT and CLIP configuration
+         >>> config_text = BertConfig()
+         >>> config_vision = CLIPConfig()
+         >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
+         >>> # Initializing a BERT and CLIPVision model
+         >>> model = FlaxHybridCLIP(config=config)
+         >>> # Accessing the model configuration
+         >>> config_text = model.config.text_config
+         >>> config_vision = model.config.vision_config
+         >>> # Saving the model, including its configuration
+         >>> model.save_pretrained('my-model')
+         >>> # Loading the model and config from a pretrained folder
+         >>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
+         >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
+     """
+
+     model_type = "hybrid-clip"
+     is_composition = True
+
+     def __init__(self, projection_dim=512, **kwargs):
+         super().__init__(**kwargs)
+
+         if "text_config" not in kwargs:
+             raise ValueError("`text_config` can not be `None`.")
+
+         if "vision_config" not in kwargs:
+             raise ValueError("`vision_config` can not be `None`.")
+
+         text_config = kwargs.pop("text_config")
+         vision_config = kwargs.pop("vision_config")
+
+         text_model_type = text_config.pop("model_type")
+         vision_model_type = vision_config.pop("model_type")
+
+         from transformers import AutoConfig
+
+         self.text_config = AutoConfig.for_model(text_model_type, **text_config)
+
+         if vision_model_type == "clip":
+             self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
+         elif vision_model_type == "clip_vision_model":
+             from transformers import CLIPVisionConfig
+
+             self.vision_config = CLIPVisionConfig(**vision_config)
+         else:
+             self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
+
+         self.projection_dim = projection_dim
+         self.initializer_factor = 1.0
+
+     @classmethod
+     def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
+         r"""
+         Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
+         vision model configuration.
+         Returns:
+             :class:`HybridCLIPConfig`: An instance of a configuration object
+         """
+
+         return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+
+     def to_dict(self):
+         """
+         Serializes this instance to a Python dictionary. Overrides the default
+         :meth:`~transformers.PretrainedConfig.to_dict`.
+         Returns:
+             :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
+         """
+         output = copy.deepcopy(self.__dict__)
+         output["text_config"] = self.text_config.to_dict()
+         output["vision_config"] = self.vision_config.to_dict()
+         output["model_type"] = self.__class__.model_type
+         return output
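
Note that `__init__` expects both nested configs as plain dicts carrying a `model_type` key, which `AutoConfig.for_model` uses to rebuild the right config class. A minimal usage sketch (not part of the commit), assuming the committed configuration_hybrid_clip.py is importable from the working directory; BertConfig and CLIPVisionConfig are illustrative choices, not requirements:

from transformers import BertConfig, CLIPVisionConfig

from configuration_hybrid_clip import HybridCLIPConfig

# Compose the hybrid config from a text config and a vision config.
config = HybridCLIPConfig.from_text_vision_configs(BertConfig(), CLIPVisionConfig(), projection_dim=512)

# to_dict() re-serializes the nested configs, so the object round-trips through a plain dict.
config_dict = config.to_dict()
restored = HybridCLIPConfig(
    projection_dim=config_dict["projection_dim"],
    text_config=config_dict["text_config"],
    vision_config=config_dict["vision_config"],
)
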
modeling_hybrid_clip.py DELETED
@@ -1 +0,0 @@
- /home/eduardogonzalezponferrada/transformers/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py
 
 
modeling_hybrid_clip.py ADDED
@@ -0,0 +1,407 @@
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Tuple
+
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from configuration_hybrid_clip import HybridCLIPConfig
+ from flax.core.frozen_dict import FrozenDict
+ from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
+ from transformers.modeling_flax_utils import FlaxPreTrainedModel
+ from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class FlaxHybridCLIPModule(nn.Module):
+     config: HybridCLIPConfig
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         text_config = self.config.text_config
+         vision_config = self.config.vision_config
+
+         self.projection_dim = self.config.projection_dim
+         self.text_embed_dim = text_config.hidden_size
+         self.vision_embed_dim = vision_config.hidden_size
+
+         text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
+         vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
+
+         self.text_model = text_module(text_config, dtype=self.dtype)
+         self.vision_model = vision_module(vision_config, dtype=self.dtype)
+
+         self.visual_projection = nn.Dense(
+             self.projection_dim,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+             use_bias=False,
+         )
+         self.text_projection = nn.Dense(
+             self.projection_dim,
+             dtype=self.dtype,
+             kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
+             use_bias=False,
+         )
+         self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
+
+     def __call__(
+         self,
+         input_ids=None,
+         pixel_values=None,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         deterministic: bool = True,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             deterministic=deterministic,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         image_embeds = vision_outputs[1]
+         image_embeds = self.visual_projection(image_embeds)
+
+         text_embeds = text_outputs[1]
+         text_embeds = self.text_projection(text_embeds)
+
+         # normalized features
+         image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+         text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
+
+         # cosine similarity as logits
+         logit_scale = jnp.exp(self.logit_scale)
+         logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
+         logits_per_image = logits_per_text.T
+
+         if not return_dict:
+             return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+
+         return FlaxCLIPOutput(
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
+
+
+ class FlaxHybridCLIP(FlaxPreTrainedModel):
+     config_class = HybridCLIPConfig
+     module_class = FlaxHybridCLIPModule
+
+     def __init__(
+         self,
+         config: HybridCLIPConfig,
+         input_shape: Optional[Tuple] = None,
+         seed: int = 0,
+         dtype: jnp.dtype = jnp.float32,
+         **kwargs
+     ):
+         if input_shape is None:
+             input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
+
+         module = self.module_class(config=config, dtype=dtype, **kwargs)
+         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+
+     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
+         # init input tensor
+         input_ids = jnp.zeros(input_shape[0], dtype="i4")
+         position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
+         token_type_ids = jnp.ones_like(input_ids)
+         attention_mask = jnp.ones_like(input_ids)
+
+         pixel_values = jax.random.normal(rng, input_shape[1])
+
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
+
+     def __call__(
+         self,
+         input_ids,
+         pixel_values,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         params: dict = None,
+         dropout_rng: jax.random.PRNGKey = None,
+         train: bool = False,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+         if position_ids is None:
+             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+         if token_type_ids is None:
+             token_type_ids = jnp.zeros_like(input_ids)
+
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         return self.module.apply(
+             {"params": params or self.params},
+             jnp.array(input_ids, dtype="i4"),
+             jnp.array(pixel_values, dtype=jnp.float32),
+             jnp.array(attention_mask, dtype="i4"),
+             jnp.array(position_ids, dtype="i4"),
+             jnp.array(token_type_ids, dtype="i4"),
+             not train,
+             output_attentions,
+             output_hidden_states,
+             return_dict,
+             rngs=rngs,
+         )
+
+     def get_text_features(
+         self,
+         input_ids,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         dropout_rng: jax.random.PRNGKey = None,
+         train=False,
+     ):
+         r"""
+         Args:
+             input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+                 Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
+                 :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
+                 for details.
+                 `What are input IDs? <../glossary.html#input-ids>`__
+         Returns:
+             text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
+             obtained by applying the projection layer to the pooled output of the text model.
+         """
+         if position_ids is None:
+             position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+
+         if token_type_ids is None:
+             token_type_ids = jnp.zeros_like(input_ids)
+
+         if attention_mask is None:
+             attention_mask = jnp.ones_like(input_ids)
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
+             text_outputs = module.text_model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 token_type_ids=token_type_ids,
+                 deterministic=deterministic,
+             )
+             pooled_output = text_outputs[1]
+             text_features = module.text_projection(pooled_output)
+             return text_features
+
+         return self.module.apply(
+             {"params": self.params},
+             jnp.array(input_ids, dtype="i4"),
+             jnp.array(attention_mask, dtype="i4"),
+             jnp.array(position_ids, dtype="i4"),
+             jnp.array(token_type_ids, dtype="i4"),
+             not train,
+             method=_get_features,
+             rngs=rngs,
+         )
+
+     def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
+         r"""
+         Args:
+             pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
+                 Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
+                 using :class:`~transformers.ImageFeatureExtractionMixin`. See
+                 :meth:`transformers.ImageFeatureExtractionMixin.__call__` for details.
+         Returns:
+             image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
+             obtained by applying the projection layer to the pooled output of the vision model.
+         """
+
+         # Handle any PRNG if needed
+         rngs = {}
+         if dropout_rng is not None:
+             rngs["dropout"] = dropout_rng
+
+         def _get_features(module, pixel_values, deterministic):
+             vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
+             pooled_output = vision_outputs[1]  # pooled_output
+             image_features = module.visual_projection(pooled_output)
+             return image_features
+
+         return self.module.apply(
+             {"params": self.params},
+             jnp.array(pixel_values, dtype=jnp.float32),
+             not train,
+             method=_get_features,
+             rngs=rngs,
+         )
+
+     @classmethod
+     def from_text_vision_pretrained(
+         cls,
+         text_model_name_or_path: str = None,
+         vision_model_name_or_path: str = None,
+         *model_args,
+         **kwargs,
+     ) -> FlaxPreTrainedModel:
+         """
+         Params:
+             text_model_name_or_path (:obj:`str`, `optional`):
+                 Information necessary to initialize the text model. Can be either:
+                     - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+                       a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                     - A path to a `directory` containing model weights saved using
+                       :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                     - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
+                       this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+                       as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
+                       a Flax model using the provided conversion scripts and loading the Flax model afterwards.
+             vision_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
+                 Information necessary to initialize the vision model. Can be either:
+                     - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
+                       Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
+                       a user or organization name, like ``dbmdz/bert-base-german-cased``.
+                     - A path to a `directory` containing model weights saved using
+                       :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                     - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In
+                       this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
+                       as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to
+                       a Flax model using the provided conversion scripts and loading the Flax model afterwards.
+             model_args (remaining positional arguments, `optional`):
+                 All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
+             kwargs (remaining dictionary of keyword arguments, `optional`):
+                 Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
+                 :obj:`output_attentions=True`).
+                 - To update the text configuration, use the prefix `text_` for each configuration parameter.
+                 - To update the vision configuration, use the prefix `vision_` for each configuration parameter.
+                 - To update the parent model configuration, do not use a prefix for each configuration parameter.
+                 Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
+         Example::
+             >>> from transformers import FlaxHybridCLIP
+             >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
+             >>> # If using CLIP's vision model, the vision projection layer will be initialized using pre-trained weights
+             >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
+             >>> # saving model after fine-tuning
+             >>> model.save_pretrained("./bert-clip")
+             >>> # load fine-tuned model
+             >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
+         """
+
+         kwargs_text = {
+             argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
+         }
+
+         kwargs_vision = {
+             argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
+         }
+
+         # remove text, vision kwargs from kwargs
+         for key in kwargs_text.keys():
+             del kwargs["text_" + key]
+         for key in kwargs_vision.keys():
+             del kwargs["vision_" + key]
+
+         # Load and initialize the text and vision model
+         text_model = kwargs_text.pop("model", None)
+         if text_model is None:
+             assert (
+                 text_model_name_or_path is not None
+             ), "If `model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
+             from transformers import FlaxAutoModel
+
+             if "config" not in kwargs_text:
+                 from transformers import AutoConfig
+
+                 text_config = AutoConfig.from_pretrained(text_model_name_or_path)
+                 kwargs_text["config"] = text_config
+
+             text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
+
+         vision_model = kwargs_vision.pop("model", None)
+         if vision_model is None:
+             assert (
+                 vision_model_name_or_path is not None
+             ), "If `model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
+             from transformers import FlaxAutoModel
+
+             if "config" not in kwargs_vision:
+                 from transformers import AutoConfig
+
+                 vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
+                 kwargs_vision["config"] = vision_config
+
+             vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
+
+         # instantiate config with corresponding kwargs
+         dtype = kwargs.pop("dtype", jnp.float32)
+         config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
+
+         # init model
+         model = cls(config, *model_args, dtype=dtype, **kwargs)
+
+         if vision_config.model_type == "clip":
+             model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
+             model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
+         else:
+             model.params["vision_model"] = vision_model.params
+
+         model.params["text_model"] = text_model.params
+
+         return model
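
For reference, a hedged end-to-end sketch (not part of the commit) of how the committed classes fit together, assuming both files are importable from the working directory; the checkpoint names mirror the docstring example above, and the dummy image batch copies the (batch, height, width, 3) shape used in init_weights:

import jax.numpy as jnp
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP

# Load a BERT text encoder and CLIP's vision tower; per the docstring, the text projection starts random
# while the vision projection is taken from the pretrained CLIP checkpoint.
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Project a text batch into the shared embedding space.
inputs = tokenizer(["a photo of a cat"], return_tensors="np", padding=True)
text_features = model.get_text_features(inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Project a dummy image batch, shaped like the tensor used in init_weights.
image_size = model.config.vision_config.image_size
pixel_values = jnp.zeros((1, image_size, image_size, 3), dtype=jnp.float32)
image_features = model.get_image_features(pixel_values)

# Cosine similarity between the two projections, mirroring FlaxHybridCLIPModule.__call__.
text_features = text_features / jnp.linalg.norm(text_features, axis=-1, keepdims=True)
image_features = image_features / jnp.linalg.norm(image_features, axis=-1, keepdims=True)
similarity = jnp.matmul(text_features, image_features.T)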