Siddharth63 commited on
Commit
98c533d
1 Parent(s): 238eef6

new additions

Browse files
README.md CHANGED
@@ -1,3 +1,3 @@
1
  ---
2
- license: afl-3.0
3
  ---
 
1
  ---
2
+ license: other
3
  ---
config.gin ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import __main__ as train_script
3
+ import seqio
4
+ from t5x import adafactor
5
+ from t5x.examples.t5 import network
6
+ from t5x import gin_utils
7
+ from t5x import models
8
+ from t5x import partitioning
9
+ from t5x import trainer
10
+ from t5x import utils
11
+ import ul2_tasks
12
+
13
+ # Macros:
14
+ # ==============================================================================
15
+ BATCH_SIZE = 64
16
+ DROPOUT_RATE = 0.0
17
+ LABEL_SMOOTHING = 0.0
18
+ LOSS_NORMALIZING_FACTOR = None
19
+ MIXTURE_OR_TASK_MODULE = None
20
+ MIXTURE_OR_TASK_NAME = 'pretrain_biological_ul2'
21
+ MODEL = @models.EncoderDecoderModel()
22
+ MODEL_DIR = '/models/bioul2-base'
23
+ OPTIMIZER = @adafactor.Adafactor()
24
+ RANDOM_SEED = None
25
+ SHUFFLE_TRAIN_EXAMPLES = True
26
+ TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 512}
27
+ TRAIN_STEPS = 1000000
28
+ USE_CACHED_TASKS = False
29
+ USE_HARDWARE_RNG = False
30
+ VOCABULARY = @seqio.SentencePieceVocabulary()
31
+ Z_LOSS = 0.0001
32
+
33
+ # Parameters for adafactor.Adafactor:
34
+ # ==============================================================================
35
+ adafactor.Adafactor.decay_rate = 0.8
36
+ adafactor.Adafactor.logical_factor_rules = \
37
+ @adafactor.standard_logical_factor_rules()
38
+ adafactor.Adafactor.step_offset = 0
39
+
40
+ # Parameters for utils.CheckpointConfig:
41
+ # ==============================================================================
42
+ utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig()
43
+ utils.CheckpointConfig.save = @utils.SaveCheckpointConfig()
44
+
45
+ # Parameters for utils.create_learning_rate_scheduler:
46
+ # ==============================================================================
47
+ utils.create_learning_rate_scheduler.base_learning_rate = 1.0
48
+ utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay'
49
+ utils.create_learning_rate_scheduler.warmup_steps = 10000
50
+
51
+ # Parameters for train/utils.DatasetConfig:
52
+ # ==============================================================================
53
+ train/utils.DatasetConfig.batch_size = %BATCH_SIZE
54
+ train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
55
+ train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
56
+ train/utils.DatasetConfig.pack = True
57
+ train/utils.DatasetConfig.seed = None
58
+ train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES
59
+ train/utils.DatasetConfig.split = 'train'
60
+ train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
61
+ train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
62
+
63
+ # Parameters for train_eval/utils.DatasetConfig:
64
+ # ==============================================================================
65
+ train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE
66
+ train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
67
+ train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE
68
+ train_eval/utils.DatasetConfig.pack = True
69
+ train_eval/utils.DatasetConfig.seed = 42
70
+ train_eval/utils.DatasetConfig.shuffle = False
71
+ train_eval/utils.DatasetConfig.split = 'validation'
72
+ train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS
73
+ train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS
74
+
75
+ # Parameters for models.EncoderDecoderModel:
76
+ # ==============================================================================
77
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
78
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
79
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
80
+ models.EncoderDecoderModel.module = @network.Transformer()
81
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
82
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
83
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
84
+
85
+ # Parameters for partitioning.PjitPartitioner:
86
+ # ==============================================================================
87
+ partitioning.PjitPartitioner.logical_axis_rules = \
88
+ @partitioning.standard_logical_axis_rules()
89
+ partitioning.PjitPartitioner.model_parallel_submesh = None
90
+ partitioning.PjitPartitioner.num_partitions = 1
91
+
92
+ # Parameters for utils.RestoreCheckpointConfig:
93
+ # ==============================================================================
94
+ utils.RestoreCheckpointConfig.path = []
95
+
96
+ # Parameters for utils.SaveCheckpointConfig:
97
+ # ==============================================================================
98
+ utils.SaveCheckpointConfig.dtype = 'float32'
99
+ utils.SaveCheckpointConfig.keep = 10
100
+ utils.SaveCheckpointConfig.period = 10000
101
+ utils.SaveCheckpointConfig.save_dataset = False
102
+
103
+ # Parameters for seqio.SentencePieceVocabulary:
104
+ # ==============================================================================
105
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model'
106
+
107
+ # Parameters for network.T5Config:
108
+ # ==============================================================================
109
+ network.T5Config.dropout_rate = %DROPOUT_RATE
110
+ network.T5Config.dtype = 'bfloat16'
111
+ network.T5Config.emb_dim = 768
112
+ network.T5Config.head_dim = 64
113
+ network.T5Config.logits_via_embedding = False
114
+ network.T5Config.mlp_activations = ('gelu', 'linear')
115
+ network.T5Config.mlp_dim = 2048
116
+ network.T5Config.num_decoder_layers = 12
117
+ network.T5Config.num_encoder_layers = 12
118
+ network.T5Config.num_heads = 12
119
+ network.T5Config.vocab_size = 32128
120
+
121
+ # Parameters for train_script.train:
122
+ # ==============================================================================
123
+ train_script.train.checkpoint_cfg = @utils.CheckpointConfig()
124
+ train_script.train.eval_period = 2000
125
+ train_script.train.eval_steps = 20
126
+ train_script.train.infer_eval_dataset_cfg = None
127
+ train_script.train.model = %MODEL
128
+ train_script.train.model_dir = %MODEL_DIR
129
+ train_script.train.partitioner = @partitioning.PjitPartitioner()
130
+ train_script.train.random_seed = %RANDOM_SEED
131
+ train_script.train.stats_period = 100
132
+ train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config
133
+ train_script.train.total_steps = %TRAIN_STEPS
134
+ train_script.train.train_dataset_cfg = @train/utils.DatasetConfig()
135
+ train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
136
+ train_script.train.trainer_cls = @trainer.Trainer
137
+ train_script.train.use_hardware_rng = %USE_HARDWARE_RNG
138
+
139
+ # Parameters for trainer.Trainer:
140
+ # ==============================================================================
141
+ trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler()
142
+ trainer.Trainer.num_microbatches = None
143
+
144
+ # Parameters for network.Transformer:
145
+ # ==============================================================================
146
+ network.Transformer.config = @network.T5Config()
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 2048,
7
+ "d_kv": 64,
8
+ "d_model": 768,
9
+ "decoder_start_token_id": 0,
10
+ "dense_act_fn": "gelu_new",
11
+ "dropout_rate": 0.1,
12
+ "eos_token_id": 1,
13
+ "feed_forward_proj": "gated-gelu",
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "is_gated_act": true,
17
+ "layer_norm_epsilon": 1e-06,
18
+ "model_type": "t5",
19
+ "num_decoder_layers": 12,
20
+ "num_heads": 12,
21
+ "num_layers": 12,
22
+ "output_past": true,
23
+ "pad_token_id": 0,
24
+ "relative_attention_max_distance": 128,
25
+ "relative_attention_num_buckets": 32,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.23.1",
29
+ "use_cache": true,
30
+ "vocab_size": 32128
31
+ }
convert_t5x_checkpoint_to_flax.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://gist.github.com/stefan-it/30e4998ef159f33696e377a46f699d9f
2
+
3
+ import argparse
4
+
5
+ from t5x import checkpoints
6
+ from transformers import T5Config, FlaxT5ForConditionalGeneration, AutoModelForSeq2SeqLM
7
+ import torch
8
+
9
+
10
+ def convert_t5x_checkpoint_to_flax(t5x_checkpoint_path, config_name, flax_dump_folder_path):
11
+ config = T5Config.from_pretrained(config_name)
12
+ flax_model = FlaxT5ForConditionalGeneration(config=config)
13
+ t5x_model = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
14
+
15
+ split_mlp_wi = "wi_0" in t5x_model["target"]["encoder"]["layers_0"]["mlp"]
16
+
17
+ # Encoder
18
+ for layer_index in range(config.num_layers):
19
+ layer_name = f"layers_{str(layer_index)}"
20
+
21
+ # Self-Attention
22
+ t5x_attention_key = t5x_model["target"]["encoder"][layer_name]["attention"]["key"]["kernel"]
23
+ t5x_attention_out = t5x_model["target"]["encoder"][layer_name]["attention"]["out"]["kernel"]
24
+ t5x_attention_query = t5x_model["target"]["encoder"][layer_name]["attention"]["query"]["kernel"]
25
+ t5x_attention_value = t5x_model["target"]["encoder"][layer_name]["attention"]["value"]["kernel"]
26
+
27
+ ## Layer Normalization
28
+ t5x_attention_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_attention_layer_norm"]["scale"]
29
+
30
+ if split_mlp_wi:
31
+ t5x_mlp_wi_0 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_0"]["kernel"]
32
+ t5x_mlp_wi_1 = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi_1"]["kernel"]
33
+ else:
34
+ t5x_mlp_wi = t5x_model["target"]["encoder"][layer_name]["mlp"]["wi"]["kernel"]
35
+
36
+ t5x_mlp_wo = t5x_model["target"]["encoder"][layer_name]["mlp"]["wo"]["kernel"]
37
+
38
+ ## Layer Normalization
39
+ t5x_mlp_layer_norm = t5x_model["target"]["encoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
40
+
41
+ # Assigning
42
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
43
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
44
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
45
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
46
+
47
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_attention_layer_norm
48
+
49
+ if split_mlp_wi:
50
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
51
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
52
+ else:
53
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
54
+
55
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
56
+ flax_model.params["encoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_mlp_layer_norm
57
+
58
+ # Only for layer 0:
59
+ t5x_encoder_rel_embedding = t5x_model["target"]["encoder"]["relpos_bias"]["rel_embedding"].T
60
+ flax_model.params["encoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_encoder_rel_embedding
61
+
62
+ # Assigning
63
+ t5x_encoder_norm = t5x_model["target"]["encoder"]["encoder_norm"]["scale"]
64
+ flax_model.params["encoder"]["final_layer_norm"]["weight"] = t5x_encoder_norm
65
+
66
+ # Decoder
67
+ for layer_index in range(config.num_decoder_layers):
68
+ layer_name = f"layers_{str(layer_index)}"
69
+
70
+ # Self-Attention
71
+ t5x_attention_key = t5x_model["target"]["decoder"][layer_name]["self_attention"]["key"]["kernel"]
72
+ t5x_attention_out = t5x_model["target"]["decoder"][layer_name]["self_attention"]["out"]["kernel"]
73
+ t5x_attention_query = t5x_model["target"]["decoder"][layer_name]["self_attention"]["query"]["kernel"]
74
+ t5x_attention_value = t5x_model["target"]["decoder"][layer_name]["self_attention"]["value"]["kernel"]
75
+
76
+ ## Layer Normalization
77
+ t5x_pre_attention_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_self_attention_layer_norm"]["scale"]
78
+
79
+ # Encoder-Decoder-Attention
80
+ t5x_enc_dec_attention_key = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["key"]["kernel"]
81
+ t5x_enc_dec_attention_out = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["out"]["kernel"]
82
+ t5x_enc_dec_attention_query = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["query"]["kernel"]
83
+ t5x_enc_dec_attention_value = t5x_model["target"]["decoder"][layer_name]["encoder_decoder_attention"]["value"]["kernel"]
84
+
85
+ ## Layer Normalization
86
+ t5x_cross_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_cross_attention_layer_norm"]["scale"]
87
+
88
+ # MLP
89
+ if split_mlp_wi:
90
+ t5x_mlp_wi_0 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_0"]["kernel"]
91
+ t5x_mlp_wi_1 = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi_1"]["kernel"]
92
+ else:
93
+ t5x_mlp_wi = t5x_model["target"]["decoder"][layer_name]["mlp"]["wi"]["kernel"]
94
+
95
+ t5x_mlp_wo = t5x_model["target"]["decoder"][layer_name]["mlp"]["wo"]["kernel"]
96
+
97
+ ## Layer Normalization
98
+ tx5_mlp_layer_norm = t5x_model["target"]["decoder"][layer_name]["pre_mlp_layer_norm"]["scale"]
99
+
100
+ # Assigning
101
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["k"]["kernel"] = t5x_attention_key
102
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["o"]["kernel"] = t5x_attention_out
103
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["q"]["kernel"] = t5x_attention_query
104
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["SelfAttention"]["v"]["kernel"] = t5x_attention_value
105
+
106
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["0"]["layer_norm"]["weight"] = t5x_pre_attention_layer_norm
107
+
108
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["k"]["kernel"] = t5x_enc_dec_attention_key
109
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["o"]["kernel"] = t5x_enc_dec_attention_out
110
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["q"]["kernel"] = t5x_enc_dec_attention_query
111
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["EncDecAttention"]["v"]["kernel"] = t5x_enc_dec_attention_value
112
+
113
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["1"]["layer_norm"]["weight"] = t5x_cross_layer_norm
114
+
115
+ if split_mlp_wi:
116
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_0"]["kernel"] = t5x_mlp_wi_0
117
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi_1"]["kernel"] = t5x_mlp_wi_1
118
+ else:
119
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wi"]["kernel"] = t5x_mlp_wi
120
+
121
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["DenseReluDense"]["wo"]["kernel"] = t5x_mlp_wo
122
+
123
+ flax_model.params["decoder"]["block"][str(layer_index)]["layer"]["2"]["layer_norm"]["weight"] = tx5_mlp_layer_norm
124
+
125
+ # Decoder Normalization
126
+ tx5_decoder_norm = t5x_model["target"]["decoder"]["decoder_norm"]["scale"]
127
+ flax_model.params["decoder"]["final_layer_norm"]["weight"] = tx5_decoder_norm
128
+
129
+ # Only for layer 0:
130
+ t5x_decoder_rel_embedding = t5x_model["target"]["decoder"]["relpos_bias"]["rel_embedding"].T
131
+ flax_model.params["decoder"]["block"]["0"]["layer"]["0"]["SelfAttention"]["relative_attention_bias"]["embedding"] = t5x_decoder_rel_embedding
132
+
133
+ # Token Embeddings
134
+ tx5_token_embeddings = t5x_model["target"]["token_embedder"]["embedding"]
135
+ flax_model.params["shared"]["embedding"] = tx5_token_embeddings
136
+
137
+ # LM Head
138
+ flax_model.params["lm_head"]["kernel"] = t5x_model["target"]["decoder"]["logits_dense"]["kernel"]
139
+
140
+ flax_model.save_pretrained(flax_dump_folder_path)
141
+ print("T5X Model was sucessfully converted!")
142
+
143
+ def convert_flax_to_pytorch(flax_dump_folder_path, pytorch_dump_folder_path):
144
+ model = AutoModelForSeq2SeqLM.from_pretrained(flax_dump_folder_path, from_flax=True, torch_dtype=torch.float32)
145
+ model.save_pretrained(pytorch_dump_folder_path)
146
+ print("Flax model was sucessfully converted to Pytorch!")
147
+
148
+ if __name__ == "__main__":
149
+ parser = argparse.ArgumentParser()
150
+ # Required parameters
151
+ parser.add_argument(
152
+ "--t5x_checkpoint_path", default=None, type=str, required=True, help="Path the TX5 checkpoint."
153
+ )
154
+ parser.add_argument(
155
+ "--config_name", default=None, type=str, required=True, help="Config name of T5 model."
156
+ )
157
+ parser.add_argument(
158
+ "--flax_dump_folder_path", default=None, type=str, required=True, help="Path to the output FLAX model."
159
+ )
160
+ args = parser.parse_args()
161
+ convert_t5x_checkpoint_to_flax(args.t5x_checkpoint_path, args.config_name, args.flax_dump_folder_path)
162
+ convert_flax_to_pytorch(args.flax_dump_folder_path, args.flax_dump_folder_path)
export_checkpoint.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
3
+
4
+ def main(args):
5
+ pt_model = T5ForConditionalGeneration.from_pretrained(args.model_dir, from_flax=True)
6
+ pt_model.save_pretrained(args.model_dir)
7
+ tf_model = TFT5ForConditionalGeneration.from_pretrained(args.model_dir, from_pt=True)
8
+ tf_model.save_pretrained(args.model_dir)
9
+
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--model_dir', type=str, default='.')
model-info.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets>=1.16.1
2
+ transformers>=4.13.0
3
+ flax>=0.3.5
4
+ optax>=0.1.0
5
+ tqdm>=4.61.1
6
+ numpy>=1.19.5
7
+ tokenizers>=0.10.3
8
+ sentencepiece>=0.1.96
9
+ protobuf>=3.17.3,<=3.20.99
10
+ tensorboard>=2.7.0
11
+ torch>=1.9.0
12
+ tensorflow>=2.7.0
13
+ jax[tpu]>=0.2.28
small.gin ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # T5.1.1 Efficient base nl36 model.
2
+
3
+ import seqio
4
+ include 't5x/examples/t5/t5_1_1/small.gin' # imports vocab, optimizer and model.
5
+
6
+ # ------------------- Model specification overrides --------------------------
7
+ VOCABULARY = @seqio.SentencePieceVocabulary()
8
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = "spiece.model"
9
+
10
+ MODEL = @models.EncoderDecoderModel()
11
+ models.EncoderDecoderModel:
12
+ input_vocabulary = %VOCABULARY
13
+ output_vocabulary = %VOCABULARY
small_pretrain.gin ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Register necessary SeqIO Tasks/Mixtures.
2
+ from __gin__ import dynamic_registration
3
+ from t5x import utils
4
+ import ul2_tasks
5
+ import __main__ as train_script
6
+
7
+ include 'small.gin'
8
+ include 't5x/configs/runs/pretrain.gin'
9
+
10
+
11
+ # ------------------- Training specification overrides --------------------------
12
+ train_script.train:
13
+ eval_period = 2000
14
+ stats_period = 100
15
+
16
+ utils.SaveCheckpointConfig:
17
+ period = 50000
18
+ keep = 4
19
+
20
+
21
+ MIXTURE_OR_TASK_NAME = "pretrain_biological_ul2"
22
+ USE_CACHED_TASKS = False
23
+ TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
24
+ TRAIN_STEPS = 500000
25
+ DROPOUT_RATE = 0.0
26
+ BATCH_SIZE = 256
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:579ebba0921710bb6bd17cd678d4379b4a81ca84756dab644d7e8529bd01009d
3
+ size 805610
spiece.vocab ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %pip install sentencepiece
2
+ # %pip install datasets
3
+ # %pip install seqio
4
+
5
+ import unicodedata
6
+ import os
7
+ import nltk
8
+ from tqdm import tqdm
9
+ import glob
10
+ from random import sample
11
+
12
+ def sample_and_make_tempfile(sentences_dir, num_files):
13
+ """ Use the set of files containing a sentence per line,
14
+ sample num_files out of those and save as a temp file """
15
+
16
+ sentence_files = glob.glob(sentences_dir + "/*.txt")
17
+
18
+ # sample num_files
19
+ sampled_files=sample(sentence_files, num_files)
20
+
21
+ print("sampled files:")
22
+ print(sampled_files)
23
+
24
+ #read all the lines from sampled files and save to a list
25
+ all_lines = []
26
+ for filename in sampled_files:
27
+ with open(filename) as f:
28
+ lines = f.read().splitlines()
29
+
30
+ all_lines.extend(lines)
31
+
32
+ print("number of lines sampled:", len(all_lines))
33
+
34
+ #combine into a single file and save
35
+ tempfile_path = os.path.join("text", "temp.txt")
36
+ with open(tempfile_path, "w") as f:
37
+
38
+ for sentence in tqdm(all_lines):
39
+
40
+ # remove newlines
41
+ line = sentence.strip()
42
+
43
+ # do not save empty items such as
44
+ if sentence != []:
45
+
46
+ f.writelines(sentence + '\n')
47
+
48
+ print("Wrote to ", tempfile_path)
49
+ return tempfile_path
50
+
51
+
52
+ def chunks(sentences, n, tot_len):
53
+ """Yield successive n-sized chunks from sentences."""
54
+ for i in range(0, tot_len, n):
55
+ end_i = min(len(sentences),i + n)
56
+ yield sentences[i:end_i]["text"]
57
+
58
+
59
+
60
+ def make_sentence_files(dataset, chunksize = 5600000, data_dir = 'text/sentences'):
61
+ """
62
+ Make a sentence per line files, chuncsize sentences per file"""
63
+
64
+ # make sure data dir exists
65
+ if not os.path.exists(data_dir):
66
+ os.makedirs(data_dir)
67
+
68
+ # use simple regex for sentence tokenizing
69
+ sent_detector = nltk.RegexpTokenizer(u'[^ !?。]*[!?。.\n]')
70
+
71
+ # loop over the chunks
72
+ for chunk_ind, sentence_chunk in enumerate(chunks(dataset, chunksize, len(dataset))):
73
+
74
+ # new file for each chunk
75
+ filename = "sent_{}.txt".format(chunk_ind)
76
+ filepath = os.path.join(data_dir, filename)
77
+
78
+ print("writing to ", filepath)
79
+
80
+ with open(filepath, "w") as f:
81
+
82
+ for sentence in tqdm(sentence_chunk):
83
+
84
+ # remove newlines
85
+ line = sentence.strip()
86
+
87
+ # unicode normalize japanese spaces etc
88
+ unicodedata.normalize('NFKC', line)
89
+
90
+ # tokenize into sentences
91
+ sentences = sent_detector.tokenize(line)
92
+
93
+ # do not save empty items such as
94
+ if sentences != []:
95
+
96
+ f.writelines(s + '\n' for s in sentences)
97
+
98
+
99
+ def combine_files(output_file, *files):
100
+ """
101
+ Combines the contents of multiple text files into a single file.
102
+
103
+ :param output_file: Path to the output file.
104
+ :param files: Paths to the files to be combined.
105
+ :return: Total number of lines in the combined file.
106
+ """
107
+ total_lines = 0
108
+
109
+ with open(output_file, 'w') as outfile:
110
+ for file in files:
111
+ with open(file, 'r') as infile:
112
+ lines = infile.readlines()
113
+ total_lines += len(lines)
114
+ outfile.writelines(lines)
115
+ # Add a newline for separation (optional)
116
+ outfile.write('\n')
117
+
118
+
119
+ return total_lines
120
+
121
+ # make sentence files from hugingface dataset
122
+ dataset_bio = datasets.load_dataset("Siddharth63/biological_dataset")
123
+ make_sentence_files(dataset_bio["train"])
124
+
125
+ # combine files to get 45 million sentences
126
+ files_to_combine = glob.glob("text/sentences/*.txt")
127
+ files_to_combine = files_to_combine[:2]
128
+ total_lines = combine_files(output_file_path, *files_to_combine)
129
+
130
+ # Train the sentencepiece transformers on 45 million sentences
131
+ import sentencepiece as spm
132
+
133
+ spm.SentencePieceTrainer.train(input="text/final_file.txt", model_prefix='spiece', vocab_size=32000, character_coverage=1.0,
134
+ pad_id=0, unk_id=2, eos_id=1, bos_id=-1,
135
+ user_defined_symbols=['[NLU]', '[NLG]', '[S2S]'],
136
+ train_extremely_large_corpus=True,
137
+ num_threads=90, input_sentence_size=45000000, shuffle_input_sentence=True)
138
+
139
+
140
+ # Add 100 extra tokens to the model
141
+ from seqio import SentencePieceVocabulary
142
+ import os
143
+ import tensorflow as tf
144
+ from sentencepiece import SentencePieceProcessor, sentencepiece_model_pb2
145
+
146
+
147
+ def add_100extra(vocab: SentencePieceVocabulary, out_dir: str):
148
+ tf.io.gfile.makedirs(out_dir)
149
+ tf.io.gfile.GFile(os.path.join(out_dir, 'spiece.model'), 'w').write(vocab.sp_model)
150
+
151
+ model = sentencepiece_model_pb2.ModelProto.FromString(vocab.sp_model)
152
+ tf.io.gfile.GFile(os.path.join(out_dir, 'spiece.vocab'), 'w').write(
153
+ '\n'.join(f'{p.piece}\t{p.score}' for p in model.pieces)
154
+ )
155
+
156
+
157
+ # vocab = t5.data.get_default_vocabulary()
158
+ # out_dir = "../vocabulary/cc_all.32000.100extra"
159
+ #
160
+ # add_100extra(vocab, out_dir)
161
+ #
162
+ # vocab = seqio.SentencePieceVocabulary("../vocabulary/nedd.32000/spiece.model", extra_ids=100)
163
+ # out_dir = "../vocabulary/nedd.32000.100extra"
164
+ # add_100extra(vocab, out_dir)
165
+ #
166
+ # vocab = seqio.SentencePieceVocabulary("../vocabulary/nedd.32000/spiece.model", extra_ids=128)
167
+ # out_dir = "../vocabulary/nedd.32000.128extra"
168
+ # add_100extra(vocab, out_dir)
169
+ #
170
+
171
+
172
+ vocab = SentencePieceVocabulary("/Users/sdeshpande/Desktop/Challenges/patents/spiece_45.model", extra_ids=100)
173
+ out_dir = "conv"
174
+ add_100extra(vocab, out_dir)
ul2_objective.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import tensorflow as tf
3
+ import seqio
4
+ import t5.data
5
+ from typing import Optional, Sequence
6
+
7
+ # found this function and modified from https://github.com/GoogleCloudPlatform/t5x-on-vertex-ai/blob/main/tasks/custom_tasks.py#L78
8
+ # UL2 paper appendix code missed this function
9
+ def prepend_prompt(
10
+ dataset: tf.data.Dataset,
11
+ output_features: seqio.preprocessors.OutputFeaturesType,
12
+ sequence_length: Optional[seqio.preprocessors.SequenceLengthType] = None,
13
+ prompt_mode: str = "",
14
+ key: str = "inputs",
15
+ mode: str = "",
16
+ ) -> tf.data.Dataset:
17
+ """Prepends a prompt at the beginning of an input sequence."""
18
+ del sequence_length
19
+ if prompt_mode and mode:
20
+ # output_features may not have inputs key
21
+ out_keys = list(output_features.keys())
22
+ prompt_tokens = output_features[out_keys[0]].vocabulary.encode_tf(prompt_mode)
23
+
24
+ def add_to_inputs(x):
25
+ x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
26
+ return x
27
+
28
+ dataset = dataset.map(add_to_inputs)
29
+ return dataset
30
+
31
+
32
+ # modified from t5.data.preprocessors because output_features may not have inputs key
33
+ def split_tokens_to_inputs_length(dataset, sequence_length, output_features, **kwargs):
34
+ max_tokens = sequence_length["inputs"]
35
+ # output_features may not have inputs key
36
+ out_keys = list(output_features.keys())
37
+ if output_features[out_keys[0]].add_eos:
38
+ # Leave room to insert an EOS token.
39
+ max_tokens -= 1
40
+
41
+ return t5.data.preprocessors.split_tokens(
42
+ dataset, max_tokens_per_segment=max_tokens, **kwargs
43
+ )
44
+
45
+
46
+ # modified from t5.data.preprocessors because output_features may not have inputs key
47
+ def prefix_lm(dataset, sequence_length, output_features):
48
+ """Prefix language modeling objective used in Raffel et al. 2019."""
49
+ ds = dataset
50
+ ds = t5.data.preprocessors.select_random_chunk(
51
+ ds, output_features=output_features, feature_key="targets", max_length=65536
52
+ )
53
+ ds = split_tokens_to_inputs_length(
54
+ ds, output_features=output_features, sequence_length=sequence_length
55
+ )
56
+ ds = t5.data.preprocessors.denoise(
57
+ ds,
58
+ output_features,
59
+ inputs_fn=t5.data.preprocessors.drop_nonnoise_tokens,
60
+ targets_fn=t5.data.preprocessors.drop_noise_tokens,
61
+ noise_density=0.5,
62
+ noise_mask_fn=t5.data.preprocessors.random_prefix_noise_mask,
63
+ )
64
+ return ds
65
+
66
+
67
+ # copied from UL2 paper https://arxiv.org/pdf/2205.05131.pdf appendix chapter 9.2
68
+ # note: modified to use the prefix_lm() from above instead of the default t5.data.preprocessors.prefix_lm() because output_features may not have inputs key
69
+ def ul2_objective(
70
+ dataset: tf.data.Dataset,
71
+ sequence_length: seqio.preprocessors.SequenceLengthType,
72
+ output_features: seqio.preprocessors.OutputFeaturesType,
73
+ use_prefix_lm_task: bool = False,
74
+ rates: Optional[Sequence[float]] = None,
75
+ mean_noise_span_lengths: Sequence[float] = (3.0,),
76
+ noise_densities: Sequence[float] = (0.15,),
77
+ shard_ds: bool = True,
78
+ optional_task_prefixes: Optional[Sequence[str]] = None,
79
+ input_feature_key: str = "inputs",
80
+ merge_examples_to_reduce_padding: bool = True,
81
+ reserved_for_packing: bool = None,
82
+ seed: int = 7,
83
+ ) -> tf.data.Dataset:
84
+ """UL2-like pre-training objectives.
85
+ This preprocessor amounts to calling the 'span_corruption' function several
86
+ times with different values of 'noise_density' and 'mean_noise_span_length'.
87
+ We either shard or copy the dataset, then apply each function to each shard.
88
+ Add S-denoising (prefixLM) using use_prefix_lm_task.
89
+ Args:
90
+ dataset: A tf.data.Dataset with dictionaries containing the key 'input_feature_key'.
91
+ sequence_length: dict mapping of feature key to int length for that feature.
92
+ output_features: mapping of keys to features.
93
+ use_prefix_lm_task: <bool> If True, include PrefixLM in the task mix.
94
+ rates: <Optional<List<float>> List of rates per task. If None, tasks are sampled uniformly.
95
+ mean_noise_span_lengths: List of mean number of tokens per masked span per example.
96
+ noise_densities: List of what fraction of the tokens to mask.
97
+ shard_ds: <bool> If True, shard dataset per objective.
98
+ optional_task_prefixes: <Optional<list<str>> Strings to prepend for each corruption scheme. NOTE: If including prefixLM task, it must be the last prefix.
99
+ input_feature_key: which feature to use from the dataset as the input text tokens.
100
+ merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
101
+ reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
102
+ seed: tf.int64 for controlling the random choice of spans.
103
+ Returns:
104
+ a dataset
105
+ """
106
+
107
+ if optional_task_prefixes: # Ensure each task has a prefix.
108
+ num_tasks = len(noise_densities) + int(use_prefix_lm_task)
109
+ valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
110
+ if not valid_number_of_prefixes:
111
+ raise ValueError("Number of task prefixes must match number of tasks.")
112
+ inputs_length = sequence_length[input_feature_key]
113
+ input_lengths, targets_lengths = [], []
114
+ sequence_lengths = {x: y for x, y in sequence_length.items()}
115
+ if reserved_for_packing:
116
+ inputs_length -= reserved_for_packing
117
+ for x, y in sequence_length.items():
118
+ sequence_lengths[x] = y - reserved_for_packing
119
+ hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
120
+ for mean_noise_span_length, noise_density in hyperparams:
121
+ input_length, targets_length = t5.data.preprocessors.random_spans_helper(
122
+ extra_tokens_per_span_inputs=1,
123
+ extra_tokens_per_span_targets=1,
124
+ inputs_length=inputs_length,
125
+ mean_noise_span_length=mean_noise_span_length,
126
+ noise_density=noise_density,
127
+ )
128
+ input_lengths.append(input_length)
129
+ targets_lengths.append(targets_length)
130
+
131
+ if sequence_length["targets"] < targets_length:
132
+ upper_bound = max(targets_lengths)
133
+ raise ValueError(
134
+ f"Expected max targets length for span corruption ({upper_bound}) is "
135
+ f"greater than configured targets length "
136
+ f"({sequence_length['targets']})"
137
+ )
138
+ ds = dataset
139
+ ds = t5.data.preprocessors.select_random_chunk(
140
+ ds, output_features=output_features, feature_key="targets", max_length=65536
141
+ )
142
+ if merge_examples_to_reduce_padding:
143
+ ds = t5.data.preprocessors.reduce_concat_tokens(
144
+ ds, feature_key="targets", batch_size=128
145
+ )
146
+ num_shards = len(input_lengths) + int(use_prefix_lm_task)
147
+ if shard_ds:
148
+ ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
149
+ else:
150
+ ds_shards = [ds for _ in range(num_shards)]
151
+ processed_ds = []
152
+ hyperparams = zip(input_lengths, hyperparams, range(num_shards))
153
+ for input_length, (noise_span_length, noise_density), i in hyperparams:
154
+ ds = ds_shards[i]
155
+ ds = t5.data.preprocessors.split_tokens(
156
+ ds,
157
+ feature_key="targets",
158
+ min_tokens_per_segment=None,
159
+ max_tokens_per_segment=input_length,
160
+ )
161
+ ds = t5.data.preprocessors.denoise(
162
+ ds,
163
+ output_features,
164
+ inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
165
+ targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
166
+ noise_density=noise_density,
167
+ noise_mask_fn=functools.partial(
168
+ t5.data.preprocessors.random_spans_noise_mask,
169
+ mean_noise_span_length=noise_span_length,
170
+ ),
171
+ input_feature_key=input_feature_key,
172
+ )
173
+ if optional_task_prefixes:
174
+ ds = prepend_prompt(
175
+ ds,
176
+ output_features,
177
+ prompt_mode=optional_task_prefixes[i],
178
+ mode=optional_task_prefixes[i],
179
+ key=input_feature_key,
180
+ )
181
+ processed_ds.append(ds)
182
+ if use_prefix_lm_task:
183
+ ds = ds_shards[-1]
184
+ ds = prefix_lm(ds, sequence_lengths, output_features)
185
+ if optional_task_prefixes:
186
+ ds = prepend_prompt(
187
+ ds,
188
+ output_features,
189
+ prompt_mode=optional_task_prefixes[-1],
190
+ mode=optional_task_prefixes[-1],
191
+ key=input_feature_key,
192
+ )
193
+ processed_ds.append(ds)
194
+ ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
195
+ return ds
ul2_tasks.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Dict
3
+
4
+ import seqio
5
+ import tensorflow as tf
6
+ from datasets import load_dataset, load_from_disk
7
+ from t5.evaluation import metrics
8
+ from seqio import utils, FunctionDataSource
9
+ import t5.data
10
+ from datasets import load_dataset, load_from_disk
11
+ from t5.data import postprocessors
12
+ from t5.data import preprocessors
13
+
14
+
15
+ from ul2_objective import ul2_objective
16
+
17
+ # values from UL2 paper https://arxiv.org/pdf/2205.05131.pdf chapter 3.1.2 table 1
18
+ R_DENOISER_SPAN_LENGTHS = [3.0, 8.0]
19
+ X_DENOISER_SPAN_LENGTHS = [3.0, 8.0, 64.0, 64.0]
20
+ R_DENOISER_CORRUPT_RATES = [0.15, 0.15]
21
+ X_DENOISER_CORRUPT_RATES = [0.5, 0.5, 0.15, 0.5]
22
+
23
+ R_DENOISER_TOKEN_PREFIX = "[NLU]"
24
+ X_DENOISER_TOKEN_PREFIX = "[NLG]"
25
+ S_DENOISER_TOKEN_PREFIX = "[S2S]"
26
+
27
+ TaskRegistry = seqio.TaskRegistry
28
+
29
+ vocabulary = seqio.SentencePieceVocabulary('spiece.model')
30
+
31
+ DEFAULT_OUTPUT_FEATURES = {
32
+ "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True, required=False),
33
+ "targets": seqio.Feature(vocabulary=vocabulary, add_eos=True),
34
+ }
35
+
36
+ def gen_dataset(split, shuffle=False, seed=None, column="text", path=None, name=None):
37
+ dataset = load_dataset(path, name, streaming=True, use_auth_token=True)
38
+ if shuffle:
39
+ if seed:
40
+ dataset = dataset.shuffle(seed=seed)
41
+ else:
42
+ dataset = dataset.shuffle()
43
+ while True:
44
+ for item in dataset[str(split)]:
45
+ yield item[column]
46
+
47
+
48
+ def dataset_fn(split, shuffle_files, seed=None, path=None, name=None):
49
+ return tf.data.Dataset.from_generator(
50
+ functools.partial(
51
+ gen_dataset, split, shuffle_files, seed, path=path, name=name
52
+ ),
53
+ output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=path),
54
+ )
55
+
56
+
57
+ @utils.map_over_dataset
58
+ def target_to_key(x, key_map, target_key):
59
+ """Assign the value from the dataset to target_key in key_map"""
60
+ return {**key_map, target_key: x}
61
+
62
+ ## First way to add to task registry
63
+ dataset_name = 'Siddharth63/biological_dataset'
64
+ dataset = load_dataset(dataset_name)
65
+
66
+ dataset_shapes = {"train": dataset["train"].num_rows,
67
+ "validation": dataset["validation"].num_rows}
68
+
69
+ TaskRegistry.add(
70
+ "pretrain_biological_ul2",
71
+ source=seqio.FunctionDataSource(
72
+ dataset_fn=functools.partial(
73
+ dataset_fn, path="Siddharth63/biological_dataset",
74
+ ),
75
+ splits=("train", "validation"),
76
+ caching_permitted=False,
77
+ ),
78
+ preprocessors=[
79
+ functools.partial(
80
+ target_to_key,
81
+ key_map={
82
+ "inputs": "text",
83
+ "targets": "text",
84
+ },
85
+ target_key="targets",
86
+ ),
87
+ seqio.preprocessors.tokenize,
88
+ functools.partial(
89
+ ul2_objective,
90
+ shard_ds=False,
91
+ use_prefix_lm_task=True, # use S-denoising
92
+ rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
93
+ + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
94
+ + [
95
+ 0.2
96
+ ], # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
97
+ mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
98
+ noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
99
+ optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]
100
+ * len(R_DENOISER_SPAN_LENGTHS)
101
+ + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
102
+ + [S_DENOISER_TOKEN_PREFIX],
103
+ reserved_for_packing=1, # make room for task prefix token
104
+ ),
105
+ seqio.preprocessors.append_eos_after_trim,
106
+ ],
107
+ output_features={
108
+ "targets": DEFAULT_OUTPUT_FEATURES["targets"],
109
+ "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
110
+ },
111
+ metric_fns=[metrics.accuracy],
112
+ )
113
+
114
+
115
+ ## Second way to add to task registry
116
+ # TaskRegistry.add(
117
+ # "pretrain_biological_ul2",
118
+ # source=seqio.FunctionDataSource(
119
+ # dataset_fn=functools.partial(
120
+ # dataset_fn, path="Siddharth63/biological_dataset", name="full"
121
+ # ),
122
+ # splits=("train", "validation"),
123
+ # caching_permitted=False,
124
+ # ),
125
+ # preprocessors=[
126
+ # functools.partial(
127
+ # target_to_key,
128
+ # key_map={
129
+ # "inputs": "text",
130
+ # "targets": "text",
131
+ # },
132
+ # target_key="targets",
133
+ # ),
134
+ # seqio.preprocessors.tokenize,
135
+ # functools.partial(
136
+ # ul2_objective,
137
+ # shard_ds=False,
138
+ # use_prefix_lm_task=True, # use S-denoising
139
+ # rates=[0.4 / len(R_DENOISER_SPAN_LENGTHS)] * len(R_DENOISER_SPAN_LENGTHS)
140
+ # + [0.4 / len(X_DENOISER_SPAN_LENGTHS)] * len(X_DENOISER_SPAN_LENGTHS)
141
+ # + [
142
+ # 0.2
143
+ # ], # equal total 40% rate for both R- and X-denoisers + 20% for S-denoising (suggested at the paper chapter 4.5)
144
+ # mean_noise_span_lengths=R_DENOISER_SPAN_LENGTHS + X_DENOISER_SPAN_LENGTHS,
145
+ # noise_densities=R_DENOISER_CORRUPT_RATES + X_DENOISER_CORRUPT_RATES,
146
+ # optional_task_prefixes=[R_DENOISER_TOKEN_PREFIX]
147
+ # * len(R_DENOISER_SPAN_LENGTHS)
148
+ # + [X_DENOISER_TOKEN_PREFIX] * len(X_DENOISER_SPAN_LENGTHS)
149
+ # + [S_DENOISER_TOKEN_PREFIX],
150
+ # reserved_for_packing=1, # make room for task prefix token
151
+ # ),
152
+ # seqio.preprocessors.append_eos_after_trim,
153
+ # ],
154
+ # output_features={
155
+ # "targets": DEFAULT_OUTPUT_FEATURES["targets"],
156
+ # "inputs": seqio.Feature(vocabulary=vocabulary, add_eos=True),
157
+ # },
158
+ # metric_fns=[metrics.accuracy],
159
+ # )