# Base configuration for the Hierarchical Transformer.
include "trainer_configuration.gin"
# Imports
from transformer import attention
from transformer import decoder_stack
from transformer import models
from transformer import nn_components
from transformer import transformer_base
from transformer import transformer_layer

NUM_LAYERS = 12
NUM_HEADS = 8
HEAD_DIM = 128
EMBED_DIM = 512         # Size of embedding vector for each token.
MLP_DIM = 2048          # Number of hidden units in transformer FFN.
NUM_EMBEDDINGS = 256    # Number of tokens in vocabulary.

DROPOUT_RATE = 0.05
ATTN_DROPOUT_RATE = 0.05

# For training on TPU.
DTYPE = "float32"
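
# Note (added, not in the original config): with the values above, and assuming
# the usual num_heads * head_size projection, the multi-head attention operates
# at 8 * 128 = 1024 dimensions, while the residual/embedding width is
# EMBED_DIM = 512 and the FFN hidden width MLP_DIM = 2048 is 4x the embedding size.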

# Configure the input task.
decoder_stack.TransformerTaskConfig:
  dataset_name = "synthetic"
  train_split = "train"
  test_split = "test"
  sequence_length = 512
  batch_size = 8
  vocab_size = %NUM_EMBEDDINGS
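
# Note (added): each training step covers batch_size * sequence_length =
# 8 * 512 = 4096 tokens per batch, and vocab_size is tied to the
# NUM_EMBEDDINGS = 256 constant above, which suggests a small byte-level-sized
# vocabulary for the synthetic task.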

transformer_layer.TransformerLayer:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  window_length = 512
  use_long_xl_architecture = True
  max_unrolled_windows = -1       # Always unroll.
  relative_position_type = "t5"   # Can be "fourier", "t5", or None.
  use_causal_mask = True
  attn_dropout_rate = %ATTN_DROPOUT_RATE   # Attention matrix dropout.
  memory_num_neighbors = 0
  compute_importance = False
  dtype = %DTYPE
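
# Note (added): window_length equals the task's sequence_length (512), so each
# training segment is a single attention window; with
# use_long_xl_architecture = True the layer is also expected to attend to
# cached keys/values from the previous segment, Transformer-XL style.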

transformer_base.TransformerBase:
  attn_mlp_factory = @transformer_attn/nn_components.MLP
  ffn_factory = @transformer_ffn/nn_components.MLP
  normalize_keys = True   # More stable with Transformer XL.
  dropout_rate = %DROPOUT_RATE
  pre_attn_dropout = True
  post_attn_dropout = False
  pre_ffn_dropout = False
  post_ffn_dropout = True
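
# Note (added): with the four flags above, dropout is applied once per
# sublayer, before the attention block (pre_attn_dropout) and after the FFN
# (post_ffn_dropout), rather than on both sides of each.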

transformer_attn/nn_components.MLP:
  num_layers = 1   # Just a single dense matmul.
  num_hidden_units = 0
  hidden_activation = None
  use_bias = False

transformer_ffn/nn_components.MLP:
  num_layers = 2
  num_hidden_units = %MLP_DIM
  hidden_activation = "relu"
  use_bias = False
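
# Note (added): "transformer_attn/" and "transformer_ffn/" are gin scopes, so
# both bindings above configure the same nn_components.MLP class. With
# num_layers = 1 and no hidden units the attention-side MLP reduces to a single
# bias-free matmul, while the FFN is a standard 2-layer ReLU MLP with %MLP_DIM
# hidden units.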

decoder_stack.DecoderStack:
  # task_config will be passed in from DecoderOnlyLanguageModel.
  num_layers = %NUM_LAYERS
  embedding_size = %EMBED_DIM
  embedding_stddev = 1.0
  layer_factory = @transformer_layer.TransformerLayer
  dstack_window_length = 0
  use_absolute_positions = False
  use_final_layernorm = True            # Final layernorm before token lookup.
  final_dropout_rate = %DROPOUT_RATE    # Dropout before token lookup.
  final_mlp_factory = None              # Final MLP to predict target tokens.
  recurrent_layer_indices = ()
  memory_factory = None                 # e.g. @memory_factory.memory_on_tpu_factory
  memory_layer_indices = ()
  dtype = %DTYPE
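
# Example (not part of this config; values are illustrative): a derived .gin
# file could enable external memory on selected layers by overriding the two
# bindings above, e.g.
#   decoder_stack.DecoderStack.memory_factory = @memory_factory.memory_on_tpu_factory
#   decoder_stack.DecoderStack.memory_layer_indices = (11,)
# The factory name is taken from the comment above; check the memory module for
# the factories that actually exist.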

models.DecoderOnlyLanguageModel:
  task_config = @decoder_stack.TransformerTaskConfig()
  decoder_factory = @decoder_stack.DecoderStack
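
# Note (added): in gin, "@decoder_stack.TransformerTaskConfig()" (with
# parentheses) calls the configurable and passes the resulting task_config
# instance, while "@decoder_stack.DecoderStack" (no parentheses) passes the
# class itself, so the model can construct the decoder stack later with its
# own arguments.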

nn_components.LayerNorm:
  use_scale = True
  use_bias = False
  use_mean = False    # Calculate and adjust for the mean as well as the scale.
  dtype = %DTYPE
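
# Example (not part of the original file): any binding above can be overridden
# from a derived config that includes this one, e.g.
#   include "base_config.gin"   # placeholder; substitute this file's actual path
#   NUM_LAYERS = 24
#   EMBED_DIM = 1024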