|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Optional, Tuple |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, unfreeze |
|
from flax.linen import combine_masks, make_causal_mask |
|
from flax.linen.attention import dot_product_attention_weights |
|
from jax import lax |
|
|
|
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward |
|
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPast,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutput,
    FlaxCausalLMOutputWithCrossAttentions,
)
|
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring |
|
from transformers.utils import logging |
|
from transformers.models.gpt2.configuration_gpt2 import GPT2Config |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CHECKPOINT_FOR_DOC = "gpt2" |
|
_CONFIG_FOR_DOC = "GPT2Config" |
|
_TOKENIZER_FOR_DOC = "GPT2Tokenizer" |
|
|
|
|
|
GPT2_START_DOCSTRING = r""" |
|
|
|
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the |
|
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
|
embeddings, pruning heads etc.) |
|
|
|
This model is also a Flax Linen `flax.nn.Module |
|
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax |
|
    Module and refer to the Flax documentation for all matters related to general usage and behavior.
|
|
|
Finally, this model supports inherent JAX features such as: |
|
|
|
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__ |
|
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__ |
|
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__ |
|
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__ |
|
|
|
Parameters: |
|
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model. |
|
Initializing with a config file does not load the weights associated with the model, only the |
|
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the |
|
model weights. |
|
""" |
|
|
|
GPT2_INPUTS_DOCSTRING = r""" |
|
Args: |
|
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`): |
|
:obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary. |
|
|
|
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See |
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for |
|
details. |
|
|
|
`What are input IDs? <../glossary.html#input-ids>`__ |
|
attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): |
|
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: |
|
|
|
- 1 for tokens that are **not masked**, |
|
- 0 for tokens that are **masked**. |
|
|
|
`What are attention masks? <../glossary.html#attention-mask>`__ |
|
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`): |
|
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
|
config.max_position_embeddings - 1]``. |
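        encoder_hidden_states (:obj:`jnp.ndarray` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of an encoder. Used in the cross-attention
            sub-layers when :obj:`config.add_cross_attention` is set.
        encoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`):
            Mask to avoid performing cross-attention on padding token indices of the encoder input. Mask values
            selected in ``[0, 1]``.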
|
past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``): |
|
            Dictionary of pre-computed hidden-states (keys and values in the attention blocks) that can be used for fast
|
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`. |
|
output_attentions (:obj:`bool`, `optional`): |
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned |
|
tensors for more detail. |
|
output_hidden_states (:obj:`bool`, `optional`): |
|
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for |
|
more detail. |
|
return_dict (:obj:`bool`, `optional`): |
|
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. |
|
""" |
|
|
|
|
|
class FlaxConv1D(nn.Module): |
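    """
    1D convolution as used by GPT-2: a linear layer whose kernel is stored with shape ``(features, in_features)``
    and transposed at call time, matching the parameter layout of the original GPT-2 checkpoints.
    """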
|
features: int |
|
use_bias: bool = True |
|
dtype: Any = jnp.float32 |
|
precision: Any = None |
|
|
|
@nn.compact |
|
def __call__(self, inputs): |
|
inputs = jnp.asarray(inputs, self.dtype) |
|
kernel = self.param("kernel", jax.nn.initializers.normal(stddev=0.02), (self.features, inputs.shape[-1])) |
|
kernel = jnp.asarray(kernel.transpose(), self.dtype) |
|
y = lax.dot_general(inputs, kernel, (((inputs.ndim - 1,), (0,)), ((), ())), precision=self.precision) |
|
if self.use_bias: |
|
bias = self.param("bias", jax.nn.initializers.zeros, (self.features,)) |
|
bias = jnp.asarray(bias, self.dtype) |
|
y = y + bias |
|
return y |
|
|
|
|
|
class FlaxGPT2Attention(nn.Module): |
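    """
    Multi-head attention block. With ``self_attn=True`` (the default) queries, keys and values are all projected
    from ``hidden_states``; with ``self_attn=False`` the block performs cross-attention, taking queries from
    ``hidden_states`` and keys/values from ``key_value_states`` (e.g. encoder outputs). ``causal`` controls whether
    a causal mask is applied.
    """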
|
config: GPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
causal: bool = True |
|
self_attn: bool = True |
|
|
|
def setup(self): |
|
config = self.config |
|
self.embed_dim = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.embed_dim // self.num_heads |
|
|
|
factor = 3 if self.self_attn else 2 |
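        # factor == 3: fused query/key/value projection for self-attention;
        # factor == 2: key/value projection only for cross-attention, whose queries come from `c_query_attn` below.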
|
self.c_attn = FlaxConv1D(features=factor * self.embed_dim, dtype=self.dtype) |
|
self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype) |
|
|
|
if not self.self_attn: |
|
self.c_query_attn = FlaxConv1D(features=1 * self.embed_dim, dtype=self.dtype) |
|
|
|
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) |
|
if self.causal: |
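            # Pre-compute a boolean lower-triangular mask over the maximum sequence length; it is sliced down to
            # the actual query/key lengths in __call__.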
|
self.causal_mask = make_causal_mask( |
|
jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool" |
|
) |
|
|
|
def _split_heads(self, hidden_states): |
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) |
|
|
|
def _merge_heads(self, hidden_states): |
|
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) |
|
|
|
@nn.compact |
|
def _concatenate_to_cache(self, key, value, query, attention_mask): |
|
""" |
|
This function takes projected key, value states from a single input token and concatenates the states to cached |
|
        states from previous steps. This function is slightly adapted from the official Flax repository:
|
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 |
|
""" |
|
|
|
is_initialized = self.has_variable("cache", "cached_key") |
|
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) |
|
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) |
|
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) |
|
|
|
if is_initialized: |
|
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape |
|
|
|
cur_index = cache_index.value |
|
indices = (0,) * len(batch_dims) + (cur_index, 0, 0) |
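            # Write the new key/value states into the cache at position `cur_index` along the sequence axis.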
|
key = lax.dynamic_update_slice(cached_key.value, key, indices) |
|
value = lax.dynamic_update_slice(cached_value.value, value, indices) |
|
cached_key.value = key |
|
cached_value.value = value |
|
num_updated_cache_vectors = query.shape[1] |
|
cache_index.value = cache_index.value + num_updated_cache_vectors |
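            # For cached decoding the query may only attend to cache slots that have already been filled,
            # i.e. indices below `cur_index + num_updated_cache_vectors`.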
|
|
|
pad_mask = jnp.broadcast_to( |
|
jnp.arange(max_length) < cur_index + num_updated_cache_vectors, |
|
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), |
|
) |
|
attention_mask = combine_masks(pad_mask, attention_mask) |
|
return key, value, attention_mask |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
key_value_states: Optional[jnp.ndarray] = None, |
|
attention_mask=None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
): |
|
|
|
|
|
|
|
is_cross_attention = key_value_states is not None |
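        # Self-attention projects q/k/v from `hidden_states`; cross-attention projects queries from `hidden_states`
        # and keys/values from the encoder's `key_value_states`.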
|
|
|
if not is_cross_attention: |
|
|
|
assert self.self_attn |
|
qkv_out = self.c_attn(hidden_states) |
|
query, key, value = jnp.split(qkv_out, 3, axis=2) |
|
else: |
|
|
|
assert not self.self_attn |
|
assert not self.causal |
|
            query = self.c_query_attn(hidden_states)
|
kv_out = self.c_attn(key_value_states) |
|
key, value = jnp.split(kv_out, 2, axis=2) |
|
|
|
query = self._split_heads(query) |
|
key = self._split_heads(key) |
|
value = self._split_heads(value) |
|
|
|
query_length, key_length = query.shape[1], key.shape[1] |
|
|
|
if self.causal: |
|
if self.has_variable("cache", "cached_key"): |
|
mask_shift = self.variables["cache"]["cache_index"] |
|
max_decoder_length = self.variables["cache"]["cached_key"].shape[1] |
|
causal_mask = lax.dynamic_slice( |
|
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) |
|
) |
|
else: |
|
causal_mask = self.causal_mask[:, :, :query_length, :key_length] |
|
|
|
batch_size = hidden_states.shape[0] |
|
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) |
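        # Combine the (optional) padding mask with the causal mask; for non-causal cross-attention only the
        # encoder padding mask is used.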
|
|
|
|
|
if attention_mask is not None and self.causal: |
|
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) |
|
attention_mask = combine_masks(attention_mask, causal_mask) |
|
elif self.causal: |
|
attention_mask = causal_mask |
|
elif attention_mask is not None: |
|
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) |
|
|
|
dropout_rng = None |
|
if not deterministic and self.config.attn_pdrop > 0.0: |
|
dropout_rng = self.make_rng("dropout") |
|
|
|
|
|
|
|
if self.causal and (self.has_variable("cache", "cached_key") or init_cache): |
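            # During fast auto-regressive decoding, concatenate the current key/value states to the cached ones
            # and extend the attention mask accordingly.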
|
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) |
|
|
|
|
|
if attention_mask is not None: |
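            # Turn the boolean attention mask into an additive bias: 0.0 where attention is allowed and a large
            # negative value where it is masked.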
|
attention_bias = lax.select( |
|
attention_mask > 0, |
|
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), |
|
jnp.full(attention_mask.shape, -1e4).astype(self.dtype), |
|
) |
|
else: |
|
attention_bias = None |
|
|
|
|
|
attn_weights = dot_product_attention_weights( |
|
query, |
|
key, |
|
bias=attention_bias, |
|
dropout_rng=dropout_rng, |
|
dropout_rate=self.config.attn_pdrop, |
|
deterministic=deterministic, |
|
dtype=self.dtype, |
|
precision=None, |
|
) |
|
|
|
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) |
|
attn_output = self._merge_heads(attn_output) |
|
attn_output = self.c_proj(attn_output) |
|
attn_output = self.resid_dropout(attn_output, deterministic=deterministic) |
|
|
|
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) |
|
return outputs |
|
|
|
|
|
class FlaxGPT2MLP(nn.Module): |
|
config: GPT2Config |
|
intermediate_size: int |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
embed_dim = self.config.hidden_size |
|
self.c_fc = FlaxConv1D(self.intermediate_size, dtype=self.dtype) |
|
self.c_proj = FlaxConv1D(embed_dim, dtype=self.dtype) |
|
self.act = ACT2FN[self.config.activation_function] |
|
self.dropout = nn.Dropout(rate=self.config.resid_pdrop) |
|
|
|
def __call__(self, hidden_states, deterministic: bool = True): |
|
hidden_states = self.c_fc(hidden_states) |
|
hidden_states = self.act(hidden_states) |
|
hidden_states = self.c_proj(hidden_states) |
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic) |
|
return hidden_states |
|
|
|
|
|
class FlaxGPT2Block(nn.Module): |
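    """
    A single GPT-2 transformer block: pre-LayerNorm self-attention, an optional cross-attention sub-layer (enabled
    via ``config.add_cross_attention``), and a pre-LayerNorm MLP, each wrapped in a residual connection.
    """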
|
config: GPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
|
|
self.only_self_attn = not self.config.add_cross_attention |
|
|
|
hidden_size = self.config.hidden_size |
|
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size |
|
|
|
self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) |
|
self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype) |
|
|
|
if not self.only_self_attn: |
|
self.cross_attn_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) |
|
|
|
self.cross_attn = FlaxGPT2Attention(config=self.config, dtype=self.dtype, causal=False, self_attn=False) |
|
|
|
if self.config.project_encoder: |
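                # `project_encoder` is a custom config flag of this variant: when set, the encoder states are
                # passed through an extra LayerNorm + MLP (with a residual connection) before cross-attention.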
|
self.encoder_projection_ln = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) |
|
self.encoder_projection_mlp = FlaxGPT2MLP(self.config, self.config.hidden_size, dtype=self.dtype) |
|
|
|
self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) |
|
self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype) |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask=None, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
): |
|
residual = hidden_states |
|
hidden_states = self.ln_1(hidden_states) |
|
outputs = self.attn( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
        attn_output = outputs[0]
        self_attn_weights = outputs[1] if output_attentions else None

        hidden_states = attn_output + residual
|
|
|
|
|
if not self.only_self_attn: |
|
assert encoder_hidden_states is not None |
|
else: |
|
assert encoder_hidden_states is None |
|
|
|
|
|
cross_attn_weights = None |
|
if encoder_hidden_states is not None: |
|
|
|
            if self.config.project_encoder:
|
residual = encoder_hidden_states |
|
encoder_hidden_states = self.encoder_projection_ln(encoder_hidden_states) |
|
feed_forward_hidden_states = self.encoder_projection_mlp( |
|
encoder_hidden_states, deterministic=deterministic |
|
) |
|
|
|
encoder_hidden_states = residual + feed_forward_hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.cross_attn_ln(hidden_states) |
|
|
|
cross_attn_outputs = self.cross_attn( |
|
hidden_states=hidden_states, |
|
key_value_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
deterministic=deterministic, |
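                # no auto-regressive cache for cross-attention: the key/value states come from the fixed encoder outputs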
|
|
|
init_cache=False, |
|
output_attentions=output_attentions, |
|
) |
|
|
|
cross_attn_output = cross_attn_outputs[0] |
|
hidden_states = cross_attn_output + residual |
|
|
|
if output_attentions: |
|
cross_attn_weights = cross_attn_outputs[1] |
|
|
|
residual = hidden_states |
|
hidden_states = self.ln_2(hidden_states) |
|
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) |
|
|
|
hidden_states = residual + feed_forward_hidden_states |
|
|
|
outputs = (hidden_states,) |
|
|
|
        if output_attentions:
            outputs += (self_attn_weights,)
            if not self.only_self_attn:
                outputs += (cross_attn_weights,)
|
|
|
return outputs |
|
|
|
|
|
class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): |
|
""" |
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
|
models. |
|
""" |
|
|
|
config_class = GPT2Config |
|
base_model_prefix = "transformer" |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: GPT2Config, |
|
input_shape: Tuple = (1, 1), |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
**kwargs, |
|
): |
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: |
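        # Initialize the module with dummy inputs; when cross-attention is enabled, dummy encoder states are passed
        # as well so that the cross-attention parameters are also created.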
|
|
|
input_ids = jnp.zeros(input_shape, dtype="i4") |
|
attention_mask = jnp.ones_like(input_ids) |
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) |
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
if self.config.add_cross_attention: |
|
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) |
|
encoder_attention_mask = attention_mask |
|
module_init_outputs = self.module.init( |
|
rngs, input_ids, attention_mask, position_ids, |
|
encoder_hidden_states, encoder_attention_mask, return_dict=False |
|
) |
|
else: |
|
module_init_outputs = self.module.init( |
|
rngs, input_ids, attention_mask, position_ids, return_dict=False |
|
) |
|
|
|
return module_init_outputs["params"] |
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_cache(self, batch_size, max_length): |
|
r""" |
|
Args: |
|
batch_size (:obj:`int`): |
|
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. |
|
max_length (:obj:`int`): |
|
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized |
|
cache. |
|
""" |
|
|
|
input_ids = jnp.ones((batch_size, max_length)) |
|
attention_mask = jnp.ones_like(input_ids) |
|
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True |
|
) |
|
return init_variables["cache"] |
|
|
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) |
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask=None, |
|
position_ids=None, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
params: dict = None, |
|
past_key_values: dict = None, |
|
dropout_rng: jax.random.PRNGKey = None, |
|
train: bool = False, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
): |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
if encoder_hidden_states is not None and encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = input_ids.shape |
|
|
|
if position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") |
|
|
|
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
|
|
|
if attention_mask is None: |
|
attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
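        # If `past_key_values` are passed, hand the cache to the module as a mutable variable collection so that it
        # is updated in place and can be returned to the caller.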
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
jnp.array(input_ids, dtype="i4"), |
|
jnp.array(attention_mask, dtype="i4"), |
|
jnp.array(position_ids, dtype="i4"), |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
not train, |
|
False, |
|
output_attentions, |
|
output_hidden_states, |
|
return_dict, |
|
rngs=rngs, |
|
mutable=mutable, |
|
) |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs, past_key_values = outputs |
|
outputs["past_key_values"] = unfreeze(past_key_values["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs, past_key_values = outputs |
|
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
|
|
class FlaxGPT2BlockCollection(nn.Module): |
|
config: GPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.blocks = [ |
|
FlaxGPT2Block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) |
|
] |
|
|
|
def __call__( |
|
self, |
|
hidden_states, |
|
attention_mask=None, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
all_attentions = () if output_attentions else None |
|
all_hidden_states = () if output_hidden_states else None |
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None |
|
|
|
for block in self.blocks: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
layer_outputs = block( |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=encoder_attention_mask, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
) |
|
hidden_states = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_attentions += (layer_outputs[1],) |
|
|
|
if encoder_hidden_states is not None: |
|
all_cross_attentions += (layer_outputs[2],) |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
outputs = [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] |
|
|
|
if not return_dict: |
|
return tuple(v for v in outputs if v is not None) |
|
|
|
if encoder_hidden_states is None: |
|
|
|
return FlaxBaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=None, |
|
hidden_states=all_hidden_states, |
|
attentions=all_attentions, |
|
) |
|
else: |
|
|
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
past_key_values=None, |
|
hidden_states=all_hidden_states, |
|
attentions=all_attentions, |
|
cross_attentions=all_cross_attentions, |
|
) |
|
|
|
class FlaxGPT2Module(nn.Module): |
|
config: GPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.embed_dim = self.config.hidden_size |
|
|
|
self.wte = nn.Embed( |
|
self.config.vocab_size, |
|
self.embed_dim, |
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.wpe = nn.Embed( |
|
self.config.max_position_embeddings, |
|
self.embed_dim, |
|
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), |
|
dtype=self.dtype, |
|
) |
|
self.dropout = nn.Dropout(rate=self.config.embd_pdrop) |
|
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype) |
|
self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic=True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
input_embeds = self.wte(input_ids.astype("i4")) |
|
position_embeds = self.wpe(position_ids.astype("i4")) |
|
|
|
hidden_states = input_embeds + position_embeds |
|
hidden_states = self.dropout(hidden_states, deterministic=deterministic) |
|
|
|
outputs = self.h( |
|
hidden_states, |
|
attention_mask, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
hidden_states = self.ln_f(hidden_states) |
|
|
|
if not return_dict: |
|
return (hidden_states,) + outputs[1:] |
|
|
|
if encoder_hidden_states is None: |
|
return FlaxBaseModelOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
else: |
|
return FlaxBaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
) |
|
|
|
@add_start_docstrings( |
|
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", |
|
GPT2_START_DOCSTRING, |
|
) |
|
class FlaxGPT2Model(FlaxGPT2PreTrainedModel): |
|
module_class = FlaxGPT2Module |
|
|
|
|
|
append_call_sample_docstring( |
|
FlaxGPT2Model, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC |
|
) |
|
|
|
|
|
class FlaxGPT2LMHeadModule(nn.Module): |
|
config: GPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
self.transformer = FlaxGPT2Module(self.config, dtype=self.dtype) |
|
self.lm_head = nn.Dense( |
|
self.config.vocab_size, |
|
use_bias=False, |
|
dtype=self.dtype, |
|
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range, dtype=self.dtype), |
|
) |
|
|
|
def __call__( |
|
self, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_hidden_states: Optional[jnp.ndarray] = None, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
deterministic: bool = True, |
|
init_cache: bool = False, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
): |
|
outputs = self.transformer( |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_hidden_states, |
|
encoder_attention_mask, |
|
deterministic=deterministic, |
|
init_cache=init_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
|
|
if self.config.tie_word_embeddings: |
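            # Weight tying: reuse the transposed token-embedding matrix as the LM-head kernel.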
|
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T |
|
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) |
|
else: |
|
lm_logits = self.lm_head(hidden_states) |
|
|
|
if not return_dict: |
|
return (lm_logits,) + outputs[1:] |
|
|
|
if encoder_hidden_states is None: |
|
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) |
|
else: |
|
return FlaxCausalLMOutputWithCrossAttentions( |
|
logits=lm_logits, |
|
past_key_values=None, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
cross_attentions=outputs.cross_attentions |
|
) |
|
|
|
@add_start_docstrings( |
|
""" |
|
The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input |
|
embeddings). |
|
""", |
|
GPT2_START_DOCSTRING, |
|
) |
|
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel): |
|
module_class = FlaxGPT2LMHeadModule |
|
|
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): |
|
|
|
batch_size, seq_length = input_ids.shape |
|
|
|
past_key_values = self.init_cache(batch_size, max_length) |
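        # Usually the attention mask would need 0s for positions beyond input_ids.shape[-1] and before the cache
        # length, but since GPT-2 uses a causal mask those positions are masked anyway, so a single static attention
        # mask of length `max_length` can be used here, which is also friendlier to JIT compilation.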
|
|
|
|
|
|
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
|
if attention_mask is not None: |
|
position_ids = attention_mask.cumsum(axis=-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) |
|
else: |
|
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"attention_mask": extended_attention_mask, |
|
"position_ids": position_ids, |
|
} |
|
|
|
def update_inputs_for_generation(self, model_outputs, model_kwargs): |
|
model_kwargs["past_key_values"] = model_outputs.past_key_values |
|
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 |
|
return model_kwargs |
|
|
|
|
|
append_call_sample_docstring( |
|
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC |
|
) |
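# Example usage (a minimal sketch; assumes the public ``gpt2`` checkpoint and ``GPT2Tokenizer`` are available):
#
#     from transformers import GPT2Tokenizer
#
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
#
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#     outputs = model(**inputs)
#     logits = outputs.logits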
|
|