import os
from typing import Callable, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)

from .configuration_vit_gpt2 import ViTGPT2Config
from transformers import ViTConfig, GPT2Config
from transformers import FlaxPreTrainedModel, FlaxViTModel
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
from .modeling_flax_gpt2 import (
    FlaxGPT2PreTrainedModel,
    FlaxGPT2Module,
    FlaxGPT2Model,
    FlaxGPT2LMHeadModule,
    FlaxGPT2LMHeadModel,
)


class FlaxViTGPT2LMModule(nn.Module):
    """Play the same role as ``FlaxBartModule`` but with the decoder equipped with a LM head.

    A ``FlaxViTModule`` encoder is paired with a ``FlaxGPT2LMHeadModule`` decoder. The
    decoder receives the encoder's first output as ``encoder_hidden_states``.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.encoder = FlaxViTModule(self.config.vision_config, dtype=self.dtype)
        self.decoder = FlaxGPT2LMHeadModule(self.config.text_config, dtype=self.dtype)

    def _get_encoder_module(self):
        return self.encoder

    def _get_decoder_module(self):
        return self.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode ``pixel_values`` with ViT, then run the GPT-2 LM-head decoder on top of it.

        Args:
            pixel_values: images handed directly to the ViT module. The wrappers in this
                file pass them channels-last.
            attention_mask: not used by the encoder. It is forwarded to the decoder as
                ``encoder_attention_mask`` and may be ``None``.
            decoder_input_ids / decoder_attention_mask / decoder_position_ids: decoder inputs.
            output_attentions, output_hidden_states, return_dict: forwarded to both sub-modules.
            deterministic: forwarded to both sub-modules (``False`` during training).

        Returns:
            A ``FlaxSeq2SeqLMOutput`` when ``return_dict`` is true, otherwise the tuple
            ``decoder_outputs + encoder_outputs``.
        """
        encoder_outputs = self.encoder(
            pixel_values=pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
    """Play the same role as ``FlaxBartForConditionalGenerationModule`` but with the decoder equipped with a LM head.

    Actually, it is identical to ``FlaxBartForConditionalGenerationModule`` with a different name.
    It wraps ``FlaxViTGPT2LMModule`` under the attribute ``model`` and returns its outputs unchanged.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Delegate every argument to the wrapped ``FlaxViTGPT2LMModule`` and return its result."""
        outputs = self.model(
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        return outputs


class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
    """Play the same role as ``FlaxBartPretrainedModel``"""

    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        if input_shape is None:
            # (encoder pixel shape, decoder input_ids shape). The pixel shape is
            # channels-last, i.e. already in the layout the ViT module consumes.
            input_shape = (
                (1, config.vision_config.image_size, config.vision_config.image_size, 3),
                (1, 1),
            )

        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # This will use ``self.init_weights``.
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialise the module parameters from dummy inputs of shape ``input_shape``.

        ``input_shape`` is an ``(encoder_input_shape, decoder_input_shape)`` pair, as built in
        ``__init__``. Returns the ``"params"`` collection from ``self.module.init``.
        """
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors (only their shapes matter for parameter initialisation)
        pixel_values = jax.random.normal(rng, encoder_input_shape)
        attention_mask = None
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        # Put the EOS token in the last position, mirroring the Flax BART initialisation.
        # ``jax.ops.index_update`` was removed from JAX; ``.at[...].set`` is its replacement.
        decoder_input_ids = decoder_input_ids.at[..., -1].set(self.config.text_config.eos_token_id)
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        """Build an empty decoder cache for auto-regressive decoding.

        Args:
            batch_size: batch size of the dummy decoder inputs.
            max_length: sequence length of the dummy decoder inputs, i.e. the cache length.
            encoder_outputs: encoder output. Its first element is passed as
                ``encoder_hidden_states``.

        Returns:
            The unfrozen ``"cache"`` collection produced by running only the decoder with
            ``init_cache=True``.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
            decoder_input_ids.shape,
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the ViT encoder.

        ``pixel_values`` is transposed with axes ``(0, 2, 3, 1)`` (channels-first to
        channels-last) before being passed to the encoder. ``attention_mask`` is accepted
        for API symmetry but not used. Unset output flags fall back to
        ``self.config.vision_config``.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.vision_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.vision_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.vision_config.return_dict

        # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the GPT-2 decoder against precomputed ``encoder_outputs``.

        Missing masks default to all-ones. Missing position ids default to
        ``0..seq_len-1``, which is rejected when ``past_key_values`` is given. When
        ``past_key_values`` is passed, the updated cache is returned with the outputs:
        as ``outputs["past_key_values"]`` if ``return_dict``, otherwise as element 1 of
        the output tuple.

        Raises:
            ValueError: if ``past_key_values`` is given without ``decoder_position_ids``.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialised and a private
        # init_cache flag has to be passed down to make sure the cache is used. The cache
        # must also be marked as mutable so that the FlaxGPT2Attention module can update it.
        # NOTE(review): truthiness is tested here but `is not None` below, so an empty
        # dict takes neither branch consistently; kept as-is to preserve behaviour.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Full encoder-decoder forward pass.

        ``pixel_values`` is transposed with axes ``(0, 2, 3, 1)`` before the module runs.
        Missing decoder inputs default as follows:
        - ids: a single ``decoder_start_token_id`` per example;
        - mask: all-ones;
        - positions: ``0..seq_len-1``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        if decoder_input_ids is None:
            decoder_input_ids = self.config.decoder_start_token_id * jnp.ones((pixel_values.shape[0], 1))
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            attention_mask=attention_mask,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )


class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
    """ViT + GPT-2 (with LM head) model with generation helpers and a pretrained-weights constructor."""

    module_class = FlaxViTGPT2LMForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Forward unchanged to ``FlaxViTGPT2LMPreTrainedModel.decode``."""
        return super().decode(
            decoder_input_ids,
            encoder_outputs,
            encoder_attention_mask,
            decoder_attention_mask,
            decoder_position_ids,
            past_key_values,
            output_attentions,
            output_hidden_states,
            return_dict,
            train,
            params,
            dropout_rng,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the initial decoding kwargs: a fresh cache, a static-length mask and position ids.

        (Annotations use ``jnp.ndarray``: the former ``jnp.DeviceArray`` alias no longer
        exists in current JAX and would fail when the class body is evaluated.)
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance to the next position (last position + 1)."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs

    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        text_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ) -> FlaxViTGPT2LMPreTrainedModel:
        """Assemble a model from a pretrained ViT checkpoint and a pretrained GPT-2 LM checkpoint.

        Keyword arguments prefixed with ``vision_`` / ``text_`` are stripped of the prefix and
        routed to the respective ``from_pretrained`` call. Special keys:
        - ``vision_model`` / ``text_model``: supply an already-loaded sub-model.
        - ``vision_model_args`` / ``text_model_args``: positional args for ``from_pretrained``.
        - ``project_encoder``: recorded on the text config and on the combined config.
        - ``dtype``: the computation dtype of the combined model.
        Remaining kwargs go to ``ViTGPT2Config.from_vision_text_configs`` and ``cls(...)``.

        Raises:
            ValueError: if a sub-model is neither supplied nor resolvable from a name/path.
        """
        vision_kwargs = {
            kwarg[len("vision_"):]: value for kwarg, value in kwargs.items() if kwarg.startswith("vision_")
        }
        text_kwargs = {
            kwarg[len("text_"):]: value for kwarg, value in kwargs.items() if kwarg.startswith("text_")
        }

        # remove vit & gpt2 kwargs from kwargs
        for key in vision_kwargs.keys():
            del kwargs["vision_" + key]
        for key in text_kwargs.keys():
            del kwargs["text_" + key]

        vision_model_args = vision_kwargs.pop("model_args", [])
        text_model_args = text_kwargs.pop("model_args", [])

        # Load and initialize the vit & gpt2 model
        vision_model = vision_kwargs.pop("model", None)
        text_model = text_kwargs.pop("model", None)

        if vision_model is None:
            # Raise rather than assert so the check survives `python -O`.
            if vision_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `model` is not defined as an argument, a `vision_pretrained_model_name_or_path` has to be defined"
                )

            if "config" not in vision_kwargs:
                vision_config = ViTConfig.from_pretrained(vision_pretrained_model_name_or_path)
                vision_kwargs["config"] = vision_config

            # TODO: How to deal with model_args?
            vision_model = FlaxViTModel.from_pretrained(
                vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
            )

        project_encoder = kwargs.pop("project_encoder", None)
        if text_model is None:
            if text_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `model` is not defined as an argument, a `text_pretrained_model_name_or_path` has to be defined"
                )

            if "config" not in text_kwargs:
                text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
                text_config.project_encoder = text_kwargs.pop("project_encoder", None)
                # A top-level `project_encoder` overrides a `text_project_encoder` kwarg.
                if project_encoder is not None:
                    text_config.project_encoder = project_encoder
                text_kwargs["config"] = text_config

            # The decoder must cross-attend to the ViT encoder outputs.
            text_kwargs["config"].add_cross_attention = True

            # TODO: How to deal with model_args?
            text_model = FlaxGPT2LMHeadModel.from_pretrained(
                text_pretrained_model_name_or_path, *text_model_args, **text_kwargs
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTGPT2Config.from_vision_text_configs(
            vision_model.config, text_model.config, project_encoder=project_encoder, **kwargs
        )

        # init model, then overwrite the random sub-module weights with the pretrained ones
        model = cls(config, *model_args, dtype=dtype, **kwargs)
        model.params["model"]["encoder"] = vision_model.params
        model.params["model"]["decoder"] = text_model.params

        return model