from typing import Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from transformers import GPT2Config, FlaxViTModel, ViTConfig
from transformers.modeling_flax_outputs import (
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from transformers.models.vit.modeling_flax_vit import FlaxViTModule

from .modeling_flax_gpt2 import (
    FlaxGPT2Module,
    FlaxGPT2Model,
    FlaxGPT2LMHeadModule,
    FlaxGPT2LMHeadModel,
    FlaxPreTrainedModel,
)
from .configuration_vit_gpt2 import ViTGPT2Config


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

    return shifted_input_ids


class FlaxViTGPT2LMModule(nn.Module):
    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype)
        self.decoder = FlaxGPT2LMHeadModule(self.config.gpt2_config, dtype=self.dtype)

    def _get_encoder_module(self):
        return self.encoder

    def _get_decoder_module(self):
        return self.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        encoder_outputs = self.encoder(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=encoder_attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.decoder_hidden_states,
            decoder_attentions=decoder_outputs.decoder_attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        position_ids,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        return outputs
class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            input_shape = (
                (1, config.vit_config.image_size, config.vit_config.image_size, 3),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        pixel_values = jax.random.normal(rng, input_shape[0])
        input_ids = jnp.zeros(input_shape[1], dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            input_ids,
            attention_mask,
            position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
            input_ids.shape,
        )

        def _decoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(input_ids, attention_mask, position_ids, **kwargs)

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # channels-first (NCHW) -> channels-last (NHWC), as expected by FlaxViTModule
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )
    def decode(
        self,
        input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed then the cache is already initialized; a private flag `init_cache`
        # has to be passed down to ensure the cache is used. The cache must be marked as mutable so that
        # it can be changed by the FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(input_ids, attention_mask, position_ids, **kwargs)

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        input_ids: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # channels-first (NCHW) -> channels-last (NHWC), as expected by FlaxViTModule
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        # TODO: check how/whether decoder input ids should be shifted with `shift_tokens_right`
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
    module_class = FlaxViTGPT2LMForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed then the cache is already initialized; a private flag `init_cache`
        # has to be passed down to ensure the cache is used. The cache must be marked as mutable so that
        # it can be changed by the FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(input_ids, attention_mask, position_ids, **kwargs)
            lm_logits = outputs[0]
            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, outputs = outputs
        else:
            (lm_logits, outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=outputs.decoder_hidden_states,
                attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs
    def prepare_inputs_for_generation(
        self,
        input_ids,
        max_length,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since the decoder uses a causal mask,
        # those positions are masked anyway. Thus, we can create a single static attention_mask here,
        # which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(
                extended_attention_mask, attention_mask, (0, 0)
            )
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": encoder_attention_mask,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
    @classmethod
    def from_vit_gpt2_pretrained(
        cls,
        vit_model_name_or_path: str = None,
        gpt2_model_name_or_path: str = None,
        *model_args,
        **kwargs,
    ) -> FlaxViTGPT2LMPreTrainedModel:
        kwargs_gpt2 = {
            argument[len("gpt2_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("gpt2_")
        }

        kwargs_vit = {
            argument[len("vit_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("vit_")
        }

        # remove gpt2, vit kwargs from kwargs
        for key in kwargs_gpt2.keys():
            del kwargs["gpt2_" + key]
        for key in kwargs_vit.keys():
            del kwargs["vit_" + key]

        # Load and initialize the gpt2 and vit model
        gpt2_model = kwargs_gpt2.pop("model", None)
        if gpt2_model is None:
            assert (
                gpt2_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined"

            if "config" not in kwargs_gpt2:
                gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path)
                kwargs_gpt2["config"] = gpt2_config

            kwargs_gpt2["config"].add_cross_attention = True
            gpt2_model = FlaxGPT2LMHeadModel.from_pretrained(
                gpt2_model_name_or_path, *model_args, **kwargs_gpt2
            )

        vit_model = kwargs_vit.pop("model", None)
        if vit_model is None:
            assert (
                vit_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined"

            if "config" not in kwargs_vit:
                vit_config = ViTConfig.from_pretrained(vit_model_name_or_path)
                kwargs_vit["config"] = vit_config

            vit_model = FlaxViTModel.from_pretrained(
                vit_model_name_or_path, *model_args, **kwargs_vit
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTGPT2Config.from_vit_gpt2_configs(
            vit_model.config, gpt2_model.config, **kwargs
        )

        # init model
        model = cls(config, *model_args, dtype=dtype, **kwargs)
        model.params["model"]["encoder"] = vit_model.params
        model.params["model"]["decoder"] = gpt2_model.params

        return model
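

# A minimal usage sketch, not part of the original module: it assumes the local package layout
# (`configuration_vit_gpt2.py`, `modeling_flax_gpt2.py`) is importable and uses
# "google/vit-base-patch16-224-in21k" / "gpt2" only as example checkpoint names. It builds the
# combined model from the two pretrained backbones, runs a forward pass on dummy inputs, and
# inspects the logits. Guarded by `__main__` so importing this module stays side-effect free.
if __name__ == "__main__":
    import numpy as np

    # combine a pretrained ViT encoder with a pretrained GPT-2 decoder; cross-attention on the
    # GPT-2 side is enabled inside `from_vit_gpt2_pretrained`
    model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
        "google/vit-base-patch16-224-in21k", "gpt2"
    )

    image_size = model.config.vit_config.image_size  # e.g. 224
    # `__call__` expects channels-first pixel values and transposes them to NHWC internally
    pixel_values = np.random.randn(1, 3, image_size, image_size).astype(np.float32)
    input_ids = np.ones((1, 8), dtype=np.int32)  # dummy decoder prompt

    outputs = model(pixel_values=pixel_values, input_ids=input_ids)
    print(outputs.logits.shape)  # (batch, decoder sequence length, GPT-2 vocab size)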