aapot commited on
Commit
e5b9953
1 Parent(s): 4de4400
.gitattributes CHANGED
@@ -31,3 +31,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ checkpoint*/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
base_nl36.gin ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5.1.1 Efficient base nl36 model.
2
+
3
+ import seqio
4
+ include 't5x/examples/t5/t5_1_1/base.gin' # imports vocab, optimizer and model.
5
+
6
+ # ------------------- Network specification overrides --------------------------
7
+ network.Transformer.config = @network.T5Config()
8
+ network.T5Config:
9
+ emb_dim = 768
10
+ num_heads = 12
11
+ num_encoder_layers = 36
12
+ num_decoder_layers = 36
13
+ head_dim = 64
14
+ mlp_dim = 3072
15
+
16
+ # ------------------- Model specification overrides --------------------------
17
+ VOCABULARY = @seqio.SentencePieceVocabulary()
18
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
19
+
20
+ MODEL = @models.EncoderDecoderModel()
21
+ models.EncoderDecoderModel:
22
+ input_vocabulary = %VOCABULARY
23
+ output_vocabulary = %VOCABULARY
base_nl36_pretrain.gin ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Register necessary SeqIO Tasks/Mixtures.
2
+ from __gin__ import dynamic_registration
3
+ from t5x import utils
4
+ import tasks
5
+ import __main__ as train_script
6
+
7
+ include 'base_nl36.gin'
8
+ include 't5x/configs/runs/pretrain.gin'
9
+
10
+
11
+ # ------------------- Training specification overrides --------------------------
12
+ train_script.train:
13
+ eval_period = 10000
14
+
15
+ utils.SaveCheckpointConfig:
16
+ period = 10000
17
+ keep = 10
18
+
19
+ MIXTURE_OR_TASK_NAME = "pretrain_finnish_ul2"
20
+ USE_CACHED_TASKS = False
21
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
22
+ TRAIN_STEPS = 1000000
23
+ DROPOUT_RATE = 0.0
24
+ BATCH_SIZE = 64
config.gin ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import __main__ as train_script
3
+ import seqio
4
+ from t5x import adafactor
5
+ from t5x.examples.t5 import network
6
+ from t5x import gin_utils
7
+ from t5x import models
8
+ from t5x import partitioning
9
+ from t5x import trainer
10
+ from t5x import utils
11
+ import tasks
12
+
13
+ # Macros:
14
+ # ==============================================================================
15
+ BATCH_SIZE = 64
16
+ DROPOUT_RATE = 0.0
17
+ LABEL_SMOOTHING = 0.0
18
+ LOSS_NORMALIZING_FACTOR = None
19
+ MIXTURE_OR_TASK_MODULE = None
20
+ MIXTURE_OR_TASK_NAME = 'pretrain_finnish_ul2'
21
+ MODEL = @models.EncoderDecoderModel()
22
+ MODEL_DIR = '/researchdisk/ul2-base-nl36-finnish'
23
+ OPTIMIZER = @adafactor.Adafactor()
24
+ RANDOM_SEED = None
25
+ SHUFFLE_TRAIN_EXAMPLES = True
26
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
27
+ TRAIN_STEPS = 1000000
28
+ USE_CACHED_TASKS = False
29
+ USE_HARDWARE_RNG = False
30
+ VOCABULARY = @seqio.SentencePieceVocabulary()
31
+ Z_LOSS = 0.0001
32
+
33
+ # Parameters for adafactor.Adafactor:
34
+ # ==============================================================================
35
+ adafactor.Adafactor.decay_rate = 0.8
36
+ adafactor.Adafactor.logical_factor_rules = \
37
+ @adafactor.standard_logical_factor_rules()
38
+ adafactor.Adafactor.step_offset = 0
39
+
40
+ # Parameters for utils.CheckpointConfig:
41
+ # ==============================================================================
42
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
43
+ utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
44
+
45
+ # Parameters for utils.create_learning_rate_scheduler:
46
+ # ==============================================================================
47
+ utils.create_learning_rate_scheduler.base_learning_rate = 1.0
48
+ utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
49
+ utils.create_learning_rate_scheduler.warmup_steps = 10000
50
+
51
+ # Parameters for train/utils.DatasetConfig:
52
+ # ==============================================================================
53
+ train/utils.DatasetConfig.batch_size = %BATCH_SIZE
54
+ train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
55
+ train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
56
+ train/utils.DatasetConfig.pack = True
57
+ train/utils.DatasetConfig.seed = None
58
+ train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
59
+ train/utils.DatasetConfig.split = 'train'
60
+ train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
61
+ train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
62
+
63
+ # Parameters for train_eval/utils.DatasetConfig:
64
+ # ==============================================================================
65
+ train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
66
+ train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
67
+ train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
68
+ train_eval/utils.DatasetConfig.pack = True
69
+ train_eval/utils.DatasetConfig.seed = 42
70
+ train_eval/utils.DatasetConfig.shuffle = False
71
+ train_eval/utils.DatasetConfig.split = 'validation'
72
+ train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
73
+ train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
74
+
75
+ # Parameters for models.EncoderDecoderModel:
76
+ # ==============================================================================
77
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
78
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
79
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
80
+ models.EncoderDecoderModel.module = @network.Transformer()
81
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
82
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
83
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
84
+
85
+ # Parameters for partitioning.PjitPartitioner:
86
+ # ==============================================================================
87
+ partitioning.PjitPartitioner.logical_axis_rules = \
88
+ @partitioning.standard_logical_axis_rules()
89
+ partitioning.PjitPartitioner.model_parallel_submesh = None
90
+ partitioning.PjitPartitioner.num_partitions = 1
91
+
92
+ # Parameters for utils.RestoreCheckpointConfig:
93
+ # ==============================================================================
94
+ utils.RestoreCheckpointConfig.path = []
95
+
96
+ # Parameters for utils.SaveCheckpointConfig:
97
+ # ==============================================================================
98
+ utils.SaveCheckpointConfig.dtype = 'float32'
99
+ utils.SaveCheckpointConfig.keep = 10
100
+ utils.SaveCheckpointConfig.period = 10000
101
+ utils.SaveCheckpointConfig.save_dataset = False
102
+
103
+ # Parameters for seqio.SentencePieceVocabulary:
104
+ # ==============================================================================
105
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model'
106
+
107
+ # Parameters for network.T5Config:
108
+ # ==============================================================================
109
+ network.T5Config.dropout_rate = %DROPOUT_RATE
110
+ network.T5Config.dtype = 'bfloat16'
111
+ network.T5Config.emb_dim = 768
112
+ network.T5Config.head_dim = 64
113
+ network.T5Config.logits_via_embedding = False
114
+ network.T5Config.mlp_activations = ('gelu', 'linear')
115
+ network.T5Config.mlp_dim = 3072
116
+ network.T5Config.num_decoder_layers = 36
117
+ network.T5Config.num_encoder_layers = 36
118
+ network.T5Config.num_heads = 12
119
+ network.T5Config.vocab_size = 32128
120
+
121
+ # Parameters for train_script.train:
122
+ # ==============================================================================
123
+ train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
124
+ train_script.train.eval_period = 10000
125
+ train_script.train.eval_steps = 20
126
+ train_script.train.infer_eval_dataset_cfg = None
127
+ train_script.train.model = %MODEL
128
+ train_script.train.model_dir = %MODEL_DIR
129
+ train_script.train.partitioner = @partitioning.PjitPartitioner()
130
+ train_script.train.random_seed = %RANDOM_SEED
131
+ train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
132
+ train_script.train.total_steps = %TRAIN_STEPS
133
+ train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
134
+ train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
135
+ train_script.train.trainer_cls = @trainer.Trainer
136
+ train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
137
+
138
+ # Parameters for trainer.Trainer:
139
+ # ==============================================================================
140
+ trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
141
+ trainer.Trainer.num_microbatches = None
142
+
143
+ # Parameters for network.Transformer:
144
+ # ==============================================================================
145
+ network.Transformer.config = @network.T5Config()
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "T5ForConditionalGeneration"
4
+ ],
5
+ "d_ff": 3072,
6
+ "d_kv": 64,
7
+ "d_model": 768,
8
+ "decoder_start_token_id": 0,
9
+ "dense_act_fn": "gelu_new",
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "initializer_factor": 1.0,
14
+ "is_encoder_decoder": true,
15
+ "is_gated_act": true,
16
+ "layer_norm_epsilon": 1e-06,
17
+ "model_type": "t5",
18
+ "n_positions": 512,
19
+ "num_decoder_layers": 36,
20
+ "num_heads": 12,
21
+ "num_layers": 36,
22
+ "output_past": true,
23
+ "pad_token_id": 0,
24
+ "relative_attention_max_distance": 128,
25
+ "relative_attention_num_buckets": 32,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.20.1",
29
+ "use_cache": true,
30
+ "vocab_size": 32128
31
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration, AutoModelForSeq2SeqLM
7
+ import torch
8
+
9
+
10
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
11
+ config = T5Config.from_pretrained(config_name)
12
+ flax_model = FlaxT5ForConditionalGeneration(config=config)
13
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
14
+
15
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
16
+
17
+ # Encoder
18
+ for layer_index in range(config.num_layers):
19
+ layer_name = f"layers_{str(layer_index)}"
20
+
21
+ # Self-Attention
22
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
23
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
24
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
25
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
26
+
27
+ ## Layer Normalization
28
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
29
+
30
+ if split_mlp_wi:
31
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
32
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
33
+ else:
34
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
35
+
36
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
37
+
38
+ ## Layer Normalization
39
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
40
+
41
+ # Assigning
42
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
43
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
44
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
45
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
46
+
47
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
48
+
49
+ if split_mlp_wi:
50
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
51
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
52
+ else:
53
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
54
+
55
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
56
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
57
+
58
+ # Only for layer 0:
59
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
60
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
61
+
62
+ # Assigning
63
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
64
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
65
+
66
+ # Decoder
67
+ for layer_index in range(config.num_decoder_layers):
68
+ layer_name = f"layers_{str(layer_index)}"
69
+
70
+ # Self-Attention
71
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
72
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
73
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
74
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
75
+
76
+ ## Layer Normalization
77
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
78
+
79
+ # Encoder-Decoder-Attention
80
+ t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
81
+ t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
82
+ t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
83
+ t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
84
+
85
+ ## Layer Normalization
86
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
87
+
88
+ # MLP
89
+ if split_mlp_wi:
90
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
91
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
92
+ else:
93
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
94
+
95
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
96
+
97
+ ## Layer Normalization
98
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
99
+
100
+ # Assigning
101
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
102
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
103
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
104
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
105
+
106
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
107
+
108
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
109
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
110
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
111
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
112
+
113
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
114
+
115
+ if split_mlp_wi:
116
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
117
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
118
+ else:
119
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
120
+
121
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
122
+
123
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
124
+
125
+ # Decoder Normalization
126
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
127
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
128
+
129
+ # Only for layer 0:
130
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
131
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
132
+
133
+ # Token Embeddings
134
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
135
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
136
+
137
+ # LM Head
138
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
139
+
140
+ flax_model.save_pretrained(flax_dump_folder_path)
141
+ print("T5X Model was sucessfully converted!")
142
+
143
+ def convert_flax_to_pytorch(flax_dump_folder_path, pytorch_dump_folder_path):
144
+ model = AutoModelForSeq2SeqLM.from_pretrained(flax_dump_folder_path, from_flax=True, torch_dtype=torch.float32)
145
+ model.save_pretrained(pytorch_dump_folder_path)
146
+ print("Flax model was sucessfully converted to Pytorch!")
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ # Required parameters
151
+ parser.add_argument(
152
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
153
+ )
154
+ parser.add_argument(
155
+ "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
156
+ )
157
+ parser.add_argument(
158
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
159
+ )
160
+ args = parser.parse_args()
161
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
162
+ convert_flax_to_pytorch(args.flax_dump_folder_path, args.flax_dump_folder_path)
163
+
model-info.txt ADDED
The diff for this file is too large to render. See raw diff
 
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b39375bd095c6d78f90949701049adc8f2a8fc1350c8c569c68ba079c80bda3f
3
+ size 821929
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff
 
