test
Browse files- __pycache__/tasks.cpython-38.pyc +0 -0
- base.gin +18 -0
- pretrain_cont.gin +111 -0
- tasks.py +89 -0
- train_base.sh +9 -0
__pycache__/tasks.cpython-38.pyc
ADDED
Binary file (2.18 kB). View file
|
|
base.gin
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include 't5x/examples/t5/mt5/base.gin'
|
2 |
+
include 'pretrain_cont.gin'
|
3 |
+
#include 't5x/configs/runs/pretrain.gin'
|
4 |
+
#include 't5x/configs/runs/finetune.gin'
|
5 |
+
|
6 |
+
|
7 |
+
# Register necessary SeqIO Tasks/Mixtures.
|
8 |
+
import t5.data.mixtures
|
9 |
+
import tasks
|
10 |
+
|
11 |
+
MIXTURE_OR_TASK_NAME = "ncc_scandinavian_span_corruption_stream"
|
12 |
+
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 512}
|
13 |
+
TRAIN_STEPS = 1_700_000
|
14 |
+
DROPOUT_RATE = 0.0 # Changed from the default since T5-1.1 recommends this.
|
15 |
+
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_NCC_plus_English_t5x_base/checkpoint_1500000"
|
16 |
+
PjitPartitioner.num_partitions = 1
|
17 |
+
|
18 |
+
|
pretrain_cont.gin
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Defaults for pretraining with train.py.
|
2 |
+
#
|
3 |
+
#
|
4 |
+
# You must also include a binding for MODEL.
|
5 |
+
#
|
6 |
+
# Required to be set
|
7 |
+
#
|
8 |
+
# - MIXTURE_OR_TASK_NAME
|
9 |
+
# - TASK_FEATURE_LENGTHS
|
10 |
+
# - TRAIN_STEPS - include pretrain steps
|
11 |
+
# - MODEL_DIR: # automatically set when using xm_launch
|
12 |
+
#
|
13 |
+
# Commonly overridden options:
|
14 |
+
#
|
15 |
+
# - train/DatasetConfig.batch_size
|
16 |
+
# - train_eval/DatasetConfig.batch_size
|
17 |
+
# - PjitPartitioner.num_partitions
|
18 |
+
# - Trainer.num_microbatches
|
19 |
+
# - DROPOUT_RATE
|
20 |
+
from __gin__ import dynamic_registration
|
21 |
+
|
22 |
+
import __main__ as train_script
|
23 |
+
from t5x import gin_utils
|
24 |
+
from t5x import partitioning
|
25 |
+
from t5x import utils
|
26 |
+
from t5x import trainer
|
27 |
+
|
28 |
+
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
|
29 |
+
TASK_FEATURE_LENGTHS = %gin.REQUIRED
|
30 |
+
TRAIN_STEPS = %gin.REQUIRED
|
31 |
+
MODEL_DIR = %gin.REQUIRED
|
32 |
+
BATCH_SIZE = 128
|
33 |
+
USE_CACHED_TASKS = True
|
34 |
+
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
|
35 |
+
|
36 |
+
# DEPRECATED: Import this module in your gin file.
|
37 |
+
MIXTURE_OR_TASK_MODULE = None
|
38 |
+
SHUFFLE_TRAIN_EXAMPLES = True
|
39 |
+
|
40 |
+
# HW RNG is faster than SW, but has limited determinism.
|
41 |
+
# Most notably it is not deterministic across different
|
42 |
+
# submeshes.
|
43 |
+
USE_HARDWARE_RNG = False
|
44 |
+
# None always uses faster, hardware RNG
|
45 |
+
RANDOM_SEED = None
|
46 |
+
|
47 |
+
# Can be overridden with `train.*`.
|
48 |
+
train_script.train:
|
49 |
+
model = %MODEL # imported from separate gin file
|
50 |
+
model_dir = %MODEL_DIR
|
51 |
+
train_dataset_cfg = @train/utils.DatasetConfig()
|
52 |
+
train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
|
53 |
+
infer_eval_dataset_cfg = None
|
54 |
+
checkpoint_cfg = @utils.CheckpointConfig()
|
55 |
+
partitioner = @partitioning.PjitPartitioner()
|
56 |
+
trainer_cls = @trainer.Trainer
|
57 |
+
total_steps = %TRAIN_STEPS
|
58 |
+
eval_steps = 20
|
59 |
+
eval_period = 1000
|
60 |
+
random_seed = %RANDOM_SEED
|
61 |
+
use_hardware_rng = %USE_HARDWARE_RNG
|
62 |
+
summarize_config_fn = @gin_utils.summarize_gin_config
|
63 |
+
|
64 |
+
partitioning.PjitPartitioner:
|
65 |
+
num_partitions = 1
|
66 |
+
model_parallel_submesh = None
|
67 |
+
logical_axis_rules = @partitioning.standard_logical_axis_rules()
|
68 |
+
|
69 |
+
train/utils.DatasetConfig:
|
70 |
+
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
71 |
+
task_feature_lengths = %TASK_FEATURE_LENGTHS
|
72 |
+
split = 'train'
|
73 |
+
batch_size = %BATCH_SIZE
|
74 |
+
shuffle = %SHUFFLE_TRAIN_EXAMPLES
|
75 |
+
seed = None # use a new seed each run/restart
|
76 |
+
use_cached = %USE_CACHED_TASKS
|
77 |
+
pack = True
|
78 |
+
module = %MIXTURE_OR_TASK_MODULE
|
79 |
+
|
80 |
+
train_eval/utils.DatasetConfig:
|
81 |
+
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
82 |
+
task_feature_lengths = %TASK_FEATURE_LENGTHS
|
83 |
+
split = 'validation'
|
84 |
+
batch_size = %BATCH_SIZE
|
85 |
+
shuffle = False
|
86 |
+
seed = 42
|
87 |
+
use_cached = %USE_CACHED_TASKS
|
88 |
+
pack = True
|
89 |
+
module = %MIXTURE_OR_TASK_MODULE
|
90 |
+
|
91 |
+
utils.CheckpointConfig:
|
92 |
+
restore = @utils.RestoreCheckpointConfig()
|
93 |
+
save = @utils.SaveCheckpointConfig()
|
94 |
+
utils.RestoreCheckpointConfig:
|
95 |
+
path = %INITIAL_CHECKPOINT_PATH
|
96 |
+
mode = 'specific'
|
97 |
+
dtype = 'float32'
|
98 |
+
utils.SaveCheckpointConfig:
|
99 |
+
period = 1000
|
100 |
+
dtype = 'float32'
|
101 |
+
keep = None # keep all checkpoints
|
102 |
+
save_dataset = False # don't checkpoint dataset state
|
103 |
+
|
104 |
+
trainer.Trainer:
|
105 |
+
num_microbatches = None
|
106 |
+
learning_rate_fn = @utils.create_learning_rate_scheduler()
|
107 |
+
|
108 |
+
utils.create_learning_rate_scheduler:
|
109 |
+
factors = 'constant * rsqrt_decay'
|
110 |
+
base_learning_rate = 0.5 #This is set to half of the original since it is continued training
|
111 |
+
warmup_steps = 10000 # 10k to keep consistent with T5/MTF defaults.
|
tasks.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
|
3 |
+
import seqio
|
4 |
+
import tensorflow as tf
|
5 |
+
import t5.data
|
6 |
+
from datasets import load_dataset
|
7 |
+
from t5.data import postprocessors
|
8 |
+
from t5.data import preprocessors
|
9 |
+
from t5.evaluation import metrics
|
10 |
+
from seqio import FunctionDataSource, utils
|
11 |
+
|
12 |
+
TaskRegistry = seqio.TaskRegistry
|
13 |
+
|
14 |
+
vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
|
15 |
+
byt5_vocabulary = t5.data.ByteVocabulary()
|
16 |
+
|
17 |
+
DEFAULT_OUTPUT_FEATURES = {
|
18 |
+
"inputs": seqio.Feature(
|
19 |
+
vocabulary=vocabulary, add_eos=True,
|
20 |
+
required=False),
|
21 |
+
"targets": seqio.Feature(
|
22 |
+
vocabulary=vocabulary, add_eos=True)
|
23 |
+
}
|
24 |
+
|
25 |
+
BYT5_DEFAULT_OUTPUT_FEATURES = {
|
26 |
+
"inputs": seqio.Feature(
|
27 |
+
vocabulary=byt5_vocabulary, add_eos=True,
|
28 |
+
required=False),
|
29 |
+
"targets": seqio.Feature(
|
30 |
+
vocabulary=byt5_vocabulary, add_eos=True)
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_params=None):
    """Yield the `column` field of every example in one split, forever.

    Args:
        split: Name of the split to read; converted with `str()` before the
            lookup in the loaded dataset.
        shuffle: When true, shuffle the loaded dataset before iterating.
        seed: Seed forwarded to `shuffle`. `None` means an unseeded shuffle;
            an explicit seed of 0 is honoured.
        column: Key of the field yielded from each example.
        dataset_params: Keyword arguments forwarded to `load_dataset`.

    Yields:
        `item[column]` for each item of the selected split, starting over from
        the beginning every time the split is exhausted.

    Raises:
        TypeError: If `dataset_params` is None.
    """
    if dataset_params is None:
        raise TypeError(
            "gen_dataset() requires `dataset_params`: the keyword arguments "
            "to pass to load_dataset()"
        )
    dataset = load_dataset(**dataset_params)
    if shuffle:
        # Compare against None (not truthiness) so that seed=0 is not
        # silently replaced by an unseeded shuffle.
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()
    # The split lookup does not change between passes, so do it once.
    examples = dataset[str(split)]
    # Intentionally endless: re-iterate the split each time it runs out.
    while True:
        for item in examples:
            yield item[column]
|
44 |
+
|
45 |
+
|
46 |
+
def dataset_fn(split, shuffle_files, seed=None, dataset_params=None):
    """Build a `tf.data.Dataset` of scalar strings backed by `gen_dataset`.

    Args:
        split: Split name, passed to `gen_dataset` as its `split` argument.
        shuffle_files: Passed to `gen_dataset` as its `shuffle` argument.
        seed: Passed to `gen_dataset` as its `seed` argument.
        dataset_params: Passed to `gen_dataset` as `dataset_params`.

    Returns:
        The dataset created by `tf.data.Dataset.from_generator`, whose elements
        are declared as scalar `tf.string` tensors.
    """
    # Bind all arguments up front; from_generator calls the generator
    # function with no arguments.
    example_generator = functools.partial(
        gen_dataset,
        split,
        shuffle_files,
        seed,
        dataset_params=dataset_params,
    )
    # `dataset_name` is a module-level global, looked up when this runs.
    string_spec = tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_name)
    return tf.data.Dataset.from_generator(
        example_generator,
        output_signature=string_spec,
    )
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
@utils.map_over_dataset
def target_to_key(x, key_map, target_key):
    """Return a copy of `key_map` in which `target_key` is set to `x`.

    Args:
        x: The value taken from the dataset.
        key_map: Mapping supplying the other keys and values; it is copied,
            not modified.
        target_key: Key under which `x` is stored, replacing any existing
            entry for that key.

    Returns:
        A new dict with the entries of `key_map` plus `target_key: x`.
    """
    merged = dict(key_map)
    merged[target_key] = x
    return merged
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
# Final pretraining task used in Raffel et al., 2019 adapted to NCC
|
62 |
+
dataset_name = 'NbAiLab/scandinavian'
|
63 |
+
dataset_params = {"path": dataset_name, "use_auth_token": True, "streaming": True}
|
64 |
+
dataset_shapes = None
|
65 |
+
TaskRegistry.add(
|
66 |
+
"ncc_scandinavian_span_corruption_stream",
|
67 |
+
source=seqio.FunctionDataSource(
|
68 |
+
dataset_fn=functools.partial(dataset_fn, dataset_params=dataset_params),
|
69 |
+
splits=("train", "validation"),
|
70 |
+
caching_permitted=False,
|
71 |
+
num_input_examples=dataset_shapes,
|
72 |
+
),
|
73 |
+
preprocessors=[
|
74 |
+
functools.partial(
|
75 |
+
target_to_key, key_map={
|
76 |
+
"inputs": None,
|
77 |
+
"targets": None,
|
78 |
+
}, target_key="targets"),
|
79 |
+
seqio.preprocessors.tokenize,
|
80 |
+
# seqio.CacheDatasetPlaceholder(),
|
81 |
+
preprocessors.span_corruption,
|
82 |
+
seqio.preprocessors.append_eos_after_trim,
|
83 |
+
],
|
84 |
+
output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"]},
|
85 |
+
metric_fns=[]
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
|
train_base.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROJECT_DIR=${HOME}"/models/long-t5x"
|
2 |
+
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
3 |
+
MODEL_DIR="gs://nb-t5x-us-central2/long_test_t5x_base"
|
4 |
+
export PYTHONPATH=${PROJECT_DIR}
|
5 |
+
|
6 |
+
python3 ${T5X_DIR}/t5x/train.py \
|
7 |
+
--gin_search_paths=${PROJECT_DIR} \
|
8 |
+
--gin_file="base.gin" \
|
9 |
+
--gin.MODEL_DIR="'${MODEL_DIR}'" \
|