from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxSequenceClassifierOutput,
)
from transformers.models.bert.modeling_flax_bert import (
    FlaxBertEncoder,
    FlaxBertOnlyMLMHead,
    FlaxBertPooler,
    FlaxPreTrainedModel,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule

from .configuration_clip_vision_bert import CLIPVisionBertConfig


class FlaxCLIPVisionBertEmbeddings(nn.Module):
    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        bert_config = self.config.bert_config
        clip_vision_config = self.config.clip_vision_config

        self.word_embeddings = nn.Embed(
            bert_config.vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=bert_config.initializer_range
            ),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            bert_config.max_position_embeddings,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=bert_config.initializer_range
            ),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            bert_config.type_vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=bert_config.initializer_range
            ),
            dtype=self.dtype,
        )

        self.clip_vision_module = FlaxCLIPVisionModule(
            clip_vision_config, dtype=self.dtype
        )
        self.visual_projection = nn.Dense(
            bert_config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                bert_config.initializer_range, self.dtype
            ),
        )
        self.visual_position_embeddings = nn.Embed(
            bert_config.max_position_embeddings,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=bert_config.initializer_range
            ),
            dtype=self.dtype,
        )
        self.visual_token_type_embeddings = nn.Embed(
            bert_config.type_vocab_size,
            bert_config.hidden_size,
            embedding_init=jax.nn.initializers.normal(
                stddev=bert_config.initializer_range
            ),
            dtype=self.dtype,
        )

        self.LayerNorm = nn.LayerNorm(
            epsilon=bert_config.layer_norm_eps, dtype=self.dtype
        )
        self.dropout = nn.Dropout(rate=bert_config.hidden_dropout_prob)

    def __call__(
        self,
        input_ids,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
    ):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        word_embeddings = inputs_embeds + token_type_embeddings + position_embeds

        # Visual Embed
        visual_inputs_embeds = self.clip_vision_module(pixel_values=pixel_values)[0]
        visual_inputs_embeds = self.visual_projection(visual_inputs_embeds)
        visual_token_type_embeddings = self.visual_token_type_embeddings(
            visual_token_type_ids.astype("i4")
        )
        visual_position_embeds = self.visual_position_embeddings(
            visual_position_ids.astype("i4")
        )

        # Sum all visual embeddings
        visual_embeddings = (
            visual_inputs_embeds
            + visual_token_type_embeddings
            + visual_position_embeds
        )

        # Concat
        hidden_states = jnp.concatenate((word_embeddings, visual_embeddings), axis=1)

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states
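
# Shape sketch (illustrative, not part of the model): the embeddings above return one
# fused sequence of text tokens followed by the CLIP visual tokens, i.e. the patch grid
# plus the CLS token. Assuming ViT-B/32 defaults (image_size=224, patch_size=32):
# num_visual_tokens = (224 // 32) ** 2 + 1  # = 50
# fused_hidden_states.shape == (batch_size, text_seq_len + num_visual_tokens, hidden_size)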


class FlaxCLIPVisionBertModule(nn.Module):
    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxCLIPVisionBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(self.config.bert_config, dtype=self.dtype)
        self.pooler = FlaxBertPooler(self.config.bert_config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(
            input_ids,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
        )

        combined_attention_mask = jnp.concatenate(
            (attention_mask, visual_attention_mask), axis=1
        )
        outputs = self.encoder(
            hidden_states,
            combined_attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
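
# The pretrained-model wrappers below take a three-part `input_shape`:
# ((batch, text_len), (batch, image_size, image_size, 3), (batch, num_visual_tokens)).
# A minimal sketch of the default built in `__init__` below, assuming ViT-B/32 settings:
# input_shape = ((1, 1), (1, 224, 224, 3), (1, (224 // 32) ** 2 + 1))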


class FlaxCLIPVisionBertModel(FlaxPreTrainedModel):
    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (
                (1, 1),
                (
                    1,
                    config.clip_vision_config.image_size,
                    config.clip_vision_config.image_size,
                    3,
                ),
                (
                    1,
                    (
                        config.clip_vision_config.image_size
                        // config.clip_vision_config.patch_size
                    )
                    ** 2
                    + 1,
                ),
            )
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape
        )
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])

        visual_attention_mask = jnp.ones(input_shape[2])
        visual_token_type_ids = jnp.ones(input_shape[2])
        visual_position_ids = jnp.broadcast_to(
            jnp.zeros(jnp.atleast_2d(visual_token_type_ids).shape[-1]), input_shape[2]
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.bert_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.bert_config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.bert_config.return_dict
        )

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))  # Don't need this for torch permuted input

        visual_sequence_length = (
            pixel_values.shape[0],
            (
                self.config.clip_vision_config.image_size
                // self.config.clip_vision_config.patch_size
            )
            ** 2
            + 1,
        )

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # default to 0..num_visual_tokens-1, mirroring the textual position ids
            visual_position_ids = jnp.broadcast_to(
                jnp.arange(visual_sequence_length[-1]), visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
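
    # Kwargs routing sketch for the loader below (illustrative call): kwargs prefixed
    # with `bert_` or `clip_vision_` are stripped of their prefix and forwarded to the
    # respective `from_pretrained` call, e.g.
    # FlaxCLIPVisionBertModel.from_bert_clip_vision_pretrained(
    #     "bert-base-uncased", "openai/clip-vit-base-patch32", bert_config=my_bert_config
    # )
    # where `my_bert_config` is a hypothetical, pre-built `BertConfig`.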

    @classmethod
    def from_bert_clip_vision_pretrained(
        cls,
        bert_model_name_or_path: str = None,
        clip_vision_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        kwargs_bert = {
            argument[len("bert_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("bert_")
        }

        kwargs_clip_vision = {
            argument[len("clip_vision_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("clip_vision_")
        }

        # remove bert, clip_vision kwargs from kwargs
        for key in kwargs_bert.keys():
            del kwargs["bert_" + key]
        for key in kwargs_clip_vision.keys():
            del kwargs["clip_vision_" + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertModel

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertModel.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_bert_clip_vision_configs(
            bert_model.config, clip_vision_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        for key in model.params.keys():
            if key != "embeddings":
                model.params[key] = bert_model.params[key]
            else:
                model.params["embeddings"][
                    "clip_vision_module"
                ] = clip_vision_model.params
                for sub_key in bert_model.params[key]:
                    model.params[key][sub_key] = bert_model.params[key][sub_key]

        return model


# flax_model = FlaxCLIPVisionBertModel.from_bert_clip_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32', seed=0, dtype=jnp.float32)
# outputs = flax_model(input_ids, attention_mask, token_type_ids, position_ids, pixel_values, visual_attention_mask, visual_token_type_ids, visual_position_ids, output_hidden_states=True)


class FlaxCLIPVisionBertForMaskedLMModule(nn.Module):
    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxCLIPVisionBertModule(
            config=self.config, add_pooling_layer=False, dtype=self.dtype
        )
        self.cls = FlaxBertOnlyMLMHead(
            config=self.config.bert_config, dtype=self.dtype
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.bert_config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["embeddings"][
                "word_embeddings"
            ]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
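
# Note (sketch of library behavior, not additional functionality): when
# `tie_word_embeddings` is enabled on the BERT config, the MLM head above reuses the
# word-embedding matrix from `params["embeddings"]["word_embeddings"]["embedding"]` as
# its output projection; otherwise FlaxBertOnlyMLMHead uses its own decoder kernel.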


class FlaxCLIPVisionBertForMaskedLM(FlaxPreTrainedModel):
    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertForMaskedLMModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (
                (1, 1),
                (
                    1,
                    config.clip_vision_config.image_size,
                    config.clip_vision_config.image_size,
                    3,
                ),
                (
                    1,
                    (
                        config.clip_vision_config.image_size
                        // config.clip_vision_config.patch_size
                    )
                    ** 2
                    + 1,
                ),
            )
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape
        )
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])

        visual_attention_mask = jnp.ones(input_shape[2])
        visual_token_type_ids = jnp.ones(input_shape[2])
        visual_position_ids = jnp.broadcast_to(
            jnp.zeros(jnp.atleast_2d(visual_token_type_ids).shape[-1]), input_shape[2]
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.bert_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.bert_config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.bert_config.return_dict
        )

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        visual_sequence_length = (
            pixel_values.shape[0],
            (
                self.config.clip_vision_config.image_size
                // self.config.clip_vision_config.patch_size
            )
            ** 2
            + 1,
        )

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # default to 0..num_visual_tokens-1, mirroring the textual position ids
            visual_position_ids = jnp.broadcast_to(
                jnp.arange(visual_sequence_length[-1]), visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # At the moment fast initialization is not supported
        # for composite models
        # kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_clip_vision_bert_pretrained(
        cls,
        clip_vision_model_name_or_path: str = None,
        bert_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        kwargs_bert = {
            argument[len("bert_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("bert_")
        }

        kwargs_clip_vision = {
            argument[len("clip_vision_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("clip_vision_")
        }

        # remove bert, clip_vision kwargs from kwargs
        for key in kwargs_bert.keys():
            del kwargs["bert_" + key]
        for key in kwargs_clip_vision.keys():
            del kwargs["clip_vision_" + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertForMaskedLM

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertForMaskedLM.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_clip_vision_bert_configs(
            clip_vision_model.config, bert_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        model.params["cls"] = bert_model.params["cls"]
        for key in model.params["model"].keys():
            if key != "embeddings":
                model.params["model"][key] = bert_model.params["bert"][key]
            else:
                model.params["model"]["embeddings"][
                    "clip_vision_module"
                ] = clip_vision_model.params
                for sub_key in bert_model.params["bert"][key]:
                    model.params["model"][key][sub_key] = bert_model.params["bert"][
                        key
                    ][sub_key]

        return model
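
# Untested usage sketch for the masked-LM model, mirroring the commented example after
# FlaxCLIPVisionBertModel above (input arrays are assumed to be prepared by the caller):
# mlm_model = FlaxCLIPVisionBertForMaskedLM.from_clip_vision_bert_pretrained('openai/clip-vit-base-patch32', 'bert-base-uncased', seed=0, dtype=jnp.float32)
# mlm_outputs = mlm_model(input_ids, attention_mask, token_type_ids, position_ids, pixel_values, visual_attention_mask, visual_token_type_ids, visual_position_ids)
# mlm_logits = mlm_outputs.logits  # (batch, text_len + num_visual_tokens, vocab_size)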


class FlaxCLIPVisionBertForSequenceClassificationModule(nn.Module):
    config: CLIPVisionBertConfig
    dtype: jnp.dtype = jnp.float32
    num_labels: int = 3129  # TODO: Remove this hard-coding!

    def setup(self):
        self.model = FlaxCLIPVisionBertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.bert_config.hidden_dropout_prob)
        self.classifier = nn.Dense(
            self.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        pixel_values,
        visual_attention_mask,
        visual_token_type_ids,
        visual_position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.model(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxCLIPVisionBertForSequenceClassification(FlaxPreTrainedModel):
    config_class = CLIPVisionBertConfig
    module_class = FlaxCLIPVisionBertForSequenceClassificationModule

    def __init__(
        self,
        config: CLIPVisionBertConfig,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (
                (1, 1),
                (
                    1,
                    config.clip_vision_config.image_size,
                    config.clip_vision_config.image_size,
                    3,
                ),
                (
                    1,
                    (
                        config.clip_vision_config.image_size
                        // config.clip_vision_config.patch_size
                    )
                    ** 2
                    + 1,
                ),
            )
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        textual_input_shape = input_shape[0]
        input_ids = jnp.zeros(textual_input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), textual_input_shape
        )
        attention_mask = jnp.ones_like(input_ids)

        pixel_values = jax.random.normal(rng, input_shape[1])

        visual_attention_mask = jnp.ones(input_shape[2])
        visual_token_type_ids = jnp.ones(input_shape[2])
        visual_position_ids = jnp.broadcast_to(
            jnp.zeros(jnp.atleast_2d(visual_token_type_ids).shape[-1]), input_shape[2]
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            pixel_values,
            visual_attention_mask,
            visual_token_type_ids,
            visual_position_ids,
            return_dict=False,
        )["params"]

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        pixel_values=None,
        visual_attention_mask=None,
        visual_token_type_ids=None,
        visual_position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.bert_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.bert_config.output_hidden_states
        )
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.bert_config.return_dict
        )

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
            )

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        visual_sequence_length = (
            pixel_values.shape[0],
            (
                self.config.clip_vision_config.image_size
                // self.config.clip_vision_config.patch_size
            )
            ** 2
            + 1,
        )

        if visual_token_type_ids is None:
            visual_token_type_ids = jnp.ones(visual_sequence_length)

        if visual_position_ids is None:
            # default to 0..num_visual_tokens-1, mirroring the textual position ids
            visual_position_ids = jnp.broadcast_to(
                jnp.arange(visual_sequence_length[-1]), visual_sequence_length
            )

        if visual_attention_mask is None:
            visual_attention_mask = jnp.ones(visual_sequence_length)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(pixel_values, dtype=jnp.float32),
            jnp.array(visual_attention_mask, dtype="i4"),
            jnp.array(visual_token_type_ids, dtype="i4"),
            jnp.array(visual_position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # At the moment fast initialization is not supported
        # for composite models
        # kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_clip_vision_bert_pretrained(
        cls,
        clip_vision_model_name_or_path: str = None,
        bert_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxPreTrainedModel:
        kwargs_bert = {
            argument[len("bert_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("bert_")
        }

        kwargs_clip_vision = {
            argument[len("clip_vision_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("clip_vision_")
        }

        # remove bert, clip_vision kwargs from kwargs
        for key in kwargs_bert.keys():
            del kwargs["bert_" + key]
        for key in kwargs_clip_vision.keys():
            del kwargs["clip_vision_" + key]

        # Load and initialize the text and vision model
        bert_model = kwargs_bert.pop("model", None)
        if bert_model is None:
            assert (
                bert_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `bert_model_name_or_path` has to be defined"
            from transformers import FlaxBertForSequenceClassification

            if "config" not in kwargs_bert:
                from transformers import BertConfig

                bert_config = BertConfig.from_pretrained(bert_model_name_or_path)
                kwargs_bert["config"] = bert_config

            bert_model = FlaxBertForSequenceClassification.from_pretrained(
                bert_model_name_or_path, *model_args, from_pt=True, **kwargs_bert
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            from transformers import FlaxCLIPVisionModel

            if "config" not in kwargs_clip_vision:
                from transformers import CLIPVisionConfig

                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionBertConfig.from_clip_vision_bert_configs(
            clip_vision_model.config, bert_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)

        # model.params["classifier"] = bert_model.params["classifier"]
        for key in model.params["model"].keys():
            if key != "embeddings":
                model.params["model"][key] = bert_model.params["bert"][key]
            else:
                model.params["model"]["embeddings"][
                    "clip_vision_module"
                ] = clip_vision_model.params
                for sub_key in bert_model.params["bert"][key]:
                    model.params["model"][key][sub_key] = bert_model.params["bert"][
                        key
                    ][sub_key]

        return model
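
# Untested usage sketch for the classification model, mirroring the example after
# FlaxCLIPVisionBertModel above (input arrays are assumed to be prepared by the caller):
# cls_model = FlaxCLIPVisionBertForSequenceClassification.from_clip_vision_bert_pretrained('openai/clip-vit-base-patch32', 'bert-base-uncased', seed=0, dtype=jnp.float32)
# cls_outputs = cls_model(input_ids, attention_mask, token_type_ids, position_ids, pixel_values, visual_attention_mask, visual_token_type_ids, visual_position_ids)
# cls_logits = cls_outputs.logits  # (batch, num_labels), with num_labels = 3129 by default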