# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey

from modeling_flax_performer_utils import make_fast_softmax_attention
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import ACT2FN
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_flax_bert import FlaxBertOnlyMLMHead, FlaxBertPreTrainedModel
from transformers.utils import logging


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BertConfig"
_TOKENIZER_FOR_DOC = "BertTokenizer"

BERT_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads etc.)

    This model is also a Flax Linen ``flax.linen.Module`` subclass. Use it as a regular Flax Module and refer to the
    Flax documentation for all matters related to general usage and behavior.

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
            weights.
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`jnp.ndarray` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`jnp.ndarray` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`jnp.ndarray` of shape :obj:`({0})`, `optional`):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
            1]``:

            - 0 corresponds to a `sentence A` token,
            - 1 corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`jnp.ndarray` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`jnp.ndarray` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (:obj:`jnp.ndarray` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""

class FlaxPerformerLayerNorm(nn.Module):
    """
    Layer normalization (https://arxiv.org/abs/1607.06450). Operates on the last axis of the input data.
    """

    epsilon: float = 1e-6
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    bias: bool = True  # If True, bias (beta) is added.
    scale: bool = True  # If True, multiply by scale (gamma). When the next layer is linear
    # (also e.g. nn.relu), this can be disabled since the scaling will be
    # done by the next layer.
    bias_init: jnp.ndarray = nn.initializers.zeros
    scale_init: jnp.ndarray = nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        """
        Applies layer normalization on the input. It normalizes the activations of the layer for each given example in
        a batch independently, rather than across a batch like Batch Normalization. That is, it applies a
        transformation that keeps the mean activation within each example close to 0 and the activation standard
        deviation close to 1.

        Args:
            x: the inputs

        Returns:
            Normalized inputs (the same shape as inputs).
        """
        features = x.shape[-1]
        mean = jnp.mean(x, axis=-1, keepdims=True)
        mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
        # Variance computed as E[x^2] - E[x]^2
        var = mean2 - jax.lax.square(mean)
        mul = jax.lax.rsqrt(var + self.epsilon)
        if self.scale:
            mul = mul * jnp.asarray(self.param("gamma", self.scale_init, (features,)), self.dtype)
        y = (x - mean) * mul
        if self.bias:
            y = y + jnp.asarray(self.param("beta", self.bias_init, (features,)), self.dtype)
        return y
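

# Hypothetical usage sketch (not part of the original module): a quick check that
# FlaxPerformerLayerNorm normalizes over the last axis. The helper name
# `_example_layer_norm` is illustrative only.
def _example_layer_norm():
    layer_norm = FlaxPerformerLayerNorm()
    x = jnp.arange(24, dtype=jnp.float32).reshape((2, 3, 4))
    variables = layer_norm.init(jax.random.PRNGKey(0), x)
    y = layer_norm.apply(variables, x)
    # Each length-4 row of `y` now has mean ~0 and standard deviation ~1
    return y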


class FlaxPerformerEmbedding(nn.Module):
    """
    Embedding layer with PyTorch-compatible parameter naming: Flax's built-in embedding stores its table under
    'embedding', while PyTorch uses 'weight', so we define our own layer that uses 'weight'.
    """

    vocab_size: int
    hidden_size: int
    emb_init: Callable[..., np.ndarray] = nn.initializers.normal(stddev=0.1)

    @nn.compact
    def __call__(self, inputs):
        embedding = self.param("weight", self.emb_init, (self.vocab_size, self.hidden_size))
        return jnp.take(embedding, inputs, axis=0)
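

# Hypothetical usage sketch: the embedding table lives under the PyTorch-style
# parameter name "weight", and lookup is a plain jnp.take. All sizes are arbitrary.
def _example_embedding():
    emb = FlaxPerformerEmbedding(vocab_size=10, hidden_size=4)
    ids = jnp.array([[1, 2, 3]], dtype=jnp.int32)
    variables = emb.init(jax.random.PRNGKey(0), ids)  # params: {"weight": (10, 4)}
    return emb.apply(variables, ids)  # shape (1, 3, 4)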


class FlaxPerformerEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        # Embed
        w_emb = FlaxPerformerEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
            jnp.atleast_2d(input_ids.astype("i4"))
        )
        p_emb = FlaxPerformerEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
            jnp.atleast_2d(position_ids.astype("i4"))
        )
        t_emb = FlaxPerformerEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
            jnp.atleast_2d(token_type_ids.astype("i4"))
        )

        # Sum all embeddings
        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

        # Layer Norm
        layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(summed_emb)
        return layer_norm


class FlaxPerformerAttention(nn.Module):
    num_heads: int
    head_size: int

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        single_head_dim = self.head_size // self.num_heads
        fast_softmax_attention = make_fast_softmax_attention(qkv_dim=single_head_dim)
        self_att = nn.attention.SelfAttention(
            num_heads=self.num_heads, qkv_features=self.head_size, name="self", attention_fn=fast_softmax_attention
        )(hidden_state, attention_mask)
        layer_norm = FlaxPerformerLayerNorm(name="layer_norm")(self_att + hidden_state)
        return layer_norm
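

# Hypothetical usage sketch for the Performer attention block. The 4D mask shape
# (batch, 1, q_len, kv_len) follows flax.linen's broadcastable attention-mask
# convention; all sizes here are arbitrary.
def _example_attention():
    attn = FlaxPerformerAttention(num_heads=2, head_size=8)
    hidden = jnp.ones((1, 5, 8), dtype=jnp.float32)
    mask = jnp.ones((1, 1, 5, 5), dtype=jnp.float32)
    variables = attn.init(jax.random.PRNGKey(0), hidden, mask)
    # Output shape (1, 5, 8): fast-attention output plus residual, layer-normed
    return attn.apply(variables, hidden, mask)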


class FlaxPerformerIntermediate(nn.Module):
    output_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state):
        # Dense projection followed by the activation configured through ACT2FN
        dense = nn.Dense(features=self.output_size, name="dense")(hidden_state)
        return ACT2FN[self.hidden_act](dense)


class FlaxPerformerOutput(nn.Module):
    @nn.compact
    def __call__(self, intermediate_output, attention_output):
        hidden_state = nn.Dense(attention_output.shape[-1], name="dense")(intermediate_output)
        hidden_state = FlaxPerformerLayerNorm(name="layer_norm")(hidden_state + attention_output)
        return hidden_state


class FlaxPerformerLayer(nn.Module):
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        attention = FlaxPerformerAttention(self.num_heads, self.head_size, name="attention")(
            hidden_state, attention_mask
        )
        intermediate = FlaxPerformerIntermediate(
            self.intermediate_size, name="intermediate", hidden_act=self.hidden_act
        )(attention)
        output = FlaxPerformerOutput(name="output")(intermediate, attention)
        return output


class FlaxPerformerLayerCollection(nn.Module):
    """
    Stores N Performer layers (the BertLayer equivalent).
    """

    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, inputs, attention_mask):
        assert self.num_layers > 0, f"num_layers should be >= 1, got {self.num_layers}"

        # Initialize input / output
        input_i = inputs

        # Forward over all encoders
        for i in range(self.num_layers):
            layer = FlaxPerformerLayer(
                self.num_heads, self.head_size, self.intermediate_size, hidden_act=self.hidden_act, name=f"{i}"
            )
            input_i = layer(input_i, attention_mask)
        return input_i


class FlaxPerformerEncoder(nn.Module):
    num_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"

    @nn.compact
    def __call__(self, hidden_state, attention_mask):
        layer = FlaxPerformerLayerCollection(
            self.num_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            name="layer",
            hidden_act=self.hidden_act,
        )(hidden_state, attention_mask)
        return layer


class FlaxPerformerPooler(nn.Module):
    @nn.compact
    def __call__(self, hidden_state):
        # Pool by projecting the hidden state of the first ([CLS]) token
        cls_token = hidden_state[:, 0]
        out = nn.Dense(hidden_state.shape[-1], name="dense")(cls_token)
        return jax.lax.tanh(out)


class FlaxPerformerModule(nn.Module):
    vocab_size: int
    hidden_size: int
    type_vocab_size: int
    max_length: int
    num_encoder_layers: int
    num_heads: int
    head_size: int
    intermediate_size: int
    hidden_act: str = "gelu"
    add_pooling_layer: bool = True

    @nn.compact
    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
        # Embedding
        embeddings = FlaxPerformerEmbeddings(
            self.vocab_size, self.hidden_size, self.type_vocab_size, self.max_length, name="embeddings"
        )(input_ids, token_type_ids, position_ids, attention_mask)

        # N stacked encoding layers
        encoder = FlaxPerformerEncoder(
            self.num_encoder_layers,
            self.num_heads,
            self.head_size,
            self.intermediate_size,
            hidden_act=self.hidden_act,
            name="encoder",
        )(embeddings, attention_mask)

        if not self.add_pooling_layer:
            return encoder

        pooled = FlaxPerformerPooler(name="pooler")(encoder)
        return encoder, pooled
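

# Hypothetical end-to-end sketch for the bare encoder module with small,
# arbitrary sizes; returns the per-token hidden states and the pooled output.
def _example_module():
    model = FlaxPerformerModule(
        vocab_size=100,
        hidden_size=16,
        type_vocab_size=2,
        max_length=32,
        num_encoder_layers=2,
        num_heads=2,
        head_size=16,
        intermediate_size=32,
    )
    input_ids = jnp.ones((1, 8), dtype=jnp.int32)
    token_type_ids = jnp.zeros_like(input_ids)
    position_ids = jnp.arange(8)[None, :]
    attention_mask = jnp.ones_like(input_ids)
    variables = model.init(jax.random.PRNGKey(0), input_ids, token_type_ids, position_ids, attention_mask)
    # sequence_output: (1, 8, 16); pooled_output: (1, 16)
    sequence_output, pooled_output = model.apply(
        variables, input_ids, token_type_ids, position_ids, attention_mask
    )
    return sequence_output, pooled_output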


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxPerformerModel(FlaxBertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
    """

    model_class = FlaxPerformerModule
    config_class = BertConfig
    base_model_prefix = "bert"

    @staticmethod
    def convert_from_pytorch(pt_state: Dict, config: BertConfig) -> Dict:
        jax_state = dict(pt_state)

        # Need to change some parameters name to match Flax names so that we don't have to fork any layer
        for key, tensor in pt_state.items():
            # Key parts
            key_parts = set(key.split("."))

            # Every dense layer has "kernel" parameters instead of "weight"
            if "dense.weight" in key:
                del jax_state[key]
                key = key.replace("weight", "kernel")
                jax_state[key] = tensor

            # SelfAttention also needs to replace "weight" by "kernel"
            if {"query", "key", "value"} & key_parts:
                # Flax SelfAttention decomposes the heads (num_heads, size // num_heads)
                if "bias" in key:
                    jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
                elif "weight" in key:
                    del jax_state[key]
                    key = key.replace("weight", "kernel")
                    tensor = tensor.reshape((config.num_attention_heads, -1, config.hidden_size)).transpose((2, 0, 1))
                    jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove one nesting
            if "attention.output.dense" in key:
                del jax_state[key]
                key = key.replace("attention.output.dense", "attention.self.out")
                jax_state[key] = tensor

            # SelfAttention output is not a separate layer, remove nesting on layer norm
            if "attention.output.LayerNorm" in key:
                del jax_state[key]
                key = key.replace("attention.output.LayerNorm", "attention.LayerNorm")
                jax_state[key] = tensor

            # There are some transposed parameters w.r.t their PyTorch counterpart
            if "intermediate.dense.kernel" in key or "output.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Self Attention output projection needs to be transposed
            if "out.kernel" in key:
                jax_state[key] = tensor.reshape((config.hidden_size, config.num_attention_heads, -1)).transpose(
                    1, 2, 0
                )

            # Pooler needs to transpose its kernel
            if "pooler.dense.kernel" in key:
                jax_state[key] = tensor.T

            # Handle LayerNorm conversion
            if "LayerNorm" in key:
                del jax_state[key]

                # Replace LayerNorm by layer_norm
                new_key = key.replace("LayerNorm", "layer_norm")

                if "weight" in key:
                    new_key = new_key.replace("weight", "gamma")
                elif "bias" in key:
                    new_key = new_key.replace("bias", "beta")

                jax_state[new_key] = tensor

        return jax_state

    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        # Note: the bare module applies no dropout, so config.hidden_dropout_prob is not forwarded here
        module = FlaxPerformerModule(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            type_vocab_size=config.type_vocab_size,
            max_length=config.max_position_embeddings,
            num_encoder_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            head_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    @property
    def module(self) -> nn.Module:
        return self._module

    def __call__(
        self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            rngs=rngs,
        )
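

# Hypothetical conversion sketch (names and sizes are illustrative): numpy arrays
# stand in for PyTorch tensors, since both expose .reshape and .T. It shows the
# "weight" -> "kernel" rename plus the per-head reshape that convert_from_pytorch
# performs for attention projections.
def _example_convert_from_pytorch():
    config = BertConfig(hidden_size=8, num_attention_heads=2)
    pt_state = {
        "bert.encoder.layer.0.attention.self.query.weight": np.zeros((8, 8)),
        "bert.encoder.layer.0.attention.self.query.bias": np.zeros((8,)),
    }
    jax_state = FlaxPerformerModel.convert_from_pytorch(pt_state, config)
    # The kernel is reshaped to (hidden_size, num_heads, head_dim) == (8, 2, 4)
    return jax_state["bert.encoder.layer.0.attention.self.query.kernel"].shape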


class FlaxPerformerForMaskedLM(FlaxBertPreTrainedModel):
    def __init__(
        self, config: BertConfig, input_shape: Tuple = (1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
    ):
        module = FlaxPerformerForMaskedLMModule(
            vocab_size=config.vocab_size,
            type_vocab_size=config.type_vocab_size,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            max_length=config.max_position_embeddings,
            hidden_act=config.hidden_act,
            **kwargs,
        )

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        params: dict = None,
        train: bool = False,
        dropout_rng: PRNGKey = None,
    ):
        input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
            input_ids, attention_mask, token_type_ids, position_ids
        )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            rngs=rngs,
        )
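

# Hypothetical usage sketch: the checkpoint name is illustrative, and loading
# relies on the weight-conversion machinery inherited from
# FlaxBertPreTrainedModel. Masked-token logits come back as the first (and only)
# tuple element.
def _example_masked_lm():
    model = FlaxPerformerForMaskedLM.from_pretrained("bert-base-uncased")
    input_ids = jnp.ones((1, 8), dtype=jnp.int32)
    (logits,) = model(input_ids)
    return logits.shape  # (1, 8, vocab_size)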


class FlaxPerformerForMaskedLMModule(nn.Module):
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    head_size: int
    num_heads: int
    num_encoder_layers: int
    type_vocab_size: int
    max_length: int
    hidden_act: str
    dropout_rate: float = 0.0
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, deterministic: bool = True
    ):
        # Model (FlaxPerformerModule expects input_ids, token_type_ids, position_ids, attention_mask, in that order)
        encoder = FlaxPerformerModule(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            type_vocab_size=self.type_vocab_size,
            max_length=self.max_length,
            num_encoder_layers=self.num_encoder_layers,
            num_heads=self.num_heads,
            head_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            add_pooling_layer=False,
            name="bert",
        )(input_ids, token_type_ids, position_ids, attention_mask)

        # Compute the prediction scores
        encoder = nn.Dropout(rate=self.dropout_rate)(encoder, deterministic=deterministic)
        logits = FlaxBertOnlyMLMHead(
            vocab_size=self.vocab_size, hidden_act=self.hidden_act, name="cls", dtype=self.dtype
        )(encoder)

        return (logits,)
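

# Hypothetical sketch for the MLM module itself, with small arbitrary sizes:
# initializes parameters and returns per-token vocabulary logits.
def _example_masked_lm_module():
    module = FlaxPerformerForMaskedLMModule(
        vocab_size=100,
        hidden_size=16,
        intermediate_size=32,
        head_size=16,
        num_heads=2,
        num_encoder_layers=2,
        type_vocab_size=2,
        max_length=32,
        hidden_act="gelu",
    )
    input_ids = jnp.ones((1, 8), dtype=jnp.int32)
    attention_mask = jnp.ones_like(input_ids)
    token_type_ids = jnp.zeros_like(input_ids)
    position_ids = jnp.arange(8)[None, :]
    variables = module.init(jax.random.PRNGKey(0), input_ids, attention_mask, token_type_ids, position_ids)
    (logits,) = module.apply(variables, input_ids, attention_mask, token_type_ids, position_ids)
    return logits.shape  # (1, 8, 100)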