start_train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set train hyperparams
2
+ unset LD_PRELOAD
3
+
4
+ PROJECT_DIR="/researchdisk/ul2-base-nl36-finnish"
5
+ T5X_DIR=${HOME}"/t5x" # directory where the t5x is cloned.
6
+ MODEL_DIR="/researchdisk/ul2-base-nl36-finnish"
7
+ export PYTHONPATH=${PROJECT_DIR}
8
+
9
+ python3 ${T5X_DIR}/t5x/train.py \
10
+ --gin_search_paths=${PROJECT_DIR} \
11
+ --gin_file="base_nl36_pretrain.gin" \
12
+ --gin.MODEL_DIR=\"${MODEL_DIR}\"
tasks.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import seqio
3
+ import tensorflow as tf
4
+ import t5.data
5
+ from datasets import load_dataset, load_from_disk
6
+ from t5.data import postprocessors
7
+ from t5.data import preprocessors
8
+ from t5.evaluation import metrics
9
+ from seqio import FunctionDataSource, utils
10
+
11
+ from ul2_objective import ul2_objective
12
+
13
+ # values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
14
+ R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
15
+ X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
16
+ R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
17
+ X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
18
+
19
+ R_DENOISER_TOKEN_PREFIX = '[NLU]'
20
+ X_DENOISER_TOKEN_PREFIX = '[NLG]'
21
+ S_DENOISER_TOKEN_PREFIX = '[S2S]'
22
+
23
+ TaskRegistry = seqio.TaskRegistry
24
+
25
+ vocabulary = seqio.SentencePieceVocabulary('spiece.model', extra_ids=0)
26
+
27
+ DEFAULT_OUTPUT_FEATURES = {
28
+ "inputs": seqio.Feature(
29
+ vocabulary=vocabulary, add_eos=True,
30
+ required=False),
31
+ "targets": seqio.Feature(
32
+ vocabulary=vocabulary, add_eos=True)
33
+ }
34
+
35
+
36
+ def gen_dataset(split, shuffle=False, seed=None, column="text", dataset=None):
37
+ if shuffle:
38
+ if seed:
39
+ dataset = dataset.shuffle(seed=seed)
40
+ else:
41
+ dataset = dataset.shuffle()
42
+ while True:
43
+ for item in dataset[str(split)]:
44
+ if item[column] is not None:
45
+ yield item[column]
46
+
47
+
48
+ def dataset_fn(split, shuffle_files, seed=None, dataset=None):
49
+ return tf.data.Dataset.from_generator(
50
+ functools.partial(gen_dataset, split, shuffle_files,
51
+ seed, dataset=dataset),
52
+ output_signature=tf.TensorSpec(
53
+ shape=(), dtype=tf.string, name=dataset_name)
54
+ )
55
+
56
+
57
+ @utils.map_over_dataset
58
+ def target_to_key(x, key_map, target_key):
59
+ """Assign the value from the dataset to target_key in key_map"""
60
+ return {**key_map, target_key: x}
61
+
62
+
63
+ dataset_name = "/researchdisk/lm_training_dataset_full"
64
+ dataset_params = {"from_disk_path": dataset_name}
65
+
66
+ if "from_disk_path" in dataset_params:
67
+ dataset = load_from_disk(dataset_params.get("from_disk_path"))
68
+ else:
69
+ dataset = load_dataset(**dataset_params)
70
+
71
+ dataset_shapes = {"train": dataset["train"].num_rows,
72
+ "validation": dataset["validation"].num_rows}
73
+
74
+ TaskRegistry.add(
75
+ "pretrain_finnish_ul2",
76
+ source=seqio.FunctionDataSource(
77
+ dataset_fn=functools.partial(dataset_fn, dataset=dataset),
78
+ splits=("train", "validation"),
79
+ caching_permitted=False,
80
+ num_input_examples=dataset_shapes,
81
+ ),
82
+ preprocessors=[
83
+ functools.partial(
84
+ target_to_key, key_map={
85
+ "inputs": None,
86
+ "targets": None,
87
+ }, target_key="targets"),
88
+ seqio.preprocessors.tokenize,
89
+ functools.partial(
90
+ ul2_objective,
91
+ shard_ds=False,
92
+ use_prefix_lm_task=True, # use S-denoising
93
+ rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)]*len(R_DENOISER_SPAN_LENGTHS) + [
94
+ 0.4 / len(X_DENOISER_SPAN_LENGTHS)]*len(X_DENOISER_SPAN_LENGTHS) + [0.2], # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
95
+ mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
96
+ noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
97
+ optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]*len(R_DENOISER_SPAN_LENGTHS) + [
98
+ X_DENOISER_TOKEN_PREFIX]*len(X_DENOISER_SPAN_LENGTHS) + [S_DENOISER_TOKEN_PREFIX],
99
+ reserved_for_packing=1, # make room for task prefix token
100
+ ),
101
+ seqio.preprocessors.append_eos_after_trim,
102
+ ],
103
+ output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
104
+ metric_fns=[metrics.accuracy]
105
+ )
train/events.out.tfevents.1666371310.t1v-n-12f94ad0-w-0.556051.0.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ec174375c1c148e5ef07745a54fd61dc130b962b15bad8738104f278650d97f
3
+ size 7475
train_sentencepiece.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import sentencepiece as spm
2
+
3
+ spm.SentencePieceTrainer.train(input="/researchdisk/training_dataset_sentences/train.txt", model_prefix='spiece', vocab_size=32000, character_coverage=1.0,
4
+ pad_id=0, unk_id=2, eos_id=1, bos_id=-1,
5
+ user_defined_symbols=['[NLU]', '[NLG]', '[S2S]'],
6
+ train_extremely_large_corpus=True,
7
+ num_threads=96, input_sentence_size=50000000, shuffle_input_sentence=True)
training_eval/pretrain_finnish_ul2/events.out.tfevents.1666371310.t1v-n-12f94ad0-w-0.556051.1.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d4be65a8327d80518f5112f644459714a7c9b936736d60465b726bb44abfb6d
3
+ size 1414
ul2_objective.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import tensorflow as tf
3
+ import seqio
4
+ import t5.data
5
+ from typing import Optional, Sequence
6
+
7
+ # found this function and modified from https://github.com/GoogleCloudPlatform/t5x-on-vertex-ai/blob/main/tasks/custom_tasks.py#L78
8
+ # UL2 paper appendix code missed this function
9
+ def prepend_prompt(dataset: tf.data.Dataset,
10
+ output_features: seqio.preprocessors.OutputFeaturesType,
11
+ sequence_length: Optional[
12
+ seqio.preprocessors.SequenceLengthType] = None,
13
+ prompt_mode: str = "",
14
+ key: str = "inputs",
15
+ mode: str = "") -> tf.data.Dataset:
16
+ """Prepends a prompt at the beginning of an input sequence."""
17
+ del sequence_length
18
+ if prompt_mode and mode:
19
+ # output_features may not have inputs key
20
+ out_keys = list(output_features.keys())
21
+ prompt_tokens = output_features[out_keys[0]
22
+ ].vocabulary.encode_tf(prompt_mode)
23
+
24
+ def add_to_inputs(x):
25
+ x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
26
+ return x
27
+
28
+ dataset = dataset.map(add_to_inputs)
29
+ return dataset
30
+
31
+ # modified from t5.data.preprocessors because output_features may not have inputs key
32
+ def split_tokens_to_inputs_length(dataset, sequence_length,
33
+ output_features, **kwargs):
34
+ max_tokens = sequence_length['inputs']
35
+ # output_features may not have inputs key
36
+ out_keys = list(output_features.keys())
37
+ if output_features[out_keys[0]].add_eos:
38
+ # Leave room to insert an EOS token.
39
+ max_tokens -= 1
40
+
41
+ return t5.data.preprocessors.split_tokens(dataset, max_tokens_per_segment=max_tokens, **kwargs)
42
+
43
+ # modified from t5.data.preprocessors because output_features may not have inputs key
44
+ def prefix_lm(dataset, sequence_length, output_features):
45
+ """Prefix language modeling objective used in Raffel et al. 2019."""
46
+ ds = dataset
47
+ ds = t5.data.preprocessors.select_random_chunk(ds, output_features=output_features,
48
+ feature_key='targets', max_length=65536)
49
+ ds = split_tokens_to_inputs_length(ds, output_features=output_features,
50
+ sequence_length=sequence_length)
51
+ ds = t5.data.preprocessors.denoise(
52
+ ds,
53
+ output_features,
54
+ inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
55
+ targets_fn=t5.data.preprocessors.drop_noise_tokens,
56
+ noise_density=0.5,
57
+ noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
58
+ )
59
+ return ds
60
+
61
+ # copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
62
+ # note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
63
+ def ul2_objective(dataset: tf.data.Dataset,
64
+ sequence_length: seqio.preprocessors.SequenceLengthType,
65
+ output_features: seqio.preprocessors.OutputFeaturesType,
66
+ use_prefix_lm_task: bool = False,
67
+ rates: Optional[Sequence[float]] = None,
68
+ mean_noise_span_lengths: Sequence[float] = (3.0,),
69
+ noise_densities: Sequence[float] = (0.15,),
70
+ shard_ds: bool = True,
71
+ optional_task_prefixes: Optional[Sequence[str]] = None,
72
+ input_feature_key: str = "inputs",
73
+ merge_examples_to_reduce_padding: bool = True,
74
+ reserved_for_packing: bool = None,
75
+ seed: int = 7) -> tf.data.Dataset:
76
+ """UL2-like pre-training objectives.
77
+ This preprocessor amounts to calling the 'span_corruption' function several
78
+ times with different values of 'noise_density' and 'mean_noise_span_length'.
79
+ We either shard or copy the dataset, then apply each function to each shard.
80
+ Add S-denoising (prefixLM) using use_prefix_lm_task.
81
+ Args:
82
+ dataset: A tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
83
+ sequence_length: dict mapping of feature key to int length for that feature.
84
+ output_features: mapping of keys to features.
85
+ use_prefix_lm_task: <bool> If True, include PrefixLM in the task mix.
86
+ rates: <Optional<List<float>> List of rates per task. If None, tasks are sampled uniformly.
87
+ mean_noise_span_lengths: List of mean number of tokens per masked span per example.
88
+ noise_densities: List of what fraction of the tokens to mask.
89
+ shard_ds: <bool> If True, shard dataset per objective.
90
+ optional_task_prefixes: <Optional<list<str>> Strings to prepend for each corruption scheme. NOTE: If including prefixLM task, it must be the last prefix.
91
+ input_feature_key: which feature to use from the dataset as the input text tokens.
92
+ merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
93
+ reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
94
+ seed: tf.int64 for controlling the random choice of spans.
95
+ Returns:
96
+ a dataset
97
+ """
98
+
99
+ if optional_task_prefixes: # Ensure each task has a prefix.
100
+ num_tasks = len(noise_densities) + int(use_prefix_lm_task)
101
+ valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
102
+ if not valid_number_of_prefixes:
103
+ raise ValueError(
104
+ "Number of task prefixes must match number of tasks.")
105
+ inputs_length = sequence_length[input_feature_key]
106
+ input_lengths, targets_lengths = [], []
107
+ sequence_lengths = {x: y for x, y in sequence_length.items()}
108
+ if reserved_for_packing:
109
+ inputs_length -= reserved_for_packing
110
+ for x, y in sequence_length.items():
111
+ sequence_lengths[x] = y - reserved_for_packing
112
+ hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
113
+ for mean_noise_span_length, noise_density in hyperparams:
114
+ input_length, targets_length = t5.data.preprocessors.random_spans_helper(
115
+ extra_tokens_per_span_inputs=1,
116
+ extra_tokens_per_span_targets=1,
117
+ inputs_length=inputs_length,
118
+ mean_noise_span_length=mean_noise_span_length,
119
+ noise_density=noise_density)
120
+ input_lengths.append(input_length)
121
+ targets_lengths.append(targets_length)
122
+
123
+ if sequence_length["targets"] < targets_length:
124
+ upper_bound = max(targets_lengths)
125
+ raise ValueError(
126
+ f'Expected max targets length for span corruption ({upper_bound}) is '
127
+ f'greater than configured targets length '
128
+ f"({sequence_length['targets']})")
129
+ ds = dataset
130
+ ds = t5.data.preprocessors.select_random_chunk(
131
+ ds,
132
+ output_features=output_features,
133
+ feature_key="targets",
134
+ max_length=65536)
135
+ if merge_examples_to_reduce_padding:
136
+ ds = t5.data.preprocessors.reduce_concat_tokens(
137
+ ds, feature_key="targets", batch_size=128)
138
+ num_shards = len(input_lengths) + int(use_prefix_lm_task)
139
+ if shard_ds:
140
+ ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
141
+ else:
142
+ ds_shards = [ds for _ in range(num_shards)]
143
+ processed_ds = []
144
+ hyperparams = zip(input_lengths, hyperparams, range(num_shards))
145
+ for input_length, (noise_span_length, noise_density), i in hyperparams:
146
+ ds = ds_shards[i]
147
+ ds = t5.data.preprocessors.split_tokens(
148
+ ds,
149
+ feature_key="targets",
150
+ min_tokens_per_segment=None,
151
+ max_tokens_per_segment=input_length)
152
+ ds = t5.data.preprocessors.denoise(
153
+ ds,
154
+ output_features,
155
+ inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
156
+ targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
157
+ noise_density=noise_density,
158
+ noise_mask_fn=functools.partial(
159
+ t5.data.preprocessors.random_spans_noise_mask,
160
+ mean_noise_span_length=noise_span_length),
161
+ input_feature_key=input_feature_key)
162
+ if optional_task_prefixes:
163
+ ds = prepend_prompt(
164
+ ds,
165
+ output_features,
166
+ prompt_mode=optional_task_prefixes[i],
167
+ mode=optional_task_prefixes[i],
168
+ key=input_feature_key)
169
+ processed_ds.append(ds)
170
+ if use_prefix_lm_task:
171
+ ds = ds_shards[-1]
172
+ ds = prefix_lm(
173
+ ds, sequence_lengths, output_features)
174
+ if optional_task_prefixes:
175
+ ds = prepend_prompt(
176
+ ds,
177
+ output_features,
178
+ prompt_mode=optional_task_prefixes[-1],
179
+ mode=optional_task_prefixes[-1],
180
+ key=input_feature_key)
181
+ processed_ds.append(ds)
182
+ ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
183
+ return ds