# Base configuration for the Hierarchical Transformer.

include "trainer_configuration.gin"

# Imports
from transformer import attention
from transformer import decoder_stack
from transformer import models
from transformer import nn_components
from transformer import transformer_base
from transformer import transformer_layer
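
# Model dimensions and dropout rates.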
NUM_LAYERS = 12
NUM_HEADS = 8
HEAD_DIM = 128
EMBED_DIM = 512         # Size of the embedding vector for each token.
MLP_DIM = 2048          # Number of hidden units in the transformer FFN.
NUM_EMBEDDINGS = 256    # Number of tokens in the vocabulary.
DROPOUT_RATE = 0.05
ATTN_DROPOUT_RATE = 0.05

# For training on TPU.
DTYPE = "float32"

# Configure the input task.
decoder_stack.TransformerTaskConfig:
  dataset_name = "synthetic"
  train_split = "train"
  test_split = "test"
  sequence_length = 512
  batch_size = 8
  vocab_size = %NUM_EMBEDDINGS
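
# Per-layer attention: causal sliding-window attention over 512-token
# windows with Transformer-XL style caching across windows and T5 relative
# position biases; external memory is disabled (memory_num_neighbors = 0).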
transformer_layer.TransformerLayer:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  window_length = 512
  use_long_xl_architecture = True
  max_unrolled_windows = -1      # Always unroll.
  relative_position_type = "t5"  # Can be "fourier", "t5", or None.
  use_causal_mask = True
  attn_dropout_rate = %ATTN_DROPOUT_RATE  # Attention matrix dropout.
  memory_num_neighbors = 0
  compute_importance = False
  dtype = %DTYPE
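
# Shared layer components: factories for the post-attention MLP and the
# feed-forward block, plus where dropout is applied (before attention and
# after the FFN).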
transformer_base.TransformerBase:
  attn_mlp_factory = @transformer_attn/nn_components.MLP
  ffn_factory = @transformer_ffn/nn_components.MLP
  normalize_keys = True  # More stable with Transformer XL.
  dropout_rate = %DROPOUT_RATE
  pre_attn_dropout = True
  post_attn_dropout = False
  pre_ffn_dropout = False
  post_ffn_dropout = True
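
# Post-attention MLP: a single dense projection with no hidden layer,
# activation, or bias.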
transformer_attn/nn_components.MLP:
  num_layers = 1  # Just a single dense matmul.
  num_hidden_units = 0
  hidden_activation = None
  use_bias = False
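
# Feed-forward network: two dense layers with a 2048-unit (MLP_DIM) ReLU
# hidden layer and no bias.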
transformer_ffn/nn_components.MLP:
  num_layers = 2
  num_hidden_units = %MLP_DIM
  hidden_activation = "relu"
  use_bias = False
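
# Decoder stack: NUM_LAYERS transformer layers over EMBED_DIM-dimensional
# token embeddings. Relies on relative positions rather than absolute
# position embeddings, with a final layernorm and dropout before the output
# token lookup. Recurrence and external memory are disabled here.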
decoder_stack.DecoderStack:
  # task_config will be passed in from DecoderOnlyLanguageModel.
  num_layers = %NUM_LAYERS
  embedding_size = %EMBED_DIM
  embedding_stddev = 1.0
  layer_factory = @transformer_layer.TransformerLayer
  dstack_window_length = 0
  use_absolute_positions = False
  use_final_layernorm = True          # Final layernorm before token lookup.
  final_dropout_rate = %DROPOUT_RATE  # Dropout before token lookup.
  final_mlp_factory = None            # Final MLP to predict target tokens.
  recurrent_layer_indices = ()
  memory_factory = None               # e.g. @memory_factory.memory_on_tpu_factory
  memory_layer_indices = ()
  dtype = %DTYPE
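
# Top-level decoder-only language model; wires the task configuration into
# the decoder stack.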
models.DecoderOnlyLanguageModel:
  task_config = @decoder_stack.TransformerTaskConfig()
  decoder_factory = @decoder_stack.DecoderStack
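
# Layer normalization: learn a scale but no bias, and do not subtract the
# mean (i.e. RMS-style normalization).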
nn_components.LayerNorm:
  use_scale = True
  use_bias = False
  use_mean = False  # If True, also subtract the mean, not just rescale.
  dtype = %DTYPE
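
# Usage sketch (illustrative assumption, not part of the original config):
# with the gin-config library, parsing this file applies the bindings above
# whenever the configured classes are instantiated. The file name and the
# constructor call below are hypothetical; any required constructor
# arguments beyond the gin bindings depend on the class definition.
#
#   import gin
#   from transformer import models
#
#   gin.parse_config_file("base_htrans.gin")   # hypothetical file name
#   model = models.DecoderOnlyLanguageModel()  # task_config and
#                                              # decoder_factory come from gin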