# Gin configuration for a small decoder-only Transformer language model.

include "trainer_configuration.gin"

from transformer import attention
from transformer import decoder_stack
from transformer import models
from transformer import nn_components
from transformer import transformer_base
from transformer import transformer_layer

# Model hyperparameters.
NUM_LAYERS = 12
NUM_HEADS = 8
HEAD_DIM = 128
EMBED_DIM = 512
MLP_DIM = 2048
NUM_EMBEDDINGS = 256  # Token vocabulary size.
DROPOUT_RATE = 0.05
ATTN_DROPOUT_RATE = 0.05

DTYPE = "float32"
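
# Illustrative only (not part of the original configuration): any of the
# macros above can be overridden by a later gin file or binding, since the
# last binding for a name wins.  For example, a derived config could set:
#
#   NUM_LAYERS = 24
#   EMBED_DIM = 1024
#   DTYPE = "bfloat16"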

# Task definition: dataset, splits, and batch geometry.
decoder_stack.TransformerTaskConfig:
  dataset_name = "synthetic"
  train_split = "train"
  test_split = "test"
  sequence_length = 512
  batch_size = 8
  vocab_size = %NUM_EMBEDDINGS

# Per-layer attention: sliding-window attention with Transformer-XL style
# caching across windows, causal masking, and T5 relative position biases.
transformer_layer.TransformerLayer:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  window_length = 512
  use_long_xl_architecture = True
  max_unrolled_windows = -1
  relative_position_type = "t5"
  use_causal_mask = True
  attn_dropout_rate = %ATTN_DROPOUT_RATE

  # External (kNN) memory is disabled in this configuration.
  memory_num_neighbors = 0
  compute_importance = False
  dtype = %DTYPE
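
# Note: window_length matches the task's sequence_length (512), so each
# training segment fits in a single attention window; the XL architecture
# then carries a cached window of keys and values across segments.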

# Shared layer internals: factories for the post-attention projection and
# the feed-forward network, plus where dropout is applied.
transformer_base.TransformerBase:
  attn_mlp_factory = @transformer_attn/nn_components.MLP
  ffn_factory = @transformer_ffn/nn_components.MLP
  normalize_keys = True
  dropout_rate = %DROPOUT_RATE
  pre_attn_dropout = True
  post_attn_dropout = False
  pre_ffn_dropout = False
  post_ffn_dropout = True

# Post-attention projection: a single linear layer with no hidden units.
transformer_attn/nn_components.MLP:
  num_layers = 1
  num_hidden_units = 0
  hidden_activation = None
  use_bias = False

# Position-wise feed-forward network: two layers with a ReLU hidden layer
# of MLP_DIM units.
transformer_ffn/nn_components.MLP:
  num_layers = 2
  num_hidden_units = %MLP_DIM
  hidden_activation = "relu"
  use_bias = False
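
# Rough shape check (commentary, assuming the usual multi-head wiring): the
# attention output has NUM_HEADS * HEAD_DIM = 8 * 128 = 1024 units, which
# the single-layer attention MLP projects back to EMBED_DIM = 512, while
# the FFN expands EMBED_DIM to MLP_DIM = 2048 and back.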

# Decoder stack: token embedding table plus NUM_LAYERS transformer layers.
decoder_stack.DecoderStack:
  num_layers = %NUM_LAYERS
  embedding_size = %EMBED_DIM
  embedding_stddev = 1.0
  layer_factory = @transformer_layer.TransformerLayer
  dstack_window_length = 0
  use_absolute_positions = False
  use_final_layernorm = True
  final_dropout_rate = %DROPOUT_RATE
  final_mlp_factory = None

  # Recurrence and external memory are both disabled.
  recurrent_layer_indices = ()
  memory_factory = None
  memory_layer_indices = ()
  dtype = %DTYPE
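
# Back-of-the-envelope size (commentary, counting only the big matrices and
# assuming standard transformer wiring): each layer has roughly
# 3 * 512 * 1024 (Q, K, V) + 1024 * 512 (output) + 2 * 512 * 2048 (FFN)
# ~= 4.2M parameters, so 12 layers come to ~50M, plus a negligible
# 256 * 512 embedding table.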

# Top-level model: a decoder-only language model built from the pieces above.
models.DecoderOnlyLanguageModel:
  task_config = @decoder_stack.TransformerTaskConfig()
  decoder_factory = @decoder_stack.DecoderStack

# Layer normalization: scale only (no bias, no mean subtraction), i.e. an
# RMSNorm-style normalizer.
nn_components.LayerNorm:
  use_scale = True
  use_bias = False
  use_mean = False
  dtype = %DTYPE
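
# Hypothetical usage sketch (kept in a comment so this file remains valid
# gin; the file name and the assumption that constructor arguments are
# gin-injected are mine, not from the original):
#
#   import gin
#   gin.parse_config_file("small_lm.gin")
#   from transformer import models
#   model = models.DecoderOnlyLanguageModel()  # remaining args bound by gin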