pere commited on
Commit
ced8bf6
1 Parent(s): f5d1117
__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
eval.py DELETED
@@ -1,238 +0,0 @@
1
- # Copyright 2022 The T5X Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- # pylint:disable=line-too-long
16
- # pyformat: disable
17
- r"""This script runs inference-evaluation on a T5X-compatible model.
18
-
19
- """
20
- # pyformat: enable
21
- # pylint:enable=line-too-long
22
-
23
- import functools
24
- import os
25
- from typing import Optional, Sequence, Type
26
-
27
- # pylint:disable=g-import-not-at-top
28
- # TODO(adarob): Re-enable once users are notified and tests are updated.
29
- os.environ['FLAX_LAZY_RNG'] = 'no'
30
- from absl import logging
31
- from clu import metric_writers
32
- import jax
33
- from jax.experimental import multihost_utils
34
- import seqio
35
- from t5x import gin_utils
36
- from t5x import models
37
- from t5x import partitioning
38
- from t5x import utils
39
- from typing_extensions import Protocol
40
-
41
- # Automatically search for gin files relative to the T5X package.
42
- _DEFAULT_GIN_SEARCH_PATHS = [
43
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
44
- ]
45
-
46
-
47
- class SummarizeConfigFn(Protocol):
48
-
49
- def __call__(self, model_dir: str,
50
- summary_writer: Optional[metric_writers.SummaryWriter],
51
- step: int) -> None:
52
- ...
53
-
54
-
55
- def evaluate(
56
- *,
57
- model: models.BaseTransformerModel,
58
- dataset_cfg: utils.DatasetConfig,
59
- restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
60
- partitioner: partitioning.BasePartitioner,
61
- output_dir: str,
62
- inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
63
- summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
64
- fallback_init_rng: Optional[int] = None):
65
- """Evaluation function.
66
-
67
- Args:
68
- model: The model object to use for inference.
69
- dataset_cfg: Specification for the dataset to infer based on.
70
- restore_checkpoint_cfg: Specification for the model parameter checkpoint to
71
- load.
72
- partitioner: Partitioner for the model parameters and data across devices.
73
- output_dir: Path to directory to write temporary files and final results.
74
- inference_evaluator_cls: seqio.Evaluator class to use for inference
75
- evaluation, potentially with bound configuration args.
76
- summarize_config_fn: A function that takes in the model directory, an
77
- optional SummaryWriter, and the step number, and writes a summary of the
78
- configuration. SummaryWriter will be None in most cases.
79
- fallback_init_rng: A random seed used for parameter initialization during
80
- model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
81
- set to True. If None, parameter initialization is not allowed during model
82
- loading and having fallback_to_scratch enabled will result in an error.
83
- """
84
- logging.info('Process ID: %d', jax.process_index())
85
- if dataset_cfg.module:
86
- utils.import_module(dataset_cfg.module)
87
- batch_size = dataset_cfg.batch_size
88
-
89
- summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)
90
-
91
- ds_vocabs = utils.get_vocabulary(dataset_cfg)
92
- if (ds_vocabs[0] != model.input_vocabulary or
93
- ds_vocabs[1] != model.output_vocabulary):
94
- raise ValueError(f'Model and Task vocabularies do not match:\n'
95
- f' task={dataset_cfg.mixture_or_task_name}\n'
96
- f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
97
- f' model.input_vocabulary={model.input_vocabulary}\n'
98
- f' model.output_vocabulary={model.output_vocabulary}\n')
99
-
100
- # ----------------------------------------------------------------------------
101
- # SeqIO (inference-based) evaluation setup
102
- # ----------------------------------------------------------------------------
103
- # Init evaluator to set up cached datasets
104
- evaluator = inference_evaluator_cls(
105
- mixture_or_task_name=dataset_cfg.mixture_or_task_name,
106
- feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
107
- eval_split=dataset_cfg.split,
108
- use_cached=dataset_cfg.use_cached,
109
- seed=dataset_cfg.seed,
110
- sequence_length=dataset_cfg.task_feature_lengths,
111
- log_dir=os.path.join(output_dir, 'inference_eval'))
112
- if not evaluator.eval_tasks:
113
- raise ValueError(
114
- f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")
115
-
116
- # ----------------------------------------------------------------------------
117
- # T5X model loading.
118
- # ----------------------------------------------------------------------------
119
-
120
- # Initialize optimizer from the existing checkpoint.
121
- input_shapes = {
122
- k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
123
- }
124
-
125
- train_state_initializer = utils.TrainStateInitializer(
126
- optimizer_def=None, # Do not load optimizer state.
127
- init_fn=model.get_initial_variables,
128
- input_shapes=input_shapes,
129
- partitioner=partitioner)
130
- train_state_axes = train_state_initializer.train_state_axes
131
- # Log the variable shapes information and write to a file.
132
- log_file = os.path.join(output_dir, 'model-info.txt')
133
- utils.log_model_info(log_file,
134
- train_state_initializer.global_train_state_shape,
135
- partitioner)
136
-
137
- predict_fn = None
138
- score_fn = None
139
-
140
- # Disable strictness since we are dropping the optimizer state.
141
- restore_checkpoint_cfg.strict = False
142
-
143
- if fallback_init_rng is not None:
144
- fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
145
- for train_state in train_state_initializer.from_checkpoints(
146
- [restore_checkpoint_cfg], init_rng=fallback_init_rng):
147
-
148
- # Compile the model only once.
149
- if not predict_fn:
150
- predict_fn = utils.get_infer_fn(
151
- infer_step=model.predict_batch,
152
- batch_size=batch_size,
153
- train_state_axes=train_state_axes,
154
- partitioner=partitioner)
155
-
156
- score_fn = utils.get_infer_fn(
157
- infer_step=model.score_batch,
158
- batch_size=batch_size,
159
- train_state_axes=train_state_axes,
160
- partitioner=partitioner)
161
-
162
- # ----------------------------------------------------------------------------
163
- # Main training loop
164
- # ----------------------------------------------------------------------------
165
-
166
- # Run final evaluation (with decoding) on the full eval dataset.
167
- all_metrics, _, _ = evaluator.evaluate(
168
- compute_metrics=jax.process_index() == 0,
169
- step=int(train_state.step),
170
- predict_fn=functools.partial(
171
- predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
172
- score_fn=functools.partial(score_fn, train_state=train_state))
173
- all_metrics.result() # Ensure metrics are finished being computed.
174
- # Wait until computations are done before continuing.
175
- multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')
176
- print(all_metrics.result())
177
-
178
- logging.info('Finished.')
179
-
180
-
181
- if __name__ == '__main__':
182
- from absl import app
183
- from absl import flags
184
- import gin
185
-
186
- FLAGS = flags.FLAGS
187
-
188
- jax.config.parse_flags_with_absl()
189
-
190
- flags.DEFINE_multi_string(
191
- 'gin_file',
192
- default=None,
193
- help='Path to gin configuration file. Multiple paths may be passed and '
194
- 'will be imported in the given order, with later configurations '
195
- 'overriding earlier ones.')
196
-
197
- flags.DEFINE_multi_string(
198
- 'gin_bindings', default=[], help='Individual gin bindings.')
199
-
200
- flags.DEFINE_list(
201
- 'gin_search_paths',
202
- default=['.'],
203
- help='Comma-separated list of gin config path prefixes to be prepended '
204
- 'to suffixes given via `--gin_file`. If a file appears in. Only the '
205
- 'first prefix that produces a valid path for each suffix will be '
206
- 'used.')
207
-
208
- flags.DEFINE_string(
209
- 'tfds_data_dir', None,
210
- 'If set, this directory will be used to store datasets prepared by '
211
- 'TensorFlow Datasets that are not available in the public TFDS GCS '
212
- 'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
213
- 'all `Task`s.')
214
-
215
-
216
- def main(argv: Sequence[str]):
217
- """Wrapper for pdb post mortems."""
218
- _main(argv)
219
-
220
- def _main(argv: Sequence[str]):
221
- """True main function."""
222
- if len(argv) > 1:
223
- raise app.UsageError('Too many command-line arguments.')
224
-
225
- if FLAGS.tfds_data_dir:
226
- seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
227
-
228
- # Create gin-configurable version of `eval`.
229
- evaluate_using_gin = gin.configurable(evaluate)
230
-
231
- gin_utils.parse_gin_flags(
232
- # User-provided gin paths take precedence if relative paths conflict.
233
- FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
234
- FLAGS.gin_file,
235
- FLAGS.gin_bindings)
236
- evaluate_using_gin()
237
-
238
- gin_utils.run(main)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eval_base.sh CHANGED
@@ -1,11 +1,10 @@
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
- EVAL_OUTPUT_DIR="/home/perk/models/t5-parliament-categorisation/results"
3
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
4
- CHECKPOINT_PATH="gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
- #python3 ${T5X_DIR}/t5x/eval.py \
8
- python3 ./eval.py \
9
  --gin_search_paths=${PROJECT_DIR} \
10
  --gin_file="eval_categorisation_base.gin" \
11
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
 
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
+ EVAL_OUTPUT_DIR="gs://nb-t5x/eval/"
3
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
4
+ CHECKPOINT_PATH="gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1043000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
+ python3 ${T5X_DIR}/t5x/eval.py \
 
8
  --gin_search_paths=${PROJECT_DIR} \
9
  --gin_file="eval_categorisation_base.gin" \
10
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
eval_categorisation_base.gin CHANGED
@@ -23,8 +23,8 @@ eval_script.evaluate:
23
 
24
  utils.DatasetConfig:
25
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
26
- task_feature_lengths = {"inputs": 512, "targets": 2} # Auto-computes the max feature lengths.
27
- split = 'validation'
28
  batch_size = 32
29
  shuffle = False
30
  seed = 42
 
23
 
24
  utils.DatasetConfig:
25
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
26
+ task_feature_lengths = None # Auto-computes the max feature lengths.
27
+ split = 'test'
28
  batch_size = 32
29
  shuffle = False
30
  seed = 42
finetune_categorisation_base.gin CHANGED
@@ -12,7 +12,7 @@ include "t5x/configs/runs/finetune.gin"
12
 
13
  MIXTURE_OR_TASK_NAME = "categorise"
14
  TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
15
- TRAIN_STEPS = 1_953_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
16
  USE_CACHED_TASKS = False
17
  DROPOUT_RATE = 0.1
18
  RANDOM_SEED = 0
@@ -29,7 +29,7 @@ RANDOM_SEED = 0
29
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
30
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
31
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
32
- INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1948000"
33
 
34
  #train_script.train:
35
  # eval_period = 500
@@ -37,7 +37,7 @@ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint
37
  # partitioning.PjitPartitioner.num_partitions = 1
38
 
39
  # `num_decodes` is equivalent to a beam size in a beam search decoding.
40
- models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
41
 
42
  #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
43
  #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
 
12
 
13
  MIXTURE_OR_TASK_NAME = "categorise"
14
  TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
15
+ TRAIN_STEPS = 1_635_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
16
  USE_CACHED_TASKS = False
17
  DROPOUT_RATE = 0.1
18
  RANDOM_SEED = 0
 
29
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
30
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
31
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
32
+ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1630000"
33
 
34
  #train_script.train:
35
  # eval_period = 500
 
37
  # partitioning.PjitPartitioner.num_partitions = 1
38
 
39
  # `num_decodes` is equivalent to a beam size in a beam search decoding.
40
+ # models.EncoderDecoderModel.predict_batch_with_aux.num_decodes = 1
41
 
42
  #mesh_transformer.learning_rate_schedules.constant_learning_rate.learning_rate = 0.0005
43
  #run.learning_rate_schedule = @learning_rate_schedules.constant_learning_rate
results/config.gin DELETED
@@ -1,89 +0,0 @@
1
- from __gin__ import dynamic_registration
2
- import __main__ as eval_script
3
- import seqio
4
- from t5.data import mixtures
5
- from t5x import adafactor
6
- from t5x.examples.t5 import network
7
- from t5x import models
8
- from t5x import partitioning
9
- from t5x import utils
10
- import tasks
11
-
12
- # Macros:
13
- # ==============================================================================
14
- CHECKPOINT_PATH = 'gs://nb-t5x/eval_norwegian_1_163_000/checkpoint_1635000'
15
- DROPOUT_RATE = 0.0
16
- EVAL_OUTPUT_DIR = '/home/perk/models/t5-parliament-categorisation/results'
17
- LABEL_SMOOTHING = 0.0
18
- LOSS_NORMALIZING_FACTOR = None
19
- MIXTURE_OR_TASK_NAME = 'categorise'
20
- MODEL = @models.EncoderDecoderModel()
21
- OPTIMIZER = @adafactor.Adafactor()
22
- VOCABULARY = @seqio.SentencePieceVocabulary()
23
- Z_LOSS = 0.0001
24
-
25
- # Parameters for adafactor.Adafactor:
26
- # ==============================================================================
27
- adafactor.Adafactor.decay_rate = 0.8
28
- adafactor.Adafactor.logical_factor_rules = \
29
- @adafactor.standard_logical_factor_rules()
30
- adafactor.Adafactor.step_offset = 0
31
-
32
- # Parameters for utils.DatasetConfig:
33
- # ==============================================================================
34
- utils.DatasetConfig.batch_size = 32
35
- utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME
36
- utils.DatasetConfig.seed = 42
37
- utils.DatasetConfig.shuffle = False
38
- utils.DatasetConfig.split = 'test'
39
- utils.DatasetConfig.task_feature_lengths = {'inputs': 512, 'targets': 2}
40
-
41
- # Parameters for models.EncoderDecoderModel:
42
- # ==============================================================================
43
- models.EncoderDecoderModel.input_vocabulary = %VOCABULARY
44
- models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING
45
- models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR
46
- models.EncoderDecoderModel.module = @network.Transformer()
47
- models.EncoderDecoderModel.optimizer_def = %OPTIMIZER
48
- models.EncoderDecoderModel.output_vocabulary = %VOCABULARY
49
- models.EncoderDecoderModel.z_loss = %Z_LOSS
50
-
51
- # Parameters for eval_script.evaluate:
52
- # ==============================================================================
53
- eval_script.evaluate.dataset_cfg = @utils.DatasetConfig()
54
- eval_script.evaluate.model = %MODEL
55
- eval_script.evaluate.output_dir = %EVAL_OUTPUT_DIR
56
- eval_script.evaluate.partitioner = @partitioning.PjitPartitioner()
57
- eval_script.evaluate.restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
58
-
59
- # Parameters for partitioning.PjitPartitioner:
60
- # ==============================================================================
61
- partitioning.PjitPartitioner.num_partitions = 2
62
-
63
- # Parameters for utils.RestoreCheckpointConfig:
64
- # ==============================================================================
65
- utils.RestoreCheckpointConfig.mode = 'specific'
66
- utils.RestoreCheckpointConfig.path = %CHECKPOINT_PATH
67
-
68
- # Parameters for seqio.SentencePieceVocabulary:
69
- # ==============================================================================
70
- seqio.SentencePieceVocabulary.sentencepiece_model_file = \
71
- 'gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model'
72
-
73
- # Parameters for network.T5Config:
74
- # ==============================================================================
75
- network.T5Config.dropout_rate = %DROPOUT_RATE
76
- network.T5Config.dtype = 'bfloat16'
77
- network.T5Config.emb_dim = 768
78
- network.T5Config.head_dim = 64
79
- network.T5Config.logits_via_embedding = False
80
- network.T5Config.mlp_activations = ('gelu', 'linear')
81
- network.T5Config.mlp_dim = 2048
82
- network.T5Config.num_decoder_layers = 12
83
- network.T5Config.num_encoder_layers = 12
84
- network.T5Config.num_heads = 12
85
- network.T5Config.vocab_size = 250112
86
-
87
- # Parameters for network.Transformer:
88
- # ==============================================================================
89
- network.Transformer.config = @network.T5Config()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
results/model-info.txt DELETED
@@ -1,285 +0,0 @@
1
- Variable decoder/decoder_norm/scale size 768 shape (embed=768) partition spec (None,)
2
- Variable decoder/layers_0/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
3
- Variable decoder/layers_0/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
4
- Variable decoder/layers_0/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
5
- Variable decoder/layers_0/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
6
- Variable decoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
7
- Variable decoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
8
- Variable decoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
9
- Variable decoder/layers_0/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
10
- Variable decoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
11
- Variable decoder/layers_0/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
12
- Variable decoder/layers_0/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
13
- Variable decoder/layers_0/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
14
- Variable decoder/layers_0/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
15
- Variable decoder/layers_0/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
16
- Variable decoder/layers_1/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
17
- Variable decoder/layers_1/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
18
- Variable decoder/layers_1/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
19
- Variable decoder/layers_1/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
20
- Variable decoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
21
- Variable decoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
22
- Variable decoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
23
- Variable decoder/layers_1/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
24
- Variable decoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
25
- Variable decoder/layers_1/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
26
- Variable decoder/layers_1/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
27
- Variable decoder/layers_1/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
28
- Variable decoder/layers_1/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
29
- Variable decoder/layers_1/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
30
- Variable decoder/layers_10/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
31
- Variable decoder/layers_10/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
32
- Variable decoder/layers_10/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
33
- Variable decoder/layers_10/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
34
- Variable decoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
35
- Variable decoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
36
- Variable decoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
37
- Variable decoder/layers_10/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
38
- Variable decoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
39
- Variable decoder/layers_10/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
40
- Variable decoder/layers_10/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
41
- Variable decoder/layers_10/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
42
- Variable decoder/layers_10/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
43
- Variable decoder/layers_10/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
44
- Variable decoder/layers_11/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
45
- Variable decoder/layers_11/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
46
- Variable decoder/layers_11/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
47
- Variable decoder/layers_11/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
48
- Variable decoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
49
- Variable decoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
50
- Variable decoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
51
- Variable decoder/layers_11/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
52
- Variable decoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
53
- Variable decoder/layers_11/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
54
- Variable decoder/layers_11/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
55
- Variable decoder/layers_11/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
56
- Variable decoder/layers_11/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
57
- Variable decoder/layers_11/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
58
- Variable decoder/layers_2/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
59
- Variable decoder/layers_2/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
60
- Variable decoder/layers_2/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
61
- Variable decoder/layers_2/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
62
- Variable decoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
63
- Variable decoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
64
- Variable decoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
65
- Variable decoder/layers_2/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
66
- Variable decoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
67
- Variable decoder/layers_2/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
68
- Variable decoder/layers_2/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
69
- Variable decoder/layers_2/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
70
- Variable decoder/layers_2/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
71
- Variable decoder/layers_2/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
72
- Variable decoder/layers_3/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
73
- Variable decoder/layers_3/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
74
- Variable decoder/layers_3/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
75
- Variable decoder/layers_3/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
76
- Variable decoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
77
- Variable decoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
78
- Variable decoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
79
- Variable decoder/layers_3/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
80
- Variable decoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
81
- Variable decoder/layers_3/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
82
- Variable decoder/layers_3/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
83
- Variable decoder/layers_3/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
84
- Variable decoder/layers_3/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
85
- Variable decoder/layers_3/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
86
- Variable decoder/layers_4/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
87
- Variable decoder/layers_4/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
88
- Variable decoder/layers_4/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
89
- Variable decoder/layers_4/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
90
- Variable decoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
91
- Variable decoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
92
- Variable decoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
93
- Variable decoder/layers_4/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
94
- Variable decoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
95
- Variable decoder/layers_4/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
96
- Variable decoder/layers_4/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
97
- Variable decoder/layers_4/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
98
- Variable decoder/layers_4/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
99
- Variable decoder/layers_4/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
100
- Variable decoder/layers_5/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
101
- Variable decoder/layers_5/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
102
- Variable decoder/layers_5/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
103
- Variable decoder/layers_5/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
104
- Variable decoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
105
- Variable decoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
106
- Variable decoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
107
- Variable decoder/layers_5/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
108
- Variable decoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
109
- Variable decoder/layers_5/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
110
- Variable decoder/layers_5/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
111
- Variable decoder/layers_5/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
112
- Variable decoder/layers_5/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
113
- Variable decoder/layers_5/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
114
- Variable decoder/layers_6/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
115
- Variable decoder/layers_6/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
116
- Variable decoder/layers_6/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
117
- Variable decoder/layers_6/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
118
- Variable decoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
119
- Variable decoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
120
- Variable decoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
121
- Variable decoder/layers_6/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
122
- Variable decoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
123
- Variable decoder/layers_6/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
124
- Variable decoder/layers_6/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
125
- Variable decoder/layers_6/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
126
- Variable decoder/layers_6/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
127
- Variable decoder/layers_6/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
128
- Variable decoder/layers_7/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
129
- Variable decoder/layers_7/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
130
- Variable decoder/layers_7/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
131
- Variable decoder/layers_7/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
132
- Variable decoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
133
- Variable decoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
134
- Variable decoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
135
- Variable decoder/layers_7/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
136
- Variable decoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
137
- Variable decoder/layers_7/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
138
- Variable decoder/layers_7/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
139
- Variable decoder/layers_7/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
140
- Variable decoder/layers_7/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
141
- Variable decoder/layers_7/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
142
- Variable decoder/layers_8/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
143
- Variable decoder/layers_8/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
144
- Variable decoder/layers_8/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
145
- Variable decoder/layers_8/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
146
- Variable decoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
147
- Variable decoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
148
- Variable decoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
149
- Variable decoder/layers_8/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
150
- Variable decoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
151
- Variable decoder/layers_8/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
152
- Variable decoder/layers_8/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
153
- Variable decoder/layers_8/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
154
- Variable decoder/layers_8/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
155
- Variable decoder/layers_8/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
156
- Variable decoder/layers_9/encoder_decoder_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
157
- Variable decoder/layers_9/encoder_decoder_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
158
- Variable decoder/layers_9/encoder_decoder_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
159
- Variable decoder/layers_9/encoder_decoder_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
160
- Variable decoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
161
- Variable decoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
162
- Variable decoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
163
- Variable decoder/layers_9/pre_cross_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
164
- Variable decoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
165
- Variable decoder/layers_9/pre_self_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
166
- Variable decoder/layers_9/self_attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
167
- Variable decoder/layers_9/self_attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
168
- Variable decoder/layers_9/self_attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
169
- Variable decoder/layers_9/self_attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
170
- Variable decoder/logits_dense/kernel size 192086016 shape (embed=768, vocab=250112) partition spec (None, 'model')
171
- Variable decoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
172
- Variable encoder/encoder_norm/scale size 768 shape (embed=768) partition spec (None,)
173
- Variable encoder/layers_0/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
174
- Variable encoder/layers_0/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
175
- Variable encoder/layers_0/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
176
- Variable encoder/layers_0/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
177
- Variable encoder/layers_0/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
178
- Variable encoder/layers_0/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
179
- Variable encoder/layers_0/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
180
- Variable encoder/layers_0/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
181
- Variable encoder/layers_0/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
182
- Variable encoder/layers_1/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
183
- Variable encoder/layers_1/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
184
- Variable encoder/layers_1/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
185
- Variable encoder/layers_1/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
186
- Variable encoder/layers_1/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
187
- Variable encoder/layers_1/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
188
- Variable encoder/layers_1/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
189
- Variable encoder/layers_1/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
190
- Variable encoder/layers_1/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
191
- Variable encoder/layers_10/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
192
- Variable encoder/layers_10/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
193
- Variable encoder/layers_10/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
194
- Variable encoder/layers_10/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
195
- Variable encoder/layers_10/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
196
- Variable encoder/layers_10/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
197
- Variable encoder/layers_10/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
198
- Variable encoder/layers_10/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
199
- Variable encoder/layers_10/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
200
- Variable encoder/layers_11/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
201
- Variable encoder/layers_11/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
202
- Variable encoder/layers_11/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
203
- Variable encoder/layers_11/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
204
- Variable encoder/layers_11/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
205
- Variable encoder/layers_11/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
206
- Variable encoder/layers_11/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
207
- Variable encoder/layers_11/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
208
- Variable encoder/layers_11/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
209
- Variable encoder/layers_2/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
210
- Variable encoder/layers_2/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
211
- Variable encoder/layers_2/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
212
- Variable encoder/layers_2/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
213
- Variable encoder/layers_2/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
214
- Variable encoder/layers_2/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
215
- Variable encoder/layers_2/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
216
- Variable encoder/layers_2/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
217
- Variable encoder/layers_2/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
218
- Variable encoder/layers_3/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
219
- Variable encoder/layers_3/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
220
- Variable encoder/layers_3/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
221
- Variable encoder/layers_3/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
222
- Variable encoder/layers_3/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
223
- Variable encoder/layers_3/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
224
- Variable encoder/layers_3/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
225
- Variable encoder/layers_3/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
226
- Variable encoder/layers_3/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
227
- Variable encoder/layers_4/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
228
- Variable encoder/layers_4/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
229
- Variable encoder/layers_4/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
230
- Variable encoder/layers_4/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
231
- Variable encoder/layers_4/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
232
- Variable encoder/layers_4/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
233
- Variable encoder/layers_4/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
234
- Variable encoder/layers_4/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
235
- Variable encoder/layers_4/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
236
- Variable encoder/layers_5/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
237
- Variable encoder/layers_5/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
238
- Variable encoder/layers_5/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
239
- Variable encoder/layers_5/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
240
- Variable encoder/layers_5/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
241
- Variable encoder/layers_5/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
242
- Variable encoder/layers_5/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
243
- Variable encoder/layers_5/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
244
- Variable encoder/layers_5/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
245
- Variable encoder/layers_6/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
246
- Variable encoder/layers_6/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
247
- Variable encoder/layers_6/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
248
- Variable encoder/layers_6/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
249
- Variable encoder/layers_6/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
250
- Variable encoder/layers_6/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
251
- Variable encoder/layers_6/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
252
- Variable encoder/layers_6/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
253
- Variable encoder/layers_6/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
254
- Variable encoder/layers_7/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
255
- Variable encoder/layers_7/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
256
- Variable encoder/layers_7/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
257
- Variable encoder/layers_7/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
258
- Variable encoder/layers_7/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
259
- Variable encoder/layers_7/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
260
- Variable encoder/layers_7/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
261
- Variable encoder/layers_7/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
262
- Variable encoder/layers_7/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
263
- Variable encoder/layers_8/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
264
- Variable encoder/layers_8/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
265
- Variable encoder/layers_8/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
266
- Variable encoder/layers_8/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
267
- Variable encoder/layers_8/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
268
- Variable encoder/layers_8/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
269
- Variable encoder/layers_8/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
270
- Variable encoder/layers_8/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
271
- Variable encoder/layers_8/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
272
- Variable encoder/layers_9/attention/key/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
273
- Variable encoder/layers_9/attention/out/kernel size 589824 shape (joined_kv=768, embed=768) partition spec ('model', None)
274
- Variable encoder/layers_9/attention/query/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
275
- Variable encoder/layers_9/attention/value/kernel size 589824 shape (embed=768, joined_kv=768) partition spec (None, 'model')
276
- Variable encoder/layers_9/mlp/wi_0/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
277
- Variable encoder/layers_9/mlp/wi_1/kernel size 1572864 shape (embed=768, mlp=2048) partition spec (None, 'model')
278
- Variable encoder/layers_9/mlp/wo/kernel size 1572864 shape (mlp=2048, embed=768) partition spec ('model', None)
279
- Variable encoder/layers_9/pre_attention_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
280
- Variable encoder/layers_9/pre_mlp_layer_norm/scale size 768 shape (embed=768) partition spec (None,)
281
- Variable encoder/relpos_bias/rel_embedding size 384 shape (heads=12, relpos_buckets=32) partition spec ('model', None)
282
- Variable token_embedder/embedding size 192086016 shape (vocab=250112, embed=768) partition spec ('model', None)
283
- Total number of parameters: 582401280
284
-
285
- Variable step size 1 shape () partition spec None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tasks.py CHANGED
@@ -5,7 +5,6 @@ import seqio
5
  import tensorflow_datasets as tfds
6
  from t5.evaluation import metrics
7
  from t5.data import preprocessors
8
- from t5.data import postprocessors
9
  import t5
10
  import tensorflow.compat.v1 as tf
11
 
@@ -60,7 +59,7 @@ seqio.TaskRegistry.add(
60
  categorise_preprocessor,
61
  seqio.preprocessors.tokenize_and_append_eos,
62
  ],
63
- metric_fns=[metrics.accuracy],
64
  output_features=DEFAULT_OUTPUT_FEATURES,
65
  )
66
 
 
5
  import tensorflow_datasets as tfds
6
  from t5.evaluation import metrics
7
  from t5.data import preprocessors
 
8
  import t5
9
  import tensorflow.compat.v1 as tf
10
 
 
59
  categorise_preprocessor,
60
  seqio.preprocessors.tokenize_and_append_eos,
61
  ],
62
+ #metric_fns=[metrics.bleu],
63
  output_features=DEFAULT_OUTPUT_FEATURES,
64
  )
65
 
train_base.sh CHANGED
@@ -1,7 +1,7 @@
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
- MODEL_DIR="gs://nb-t5x/eval_norwegian_1_194_000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \
 
1
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
2
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
3
  #Needs to be updated when moving to tpu-v4 it should then be in another zone
4
+ MODEL_DIR="gs://nb-t5x/eval_norwegian_1_163_000"
5
  export PYTHONPATH=${PROJECT_DIR}
6
 
7
  python3 ${T5X_DIR}/t5x/train.py \