# vit-gpt2/vit_gpt2/modeling_flax_vit_gpt2_lm.py
# Last commit 54ece9e by ydshieh: "Add project_encoder and related layers"
import os
from typing import Callable, Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from transformers.modeling_flax_outputs import (
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
)
from .configuration_vit_gpt2 import ViTGPT2Config
from transformers import ViTConfig, GPT2Config
from transformers import FlaxPreTrainedModel, FlaxViTModel
from transformers.models.vit.modeling_flax_vit import FlaxViTModule
from .modeling_flax_gpt2 import (
FlaxGPT2PreTrainedModel,
FlaxGPT2Module,
FlaxGPT2Model,
FlaxGPT2LMHeadModule,
FlaxGPT2LMHeadModel,
)
class FlaxViTGPT2LMModule(nn.Module):
    """Vision-encoder / text-decoder backbone, in the role of ``FlaxBartModule``.

    The encoder is a ``FlaxViTModule`` built from ``config.vision_config``. The
    decoder is a ``FlaxGPT2LMHeadModule`` built from ``config.text_config``, so
    the decoder output already carries LM logits.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        vision_config = self.config.vision_config
        text_config = self.config.text_config
        self.encoder = FlaxViTModule(vision_config, dtype=self.dtype)
        self.decoder = FlaxGPT2LMHeadModule(text_config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the ViT encoder submodule."""
        return self.encoder

    def _get_decoder_module(self):
        """Return the GPT-2 LM-head decoder submodule."""
        return self.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode ``pixel_values`` with ViT, then decode with GPT-2 cross-attending to it.

        Args:
            pixel_values: image batch, passed to the encoder unchanged.
            attention_mask: passed to the decoder as ``encoder_attention_mask``.
                The encoder itself receives no mask.
            decoder_input_ids: passed to the decoder as ``input_ids``.
            decoder_attention_mask: passed to the decoder as ``attention_mask``.
            decoder_position_ids: passed to the decoder as ``position_ids``.
            output_attentions: forwarded to both encoder and decoder.
            output_hidden_states: forwarded to both encoder and decoder.
            return_dict: forwarded to both encoder and decoder. Also selects the
                return format.
            deterministic: forwarded to both encoder and decoder.

        Returns:
            A ``FlaxSeq2SeqLMOutput`` when ``return_dict`` is true. Otherwise, the
            tuple ``decoder_outputs + encoder_outputs``.
        """
        common_kwargs = {
            "deterministic": deterministic,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        encoder_outputs = self.encoder(pixel_values=pixel_values, **common_kwargs)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            # Indexed with [0] rather than ``.last_hidden_state`` so this also
            # works when ``return_dict`` is False and the encoder returns a tuple.
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            **common_kwargs,
        )

        if return_dict:
            return FlaxSeq2SeqLMOutput(
                logits=decoder_outputs.logits,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )
        return decoder_outputs + encoder_outputs
class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module):
    """Top-level generation module, in the role of ``FlaxBartForConditionalGenerationModule``.

    It holds a ``FlaxViTGPT2LMModule`` under the attribute name ``model`` and
    forwards every call to it unchanged. That module's decoder is a
    ``FlaxGPT2LMHeadModule``, so no extra LM head is created here.
    """

    config: ViTGPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype)

    def _get_encoder_module(self):
        """Return the wrapped model's ViT encoder."""
        wrapped = self.model
        return wrapped.encoder

    def _get_decoder_module(self):
        """Return the wrapped model's GPT-2 LM-head decoder."""
        wrapped = self.model
        return wrapped.decoder

    def __call__(
        self,
        pixel_values,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Delegate to ``self.model`` with the same arguments and return its result as-is."""
        return self.model(
            deterministic=deterministic,
            return_dict=return_dict,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
        )
class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel):
    """Pretrained-model wrapper, in the role of ``FlaxBartPretrainedModel``.

    Subclasses set ``module_class``. This class provides:

    * weight initialization (``init_weights``);
    * decoder cache creation (``init_cache``);
    * the ``encode`` / ``decode`` / ``__call__`` entry points.
    """

    config_class = ViTGPT2Config
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: ViTGPT2Config,
        input_shape: Tuple = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        """Instantiate ``module_class`` and let ``FlaxPreTrainedModel`` initialize weights.

        Args:
            config: composite ViT + GPT-2 configuration.
            input_shape: ``(encoder_input_shape, decoder_input_shape)`` for the
                dummy initialization pass. Defaults to a single channels-last
                ``(image_size, image_size, 3)`` image and a single decoder token.
            seed: PRNG seed for parameter initialization.
            dtype: computation dtype of the module.
            **kwargs: forwarded to ``module_class``.
        """
        if input_shape is None:
            image_size = config.vision_config.image_size
            input_shape = (
                # Channels-last: ``encode``/``__call__`` transpose NCHW input to
                # NHWC before it reaches the module.
                (1, image_size, image_size, 3),
                (1, 1),
            )
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        # This will use ``self.init_weights``.
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize parameters by running the module once on dummy inputs.

        Args:
            rng: PRNG key. Used to draw the dummy pixels, then split into the
                ``params`` and ``dropout`` streams.
            input_shape: ``(encoder_input_shape, decoder_input_shape)``.

        Returns:
            The ``"params"`` collection produced by ``module.init``.
        """
        encoder_input_shape, decoder_input_shape = input_shape

        # init input tensors
        pixel_values = jax.random.normal(rng, encoder_input_shape)
        attention_mask = None
        decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
        # Put EOS in the last decoder position. ``.at[...].set`` replaces
        # ``jax.ops.index_update``, which was removed from newer JAX releases.
        decoder_input_ids = decoder_input_ids.at[..., -1].set(self.config.text_config.eos_token_id)
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            pixel_values,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
        )["params"]

    def init_cache(self, batch_size, max_length, encoder_outputs):
        """Create the decoder's ``"cache"`` variables for auto-regressive decoding.

        Only the decoder is run (with ``init_cache=True``) on dummy all-ones
        token ids of shape ``(batch_size, max_length)``.

        Args:
            batch_size: batch size the cache is built for.
            max_length: decoding length the cache is built for.
            encoder_outputs: encoder output. Element 0 is passed to the decoder
                as ``encoder_hidden_states``.

        Returns:
            The unfrozen ``"cache"`` collection.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape,
        )

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                input_ids=decoder_input_ids,
                attention_mask=decoder_attention_mask,
                position_ids=decoder_position_ids,
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            decoder_position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_outputs[0],
            init_cache=True,
            method=_decoder_forward,  # we only need to call the decoder to init the cache
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the ViT encoder.

        Args:
            pixel_values: images with channels on axis 1. They are transposed to
                channels-last before encoding, like ``FlaxViTPreTrainedModel.__call__``.
            attention_mask: accepted for API symmetry. It is not used here.
            output_attentions / output_hidden_states / return_dict: default to the
                corresponding ``config.vision_config`` values when ``None``.
            train: when true, dropout is active (``deterministic=False``).
            params: parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            The encoder output from the ViT module.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.vision_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.vision_config.return_dict

        # (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, pixel_values, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(pixel_values, **kwargs)

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Run only the GPT-2 decoder, cross-attending to ``encoder_outputs[0]``.

        Args:
            decoder_input_ids: 2-D ``(batch, seq)`` token ids.
            encoder_outputs: encoder output. Element 0 is the hidden states.
            encoder_attention_mask: defaults to all ones over the first two
                axes of the encoder hidden states.
            decoder_attention_mask: defaults to all ones.
            decoder_position_ids: defaults to ``arange(seq)``. It must be given
                explicitly when ``past_key_values`` is given.
            past_key_values: decoder cache (e.g. from ``init_cache``). When
                truthy, it is applied as a mutable ``"cache"`` collection.
            output_attentions / output_hidden_states / return_dict: default to the
                corresponding ``config.text_config`` values when ``None``.
            train: when true, dropout is active (``deterministic=False``).
            params: parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            The decoder output. When ``past_key_values`` is given, the updated
            cache is added as ``outputs["past_key_values"]`` (dict outputs) or
            inserted at tuple index 1 (tuple outputs).

        Raises:
            ValueError: ``past_key_values`` given without ``decoder_position_ids``.
        """
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.text_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.text_config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized. It must
        # be marked as mutable so that the decoder's attention layers can update it.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        # Reached without a cache, or after splicing the cache into a tuple output.
        return outputs

    def __call__(
        self,
        pixel_values: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Full encoder-decoder forward pass.

        Args:
            pixel_values: images with channels on axis 1 (transposed to
                channels-last here).
            attention_mask: forwarded unchanged to the module.
            decoder_input_ids: defaults to one ``config.decoder_start_token_id``
                token per image.
            decoder_attention_mask: defaults to all ones.
            decoder_position_ids: defaults to ``arange(seq)``.
            output_attentions / output_hidden_states / return_dict: default to the
                top-level ``config`` values when ``None``.
            train: when true, dropout is active (``deterministic=False``).
            params: parameters to use instead of ``self.params``.
            dropout_rng: PRNG key for dropout, if any.

        Returns:
            The module output (``FlaxSeq2SeqLMOutput`` or tuple).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs (`transpose` is done in `FlaxViTPreTrainedModel.__call__()`, so we do the same here.)
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        if decoder_input_ids is None:
            decoder_input_ids = self.config.decoder_start_token_id * jnp.ones((pixel_values.shape[0], 1))
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
            attention_mask=attention_mask,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
# @add_start_docstrings(
# "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
# BART_START_DOCSTRING,
# )
# class FlaxViTGPT2LMModel(FlaxViTGPT2LMPreTrainedModel):
# config: BartConfig
# dtype: jnp.dtype = jnp.float32 # the dtype of the computation
# module_class = FlaxViTGPT2LMModule
#
#
# append_call_sample_docstring(
# FlaxBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
# )
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel):
    """ViT-encoder / GPT-2-decoder model for image-conditioned text generation."""

    module_class = FlaxViTGPT2LMForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        deterministic: bool = True,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """Same as the parent ``decode`` but takes ``deterministic`` instead of ``train``.

        ``deterministic`` is translated to the parent's ``train = not deterministic``.
        All other arguments are passed through positionally in the parent's order.
        """
        return super().decode(
            decoder_input_ids,
            encoder_outputs,
            encoder_attention_mask,
            decoder_attention_mask,
            decoder_position_ids,
            past_key_values,
            output_attentions,
            output_hidden_states,
            return_dict,
            not deterministic,
            params,
            dropout_rng,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Build the initial keyword arguments for cached auto-regressive decoding.

        Args:
            decoder_input_ids: ``(batch, seq)`` prompt token ids.
            max_length: length of the decoder cache and of the static attention mask.
            attention_mask: returned as ``encoder_attention_mask``.
            decoder_attention_mask: optional prompt mask. When given, position ids
                are its cumulative sum minus one, and it is copied into the start of
                the static mask.
            encoder_outputs: encoder output, used to initialize the cache.
            **kwargs: ignored.

        Returns:
            Dict with ``past_key_values``, ``encoder_outputs``,
            ``encoder_attention_mask``, ``decoder_attention_mask`` and
            ``decoder_position_ids``.
        """
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)

        # Usually one would put 0's in the attention_mask for x > input_ids.shape[-1]
        # and x < cache_length. Because the decoder uses a causal mask, those positions
        # are masked anyway, so a single static attention_mask (cheaper to compile) is used.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(
                jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
            )

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance the decoder position by one.

        Mutates ``model_kwargs`` in place and returns it.
        """
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs

    @classmethod
    def from_vision_text_pretrained(
        cls,
        vision_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        text_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        **kwargs,
    ) -> FlaxViTGPT2LMPreTrainedModel:
        """Combine a pretrained ViT encoder and a pretrained GPT-2 LM into one model.

        Keyword arguments prefixed with ``vision_`` / ``text_`` are stripped of the
        prefix and routed to the respective sub-model loader. Recognized routed keys:

        * ``model`` — an already-loaded sub-model, which skips loading;
        * ``model_args`` — extra positional args for that loader;
        * ``config`` — a config to use instead of loading one.

        ``text_project_encoder`` is set on the GPT-2 config as ``project_encoder``.
        ``add_cross_attention`` is always forced on for the GPT-2 config. The
        remaining kwargs (minus ``dtype``) go to ``ViTGPT2Config.from_vision_text_configs``
        and to ``cls``.

        Raises:
            AssertionError: a sub-model is neither given nor given a name/path.
        """
        vision_kwargs = {
            kwarg[len("vision_"):]: value
            for kwarg, value in kwargs.items()
            if kwarg.startswith("vision_")
        }
        text_kwargs = {
            kwarg[len("text_"):]: value
            for kwarg, value in kwargs.items()
            if kwarg.startswith("text_")
        }

        # remove vit & gpt2 kwargs from kwargs
        for key in vision_kwargs.keys():
            del kwargs["vision_" + key]
        for key in text_kwargs.keys():
            del kwargs["text_" + key]

        vision_model_args = vision_kwargs.pop('model_args', [])
        text_model_args = text_kwargs.pop('model_args', [])

        # Load and initialize the vit & gpt2 model
        vision_model = vision_kwargs.pop("model", None)
        text_model = text_kwargs.pop("model", None)

        if vision_model is None:
            assert (
                vision_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `vision_pretrained_model_name_or_path` has to be defined"

            if "config" not in vision_kwargs:
                vision_config = ViTConfig.from_pretrained(vision_pretrained_model_name_or_path)
                vision_kwargs["config"] = vision_config

            # TODO: How to deal with model_args?
            vision_model = FlaxViTModel.from_pretrained(
                vision_pretrained_model_name_or_path, *vision_model_args, **vision_kwargs
            )

        if text_model is None:
            assert (
                text_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, a `text_pretrained_model_name_or_path` has to be defined"

            # `project_encoder` is a config attribute, not a loader argument. Always
            # consume it here; otherwise, when a config is supplied, it would be
            # forwarded to `FlaxGPT2LMHeadModel.from_pretrained` as a stray kwarg.
            project_encoder = text_kwargs.pop("project_encoder", None)
            if "config" not in text_kwargs:
                text_config = GPT2Config.from_pretrained(text_pretrained_model_name_or_path)
                text_config.project_encoder = project_encoder
                text_kwargs["config"] = text_config
            elif project_encoder is not None:
                # Only override a user-supplied config when explicitly requested.
                text_kwargs["config"].project_encoder = project_encoder

            text_kwargs["config"].add_cross_attention = True

            # TODO: How to deal with model_args?
            text_model = FlaxGPT2LMHeadModel.from_pretrained(
                text_pretrained_model_name_or_path, *text_model_args, **text_kwargs
            )

        # instantiate config with corresponding kwargs
        dtype = kwargs.pop("dtype", jnp.float32)
        config = ViTGPT2Config.from_vision_text_configs(
            vision_model.config, text_model.config, **kwargs
        )

        # init model, then overwrite the random sub-model weights with the pretrained ones
        model = cls(config, *model_args, dtype=dtype, **kwargs)
        model.params["model"]["encoder"] = vision_model.params
        model.params["model"]["decoder"] = text_model.params

        return model