# coding=utf-8
# Copyright 2021 Google AI, Google Brain and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_albert import AlbertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "albert-base-v2"
_CONFIG_FOR_DOC = "AlbertConfig"


@flax.struct.dataclass
class FlaxAlbertForPreTrainingOutput(ModelOutput):
""" | |
Output type of [`FlaxAlbertForPreTraining`]. | |
Args: | |
prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): | |
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). | |
sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`): | |
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation | |
before SoftMax). | |
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): | |
Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape | |
`(batch_size, sequence_length, hidden_size)`. | |
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | |
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): | |
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, | |
sequence_length)`. | |
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention | |
heads. | |
""" | |
prediction_logits: jnp.ndarray = None | |
sop_logits: jnp.ndarray = None | |
hidden_states: Optional[Tuple[jnp.ndarray]] = None | |
attentions: Optional[Tuple[jnp.ndarray]] = None | |


ALBERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen Module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

ALBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlaxAlbertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.embedding_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxAlbertSelfAttention(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, attention_mask, deterministic=True, output_attentions: bool = False):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        projected_attn_output = self.dense(attn_output)
        projected_attn_output = self.dropout(projected_attn_output, deterministic=deterministic)
        layernormed_attn_output = self.LayerNorm(projected_attn_output + hidden_states)
        outputs = (layernormed_attn_output, attn_weights) if output_attentions else (layernormed_attn_output,)
        return outputs


class FlaxAlbertLayer(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
        self.ffn = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]
        self.ffn_output = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        attention_outputs = self.attention(
            hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
        )
        attention_output = attention_outputs[0]
        ffn_output = self.ffn(attention_output)
        ffn_output = self.activation(ffn_output)
        ffn_output = self.ffn_output(ffn_output)
        ffn_output = self.dropout(ffn_output, deterministic=deterministic)
        hidden_states = self.full_layer_layer_norm(ffn_output + attention_output)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
        return outputs


class FlaxAlbertLayerCollection(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        layer_hidden_states = ()
        layer_attentions = ()

        for layer_index, albert_layer in enumerate(self.layers):
            layer_output = albert_layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            hidden_states = layer_output[0]

            if output_attentions:
                layer_attentions = layer_attentions + (layer_output[1],)

            if output_hidden_states:
                layer_hidden_states = layer_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs = outputs + (layer_hidden_states,)
        if output_attentions:
            outputs = outputs + (layer_attentions,)
        return outputs  # last-layer hidden state, (layer hidden states), (layer attentions)


class FlaxAlbertLayerCollections(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    layer_index: Optional[str] = None

    def setup(self):
        self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        outputs = self.albert_layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return outputs


class FlaxAlbertLayerGroups(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_groups)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for i in range(self.config.num_hidden_layers):
            # Index of the hidden group
            group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
            layer_group_output = self.layers[group_idx](
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )
            hidden_states = layer_group_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_group_output[-1]

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxAlbertEncoder(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedding_hidden_mapping_in = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embedding_hidden_mapping_in(hidden_states)
        return self.albert_layer_groups(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )


class FlaxAlbertOnlyMLMHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        hidden_states += self.bias
        return hidden_states


class FlaxAlbertSOPHead(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dropout = nn.Dropout(self.config.classifier_dropout_prob)
        self.classifier = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output, deterministic=True):
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)
        return logits


class FlaxAlbertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlbertConfig
    base_model_prefix = "albert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: AlbertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs, input_ids, attention_mask, token_type_ids, position_ids, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )


class FlaxAlbertModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxAlbertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxAlbertEncoder(self.config, dtype=self.dtype)
        if self.add_pooling_layer:
            self.pooler = nn.Dense(
                self.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
                name="pooler",
            )
            self.pooler_activation = nn.tanh
        else:
            self.pooler = None
            self.pooler_activation = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[np.ndarray] = None,
        position_ids: Optional[np.ndarray] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(input_ids, token_type_ids, position_ids, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if self.add_pooling_layer:
            pooled = self.pooler(hidden_states[:, 0])
            pooled = self.pooler_activation(pooled)
        else:
            pooled = None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertModel(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertModule


append_call_sample_docstring(FlaxAlbertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
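
# Illustrative only: a minimal end-to-end sketch of running the bare FlaxAlbertModel, assuming the
# public "albert-base-v2" checkpoint. Kept as a doctest-style comment so importing this module has
# no side effects.
#
#     >>> from transformers import AutoTokenizer, FlaxAlbertModel
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertModel.from_pretrained("albert-base-v2")
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#     >>> outputs = model(**inputs)
#
#     >>> # (batch_size, sequence_length, hidden_size) token states and the pooled [CLS] state
#     >>> last_hidden_state = outputs.last_hidden_state
#     >>> pooled_output = outputs.pooler_output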


class FlaxAlbertForPreTrainingModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)
        self.sop_classifier = FlaxAlbertSOPHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        sop_scores = self.sop_classifier(pooled_output, deterministic=deterministic)

        if not return_dict:
            return (prediction_scores, sop_scores) + outputs[2:]

        return FlaxAlbertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            sop_logits=sop_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForPreTraining(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForPreTrainingModule


FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
    >>> model = FlaxAlbertForPreTraining.from_pretrained("albert-base-v2")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.sop_logits
    ```
"""

overwrite_call_docstring(
    FlaxAlbertForPreTraining,
    ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ALBERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxAlbertForPreTraining, output_type=FlaxAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxAlbertForMaskedLMModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype)
        self.predictions = FlaxAlbertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.albert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.predictions(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMaskedLMModule


append_call_sample_docstring(FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
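
# Illustrative only: a minimal masked-language-modeling sketch for FlaxAlbertForMaskedLM, assuming the
# public "albert-base-v2" checkpoint (the [MASK] token below is the tokenizer's mask token). Kept as a
# doctest-style comment so importing this module has no side effects.
#
#     >>> import jax.numpy as jnp
#     >>> from transformers import AutoTokenizer, FlaxAlbertForMaskedLM
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")
#     >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
#     >>> logits = model(**inputs).logits
#
#     >>> # most likely token at the masked position
#     >>> mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).argmax()
#     >>> predicted_id = jnp.argmax(logits[0, mask_index])
#     >>> tokenizer.decode(int(predicted_id))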


class FlaxAlbertForSequenceClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForSequenceClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxAlbertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
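
# Illustrative only: a minimal sequence-classification sketch for FlaxAlbertForSequenceClassification,
# assuming the public "albert-base-v2" checkpoint (its classification head is randomly initialized, so
# the predicted label is meaningless until the model is fine-tuned). Kept as a doctest-style comment.
#
#     >>> import jax.numpy as jnp
#     >>> from transformers import AutoTokenizer, FlaxAlbertForSequenceClassification
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertForSequenceClassification.from_pretrained("albert-base-v2")
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#     >>> logits = model(**inputs).logits  # shape: (batch_size, num_labels)
#     >>> predicted_class = int(jnp.argmax(logits, axis=-1)[0])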


class FlaxAlbertForMultipleChoiceModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForMultipleChoice(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxAlbertForMultipleChoice, ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxAlbertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
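
# Illustrative only: a minimal multiple-choice sketch for FlaxAlbertForMultipleChoice, assuming the
# public "albert-base-v2" checkpoint. Inputs are shaped (batch_size, num_choices, sequence_length),
# hence the extra leading axis added below. Kept as a doctest-style comment.
#
#     >>> from transformers import AutoTokenizer, FlaxAlbertForMultipleChoice
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertForMultipleChoice.from_pretrained("albert-base-v2")
#
#     >>> prompt = "The sky is"
#     >>> choice0 = "blue on a clear day."
#     >>> choice1 = "made of cheese."
#     >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="np", padding=True)
#
#     >>> # add the batch dimension: (1, num_choices, sequence_length)
#     >>> outputs = model(**{k: v[None, :] for k, v in encoding.items()})
#     >>> logits = outputs.logits  # shape: (1, num_choices)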


class FlaxAlbertForTokenClassificationModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        classifier_dropout = (
            self.config.classifier_dropout_prob
            if self.config.classifier_dropout_prob is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForTokenClassification(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForTokenClassificationModule


append_call_sample_docstring(
    FlaxAlbertForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
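
# Illustrative only: a minimal token-classification (NER-style) sketch for
# FlaxAlbertForTokenClassification, assuming the public "albert-base-v2" checkpoint (the token
# classification head is randomly initialized until fine-tuned). Kept as a doctest-style comment.
#
#     >>> import jax.numpy as jnp
#     >>> from transformers import AutoTokenizer, FlaxAlbertForTokenClassification
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertForTokenClassification.from_pretrained("albert-base-v2")
#     >>> inputs = tokenizer("HuggingFace is based in Paris and New York", return_tensors="np")
#     >>> logits = model(**inputs).logits  # shape: (batch_size, sequence_length, num_labels)
#     >>> predicted_label_ids = jnp.argmax(logits, axis=-1)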


class FlaxAlbertForQuestionAnsweringModule(nn.Module):
    config: AlbertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.albert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class FlaxAlbertForQuestionAnswering(FlaxAlbertPreTrainedModel):
    module_class = FlaxAlbertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxAlbertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
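
# Illustrative only: a minimal extractive question-answering sketch for FlaxAlbertForQuestionAnswering,
# assuming the public "albert-base-v2" checkpoint (the QA head is randomly initialized until fine-tuned,
# so the extracted span is meaningless here). Kept as a doctest-style comment.
#
#     >>> import jax.numpy as jnp
#     >>> from transformers import AutoTokenizer, FlaxAlbertForQuestionAnswering
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
#     >>> model = FlaxAlbertForQuestionAnswering.from_pretrained("albert-base-v2")
#
#     >>> question, context = "Who wrote ALBERT?", "ALBERT was proposed by researchers at Google."
#     >>> inputs = tokenizer(question, context, return_tensors="np")
#     >>> outputs = model(**inputs)
#
#     >>> start_index = int(jnp.argmax(outputs.start_logits, axis=-1)[0])
#     >>> end_index = int(jnp.argmax(outputs.end_logits, axis=-1)[0])
#     >>> answer_ids = inputs["input_ids"][0, start_index : end_index + 1]
#     >>> tokenizer.decode(answer_ids)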