File size: 3,787 Bytes
81de315
 
 
 
 
 
 
 
 
 
 
 
 
e01a248
57a778a
81de315
 
 
 
57a778a
81de315
 
57a778a
81de315
 
 
 
 
 
 
 
 
 
 
 
e01a248
81de315
 
 
e01a248
81de315
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Gin operative configuration for a t5x evaluation run (`eval_script.evaluate`).
# Auto-generated-style dump: macros first, then per-configurable parameter bindings.
# Evaluates the 'sentiment' mixture/task on its 'test' split with a T5-Base-sized
# mT5-vocabulary EncoderDecoder model restored from a specific fine-tuned checkpoint.
from __gin__ import dynamic_registration
import __main__ as eval_script
import seqio
from t5.data import mixtures
from t5x import adafactor
from t5x.examples.t5 import network
from t5x import models
from t5x import partitioning
from t5x import utils
import tasks

# Macros:
# ==============================================================================
# Fine-tuned checkpoint to evaluate (step 1,510,000 of a Norwegian NCC + English
# base model fine-tuned for sentiment); restored in 'specific' mode below.
CHECKPOINT_PATH = \
    'gs://nb-t5x-us-central2/finetuned/v2_norwegian_NCC_plus_English_t5x_base_1_500_000_sentiment/checkpoint_1510000'
# Dropout disabled for evaluation (inference-time setting).
DROPOUT_RATE = 0.0
EVAL_OUTPUT_DIR = './log/'
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None
MIXTURE_OR_TASK_NAME = 'sentiment'
MODEL = @models.EncoderDecoderModel()
OPTIMIZER = @adafactor.Adafactor()
# Evaluate on the held-out test split.
SPLIT = 'test'
VOCABULARY = @seqio.SentencePieceVocabulary()
Z_LOSS = 0.0001

# Parameters for adafactor.Adafactor:
# ==============================================================================
# Optimizer definition is required by the model object even though evaluation
# performs no updates; values mirror standard t5x Adafactor settings.
adafactor.Adafactor.decay_rate = 0.8
adafactor.Adafactor.logical_factor_rules = \
    @adafactor.standard_logical_factor_rules()
adafactor.Adafactor.step_offset = 0

# Parameters for utils.DatasetConfig:
# ==============================================================================
# Deterministic eval pipeline: fixed seed, no shuffling.
utils.DatasetConfig.batch_size = 16
utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
utils.DatasetConfig.seed = 42
utils.DatasetConfig.shuffle = False
utils.DatasetConfig.split = %SPLIT
# targets length of 2 suggests short classification-style labels
# (e.g. a single sentiment token plus EOS) — confirm against the task definition.
utils.DatasetConfig.task_feature_lengths = {'inputs': 512, 'targets': 2}

# Parameters for models.EncoderDecoderModel:
# ==============================================================================
# Shared vocabulary on both encoder input and decoder output sides.
models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
models.EncoderDecoderModel.module = @network.Transformer()
models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
models.EncoderDecoderModel.z_loss = %Z_LOSS

# Parameters for eval_script.evaluate:
# ==============================================================================
# Top-level entry point: wires dataset, model, partitioner, and checkpoint
# restore config together; results are written under EVAL_OUTPUT_DIR.
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
eval_script.evaluate.model = %MODEL
eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()

# Parameters for partitioning.PjitPartitioner:
# ==============================================================================
# Model parallelism across 2 partitions — presumably sized for the target
# TPU/accelerator topology; verify against the launch environment.
partitioning.PjitPartitioner.num_partitions = 2

# Parameters for utils.RestoreCheckpointConfig:
# ==============================================================================
# 'specific' mode restores exactly the checkpoint at CHECKPOINT_PATH
# (as opposed to scanning a directory for the latest step).
utils.RestoreCheckpointConfig.mode = 'specific'
utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH

# Parameters for seqio.SentencePieceVocabulary:
# ==============================================================================
# mC4 SentencePiece model with 250,000 pieces + 100 extra ids (the mT5 vocab);
# matches vocab_size = 250112 below.
seqio.SentencePieceVocabulary.sentencepiece_model_file = \
    'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'

# Parameters for network.T5Config:
# ==============================================================================
# T5-Base-scale architecture: 12+12 layers, 12 heads x 64 head_dim, 768 emb_dim.
network.T5Config.dropout_rate = %DROPOUT_RATE
network.T5Config.dtype = 'bfloat16'
network.T5Config.emb_dim = 768
network.T5Config.head_dim = 64
# Separate output projection rather than tying logits to the embedding matrix.
network.T5Config.logits_via_embedding = False
# Gated-GELU feed-forward ('gelu', 'linear'), as in T5.1.1-style models.
network.T5Config.mlp_activations = ('gelu', 'linear')
network.T5Config.mlp_dim = 2048
network.T5Config.num_decoder_layers = 12
network.T5Config.num_encoder_layers = 12
network.T5Config.num_heads = 12
network.T5Config.vocab_size = 250112

# Parameters for network.Transformer:
# ==============================================================================
network.Transformer.config = @network.T5Config()