# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer Encoders.

Includes configurations and factory methods.
"""
import dataclasses
from typing import Optional, Sequence, Union

import gin
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.projects.bigbird import encoder as bigbird_encoder


@dataclasses.dataclass
class BertEncoderConfig(hyperparams.Config):
  """BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  return_attention_scores: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False


@dataclasses.dataclass
class FunnelEncoderConfig(hyperparams.Config):
  """Funnel encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  max_position_embeddings: int = 512
  type_vocab_size: int = 16
  inner_dim: int = 3072
  hidden_activation: str = "gelu"
  approx_gelu: bool = True
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  pool_type: str = "max"
  pool_stride: Union[int, Sequence[Union[int, float]]] = 2
  unpool_length: int = 0
  initializer_range: float = 0.02
  output_range: Optional[int] = None
  embedding_width: Optional[int] = None
  embedding_layer: Optional[tf_keras.layers.Layer] = None
  norm_first: bool = False
  share_rezero: bool = False
  append_dense_inputs: bool = False
  transformer_cls: str = "TransformerEncoderBlock"


@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of the input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
      layer.
    hidden_activation: the non-linear activation function to apply to the
      output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of the bottleneck.
    initializer_range: The stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    use_bottleneck_attention: Use attention inputs from the bottleneck
      transformation.
      If true, the following `key_query_shared_bottleneck` will be ignored.
    key_query_shared_bottleneck: whether to share linear transformation for
      keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization; only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: whether to use the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
  """
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  hidden_activation: str = "gelu"
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  use_bottleneck_attention: bool = False
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True
  input_mask_dtype: str = "int32"


@dataclasses.dataclass
class AlbertEncoderConfig(hyperparams.Config):
  """ALBERT encoder configuration."""
  vocab_size: int = 30000
  embedding_width: int = 128
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.0
  attention_dropout_rate: float = 0.0
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02


@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
  """BigBird encoder configuration."""
  vocab_size: int = 50358
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Pre/Post-LN Transformer
  norm_first: bool = False
  max_position_embeddings: int = 4096
  num_rand_blocks: int = 3
  block_size: int = 64
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  use_gradient_checkpointing: bool = False


@dataclasses.dataclass
class KernelEncoderConfig(hyperparams.Config):
  """Linear encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Pre/Post-LN Transformer
  norm_first: bool = False
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  feature_transform: str = "exp"
  num_random_features: int = 256
  redraw: bool = False
  is_short_seq: bool = False
  begin_kernel: int = 0
  scale: Optional[float] = None


@dataclasses.dataclass
class ReuseEncoderConfig(hyperparams.Config):
  """Reuse encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
  # Reuse transformer
  reuse_attention: int = -1
  use_relative_pe: bool = False
  pe_max_seq_length: int = 512
  max_reuse_layer_idx: int = 6


@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
  vocab_size: int = 32000
  num_layers: int = 24
  hidden_size: int = 1024
  num_attention_heads: int = 16
  head_size: int = 64
  inner_size: int = 4096
  inner_activation: str = "gelu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  attention_type: str = "bi"
  bi_data: bool = False
  tie_attention_biases: bool = False
  memory_length: int = 0
  same_length: bool = False
  clamp_length: int = -1
  reuse_length: int = 0
  use_cls_mask: bool = False
  embedding_width: int = 1024
  initializer_range: float = 0.02
  two_stream: bool = False


@dataclasses.dataclass
class QueryBertConfig(hyperparams.Config):
  """Query BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  return_attention_scores: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False


@dataclasses.dataclass
class FNetEncoderConfig(hyperparams.Config):
  """FNet encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  inner_activation: str = "gelu"
  inner_dim: int = 3072
  output_dropout: float = 0.1
  attention_dropout: float = 0.1
  max_sequence_length: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  output_range: Optional[int] = None
  norm_first: bool = False
  use_fft: bool = False
  attention_layers: Sequence[int] = ()


@dataclasses.dataclass
class SparseMixerEncoderConfig(hyperparams.Config):
  """SparseMixer encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 14
  moe_layers: Sequence[int] = (5, 6, 7, 8)
  attention_layers: Sequence[int] = (10, 11, 12, 13)
  num_experts: int = 16
  train_capacity_factor: float = 1.
  eval_capacity_factor: float = 1.
  examples_per_group: float = 1.
  use_fft: bool = False
  num_attention_heads: int = 8
  max_sequence_length: int = 512
  type_vocab_size: int = 2
  inner_dim: int = 3072
  inner_activation: str = "gelu"
  output_dropout: float = 0.1
  attention_dropout: float = 0.1
  initializer_range: float = 0.02
  output_range: Optional[int] = None
  embedding_width: Optional[int] = None
  norm_first: bool = False


@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = dataclasses.field(
      default_factory=AlbertEncoderConfig
  )
  bert: BertEncoderConfig = dataclasses.field(default_factory=BertEncoderConfig)
  bert_v2: BertEncoderConfig = dataclasses.field(
      default_factory=BertEncoderConfig
  )
  bigbird: BigBirdEncoderConfig = dataclasses.field(
      default_factory=BigBirdEncoderConfig
  )
  funnel: FunnelEncoderConfig = dataclasses.field(
      default_factory=FunnelEncoderConfig
  )
  kernel: KernelEncoderConfig = dataclasses.field(
      default_factory=KernelEncoderConfig
  )
  mobilebert: MobileBertEncoderConfig = dataclasses.field(
      default_factory=MobileBertEncoderConfig
  )
  reuse: ReuseEncoderConfig = dataclasses.field(
      default_factory=ReuseEncoderConfig
  )
  xlnet: XLNetEncoderConfig = dataclasses.field(
      default_factory=XLNetEncoderConfig
  )
  query_bert: QueryBertConfig = dataclasses.field(
      default_factory=QueryBertConfig
  )
  fnet: FNetEncoderConfig = dataclasses.field(default_factory=FNetEncoderConfig)
  sparse_mixer: SparseMixerEncoderConfig = dataclasses.field(
      default_factory=SparseMixerEncoderConfig
  )
  # If `any` is used, the encoder building relies on any.BUILDER.
  any: hyperparams.Config = dataclasses.field(
      default_factory=hyperparams.Config
  )


@gin.configurable
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf_keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder class not included in the supported
      encoders, usually used by gin.configurable.
    bypass_config: whether to ignore the config instance and create the object
      directly with `encoder_cls`.

  Returns:
    An encoder instance.
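
  Example:
    A minimal usage sketch, assuming the default BERT settings above; with
    `type="bert"` the factory falls through to `networks.BertEncoder`:

      config = EncoderConfig(type="bert")
      encoder = build_encoder(config)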
""" if bypass_config: return encoder_cls() encoder_type = config.type encoder_cfg = config.get() if encoder_cls and encoder_cls.__name__ == "EncoderScaffold": embedding_cfg = dict( vocab_size=encoder_cfg.vocab_size, type_vocab_size=encoder_cfg.type_vocab_size, hidden_size=encoder_cfg.hidden_size, max_seq_length=encoder_cfg.max_position_embeddings, initializer=tf_keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), dropout_rate=encoder_cfg.dropout_rate, ) hidden_cfg = dict( num_attention_heads=encoder_cfg.num_attention_heads, intermediate_size=encoder_cfg.intermediate_size, intermediate_activation=tf_utils.get_activation( encoder_cfg.hidden_activation), dropout_rate=encoder_cfg.dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate, kernel_initializer=tf_keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), ) kwargs = dict( embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg, num_hidden_instances=encoder_cfg.num_layers, pooled_output_dim=encoder_cfg.hidden_size, pooler_layer_initializer=tf_keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs, dict_outputs=True) return encoder_cls(**kwargs) if encoder_type == "any": encoder = encoder_cfg.BUILDER(encoder_cfg) if not isinstance(encoder, (tf.Module, tf_keras.Model, tf_keras.layers.Layer)): raise ValueError("The BUILDER returns an unexpected instance. The " "`build_encoder` should returns a tf.Module, " "tf_keras.Model or tf_keras.layers.Layer. However, " f"we get {encoder.__class__}") return encoder if encoder_type == "mobilebert": return networks.MobileBERTEncoder( word_vocab_size=encoder_cfg.word_vocab_size, word_embed_size=encoder_cfg.word_embed_size, type_vocab_size=encoder_cfg.type_vocab_size, max_sequence_length=encoder_cfg.max_sequence_length, num_blocks=encoder_cfg.num_blocks, hidden_size=encoder_cfg.hidden_size, num_attention_heads=encoder_cfg.num_attention_heads, intermediate_size=encoder_cfg.intermediate_size, intermediate_act_fn=encoder_cfg.hidden_activation, hidden_dropout_prob=encoder_cfg.hidden_dropout_prob, attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob, intra_bottleneck_size=encoder_cfg.intra_bottleneck_size, initializer_range=encoder_cfg.initializer_range, use_bottleneck_attention=encoder_cfg.use_bottleneck_attention, key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck, num_feedforward_networks=encoder_cfg.num_feedforward_networks, normalization_type=encoder_cfg.normalization_type, classifier_activation=encoder_cfg.classifier_activation, input_mask_dtype=encoder_cfg.input_mask_dtype) if encoder_type == "albert": return networks.AlbertEncoder( vocab_size=encoder_cfg.vocab_size, embedding_width=encoder_cfg.embedding_width, hidden_size=encoder_cfg.hidden_size, num_layers=encoder_cfg.num_layers, num_attention_heads=encoder_cfg.num_attention_heads, max_sequence_length=encoder_cfg.max_position_embeddings, type_vocab_size=encoder_cfg.type_vocab_size, intermediate_size=encoder_cfg.intermediate_size, activation=tf_utils.get_activation(encoder_cfg.hidden_activation), dropout_rate=encoder_cfg.dropout_rate, attention_dropout_rate=encoder_cfg.attention_dropout_rate, initializer=tf_keras.initializers.TruncatedNormal( stddev=encoder_cfg.initializer_range), dict_outputs=True) if encoder_type == "bigbird": # TODO(frederickliu): Support use_gradient_checkpointing and update # experiments to use the EncoderScaffold only. 
    if encoder_cfg.use_gradient_checkpointing:
      return bigbird_encoder.BigBirdEncoder(
          vocab_size=encoder_cfg.vocab_size,
          hidden_size=encoder_cfg.hidden_size,
          num_layers=encoder_cfg.num_layers,
          num_attention_heads=encoder_cfg.num_attention_heads,
          intermediate_size=encoder_cfg.intermediate_size,
          activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
          dropout_rate=encoder_cfg.dropout_rate,
          attention_dropout_rate=encoder_cfg.attention_dropout_rate,
          num_rand_blocks=encoder_cfg.num_rand_blocks,
          block_size=encoder_cfg.block_size,
          max_position_embeddings=encoder_cfg.max_position_embeddings,
          type_vocab_size=encoder_cfg.type_vocab_size,
          initializer=tf_keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          embedding_width=encoder_cfg.embedding_width,
          use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        max_rand_mask_length=encoder_cfg.max_position_embeddings,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        from_block_size=encoder_cfg.block_size,
        to_block_size=encoder_cfg.block_size,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.BigBirdAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.BigBirdMasks,
        mask_cfg=dict(block_size=encoder_cfg.block_size),
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "funnel":
    if encoder_cfg.hidden_activation == "gelu":
      activation = tf_utils.get_activation(
          encoder_cfg.hidden_activation, approximate=encoder_cfg.approx_gelu)
    else:
      activation = tf_utils.get_activation(encoder_cfg.hidden_activation)
    return networks.FunnelTransformerEncoder(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        inner_dim=encoder_cfg.inner_dim,
        inner_activation=activation,
        output_dropout=encoder_cfg.dropout_rate,
        attention_dropout=encoder_cfg.attention_dropout_rate,
        pool_type=encoder_cfg.pool_type,
        pool_stride=encoder_cfg.pool_stride,
        unpool_length=encoder_cfg.unpool_length,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_width,
        embedding_layer=embedding_layer,
        norm_first=encoder_cfg.norm_first,
        share_rezero=encoder_cfg.share_rezero,
        append_dense_inputs=encoder_cfg.append_dense_inputs,
        transformer_cls=encoder_cfg.transformer_cls,
    )

  if encoder_type == "kernel":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    attention_cfg = dict(
        num_heads=encoder_cfg.num_attention_heads,
        key_dim=int(encoder_cfg.hidden_size //
                    encoder_cfg.num_attention_heads),
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        feature_transform=encoder_cfg.feature_transform,
        num_random_features=encoder_cfg.num_random_features,
        redraw=encoder_cfg.redraw,
        is_short_seq=encoder_cfg.is_short_seq,
        begin_kernel=encoder_cfg.begin_kernel,
        scale=encoder_cfg.scale,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        attention_cls=layers.KernelAttention,
        attention_cfg=attention_cfg)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.TransformerScaffold,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        mask_cls=layers.KernelMask,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        layer_idx_as_attention_seed=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "xlnet":
    return networks.XLNetBase(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        head_size=encoder_cfg.head_size,
        inner_size=encoder_cfg.inner_size,
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        attention_type=encoder_cfg.attention_type,
        bi_data=encoder_cfg.bi_data,
        two_stream=encoder_cfg.two_stream,
        tie_attention_biases=encoder_cfg.tie_attention_biases,
        memory_length=encoder_cfg.memory_length,
        clamp_length=encoder_cfg.clamp_length,
        reuse_length=encoder_cfg.reuse_length,
        inner_activation=encoder_cfg.inner_activation,
        use_cls_mask=encoder_cfg.use_cls_mask,
        embedding_width=encoder_cfg.embedding_width,
        initializer=tf_keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))

  if encoder_type == "reuse":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate)
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        inner_dim=encoder_cfg.intermediate_size,
        inner_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        output_dropout=encoder_cfg.dropout_rate,
        attention_dropout=encoder_cfg.attention_dropout_rate,
        norm_first=encoder_cfg.norm_first,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        reuse_attention=encoder_cfg.reuse_attention,
        use_relative_pe=encoder_cfg.use_relative_pe,
        pe_max_seq_length=encoder_cfg.pe_max_seq_length,
        max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cls=layers.ReuseTransformer,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=False,
        dict_outputs=True,
        feed_layer_idx=True,
        recursive=True)
    return networks.EncoderScaffold(**kwargs)

  if encoder_type == "query_bert":
    embedding_layer = layers.FactorizedEmbedding(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_size,
        output_dim=encoder_cfg.hidden_size,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        name="word_embeddings")
    return networks.BertEncoderV2(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_layer=embedding_layer,
        return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
        return_attention_scores=encoder_cfg.return_attention_scores,
        dict_outputs=True,
        norm_first=encoder_cfg.norm_first)

  if encoder_type == "fnet":
    return networks.FNet(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        inner_dim=encoder_cfg.inner_dim,
        inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
        output_dropout=encoder_cfg.output_dropout,
        attention_dropout=encoder_cfg.attention_dropout,
        max_sequence_length=encoder_cfg.max_sequence_length,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_width,
        embedding_layer=embedding_layer,
        norm_first=encoder_cfg.norm_first,
        use_fft=encoder_cfg.use_fft,
        attention_layers=encoder_cfg.attention_layers)

  if encoder_type == "sparse_mixer":
    return networks.SparseMixer(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        moe_layers=encoder_cfg.moe_layers,
        attention_layers=encoder_cfg.attention_layers,
        num_experts=encoder_cfg.num_experts,
        train_capacity_factor=encoder_cfg.train_capacity_factor,
        eval_capacity_factor=encoder_cfg.eval_capacity_factor,
        examples_per_group=encoder_cfg.examples_per_group,
        use_fft=encoder_cfg.use_fft,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_sequence_length,
        type_vocab_size=encoder_cfg.type_vocab_size,
        inner_dim=encoder_cfg.inner_dim,
        inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
        output_dropout=encoder_cfg.output_dropout,
        attention_dropout=encoder_cfg.attention_dropout,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        output_range=encoder_cfg.output_range,
        embedding_width=encoder_cfg.embedding_width,
        norm_first=encoder_cfg.norm_first,
        embedding_layer=embedding_layer)

  bert_encoder_cls = networks.BertEncoder
  if encoder_type == "bert_v2":
    bert_encoder_cls = networks.BertEncoderV2

  # Uses the default BERT encoder configuration schema to create the encoder.
  # If the config does not match, please add a branch for the new encoder type.
  return bert_encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      return_attention_scores=encoder_cfg.return_attention_scores,
      dict_outputs=True,
      norm_first=encoder_cfg.norm_first)
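

# A brief usage sketch: `build_encoder` dispatches on `config.type` and reads
# the matching sub-config, so a non-default encoder (here ALBERT, with an
# overridden depth) can be selected through the one-of config. The values
# shown are illustrative, not defaults from this module.
#
#   config = EncoderConfig(
#       type="albert",
#       albert=AlbertEncoderConfig(num_layers=6))
#   encoder = build_encoder(config)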