from typing import Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from transformers import (
    CLIPVisionConfig,
    FlaxCLIPVisionModel,
    FlaxMarianMTModel,
    MarianConfig,
)
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutputWithPooling,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
from transformers.models.marian.modeling_flax_marian import (
    FlaxMarianDecoder,
    FlaxPreTrainedModel,
    shift_tokens_right,
)

from .configuration_clip_vision_marian import CLIPVisionMarianConfig
from .modeling_clip_vision_marian_utils import FlaxCLIPVisionMarianPreTrainedModel


class FlaxCLIPVisionMarianModule(nn.Module):
    config: CLIPVisionMarianConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.shared = nn.Embed(
            self.config.marian_config.vocab_size,
            self.config.marian_config.d_model,
            embedding_init=jax.nn.initializers.normal(
                self.config.marian_config.init_std, self.dtype
            ),
            dtype=self.dtype,
        )
        self.encoder = FlaxCLIPVisionModule(
            self.config.clip_vision_config, dtype=self.dtype
        )
        self.decoder = FlaxMarianDecoder(
            self.config.marian_config, dtype=self.dtype, embed_tokens=self.shared
        )
        # Projects CLIP vision hidden states to the Marian decoder's hidden size
        # so they can be consumed by the decoder's cross-attention.
        self.visual_projection = nn.Dense(
            self.config.marian_config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.marian_config.init_std, self.dtype
            ),
        )

    def _get_encoder_module(self):
        return self.encoder

    def _get_decoder_module(self):
        return self.decoder

    def __call__(
        self,
        pixel_values,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        encoder_outputs = self.encoder(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        batch_size, sequence_length = encoder_outputs[0].shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))
        encoder_hidden_states = self.visual_projection(encoder_outputs[0])

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
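
# Shape sketch of the data flow above (illustrative; actual sizes depend on
# the CLIP and Marian configs in use):
#
#   pixel_values (batch, image_size, image_size, 3)
#     -> CLIP vision encoder  -> (batch, num_patches + 1, clip_hidden_size)
#     -> visual_projection    -> (batch, num_patches + 1, marian_d_model)
#     -> Marian decoder cross-attends over the projected patch sequence.
#
# The all-ones encoder_attention_mask reflects that every visual position
# (patch embeddings plus the class token) is attended to.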


class FlaxCLIPVisionMarianMTModule(nn.Module):
    config: CLIPVisionMarianConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxCLIPVisionMarianModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.marian_config.init_std, self.dtype
            ),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings)
        )

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def _get_visual_projection_module(self):
        return self.model.visual_projection

    def __call__(
        self,
        pixel_values,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply(
                {"params": {"kernel": shared_embedding.T}}, hidden_states
            )
        else:
            lm_logits = self.lm_head(hidden_states)

        lm_logits += self.final_logits_bias

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
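
# When `tie_word_embeddings` is enabled, `lm_head` carries no trained kernel of
# its own: logits are computed against the transposed shared embedding table,
# roughly (a sketch, ignoring dtype handling):
#
#   lm_logits = hidden_states @ shared_embedding.T + final_logits_bias
#
# so each decoder position is scored against every vocabulary embedding.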


class FlaxCLIPVisionMarianOuterPreTrainedModel(FlaxCLIPVisionMarianPreTrainedModel):
    config_class = CLIPVisionMarianConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: CLIPVisionMarianConfig,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (
                (
                    1,
                    config.clip_vision_config.image_size,
                    config.clip_vision_config.image_size,
                    3,
                ),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config, module, input_shape=input_shape, seed=seed, dtype=dtype
        )

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        pixel_values = jax.random.normal(rng, input_shape[0])
        decoder_input_ids = jnp.zeros(input_shape[1], dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
            decoder_input_ids.shape,
        )

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            visual_projection = module._get_visual_projection_module()
            outputs = encode_module(pixel_values, **kwargs)
            # Project the vision hidden states here so that `decode` can consume
            # the encoder outputs directly.
            return FlaxBaseModelOutputWithPooling(
                last_hidden_state=visual_projection(outputs.last_hidden_state),
                pooler_output=outputs.pooler_output,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized; the private
        # flag init_cache has to be passed down to ensure the cache is used. The cache
        # must also be marked as mutable so that it can be updated by the
        # FlaxMarianAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        # TODO: check whether decoder_input_ids should be derived via
        # shift_tokens_right when only labels are available:
        # decoder_input_ids = shift_tokens_right(
        #     input_ids, self.config.pad_token_id,
        #     decoder_start_token_id=self.config.decoder_start_token_id,
        # )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
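
# At generation time the two halves above are used separately (a sketch of the
# intended call pattern; `model`, `pixel_values`, `start_ids`, and `pos` are
# placeholder names, not part of this module):
#
#   encoder_outputs = model.encode(pixel_values)  # run the CLIP encoder once
#   cache = model.init_cache(batch_size, max_length, encoder_outputs)
#   out = model.decode(start_ids, encoder_outputs,
#                      past_key_values=cache, decoder_position_ids=pos)
#
# Each later decode step feeds one new token plus the updated past_key_values,
# so the vision encoder is never re-run.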


class FlaxCLIPVisionMarianMT(FlaxCLIPVisionMarianOuterPreTrainedModel):
    module_class = FlaxCLIPVisionMarianMTModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized; the private
        # flag init_cache has to be passed down to ensure the cache is used. The cache
        # must also be marked as mutable so that it can be updated by the
        # FlaxMarianAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        ):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                shared_embedding = module.model.variables["params"]["shared"][
                    "embedding"
                ]
                lm_logits = module.lm_head.apply(
                    {"params": {"kernel": shared_embedding.T}}, hidden_states
                )
            else:
                lm_logits = module.lm_head(hidden_states)

            lm_logits += module.final_logits_bias
            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def _adapt_logits_for_beam_search(self, logits):
        """This function ensures that the pad token is never generated."""
        # jax.ops.index_update was removed from JAX; use the .at[...].set(...)
        # indexed-update syntax instead.
        logits = logits.at[:, :, self.config.marian_config.pad_token_id].set(
            float("-inf")
        )
        return logits

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since the decoder uses a
        # causal mask, those positions are masked anyway. Thus we can create a single
        # static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = (
            model_kwargs["decoder_position_ids"][:, -1:] + 1
        )
        return model_kwargs

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # At the moment fast initialization is not supported for composite models
        # kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)

    @classmethod
    def from_clip_vision_marian_pretrained(
        cls,
        clip_vision_model_name_or_path: str = None,
        marian_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxCLIPVisionMarianPreTrainedModel:
        kwargs_marian = {
            argument[len("marian_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("marian_")
        }

        kwargs_clip_vision = {
            argument[len("clip_vision_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("clip_vision_")
        }

        # remove marian, clip_vision kwargs from kwargs
        for key in kwargs_marian.keys():
            del kwargs["marian_" + key]
        for key in kwargs_clip_vision.keys():
            del kwargs["clip_vision_" + key]

        # Load and initialize the marian and clip_vision model
        marian_model = kwargs_marian.pop("model", None)
        if marian_model is None:
            assert (
                marian_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `marian_model_name_or_path` has to be defined"

            if "config" not in kwargs_marian:
                marian_config = MarianConfig.from_pretrained(marian_model_name_or_path)
                kwargs_marian["config"] = marian_config

            marian_model = FlaxMarianMTModel.from_pretrained(
                marian_model_name_or_path, *model_args, **kwargs_marian
            )

        clip_vision_model = kwargs_clip_vision.pop("model", None)
        if clip_vision_model is None:
            assert (
                clip_vision_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `clip_vision_model_name_or_path` has to be defined"
            if "config" not in kwargs_clip_vision:
                clip_vision_config = CLIPVisionConfig.from_pretrained(
                    clip_vision_model_name_or_path
                )
                kwargs_clip_vision["config"] = clip_vision_config

            clip_vision_model = FlaxCLIPVisionModel.from_pretrained(
                clip_vision_model_name_or_path, *model_args, **kwargs_clip_vision
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = CLIPVisionMarianConfig.from_clip_vision_marian_configs(
            clip_vision_model.config, marian_model.config, **kwargs
        )

        # init model, then copy the pretrained weights into the freshly
        # initialized parameter tree
        model = cls(config, *model_args, dtype=dtype, **kwargs)
        model.params["model"]["encoder"] = clip_vision_model.params
        model.params["model"]["decoder"] = marian_model.params["model"]["decoder"]
        model.params["model"]["shared"] = marian_model.params["model"]["shared"]
        model.params["final_logits_bias"] = marian_model.params["final_logits_bias"]

        return model
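
# Example usage (a sketch only; the checkpoint names are illustrative
# assumptions, and the exact `generate` call depends on the installed
# transformers version's Flax generation API):
#
#   model = FlaxCLIPVisionMarianMT.from_clip_vision_marian_pretrained(
#       "openai/clip-vit-base-patch32",
#       "Helsinki-NLP/opus-mt-en-es",
#   )
#   # pixel_values: channels-last float array of shape (1, 224, 224, 3)
#   encoder_outputs = model.encode(pixel_values)
#   generated_ids = model.generate(pixel_values, max_length=64)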