pere committed on
Commit
f5d1117
1 Parent(s): 437a8ab
__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
eval.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The T5X Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # pylint:disable=line-too-long
16
+ # pyformat: disable
17
+ r"""This script runs inference-evaluation on a T5X-compatible model.
18
+
19
+ """
20
+ # pyformat: enable
21
+ # pylint:enable=line-too-long
22
+
23
+ import functools
24
+ import os
25
+ from typing import Optional, Sequence, Type
26
+
27
+ # pylint:disable=g-import-not-at-top
28
+ # TODO(adarob): Re-enable once users are notified and tests are updated.
29
+ os.environ['FLAX_LAZY_RNG'] = 'no'
30
+ from absl import logging
31
+ from clu import metric_writers
32
+ import jax
33
+ from jax.experimental import multihost_utils
34
+ import seqio
35
+ from t5x import gin_utils
36
+ from t5x import models
37
+ from t5x import partitioning
38
+ from t5x import utils
39
+ from typing_extensions import Protocol
40
+
41
+ # Automatically search for gin files relative to the T5X package.
42
+ _DEFAULT_GIN_SEARCH_PATHS = [
43
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
44
+ ]
45
+
46
+
47
+ class SummarizeConfigFn(Protocol):
48
+
49
+ def __call__(self, model_dir: str,
50
+ summary_writer: Optional[metric_writers.SummaryWriter],
51
+ step: int) -> None:
52
+ ...
53
+
54
+
55
+ def evaluate(
56
+ *,
57
+ model: models.BaseTransformerModel,
58
+ dataset_cfg: utils.DatasetConfig,
59
+ restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
60
+ partitioner: partitioning.BasePartitioner,
61
+ output_dir: str,
62
+ inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
63
+ summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
64
+ fallback_init_rng: Optional[int] = None):
65
+ """Evaluation function.
66
+
67
+ Args:
68
+ model: The model object to use for inference.
69
+ dataset_cfg: Specification for the dataset to infer based on.
70
+ restore_checkpoint_cfg: Specification for the model parameter checkpoint to
71
+ load.
72
+ partitioner: Partitioner for the model parameters and data across devices.
73
+ output_dir: Path to directory to write temporary files and final results.
74
+ inference_evaluator_cls: seqio.Evaluator class to use for inference
75
+ evaluation, potentially with bound configuration args.
76
+ summarize_config_fn: A function that takes in the model directory, an
77
+ optional SummaryWriter, and the step number, and writes a summary of the
78
+ configuration. SummaryWriter will be None in most cases.
79
+ fallback_init_rng: A random seed used for parameter initialization during
80
+ model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
81
+ set to True. If None, parameter initialization is not allowed during model
82
+ loading and having fallback_to_scratch enabled will result in an error.
83
+ """
84
+ logging.info('Process ID: %d', jax.process_index())
85
+ if dataset_cfg.module:
86
+ utils.import_module(dataset_cfg.module)
87
+ batch_size = dataset_cfg.batch_size
88
+
89
+ summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)
90
+
91
+ ds_vocabs = utils.get_vocabulary(dataset_cfg)
92
+ if (ds_vocabs[0] != model.input_vocabulary or
93
+ ds_vocabs[1] != model.output_vocabulary):
94
+ raise ValueError(f'Model and Task vocabularies do not match:\n'
95
+ f' task={dataset_cfg.mixture_or_task_name}\n'
96
+ f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
97
+ f' model.input_vocabulary={model.input_vocabulary}\n'
98
+ f' model.output_vocabulary={model.output_vocabulary}\n')
99
+
100
+ # ----------------------------------------------------------------------------
101
+ # SeqIO (inference-based) evaluation setup
102
+ # ----------------------------------------------------------------------------
103
+ # Init evaluator to set up cached datasets
104
+ evaluator = inference_evaluator_cls(
105
+ mixture_or_task_name=dataset_cfg.mixture_or_task_name,
106
+ feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
107
+ eval_split=dataset_cfg.split,
108
+ use_cached=dataset_cfg.use_cached,
109
+ seed=dataset_cfg.seed,
110
+ sequence_length=dataset_cfg.task_feature_lengths,
111
+ log_dir=os.path.join(output_dir, 'inference_eval'))
112
+ if not evaluator.eval_tasks:
113
+ raise ValueError(
114
+ f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")
115
+
116
+ # ----------------------------------------------------------------------------
117
+ # T5X model loading.
118
+ # ----------------------------------------------------------------------------
119
+
120
+ # Initialize optimizer from the existing checkpoint.
121
+ input_shapes = {
122
+ k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
123
+ }
124
+
125
+ train_state_initializer = utils.TrainStateInitializer(
126
+ optimizer_def=None, # Do not load optimizer state.
127
+ init_fn=model.get_initial_variables,
128
+ input_shapes=input_shapes,
129
+ partitioner=partitioner)
130
+ train_state_axes = train_state_initializer.train_state_axes
131
+ # Log the variable shapes information and write to a file.
132
+ log_file = os.path.join(output_dir, 'model-info.txt')
133
+ utils.log_model_info(log_file,
134
+ train_state_initializer.global_train_state_shape,
135
+ partitioner)
136
+
137
+ predict_fn = None
138
+ score_fn = None
139
+
140
+ # Disable strictness since we are dropping the optimizer state.
141
+ restore_checkpoint_cfg.strict = False
142
+
143
+ if fallback_init_rng is not None:
144
+ fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
145
+ for train_state in train_state_initializer.from_checkpoints(
146
+ [restore_checkpoint_cfg], init_rng=fallback_init_rng):
147
+
148
+ # Compile the model only once.
149
+ if not predict_fn:
150
+ predict_fn = utils.get_infer_fn(
151
+ infer_step=model.predict_batch,
152
+ batch_size=batch_size,
153
+ train_state_axes=train_state_axes,
154
+ partitioner=partitioner)
155
+
156
+ score_fn = utils.get_infer_fn(
157
+ infer_step=model.score_batch,
158
+ batch_size=batch_size,
159
+ train_state_axes=train_state_axes,
160
+ partitioner=partitioner)
161
+
162
+ # ----------------------------------------------------------------------------
163
+ # Main evaluation loop
164
+ # ----------------------------------------------------------------------------
165
+
166
+ # Run final evaluation (with decoding) on the full eval dataset.
167
+ all_metrics, _, _ = evaluator.evaluate(
168
+ compute_metrics=jax.process_index() == 0,
169
+ step=int(train_state.step),
170
+ predict_fn=functools.partial(
171
+ predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
172
+ score_fn=functools.partial(score_fn, train_state=train_state))
173
+ all_metrics.result() # Ensure metrics are finished being computed.
174
+ # Wait until computations are done before continuing.
175
+ multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')
176
+ print(all_metrics.result())
177
+
178
+ logging.info('Finished.')
179
+
180
+
181
+ if __name__ == '__main__':
182
+ from absl import app
183
+ from absl import flags
184
+ import gin
185
+
186
+ FLAGS = flags.FLAGS
187
+
188
+ jax.config.parse_flags_with_absl()
189
+
190
+ flags.DEFINE_multi_string(
191
+ 'gin_file',
192
+ default=None,
193
+ help='Path to gin configuration file. Multiple paths may be passed and '
194
+ 'will be imported in the given order, with later configurations '
195
+ 'overriding earlier ones.')
196
+
197
+ flags.DEFINE_multi_string(
198
+ 'gin_bindings', default=[], help='Individual gin bindings.')
199
+
200
+ flags.DEFINE_list(
201
+ 'gin_search_paths',
202
+ default=['.'],
203
+ help='Comma-separated list of gin config path prefixes to be prepended '
204
+ 'to suffixes given via `--gin_file`. Only the '
205
+ 'first prefix that produces a valid path for each suffix will be '
206
+ 'used.')
207
+
208
+ flags.DEFINE_string(
209
+ 'tfds_data_dir', None,
210
+ 'If set, this directory will be used to store datasets prepared by '
211
+ 'TensorFlow Datasets that are not available in the public TFDS GCS '
212
+ 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
213
+ 'all `Task`s.')
214
+
215
+
216
+ def main(argv: Sequence[str]):
217
+ """Wrapper for pdb post mortems."""
218
+ _main(argv)
219
+
220
+ def _main(argv: Sequence[str]):
221
+ """True main function."""
222
+ if len(argv) > 1:
223
+ raise app.UsageError('Too many command-line arguments.')
224
+
225
+ if FLAGS.tfds_data_dir:
226
+ seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
227
+
228
+ # Create gin-configurable version of `eval`.
229
+ evaluate_using_gin = gin.configurable(evaluate)
230
+
231
+ gin_utils.parse_gin_flags(
232
+ # User-provided gin paths take precedence if relative paths conflict.
233
+ FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
234
+ FLAGS.gin_file,
235
+ FLAGS.gin_bindings)
236
+ evaluate_using_gin()
237
+
238
+ gin_utils.run(main)
eval_base.sh CHANGED
@@ -1,10 +1,11 @@
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
- EVAL_OUTPUT_DIR="gs://nb-t5x/eval/"
3
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
4
- CHECKPOINT_PATH="gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1043000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
- python3 ${T5X_DIR}/t5x/eval.py \
 
8
  --gin_search_paths=${PROJECT_DIR} \
9
  --gin_file="eval_categorisation_base.gin" \
10
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
 
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
+ EVAL_OUTPUT_DIR="/home/perk/models/t5-parliament-categorisation/results"
3
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
4
+ CHECKPOINT_PATH="gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
+ #python3 ${T5X_DIR}/t5x/eval.py \
8
+ python3 ./eval.py \
9
  --gin_search_paths=${PROJECT_DIR} \
10
  --gin_file="eval_categorisation_base.gin" \
11
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
eval_categorisation_base.gin CHANGED
@@ -23,8 +23,8 @@ eval_script.evaluate:
23
 
24
  utils.DatasetConfig:
25
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
26
- task_feature_lengths = None # Auto-computes the max feature lengths.
27
- split = 'test'
28
  batch_size = 32
29
  shuffle = False
30
  seed = 42
 
23
 
24
  utils.DatasetConfig:
25
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
26
+ task_feature_lengths = {"inputs": 512, "targets": 2} # Fixed feature lengths (was None, which auto-computes the max).
27
+ split = 'validation'
28
  batch_size = 32
29
  shuffle = False
30
  seed = 42
finetune_categorisation_base.gin CHANGED
@@ -12,7 +12,7 @@ include "t5x/configs/runs/finetune.gin"
12
 
13
  MIXTURE_OR_TASK_NAME = "categorise"
14
  TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
15
- TRAIN_STEPS = 1_635_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
16
  USE_CACHED_TASKS = False
17
  DROPOUT_RATE = 0.1
18
  RANDOM_SEED = 0
@@ -29,7 +29,7 @@ RANDOM_SEED = 0
29
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
30
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
31
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
32
- INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1630000"
33
 
34
  #train_script.train:
35
  # eval_period = 500
@@ -37,7 +37,7 @@ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint
37
  # partitioning.PjitPartitioner.num_partitions = 1
38
 
39
  # `num_decodes` is equivalent to a beam size in a beam search decoding.
40
- # models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
41
 
42
  #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
43
  #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
 
12
 
13
  MIXTURE_OR_TASK_NAME = "categorise"
14
  TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
15
+ TRAIN_STEPS = 1_953_000 # 1,948,000 pre-trained steps + 5,000 fine-tuning steps.
16
  USE_CACHED_TASKS = False
17
  DROPOUT_RATE = 0.1
18
  RANDOM_SEED = 0
 
29
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
30
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
31
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
32
+ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1948000"
33
 
34
  #train_script.train:
35
  # eval_period = 500
 
37
  # partitioning.PjitPartitioner.num_partitions = 1
38
 
39
  # `num_decodes` is equivalent to a beam size in a beam search decoding.
40
+ models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
41
 
42
  #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
43
  #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
results/config.gin ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __gin__ import dynamic_registration
2
+ import __main__ as eval_script
3
+ import seqio
4
+ from t5.data import mixtures
5
+ from t5x import adafactor
6
+ from t5x.examples.t5 import network
7
+ from t5x import models
8
+ from t5x import partitioning
9
+ from t5x import utils
10
+ import tasks
11
+
12
+ # Macros:
13
+ # ==============================================================================
14
+ CHECKPOINT_PATH = 'gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000'
15
+ DROPOUT_RATE = 0.0
16
+ EVAL_OUTPUT_DIR = '/home/perk/models/t5-parliament-categorisation/results'
17
+ LABEL_SMOOTHING = 0.0
18
+ LOSS_NORMALIZING_FACTOR = None
19
+ MIXTURE_OR_TASK_NAME = 'categorise'
20
+ MODEL = @models.EncoderDecoderModel()
21
+ OPTIMIZER = @adafactor.Adafactor()
22
+ VOCABULARY = @seqio.SentencePieceVocabulary()
23
+ Z_LOSS = 0.0001
24
+
25
+ # Parameters for adafactor.Adafactor:
26
+ # ==============================================================================
27
+ adafactor.Adafactor.decay_rate = 0.8
28
+ adafactor.Adafactor.logical_factor_rules = \
29
+ @adafactor.standard_logical_factor_rules()
30
+ adafactor.Adafactor.step_offset = 0
31
+
32
+ # Parameters for utils.DatasetConfig:
33
+ # ==============================================================================
34
+ utils.DatasetConfig.batch_size = 32
35
+ utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
36
+ utils.DatasetConfig.seed = 42
37
+ utils.DatasetConfig.shuffle = False
38
+ utils.DatasetConfig.split = 'test'
39
+ utils.DatasetConfig.task_feature_lengths = {'inputs': 512, 'targets': 2}
40
+
41
+ # Parameters for models.EncoderDecoderModel:
42
+ # ==============================================================================
43
+ models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
44
+ models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
45
+ models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
46
+ models.EncoderDecoderModel.module = @network.Transformer()
47
+ models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
48
+ models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
49
+ models.EncoderDecoderModel.z_loss = %Z_LOSS
50
+
51
+ # Parameters for eval_script.evaluate:
52
+ # ==============================================================================
53
+ eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
54
+ eval_script.evaluate.model = %MODEL
55
+ eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
56
+ eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
57
+ eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
58
+
59
+ # Parameters for partitioning.PjitPartitioner:
60
+ # ==============================================================================
61
+ partitioning.PjitPartitioner.num_partitions = 2
62
+
63
+ # Parameters for utils.RestoreCheckpointConfig:
64
+ # ==============================================================================
65
+ utils.RestoreCheckpointConfig.mode = 'specific'
66
+ utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH
67
+
68
+ # Parameters for seqio.SentencePieceVocabulary:
69
+ # ==============================================================================
70
+ seqio.SentencePieceVocabulary.sentencepiece_model_file = \
71
+ 'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'
72
+
73
+ # Parameters for network.T5Config:
74
+ # ==============================================================================
75
+ network.T5Config.dropout_rate = %DROPOUT_RATE
76
+ network.T5Config.dtype = 'bfloat16'
77
+ network.T5Config.emb_dim = 768
78
+ network.T5Config.head_dim = 64
79
+ network.T5Config.logits_via_embedding = False
80
+ network.T5Config.mlp_activations = ('gelu', 'linear')
81
+ network.T5Config.mlp_dim = 2048
82
+ network.T5Config.num_decoder_layers = 12
83
+ network.T5Config.num_encoder_layers = 12
84
+ network.T5Config.num_heads = 12
85
+ network.T5Config.vocab_size = 250112
86
+
87
+ # Parameters for network.Transformer:
88
+ # ==============================================================================
89
+ network.Transformer.config = @network.T5Config()
results/model-info.txt ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Variable decoder/decoder_norm/scale size 768 shape (embed=768) partition spec (None,)
2
+ Variable decoder/layers_0/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
3
+ Variable decoder/layers_0/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
4
+ Variable decoder/layers_0/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
5
+ Variable decoder/layers_0/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
6
+ Variable decoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
7
+ Variable decoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
8
+ Variable decoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
9
+ Variable decoder/layers_0/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
10
+ Variable decoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
11
+ Variable decoder/layers_0/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
12
+ Variable decoder/layers_0/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
13
+ Variable decoder/layers_0/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
14
+ Variable decoder/layers_0/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
15
+ Variable decoder/layers_0/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
16
+ Variable decoder/layers_1/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
17
+ Variable decoder/layers_1/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
18
+ Variable decoder/layers_1/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
19
+ Variable decoder/layers_1/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
20
+ Variable decoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
21
+ Variable decoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
22
+ Variable decoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
23
+ Variable decoder/layers_1/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
24
+ Variable decoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
25
+ Variable decoder/layers_1/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
26
+ Variable decoder/layers_1/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
27
+ Variable decoder/layers_1/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
28
+ Variable decoder/layers_1/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
29
+ Variable decoder/layers_1/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
30
+ Variable decoder/layers_10/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
31
+ Variable decoder/layers_10/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
32
+ Variable decoder/layers_10/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
33
+ Variable decoder/layers_10/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
34
+ Variable decoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
35
+ Variable decoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
36
+ Variable decoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
37
+ Variable decoder/layers_10/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
38
+ Variable decoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
39
+ Variable decoder/layers_10/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
40
+ Variable decoder/layers_10/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
41
+ Variable decoder/layers_10/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
42
+ Variable decoder/layers_10/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
43
+ Variable decoder/layers_10/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
44
+ Variable decoder/layers_11/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
45
+ Variable decoder/layers_11/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
46
+ Variable decoder/layers_11/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
47
+ Variable decoder/layers_11/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
48
+ Variable decoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
49
+ Variable decoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
50
+ Variable decoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
51
+ Variable decoder/layers_11/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
52
+ Variable decoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
53
+ Variable decoder/layers_11/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
54
+ Variable decoder/layers_11/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
55
+ Variable decoder/layers_11/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
56
+ Variable decoder/layers_11/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
57
+ Variable decoder/layers_11/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
58
+ Variable decoder/layers_2/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
59
+ Variable decoder/layers_2/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
60
+ Variable decoder/layers_2/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
61
+ Variable decoder/layers_2/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
62
+ Variable decoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
63
+ Variable decoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
64
+ Variable decoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
65
+ Variable decoder/layers_2/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
66
+ Variable decoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
67
+ Variable decoder/layers_2/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
68
+ Variable decoder/layers_2/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
69
+ Variable decoder/layers_2/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
70
+ Variable decoder/layers_2/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
71
+ Variable decoder/layers_2/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
72
+ Variable decoder/layers_3/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
73
+ Variable decoder/layers_3/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
74
+ Variable decoder/layers_3/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
75
+ Variable decoder/layers_3/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
76
+ Variable decoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
77
+ Variable decoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
78
+ Variable decoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
79
+ Variable decoder/layers_3/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
80
+ Variable decoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
81
+ Variable decoder/layers_3/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
82
+ Variable decoder/layers_3/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
83
+ Variable decoder/layers_3/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
84
+ Variable decoder/layers_3/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
85
+ Variable decoder/layers_3/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
86
+ Variable decoder/layers_4/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
87
+ Variable decoder/layers_4/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
88
+ Variable decoder/layers_4/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
89
+ Variable decoder/layers_4/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
90
+ Variable decoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
91
+ Variable decoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
92
+ Variable decoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
93
+ Variable decoder/layers_4/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
94
+ Variable decoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
95
+ Variable decoder/layers_4/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
96
+ Variable decoder/layers_4/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
97
+ Variable decoder/layers_4/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
98
+ Variable decoder/layers_4/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
99
+ Variable decoder/layers_4/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
100
+ Variable decoder/layers_5/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
101
+ Variable decoder/layers_5/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
102
+ Variable decoder/layers_5/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
103
+ Variable decoder/layers_5/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
104
+ Variable decoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
105
+ Variable decoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
106
+ Variable decoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
107
+ Variable decoder/layers_5/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
108
+ Variable decoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
109
+ Variable decoder/layers_5/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
110
+ Variable decoder/layers_5/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
111
+ Variable decoder/layers_5/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
112
+ Variable decoder/layers_5/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
113
+ Variable decoder/layers_5/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
114
+ Variable decoder/layers_6/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
115
+ Variable decoder/layers_6/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
116
+ Variable decoder/layers_6/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
117
+ Variable decoder/layers_6/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
118
+ Variable decoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
119
+ Variable decoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
120
+ Variable decoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
121
+ Variable decoder/layers_6/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
122
+ Variable decoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
123
+ Variable decoder/layers_6/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
124
+ Variable decoder/layers_6/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
125
+ Variable decoder/layers_6/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
126
+ Variable decoder/layers_6/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
127
+ Variable decoder/layers_6/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
128
+ Variable decoder/layers_7/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
129
+ Variable decoder/layers_7/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
130
+ Variable decoder/layers_7/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
131
+ Variable decoder/layers_7/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
132
+ Variable decoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
133
+ Variable decoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
134
+ Variable decoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
135
+ Variable decoder/layers_7/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
136
+ Variable decoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
137
+ Variable decoder/layers_7/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
138
+ Variable decoder/layers_7/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
139
+ Variable decoder/layers_7/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
140
+ Variable decoder/layers_7/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
141
+ Variable decoder/layers_7/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
142
+ Variable decoder/layers_8/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
143
+ Variable decoder/layers_8/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
144
+ Variable decoder/layers_8/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
145
+ Variable decoder/layers_8/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
146
+ Variable decoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
147
+ Variable decoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
148
+ Variable decoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
149
+ Variable decoder/layers_8/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
150
+ Variable decoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
151
+ Variable decoder/layers_8/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
152
+ Variable decoder/layers_8/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
153
+ Variable decoder/layers_8/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
154
+ Variable decoder/layers_8/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
155
+ Variable decoder/layers_8/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
156
+ Variable decoder/layers_9/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
157
+ Variable decoder/layers_9/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
158
+ Variable decoder/layers_9/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
159
+ Variable decoder/layers_9/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
160
+ Variable decoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
161
+ Variable decoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
162
+ Variable decoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
163
+ Variable decoder/layers_9/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
164
+ Variable decoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
165
+ Variable decoder/layers_9/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
166
+ Variable decoder/layers_9/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
167
+ Variable decoder/layers_9/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
168
+ Variable decoder/layers_9/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
169
+ Variable decoder/layers_9/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
170
+ Variable decoder/logits_dense/kernel size 192086016 shape (embed=768, vocab=250112) partition spec (None, 'model')
171
+ Variable decoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
172
+ Variable encoder/encoder_norm/scale size 768 shape (embed=768) partition spec (None,)
173
+ Variable encoder/layers_0/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
174
+ Variable encoder/layers_0/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
175
+ Variable encoder/layers_0/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
176
+ Variable encoder/layers_0/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
177
+ Variable encoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
178
+ Variable encoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
179
+ Variable encoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
180
+ Variable encoder/layers_0/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
181
+ Variable encoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
182
+ Variable encoder/layers_1/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
183
+ Variable encoder/layers_1/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
184
+ Variable encoder/layers_1/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
185
+ Variable encoder/layers_1/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
186
+ Variable encoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
187
+ Variable encoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
188
+ Variable encoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
189
+ Variable encoder/layers_1/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
190
+ Variable encoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
191
+ Variable encoder/layers_10/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
192
+ Variable encoder/layers_10/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
193
+ Variable encoder/layers_10/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
194
+ Variable encoder/layers_10/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
195
+ Variable encoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
196
+ Variable encoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
197
+ Variable encoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
198
+ Variable encoder/layers_10/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
199
+ Variable encoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
200
+ Variable encoder/layers_11/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
201
+ Variable encoder/layers_11/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
202
+ Variable encoder/layers_11/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
203
+ Variable encoder/layers_11/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
204
+ Variable encoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
205
+ Variable encoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
206
+ Variable encoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
207
+ Variable encoder/layers_11/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
208
+ Variable encoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
209
+ Variable encoder/layers_2/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
210
+ Variable encoder/layers_2/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
211
+ Variable encoder/layers_2/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
212
+ Variable encoder/layers_2/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
213
+ Variable encoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
214
+ Variable encoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
215
+ Variable encoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
216
+ Variable encoder/layers_2/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
217
+ Variable encoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
218
+ Variable encoder/layers_3/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
219
+ Variable encoder/layers_3/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
220
+ Variable encoder/layers_3/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
221
+ Variable encoder/layers_3/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
222
+ Variable encoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
223
+ Variable encoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
224
+ Variable encoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
225
+ Variable encoder/layers_3/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
226
+ Variable encoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
227
+ Variable encoder/layers_4/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
228
+ Variable encoder/layers_4/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
229
+ Variable encoder/layers_4/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
230
+ Variable encoder/layers_4/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
231
+ Variable encoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
232
+ Variable encoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
233
+ Variable encoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
234
+ Variable encoder/layers_4/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
235
+ Variable encoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
236
+ Variable encoder/layers_5/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
237
+ Variable encoder/layers_5/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
238
+ Variable encoder/layers_5/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
239
+ Variable encoder/layers_5/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
240
+ Variable encoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
241
+ Variable encoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
242
+ Variable encoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
243
+ Variable encoder/layers_5/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
244
+ Variable encoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
245
+ Variable encoder/layers_6/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
246
+ Variable encoder/layers_6/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
247
+ Variable encoder/layers_6/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
248
+ Variable encoder/layers_6/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
249
+ Variable encoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
250
+ Variable encoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
251
+ Variable encoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
252
+ Variable encoder/layers_6/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
253
+ Variable encoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
254
+ Variable encoder/layers_7/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
255
+ Variable encoder/layers_7/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
256
+ Variable encoder/layers_7/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
257
+ Variable encoder/layers_7/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
258
+ Variable encoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
259
+ Variable encoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
260
+ Variable encoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
261
+ Variable encoder/layers_7/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
262
+ Variable encoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
263
+ Variable encoder/layers_8/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
264
+ Variable encoder/layers_8/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
265
+ Variable encoder/layers_8/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
266
+ Variable encoder/layers_8/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
267
+ Variable encoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
268
+ Variable encoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
269
+ Variable encoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
270
+ Variable encoder/layers_8/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
271
+ Variable encoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
272
+ Variable encoder/layers_9/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
273
+ Variable encoder/layers_9/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
274
+ Variable encoder/layers_9/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
275
+ Variable encoder/layers_9/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
276
+ Variable encoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
277
+ Variable encoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
278
+ Variable encoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
279
+ Variable encoder/layers_9/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
280
+ Variable encoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
281
+ Variable encoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
282
+ Variable token_embedder/embedding size 192086016 shape (vocab=250112, embed=768) partition spec ('model', None)
283
+ Total number of parameters: 582401280
284
+
285
+ Variable step size 1 shape () partition spec None
tasks.py CHANGED
@@ -5,6 +5,7 @@ import seqio
5
  import tensorflow_datasets as tfds
6
  from t5.evaluation import metrics
7
  from t5.data import preprocessors
 
8
  import t5
9
  import tensorflow.compat.v1 as tf
10
 
@@ -59,7 +60,7 @@ seqio.TaskRegistry.add(
59
  categorise_preprocessor,
60
  seqio.preprocessors.tokenize_and_append_eos,
61
  ],
62
- #metric_fns=[metrics.bleu],
63
  output_features=DEFAULT_OUTPUT_FEATURES,
64
  )
65
 
 
5
  import tensorflow_datasets as tfds
6
  from t5.evaluation import metrics
7
  from t5.data import preprocessors
8
+ from t5.data import postprocessors
9
  import t5
10
  import tensorflow.compat.v1 as tf
11
 
 
60
  categorise_preprocessor,
61
  seqio.preprocessors.tokenize_and_append_eos,
62
  ],
63
+ metric_fns=[metrics.accuracy],
64
  output_features=DEFAULT_OUTPUT_FEATURES,
65
  )
66
 
train_base.sh CHANGED
@@ -1,7 +1,7 @@
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
- MODEL_DIR="gs://nb-t5x/eval_norwegian_1_163_000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \
 
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
+ MODEL_DIR="gs://nb-t5x/eval_norwegian_1_194_000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \