eval
Browse files- __pycache__/tasks.cpython-38.pyc +0 -0
- eval.py +238 -0
- eval_base.sh +4 -3
- eval_categorisation_base.gin +2 -2
- finetune_categorisation_base.gin +3 -3
- results/config.gin +89 -0
- results/model-info.txt +285 -0
- tasks.py +2 -1
- train_base.sh +1 -1
__pycache__/tasks.cpython-38.pyc
CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
|
|
eval.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint:disable=line-too-long
|
16 |
+
# pyformat: disable
|
17 |
+
r"""This script runs inference-evaluation on a T5X-compatible model.
|
18 |
+
|
19 |
+
"""
|
20 |
+
# pyformat: enable
|
21 |
+
# pylint:enable=line-too-long
|
22 |
+
|
23 |
+
import functools
|
24 |
+
import os
|
25 |
+
from typing import Optional, Sequence, Type
|
26 |
+
|
27 |
+
# pylint:disable=g-import-not-at-top
|
28 |
+
# TODO(adarob): Re-enable once users are notified and tests are updated.
|
29 |
+
os.environ['FLAX_LAZY_RNG'] = 'no'
|
30 |
+
from absl import logging
|
31 |
+
from clu import metric_writers
|
32 |
+
import jax
|
33 |
+
from jax.experimental import multihost_utils
|
34 |
+
import seqio
|
35 |
+
from t5x import gin_utils
|
36 |
+
from t5x import models
|
37 |
+
from t5x import partitioning
|
38 |
+
from t5x import utils
|
39 |
+
from typing_extensions import Protocol
|
40 |
+
|
41 |
+
# Automatically search for gin files relative to the T5X package.
|
42 |
+
_DEFAULT_GIN_SEARCH_PATHS = [
|
43 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
44 |
+
]
|
45 |
+
|
46 |
+
|
47 |
+
class SummarizeConfigFn(Protocol):
  """Structural type for a config-summary callback.

  Matches `gin_utils.summarize_gin_config` (the default used by `evaluate`):
  called once at startup to record the run configuration under `model_dir`.
  """

  def __call__(self, model_dir: str,
               summary_writer: Optional[metric_writers.SummaryWriter],
               step: int) -> None:
    # Implementations write a config summary; `summary_writer` is None in
    # this eval-only script (see the call in `evaluate`).
    ...
53 |
+
|
54 |
+
|
55 |
+
def evaluate(
    *,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
    fallback_init_rng: Optional[int] = None):
  """Evaluation function.

  Restores the model from `restore_checkpoint_cfg` and runs SeqIO
  inference-based evaluation (decode + score) on the configured dataset,
  writing results under `output_dir`.

  Args:
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: Partitioner for the model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    inference_evaluator_cls: seqio.Evaluator class to use for inference
      evaluation, potentially with bound configuration args.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
      set to True. If None, parameter initialization is not allowed during model
      loading and having fallback_to_scratch enabled will result in an error.

  Raises:
    ValueError: if the task/mixture vocabularies do not match the model's, or
      if the task defines no metrics to evaluate.
  """
  logging.info('Process ID: %d', jax.process_index())
  if dataset_cfg.module:
    # Import the module that registers the SeqIO task/mixture (e.g. tasks.py).
    utils.import_module(dataset_cfg.module)
  batch_size = dataset_cfg.batch_size

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)

  # Fail fast if the dataset's vocabularies disagree with the model's.
  ds_vocabs = utils.get_vocabulary(dataset_cfg)
  if (ds_vocabs[0] != model.input_vocabulary or
      ds_vocabs[1] != model.output_vocabulary):
    raise ValueError(f'Model and Task vocabularies do not match:\n'
                     f'  task={dataset_cfg.mixture_or_task_name}\n'
                     f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                     f'  model.input_vocabulary={model.input_vocabulary}\n'
                     f'  model.output_vocabulary={model.output_vocabulary}\n')

  # ----------------------------------------------------------------------------
  # SeqIO (inference-based) evaluation setup
  # ----------------------------------------------------------------------------
  # Init evaluator to set up cached datasets
  evaluator = inference_evaluator_cls(
      mixture_or_task_name=dataset_cfg.mixture_or_task_name,
      feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
      eval_split=dataset_cfg.split,
      use_cached=dataset_cfg.use_cached,
      seed=dataset_cfg.seed,
      sequence_length=dataset_cfg.task_feature_lengths,
      log_dir=os.path.join(output_dir, 'inference_eval'))
  if not evaluator.eval_tasks:
    raise ValueError(
        f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")

  # ----------------------------------------------------------------------------
  # T5X model loading.
  # ----------------------------------------------------------------------------

  # Initialize optimizer from the existing checkpoint.
  # Batch dimension is prepended to the per-feature shapes reported by the
  # evaluator.
  input_shapes = {
      k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
  }

  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=None,  # Do not load optimizer state.
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      partitioner=partitioner)
  train_state_axes = train_state_initializer.train_state_axes
  # Log the variable shapes information and write to a file.
  log_file = os.path.join(output_dir, 'model-info.txt')
  utils.log_model_info(log_file,
                       train_state_initializer.global_train_state_shape,
                       partitioner)

  predict_fn = None
  score_fn = None

  # Disable strictness since we are dropping the optimizer state.
  restore_checkpoint_cfg.strict = False

  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  # `from_checkpoints` may yield several restored train states (one per
  # matched checkpoint); each is evaluated in turn.
  for train_state in train_state_initializer.from_checkpoints(
      [restore_checkpoint_cfg], init_rng=fallback_init_rng):

    # Compile the model only once.
    if not predict_fn:
      predict_fn = utils.get_infer_fn(
          infer_step=model.predict_batch,
          batch_size=batch_size,
          train_state_axes=train_state_axes,
          partitioner=partitioner)

      score_fn = utils.get_infer_fn(
          infer_step=model.score_batch,
          batch_size=batch_size,
          train_state_axes=train_state_axes,
          partitioner=partitioner)

    # ----------------------------------------------------------------------------
    # Main evaluation loop (runs once per restored checkpoint)
    # ----------------------------------------------------------------------------

    # Run final evaluation (with decoding) on the full eval dataset.
    # Metrics are only computed on process 0 in a multi-host setup.
    all_metrics, _, _ = evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        step=int(train_state.step),
        predict_fn=functools.partial(
            predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
        score_fn=functools.partial(score_fn, train_state=train_state))
    all_metrics.result()  # Ensure metrics are finished being computed.
    # Wait until computations are done before continuing.
    multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')
    print(all_metrics.result())

  logging.info('Finished.')
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == '__main__':
  # Script-only imports; kept local so importing this module for its
  # `evaluate` function does not pull in the CLI machinery.
  from absl import app
  from absl import flags
  import gin

  FLAGS = flags.FLAGS

  # Must run before flags are parsed so JAX picks up its own flags.
  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  # NOTE(review): the help text below appears truncated ("If a file appears
  # in.") — verify the intended wording against upstream t5x/eval.py.
  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. If a file appears in. Only the '
      'first prefix that produces a valid path for each suffix will be '
      'used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')


  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function.

    Parses gin config flags and invokes the gin-configured `evaluate`.
    """
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    # Create gin-configurable version of `eval`.
    evaluate_using_gin = gin.configurable(evaluate)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)
    # All arguments to `evaluate` are supplied via the parsed gin config.
    evaluate_using_gin()

  gin_utils.run(main)
|
eval_base.sh
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
-
EVAL_OUTPUT_DIR="
|
3 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
4 |
-
CHECKPOINT_PATH="gs://nb-t5x
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
-
python3 ${T5X_DIR}/t5x/eval.py \
|
|
|
8 |
--gin_search_paths=${PROJECT_DIR} \
|
9 |
--gin_file="eval_categorisation_base.gin" \
|
10 |
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
+
EVAL_OUTPUT_DIR="/home/perk/models/t5-parliament-categorisation/results"
|
3 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
4 |
+
CHECKPOINT_PATH="gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000"
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
+
#python3 ${T5X_DIR}/t5x/eval.py \
|
8 |
+
python3 ./eval.py \
|
9 |
--gin_search_paths=${PROJECT_DIR} \
|
10 |
--gin_file="eval_categorisation_base.gin" \
|
11 |
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
eval_categorisation_base.gin
CHANGED
@@ -23,8 +23,8 @@ eval_script.evaluate:
|
|
23 |
|
24 |
utils.DatasetConfig:
|
25 |
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
26 |
-
task_feature_lengths =
|
27 |
-
split = '
|
28 |
batch_size = 32
|
29 |
shuffle = False
|
30 |
seed = 42
|
|
|
23 |
|
24 |
utils.DatasetConfig:
|
25 |
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
26 |
+
task_feature_lengths = {"inputs": 512, "targets": 2} # Auto-computes the max feature lengths.
|
27 |
+
split = 'validation'
|
28 |
batch_size = 32
|
29 |
shuffle = False
|
30 |
seed = 42
|
finetune_categorisation_base.gin
CHANGED
@@ -12,7 +12,7 @@ include "t5x/configs/runs/finetune.gin"
|
|
12 |
|
13 |
MIXTURE_OR_TASK_NAME = "categorise"
|
14 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
|
15 |
-
TRAIN_STEPS =
|
16 |
USE_CACHED_TASKS = False
|
17 |
DROPOUT_RATE = 0.1
|
18 |
RANDOM_SEED = 0
|
@@ -29,7 +29,7 @@ RANDOM_SEED = 0
|
|
29 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
|
30 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
|
31 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
|
32 |
-
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/
|
33 |
|
34 |
#train_script.train:
|
35 |
# eval_period = 500
|
@@ -37,7 +37,7 @@ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint
|
|
37 |
# partitioning.PjitPartitioner.num_partitions = 1
|
38 |
|
39 |
# `num_decodes` is equivalent to a beam size in a beam search decoding.
|
40 |
-
|
41 |
|
42 |
#mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
|
43 |
#run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
|
|
|
12 |
|
13 |
MIXTURE_OR_TASK_NAME = "categorise"
|
14 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
|
15 |
+
TRAIN_STEPS = 1_953_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
|
16 |
USE_CACHED_TASKS = False
|
17 |
DROPOUT_RATE = 0.1
|
18 |
RANDOM_SEED = 0
|
|
|
29 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
|
30 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
|
31 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
|
32 |
+
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1948000"
|
33 |
|
34 |
#train_script.train:
|
35 |
# eval_period = 500
|
|
|
37 |
# partitioning.PjitPartitioner.num_partitions = 1
|
38 |
|
39 |
# `num_decodes` is equivalent to a beam size in a beam search decoding.
|
40 |
+
models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
|
41 |
|
42 |
#mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
|
43 |
#run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
|
results/config.gin
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __gin__ import dynamic_registration
|
2 |
+
import __main__ as eval_script
|
3 |
+
import seqio
|
4 |
+
from t5.data import mixtures
|
5 |
+
from t5x import adafactor
|
6 |
+
from t5x.examples.t5 import network
|
7 |
+
from t5x import models
|
8 |
+
from t5x import partitioning
|
9 |
+
from t5x import utils
|
10 |
+
import tasks
|
11 |
+
|
12 |
+
# Macros:
|
13 |
+
# ==============================================================================
|
14 |
+
CHECKPOINT_PATH = 'gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000'
|
15 |
+
DROPOUT_RATE = 0.0
|
16 |
+
EVAL_OUTPUT_DIR = '/home/perk/models/t5-parliament-categorisation/results'
|
17 |
+
LABEL_SMOOTHING = 0.0
|
18 |
+
LOSS_NORMALIZING_FACTOR = None
|
19 |
+
MIXTURE_OR_TASK_NAME = 'categorise'
|
20 |
+
MODEL = @models.EncoderDecoderModel()
|
21 |
+
OPTIMIZER = @adafactor.Adafactor()
|
22 |
+
VOCABULARY = @seqio.SentencePieceVocabulary()
|
23 |
+
Z_LOSS = 0.0001
|
24 |
+
|
25 |
+
# Parameters for adafactor.Adafactor:
|
26 |
+
# ==============================================================================
|
27 |
+
adafactor.Adafactor.decay_rate = 0.8
|
28 |
+
adafactor.Adafactor.logical_factor_rules = \
|
29 |
+
@adafactor.standard_logical_factor_rules()
|
30 |
+
adafactor.Adafactor.step_offset = 0
|
31 |
+
|
32 |
+
# Parameters for utils.DatasetConfig:
|
33 |
+
# ==============================================================================
|
34 |
+
utils.DatasetConfig.batch_size = 32
|
35 |
+
utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
36 |
+
utils.DatasetConfig.seed = 42
|
37 |
+
utils.DatasetConfig.shuffle = False
|
38 |
+
utils.DatasetConfig.split = 'test'
|
39 |
+
utils.DatasetConfig.task_feature_lengths = {'inputs': 512, 'targets': 2}
|
40 |
+
|
41 |
+
# Parameters for models.EncoderDecoderModel:
|
42 |
+
# ==============================================================================
|
43 |
+
models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
|
44 |
+
models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
|
45 |
+
models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
|
46 |
+
models.EncoderDecoderModel.module = @network.Transformer()
|
47 |
+
models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
|
48 |
+
models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
|
49 |
+
models.EncoderDecoderModel.z_loss = %Z_LOSS
|
50 |
+
|
51 |
+
# Parameters for eval_script.evaluate:
|
52 |
+
# ==============================================================================
|
53 |
+
eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
|
54 |
+
eval_script.evaluate.model = %MODEL
|
55 |
+
eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
|
56 |
+
eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
|
57 |
+
eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
|
58 |
+
|
59 |
+
# Parameters for partitioning.PjitPartitioner:
|
60 |
+
# ==============================================================================
|
61 |
+
partitioning.PjitPartitioner.num_partitions = 2
|
62 |
+
|
63 |
+
# Parameters for utils.RestoreCheckpointConfig:
|
64 |
+
# ==============================================================================
|
65 |
+
utils.RestoreCheckpointConfig.mode = 'specific'
|
66 |
+
utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH
|
67 |
+
|
68 |
+
# Parameters for seqio.SentencePieceVocabulary:
|
69 |
+
# ==============================================================================
|
70 |
+
seqio.SentencePieceVocabulary.sentencepiece_model_file = \
|
71 |
+
'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'
|
72 |
+
|
73 |
+
# Parameters for network.T5Config:
|
74 |
+
# ==============================================================================
|
75 |
+
network.T5Config.dropout_rate = %DROPOUT_RATE
|
76 |
+
network.T5Config.dtype = 'bfloat16'
|
77 |
+
network.T5Config.emb_dim = 768
|
78 |
+
network.T5Config.head_dim = 64
|
79 |
+
network.T5Config.logits_via_embedding = False
|
80 |
+
network.T5Config.mlp_activations = ('gelu', 'linear')
|
81 |
+
network.T5Config.mlp_dim = 2048
|
82 |
+
network.T5Config.num_decoder_layers = 12
|
83 |
+
network.T5Config.num_encoder_layers = 12
|
84 |
+
network.T5Config.num_heads = 12
|
85 |
+
network.T5Config.vocab_size = 250112
|
86 |
+
|
87 |
+
# Parameters for network.Transformer:
|
88 |
+
# ==============================================================================
|
89 |
+
network.Transformer.config = @network.T5Config()
|
results/model-info.txt
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Variable decoder/decoder_norm/scale size 768 shape (embed=768) partition spec (None,)
|
2 |
+
Variable decoder/layers_0/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
3 |
+
Variable decoder/layers_0/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
4 |
+
Variable decoder/layers_0/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
5 |
+
Variable decoder/layers_0/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
6 |
+
Variable decoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
7 |
+
Variable decoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
8 |
+
Variable decoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
9 |
+
Variable decoder/layers_0/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
10 |
+
Variable decoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
11 |
+
Variable decoder/layers_0/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
12 |
+
Variable decoder/layers_0/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
13 |
+
Variable decoder/layers_0/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
14 |
+
Variable decoder/layers_0/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
15 |
+
Variable decoder/layers_0/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
16 |
+
Variable decoder/layers_1/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
17 |
+
Variable decoder/layers_1/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
18 |
+
Variable decoder/layers_1/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
19 |
+
Variable decoder/layers_1/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
20 |
+
Variable decoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
21 |
+
Variable decoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
22 |
+
Variable decoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
23 |
+
Variable decoder/layers_1/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
24 |
+
Variable decoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
25 |
+
Variable decoder/layers_1/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
26 |
+
Variable decoder/layers_1/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
27 |
+
Variable decoder/layers_1/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
28 |
+
Variable decoder/layers_1/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
29 |
+
Variable decoder/layers_1/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
30 |
+
Variable decoder/layers_10/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
31 |
+
Variable decoder/layers_10/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
32 |
+
Variable decoder/layers_10/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
33 |
+
Variable decoder/layers_10/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
34 |
+
Variable decoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
35 |
+
Variable decoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
36 |
+
Variable decoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
37 |
+
Variable decoder/layers_10/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
38 |
+
Variable decoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
39 |
+
Variable decoder/layers_10/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
40 |
+
Variable decoder/layers_10/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
41 |
+
Variable decoder/layers_10/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
42 |
+
Variable decoder/layers_10/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
43 |
+
Variable decoder/layers_10/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
44 |
+
Variable decoder/layers_11/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
45 |
+
Variable decoder/layers_11/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
46 |
+
Variable decoder/layers_11/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
47 |
+
Variable decoder/layers_11/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
48 |
+
Variable decoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
49 |
+
Variable decoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
50 |
+
Variable decoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
51 |
+
Variable decoder/layers_11/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
52 |
+
Variable decoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
53 |
+
Variable decoder/layers_11/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
54 |
+
Variable decoder/layers_11/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
55 |
+
Variable decoder/layers_11/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
56 |
+
Variable decoder/layers_11/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
57 |
+
Variable decoder/layers_11/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
58 |
+
Variable decoder/layers_2/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
59 |
+
Variable decoder/layers_2/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
60 |
+
Variable decoder/layers_2/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
61 |
+
Variable decoder/layers_2/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
62 |
+
Variable decoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
63 |
+
Variable decoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
64 |
+
Variable decoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
65 |
+
Variable decoder/layers_2/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
66 |
+
Variable decoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
67 |
+
Variable decoder/layers_2/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
68 |
+
Variable decoder/layers_2/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
69 |
+
Variable decoder/layers_2/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
70 |
+
Variable decoder/layers_2/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
71 |
+
Variable decoder/layers_2/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
72 |
+
Variable decoder/layers_3/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
73 |
+
Variable decoder/layers_3/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
74 |
+
Variable decoder/layers_3/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
75 |
+
Variable decoder/layers_3/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
76 |
+
Variable decoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
77 |
+
Variable decoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
78 |
+
Variable decoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
79 |
+
Variable decoder/layers_3/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
80 |
+
Variable decoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
81 |
+
Variable decoder/layers_3/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
82 |
+
Variable decoder/layers_3/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
83 |
+
Variable decoder/layers_3/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
84 |
+
Variable decoder/layers_3/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
85 |
+
Variable decoder/layers_3/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
86 |
+
Variable decoder/layers_4/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
87 |
+
Variable decoder/layers_4/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
88 |
+
Variable decoder/layers_4/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
89 |
+
Variable decoder/layers_4/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
90 |
+
Variable decoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
91 |
+
Variable decoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
92 |
+
Variable decoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
93 |
+
Variable decoder/layers_4/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
94 |
+
Variable decoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
95 |
+
Variable decoder/layers_4/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
96 |
+
Variable decoder/layers_4/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
97 |
+
Variable decoder/layers_4/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
98 |
+
Variable decoder/layers_4/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
99 |
+
Variable decoder/layers_4/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
100 |
+
Variable decoder/layers_5/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
101 |
+
Variable decoder/layers_5/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
102 |
+
Variable decoder/layers_5/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
103 |
+
Variable decoder/layers_5/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
104 |
+
Variable decoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
105 |
+
Variable decoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
106 |
+
Variable decoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
107 |
+
Variable decoder/layers_5/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
108 |
+
Variable decoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
109 |
+
Variable decoder/layers_5/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
110 |
+
Variable decoder/layers_5/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
111 |
+
Variable decoder/layers_5/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
112 |
+
Variable decoder/layers_5/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
113 |
+
Variable decoder/layers_5/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
114 |
+
Variable decoder/layers_6/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
115 |
+
Variable decoder/layers_6/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
116 |
+
Variable decoder/layers_6/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
117 |
+
Variable decoder/layers_6/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
118 |
+
Variable decoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
119 |
+
Variable decoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
120 |
+
Variable decoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
121 |
+
Variable decoder/layers_6/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
122 |
+
Variable decoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
123 |
+
Variable decoder/layers_6/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
124 |
+
Variable decoder/layers_6/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
125 |
+
Variable decoder/layers_6/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
126 |
+
Variable decoder/layers_6/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
127 |
+
Variable decoder/layers_6/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
128 |
+
Variable decoder/layers_7/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
129 |
+
Variable decoder/layers_7/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
130 |
+
Variable decoder/layers_7/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
131 |
+
Variable decoder/layers_7/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
132 |
+
Variable decoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
133 |
+
Variable decoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
134 |
+
Variable decoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
135 |
+
Variable decoder/layers_7/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
136 |
+
Variable decoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
137 |
+
Variable decoder/layers_7/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
138 |
+
Variable decoder/layers_7/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
139 |
+
Variable decoder/layers_7/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
140 |
+
Variable decoder/layers_7/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
141 |
+
Variable decoder/layers_7/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
142 |
+
Variable decoder/layers_8/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
143 |
+
Variable decoder/layers_8/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
144 |
+
Variable decoder/layers_8/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
145 |
+
Variable decoder/layers_8/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
146 |
+
Variable decoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
147 |
+
Variable decoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
148 |
+
Variable decoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
149 |
+
Variable decoder/layers_8/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
150 |
+
Variable decoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
151 |
+
Variable decoder/layers_8/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
152 |
+
Variable decoder/layers_8/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
153 |
+
Variable decoder/layers_8/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
154 |
+
Variable decoder/layers_8/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
155 |
+
Variable decoder/layers_8/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
156 |
+
Variable decoder/layers_9/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
157 |
+
Variable decoder/layers_9/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
158 |
+
Variable decoder/layers_9/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
159 |
+
Variable decoder/layers_9/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
160 |
+
Variable decoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
161 |
+
Variable decoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
162 |
+
Variable decoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
163 |
+
Variable decoder/layers_9/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
164 |
+
Variable decoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
165 |
+
Variable decoder/layers_9/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
166 |
+
Variable decoder/layers_9/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
167 |
+
Variable decoder/layers_9/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
168 |
+
Variable decoder/layers_9/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
169 |
+
Variable decoder/layers_9/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
170 |
+
Variable decoder/logits_dense/kernel size 192086016 shape (embed=768, vocab=250112) partition spec (None, 'model')
|
171 |
+
Variable decoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
|
172 |
+
Variable encoder/encoder_norm/scale size 768 shape (embed=768) partition spec (None,)
|
173 |
+
Variable encoder/layers_0/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
174 |
+
Variable encoder/layers_0/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
175 |
+
Variable encoder/layers_0/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
176 |
+
Variable encoder/layers_0/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
177 |
+
Variable encoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
178 |
+
Variable encoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
179 |
+
Variable encoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
180 |
+
Variable encoder/layers_0/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
181 |
+
Variable encoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
182 |
+
Variable encoder/layers_1/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
183 |
+
Variable encoder/layers_1/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
184 |
+
Variable encoder/layers_1/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
185 |
+
Variable encoder/layers_1/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
186 |
+
Variable encoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
187 |
+
Variable encoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
188 |
+
Variable encoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
189 |
+
Variable encoder/layers_1/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
190 |
+
Variable encoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
191 |
+
Variable encoder/layers_10/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
192 |
+
Variable encoder/layers_10/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
193 |
+
Variable encoder/layers_10/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
194 |
+
Variable encoder/layers_10/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
195 |
+
Variable encoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
196 |
+
Variable encoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
197 |
+
Variable encoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
198 |
+
Variable encoder/layers_10/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
199 |
+
Variable encoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
200 |
+
Variable encoder/layers_11/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
201 |
+
Variable encoder/layers_11/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
202 |
+
Variable encoder/layers_11/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
203 |
+
Variable encoder/layers_11/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
204 |
+
Variable encoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
205 |
+
Variable encoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
206 |
+
Variable encoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
207 |
+
Variable encoder/layers_11/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
208 |
+
Variable encoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
209 |
+
Variable encoder/layers_2/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
210 |
+
Variable encoder/layers_2/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
211 |
+
Variable encoder/layers_2/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
212 |
+
Variable encoder/layers_2/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
213 |
+
Variable encoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
214 |
+
Variable encoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
215 |
+
Variable encoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
216 |
+
Variable encoder/layers_2/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
217 |
+
Variable encoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
218 |
+
Variable encoder/layers_3/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
219 |
+
Variable encoder/layers_3/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
220 |
+
Variable encoder/layers_3/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
221 |
+
Variable encoder/layers_3/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
222 |
+
Variable encoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
223 |
+
Variable encoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
224 |
+
Variable encoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
225 |
+
Variable encoder/layers_3/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
226 |
+
Variable encoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
227 |
+
Variable encoder/layers_4/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
228 |
+
Variable encoder/layers_4/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
229 |
+
Variable encoder/layers_4/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
230 |
+
Variable encoder/layers_4/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
231 |
+
Variable encoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
232 |
+
Variable encoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
233 |
+
Variable encoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
234 |
+
Variable encoder/layers_4/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
235 |
+
Variable encoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
236 |
+
Variable encoder/layers_5/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
237 |
+
Variable encoder/layers_5/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
238 |
+
Variable encoder/layers_5/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
239 |
+
Variable encoder/layers_5/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
240 |
+
Variable encoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
241 |
+
Variable encoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
242 |
+
Variable encoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
243 |
+
Variable encoder/layers_5/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
244 |
+
Variable encoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
245 |
+
Variable encoder/layers_6/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
246 |
+
Variable encoder/layers_6/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
247 |
+
Variable encoder/layers_6/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
248 |
+
Variable encoder/layers_6/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
249 |
+
Variable encoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
250 |
+
Variable encoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
251 |
+
Variable encoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
252 |
+
Variable encoder/layers_6/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
253 |
+
Variable encoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
254 |
+
Variable encoder/layers_7/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
255 |
+
Variable encoder/layers_7/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
256 |
+
Variable encoder/layers_7/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
257 |
+
Variable encoder/layers_7/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
258 |
+
Variable encoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
259 |
+
Variable encoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
260 |
+
Variable encoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
261 |
+
Variable encoder/layers_7/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
262 |
+
Variable encoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
263 |
+
Variable encoder/layers_8/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
264 |
+
Variable encoder/layers_8/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
265 |
+
Variable encoder/layers_8/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
266 |
+
Variable encoder/layers_8/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
267 |
+
Variable encoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
268 |
+
Variable encoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
269 |
+
Variable encoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
270 |
+
Variable encoder/layers_8/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
271 |
+
Variable encoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
272 |
+
Variable encoder/layers_9/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
273 |
+
Variable encoder/layers_9/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
|
274 |
+
Variable encoder/layers_9/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
275 |
+
Variable encoder/layers_9/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
|
276 |
+
Variable encoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
277 |
+
Variable encoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
|
278 |
+
Variable encoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
|
279 |
+
Variable encoder/layers_9/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
280 |
+
Variable encoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
|
281 |
+
Variable encoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
|
282 |
+
Variable token_embedder/embedding size 192086016 shape (vocab=250112, embed=768) partition spec ('model', None)
|
283 |
+
Total number of parameters: 582401280
|
284 |
+
|
285 |
+
Variable step size 1 shape () partition spec None
|
tasks.py
CHANGED
@@ -5,6 +5,7 @@ import seqio
|
|
5 |
import tensorflow_datasets as tfds
|
6 |
from t5.evaluation import metrics
|
7 |
from t5.data import preprocessors
|
|
|
8 |
import t5
|
9 |
import tensorflow.compat.v1 as tf
|
10 |
|
@@ -59,7 +60,7 @@ seqio.TaskRegistry.add(
|
|
59 |
categorise_preprocessor,
|
60 |
seqio.preprocessors.tokenize_and_append_eos,
|
61 |
],
|
62 |
-
|
63 |
output_features=DEFAULT_OUTPUT_FEATURES,
|
64 |
)
|
65 |
|
|
|
5 |
import tensorflow_datasets as tfds
|
6 |
from t5.evaluation import metrics
|
7 |
from t5.data import preprocessors
|
8 |
+
from t5.data import postprocessors
|
9 |
import t5
|
10 |
import tensorflow.compat.v1 as tf
|
11 |
|
|
|
60 |
categorise_preprocessor,
|
61 |
seqio.preprocessors.tokenize_and_append_eos,
|
62 |
],
|
63 |
+
metric_fns=[metrics.accuracy],
|
64 |
output_features=DEFAULT_OUTPUT_FEATURES,
|
65 |
)
|
66 |
|
train_base.sh
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
3 |
#Needs to be updated when moving to tpu-v4 it should then be in another zone
|
4 |
-
MODEL_DIR="gs://nb-t5x/
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
python3 ${T5X_DIR}/t5x/train.py \
|
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
3 |
#Needs to be updated when moving to tpu-v4 it should then be in another zone
|
4 |
+
MODEL_DIR="gs://nb-t5x/eval_norwegian_1_194_000"
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
python3 ${T5X_DIR}/t5x/train.py \
|