sultan committed
Commit
a6ee278
1 Parent(s): dcee2eb

Upload 7 files

config.gin ADDED
@@ -0,0 +1,148 @@
+ from __gin__ import dynamic_registration
+ import __main__ as train_script
+ import seqio
+ import t5.data.mixtures
+ from t5x import adafactor
+ from t5x.examples.t5 import network
+ from t5x import gin_utils
+ from t5x import models
+ from t5x import partitioning
+ from t5x import trainer
+ from t5x import utils
+ import tasks
+
+ # Macros:
+ # ==============================================================================
+ BATCH_SIZE = 256
+ DROPOUT_RATE = 0.0
+ LABEL_SMOOTHING = 0.0
+ LOSS_NORMALIZING_FACTOR = None
+ MIXTURE_OR_TASK_MODULE = None
+ MIXTURE_OR_TASK_NAME = 'arabic_dataset'
+ MODEL = @models.EncoderDecoderModel()
+ MODEL_DIR = 'gs://sultan-t5x/arabict5_base'
+ OPTIMIZER = @adafactor.Adafactor()
+ RANDOM_SEED = None
+ SHUFFLE_TRAIN_EXAMPLES = True
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114}
+ TRAIN_STEPS = 2000000
+ USE_CACHED_TASKS = False
+ USE_HARDWARE_RNG = False
+ VOCABULARY = @seqio.SentencePieceVocabulary()
+ Z_LOSS = 0.0001
+
+ # Parameters for adafactor.Adafactor:
+ # ==============================================================================
+ adafactor.Adafactor.decay_rate = 0.8
+ adafactor.Adafactor.logical_factor_rules = \
+     @adafactor.standard_logical_factor_rules()
+ adafactor.Adafactor.step_offset = 0
+
+ # Parameters for utils.CheckpointConfig:
+ # ==============================================================================
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
+ utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
+
+ # Parameters for utils.create_learning_rate_scheduler:
+ # ==============================================================================
+ utils.create_learning_rate_scheduler.base_learning_rate = 1.0
+ utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
+ utils.create_learning_rate_scheduler.warmup_steps = 10000
+
+ # Parameters for train/utils.DatasetConfig:
+ # ==============================================================================
+ train/utils.DatasetConfig.batch_size = %BATCH_SIZE
+ train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+ train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+ train/utils.DatasetConfig.pack = True
+ train/utils.DatasetConfig.seed = None
+ train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
+ train/utils.DatasetConfig.split = 'train'
+ train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+ train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+ # Parameters for train_eval/utils.DatasetConfig:
+ # ==============================================================================
+ train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
+ train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
+ train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
+ train_eval/utils.DatasetConfig.pack = True
+ train_eval/utils.DatasetConfig.seed = 42
+ train_eval/utils.DatasetConfig.shuffle = False
+ train_eval/utils.DatasetConfig.split = 'validation'
+ train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
+ train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
+
+ # Parameters for models.EncoderDecoderModel:
+ # ==============================================================================
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
+ models.EncoderDecoderModel.module = @network.Transformer()
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
+
+ # Parameters for partitioning.PjitPartitioner:
+ # ==============================================================================
+ partitioning.PjitPartitioner.logical_axis_rules = \
+     @partitioning.standard_logical_axis_rules()
+ partitioning.PjitPartitioner.model_parallel_submesh = None
+ partitioning.PjitPartitioner.num_partitions = 1
+
+ # Parameters for utils.RestoreCheckpointConfig:
+ # ==============================================================================
+ utils.RestoreCheckpointConfig.path = []
+
+ # Parameters for utils.SaveCheckpointConfig:
+ # ==============================================================================
+ utils.SaveCheckpointConfig.dtype = 'float32'
+ utils.SaveCheckpointConfig.keep = None
+ utils.SaveCheckpointConfig.period = 50000
+ utils.SaveCheckpointConfig.save_dataset = False
+
+ # Parameters for seqio.SentencePieceVocabulary:
+ # ==============================================================================
+ seqio.SentencePieceVocabulary.extra_ids = 100
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = \
+     'gs://sultan-t5x/spiece.model'
+
+ # Parameters for network.T5Config:
+ # ==============================================================================
+ network.T5Config.dropout_rate = %DROPOUT_RATE
+ network.T5Config.dtype = 'bfloat16'
+ network.T5Config.emb_dim = 768
+ network.T5Config.head_dim = 64
+ network.T5Config.logits_via_embedding = False
+ network.T5Config.mlp_activations = ('gelu', 'linear')
+ network.T5Config.mlp_dim = 2048
+ network.T5Config.num_decoder_layers = 16
+ network.T5Config.num_encoder_layers = 16
+ network.T5Config.num_heads = 12
+ network.T5Config.vocab_size = 32128
+
+ # Parameters for train_script.train:
+ # ==============================================================================
+ train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
+ train_script.train.eval_period = 5000
+ train_script.train.eval_steps = 20
+ train_script.train.infer_eval_dataset_cfg = None
+ train_script.train.model = %MODEL
+ train_script.train.model_dir = %MODEL_DIR
+ train_script.train.partitioner = @partitioning.PjitPartitioner()
+ train_script.train.random_seed = %RANDOM_SEED
+ train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
+ train_script.train.total_steps = %TRAIN_STEPS
+ train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
+ train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
+ train_script.train.trainer_cls = @trainer.Trainer
+ train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
+
+ # Parameters for trainer.Trainer:
+ # ==============================================================================
+ trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
+ trainer.Trainer.num_microbatches = None
+
+ # Parameters for network.Transformer:
+ # ==============================================================================
+ network.Transformer.config = @network.T5Config()
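
The gin file above imports a local `tasks` module and points MIXTURE_OR_TASK_NAME at 'arabic_dataset', but that module is not part of this commit. As a rough illustration only, a SeqIO registration for such a task could look like the sketch below; the file patterns, data source, and span-corruption objective are assumptions, not taken from this repository.

# tasks.py -- hypothetical sketch; the real module is not included in this commit.
import seqio
from t5.data import preprocessors

# Vocabulary settings mirror config.gin (spiece.model with 100 extra ids).
VOCAB = seqio.SentencePieceVocabulary("gs://sultan-t5x/spiece.model", extra_ids=100)

seqio.TaskRegistry.add(
    "arabic_dataset",
    source=seqio.TextLineDataSource(
        split_to_filepattern={
            "train": "gs://<bucket>/train-*.txt",            # placeholder paths
            "validation": "gs://<bucket>/validation-*.txt",
        }),
    preprocessors=[
        # Wrap each raw text line into the feature dict expected downstream.
        lambda ds: ds.map(lambda line: {"targets": line}),
        seqio.preprocessors.tokenize,
        # Assumes the usual T5 denoising objective; the actual task may differ.
        preprocessors.span_corruption,
        seqio.preprocessors.append_eos_after_trim,
    ],
    output_features={
        "inputs": seqio.Feature(vocabulary=VOCAB, add_eos=True),
        "targets": seqio.Feature(vocabulary=VOCAB, add_eos=True),
    },
)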
config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "architectures": [
+     "T5ForConditionalGeneration"
+   ],
+   "classifier_dropout": 0.0,
+   "d_ff": 2048,
+   "d_kv": 64,
+   "d_model": 768,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "num_decoder_layers": 16,
+   "num_heads": 12,
+   "num_layers": 16,
+   "output_past": true,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.34.0",
+   "use_cache": true,
+   "vocab_size": 32128
+ }
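
config.json describes the same architecture in Hugging Face terms: a gated-GELU T5 with 16 encoder and 16 decoder layers, d_model 768, d_ff 2048, and 12 heads of size 64, matching the T5Config block in the gin file. A minimal loading sketch, assuming the files from this commit are available in a local directory (a Hub repo id would work the same way):

from transformers import T5Config, T5ForConditionalGeneration

# "./" stands for a local clone of this repository; replace with the actual
# path or repo id as needed.
config = T5Config.from_pretrained("./")                   # parses config.json above
model = T5ForConditionalGeneration.from_pretrained("./")  # loads pytorch_model.bin

print(config.num_layers, config.num_decoder_layers, config.d_model)  # 16 16 768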
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e3d8ffebf6ae4667cfecf46a1b0e611aaa336f9a445c086a226875b6c1c38e72
+ size 1254630175
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.34.0"
+ }
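
generation_config.json only pins the special-token defaults (decoder start 0, EOS 1, pad 0) that model.generate() falls back to when no explicit GenerationConfig is passed. Continuing the hypothetical loading sketch above:

# Assumes `model` from the previous sketch; generate() reads these defaults
# from model.generation_config.
print(model.generation_config.decoder_start_token_id)  # 0
print(model.generation_config.eos_token_id)            # 1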
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bc91038e5fdfe8dada9106edce062b3f625ca04b20531e9e4e669041aeb96dd
+ size 1254738594
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e0e213aa09cfbd551379f003065090bd61adc4d6a41cf298f147d3111048920
+ size 893053
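
spiece.model is the SentencePiece model that config.gin wraps with seqio.SentencePieceVocabulary (with 100 sentinel ids). Assuming a local copy of the file, it can also be wrapped with the slow Hugging Face T5 tokenizer:

from transformers import T5Tokenizer  # requires the sentencepiece package

# extra_ids=100 mirrors seqio.SentencePieceVocabulary.extra_ids in config.gin.
tokenizer = T5Tokenizer(vocab_file="spiece.model", extra_ids=100)
print(len(tokenizer))  # SentencePiece pieces plus the 100 sentinel tokens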
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff