pere committed on
Commit b3a728f
1 Parent(s): ced8bf6

updated eval script

__pycache__/tasks.cpython-38.pyc CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
 
eval.py ADDED
@@ -0,0 +1,258 @@
+ # Copyright 2022 The T5X Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint:disable=line-too-long
+ # pyformat: disable
+ r"""This script runs inference-evaluation on a T5X-compatible model.
+
+ """
+ # pyformat: enable
+ # pylint:enable=line-too-long
+
+ import functools
+ import os
+ import socket
+ from datetime import datetime
+ import jsonlines
+ from typing import Optional, Sequence, Type
+
+ # pylint:disable=g-import-not-at-top
+ # TODO(adarob): Re-enable once users are notified and tests are updated.
+ os.environ['FLAX_LAZY_RNG'] = 'no'
+ from absl import logging
+ from clu import metric_writers
+ import jax
+ from jax.experimental import multihost_utils
+ import seqio
+ from t5x import gin_utils
+ from t5x import models
+ from t5x import partitioning
+ from t5x import utils
+ from typing_extensions import Protocol
+
+ # Automatically search for gin files relative to the T5X package.
+ _DEFAULT_GIN_SEARCH_PATHS = [
+     os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ ]
+
+
+ class SummarizeConfigFn(Protocol):
+
+   def __call__(self, model_dir: str,
+                summary_writer: Optional[metric_writers.SummaryWriter],
+                step: int) -> None:
+     ...
+
+
+ def evaluate(
+     *,
+     model: models.BaseTransformerModel,
+     dataset_cfg: utils.DatasetConfig,
+     restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
+     partitioner: partitioning.BasePartitioner,
+     output_dir: str,
+     inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
+     summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
+     fallback_init_rng: Optional[int] = None):
+   """Evaluation function.
+
+   Args:
+     model: The model object to use for inference.
+     dataset_cfg: Specification for the dataset to infer based on.
+     restore_checkpoint_cfg: Specification for the model parameter checkpoint to
+       load.
+     partitioner: Partitioner for the model parameters and data across devices.
+     output_dir: Path to directory to write temporary files and final results.
+     inference_evaluator_cls: seqio.Evaluator class to use for inference
+       evaluation, potentially with bound configuration args.
+     summarize_config_fn: A function that takes in the model directory, an
+       optional SummaryWriter, and the step number, and writes a summary of the
+       configuration. SummaryWriter will be None in most cases.
+     fallback_init_rng: A random seed used for parameter initialization during
+       model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
+       set to True. If None, parameter initialization is not allowed during model
+       loading and having fallback_to_scratch enabled will result in an error.
+   """
+   logging.info('Process ID: %d', jax.process_index())
+   if dataset_cfg.module:
+     utils.import_module(dataset_cfg.module)
+   batch_size = dataset_cfg.batch_size
+
+   summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)
+
+   ds_vocabs = utils.get_vocabulary(dataset_cfg)
+   if (ds_vocabs[0] != model.input_vocabulary or
+       ds_vocabs[1] != model.output_vocabulary):
+     raise ValueError(f'Model and Task vocabularies do not match:\n'
+                      f'  task={dataset_cfg.mixture_or_task_name}\n'
+                      f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
+                      f'  model.input_vocabulary={model.input_vocabulary}\n'
+                      f'  model.output_vocabulary={model.output_vocabulary}\n')
+
+   # --------------------------------------------------------------------------
+   # SeqIO (inference-based) evaluation setup
+   # --------------------------------------------------------------------------
+   # Init evaluator to set up cached datasets
+   evaluator = inference_evaluator_cls(
+       mixture_or_task_name=dataset_cfg.mixture_or_task_name,
+       feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
+       eval_split=dataset_cfg.split,
+       use_cached=dataset_cfg.use_cached,
+       seed=dataset_cfg.seed,
+       sequence_length=dataset_cfg.task_feature_lengths,
+       log_dir=os.path.join(output_dir, 'inference_eval'))
+   if not evaluator.eval_tasks:
+     raise ValueError(
+         f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")
+
+   # --------------------------------------------------------------------------
+   # T5X model loading.
+   # --------------------------------------------------------------------------
+
+   # Initialize optimizer from the existing checkpoint.
+   input_shapes = {
+       k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
+   }
+
+   train_state_initializer = utils.TrainStateInitializer(
+       optimizer_def=None,  # Do not load optimizer state.
+       init_fn=model.get_initial_variables,
+       input_shapes=input_shapes,
+       partitioner=partitioner)
+   train_state_axes = train_state_initializer.train_state_axes
+   # Log the variable shapes information and write to a file.
+   log_file = os.path.join(output_dir, 'model-info.txt')
+   utils.log_model_info(log_file,
+                        train_state_initializer.global_train_state_shape,
+                        partitioner)
+
+   predict_fn = None
+   score_fn = None
+
+   # Disable strictness since we are dropping the optimizer state.
+   restore_checkpoint_cfg.strict = False
+
+   if fallback_init_rng is not None:
+     fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
+   for train_state in train_state_initializer.from_checkpoints(
+       [restore_checkpoint_cfg], init_rng=fallback_init_rng):
+
+     # Compile the model only once.
+     if not predict_fn:
+       predict_fn = utils.get_infer_fn(
+           infer_step=model.predict_batch,
+           batch_size=batch_size,
+           train_state_axes=train_state_axes,
+           partitioner=partitioner)
+
+       score_fn = utils.get_infer_fn(
+           infer_step=model.score_batch,
+           batch_size=batch_size,
+           train_state_axes=train_state_axes,
+           partitioner=partitioner)
+
+     # ------------------------------------------------------------------------
+     # Main evaluation loop
+     # ------------------------------------------------------------------------
+
+     # Run final evaluation (with decoding) on the full eval dataset.
+     all_metrics, _, _ = evaluator.evaluate(
+         compute_metrics=jax.process_index() == 0,
+         step=int(train_state.step),
+         predict_fn=functools.partial(
+             predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
+         score_fn=functools.partial(score_fn, train_state=train_state))
+     all_metrics.result()  # Ensure metrics are finished being computed.
+     # Wait until computations are done before continuing.
+     multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')
+
+     # Write the evaluation results to the local log directory.
+     now = datetime.now()
+     logtime = now.strftime("%d-%m-%Y %H:%M:%S")
+
+     if not os.path.exists("log"):
+       os.makedirs("log")
+
+     logname = "./log/eval_results_" + socket.gethostname() + ".jsonl"
+
+     output = {}
+     output["model"] = restore_checkpoint_cfg.path
+     output["eval_date"] = logtime
+     output["split"] = dataset_cfg.split
+     output["result"] = all_metrics.result()[dataset_cfg.mixture_or_task_name]
+
+     with jsonlines.open(logname, mode="a") as writer:
+       writer.write(output)
+
+   logging.info('Finished.')
+
+
+ if __name__ == '__main__':
+   from absl import app
+   from absl import flags
+   import gin
+
+   FLAGS = flags.FLAGS
+
+   jax.config.parse_flags_with_absl()
+
+   flags.DEFINE_multi_string(
+       'gin_file',
+       default=None,
+       help='Path to gin configuration file. Multiple paths may be passed and '
+       'will be imported in the given order, with later configurations '
+       'overriding earlier ones.')
+
+   flags.DEFINE_multi_string(
+       'gin_bindings', default=[], help='Individual gin bindings.')
+
+   flags.DEFINE_list(
+       'gin_search_paths',
+       default=['.'],
+       help='Comma-separated list of gin config path prefixes to be prepended '
+       'to suffixes given via `--gin_file`. If a file appears in multiple '
+       'paths, only the first prefix that produces a valid path for each '
+       'suffix will be used.')
+
+   flags.DEFINE_string(
+       'tfds_data_dir', None,
+       'If set, this directory will be used to store datasets prepared by '
+       'TensorFlow Datasets that are not available in the public TFDS GCS '
+       'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
+       'all `Task`s.')
+
+
+   def main(argv: Sequence[str]):
+     """Wrapper for pdb post mortems."""
+     _main(argv)
+
+   def _main(argv: Sequence[str]):
+     """True main function."""
+     if len(argv) > 1:
+       raise app.UsageError('Too many command-line arguments.')
+
+     if FLAGS.tfds_data_dir:
+       seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
+
+     # Create gin-configurable version of `eval`.
+     evaluate_using_gin = gin.configurable(evaluate)
+
+     gin_utils.parse_gin_flags(
+         # User-provided gin paths take precedence if relative paths conflict.
+         FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
+         FLAGS.gin_file,
+         FLAGS.gin_bindings)
+     evaluate_using_gin()
+
+   gin_utils.run(main)
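
For reference, each run of the new eval.py appends one JSON line to ./log/eval_results_<hostname>.jsonl. A minimal sketch of reading that log back, using the same jsonlines package the script imports; the path and host name are illustrative placeholders, not output from a real run:

import jsonlines

# Each record carries the fields assembled at the end of evaluate():
# "model" (the checkpoint path), "eval_date" (a "%d-%m-%Y %H:%M:%S"
# timestamp), "split", and "result" (the per-task metrics from seqio).
with jsonlines.open("log/eval_results_myhost.jsonl") as reader:
    for record in reader:
        print(record["model"], record["split"], record["result"])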
eval_base.sh CHANGED
@@ -1,10 +1,10 @@
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
  EVAL_OUTPUT_DIR="gs://nb-t5x/eval/"
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
- CHECKPOINT_PATH="gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1043000"
+ CHECKPOINT_PATH="gs://nb-t5x/eval_norwegian_NCC_2_000_000/checkpoint_2005000"
  export PYTHONPATH=${PROJECT_DIR}

- python3 ${T5X_DIR}/t5x/eval.py \
+ python3 eval.py \
  --gin_search_paths=${PROJECT_DIR} \
  --gin_file="eval_categorisation_base.gin" \
  --gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
eval_categorisation_base.gin CHANGED
@@ -24,7 +24,7 @@ eval_script.evaluate:
  utils.DatasetConfig:
    mixture_or_task_name = %MIXTURE_OR_TASK_NAME
    task_feature_lengths = None # Auto-computes the max feature lengths.
-   split = 'test'
+   split = 'validation'
    batch_size = 32
    shuffle = False
    seed = 42
finetune_categorisation_base.gin CHANGED
@@ -12,7 +12,7 @@ include "t5x/configs/runs/finetune.gin"

  MIXTURE_OR_TASK_NAME = "categorise"
  TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
- TRAIN_STEPS = 1_635_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
+ TRAIN_STEPS = 2_005_000 # 2000000 pre-trained steps + 5000 fine-tuning steps.
  USE_CACHED_TASKS = False
  DROPOUT_RATE = 0.1
  RANDOM_SEED = 0
@@ -29,7 +29,7 @@ RANDOM_SEED = 0
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
  #INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
- INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1630000"
+ INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_2000000"

  #train_script.train:
  #  eval_period = 500
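
TRAIN_STEPS in T5X counts total steps, including those already taken by the restored checkpoint, so the updated values schedule the difference as new fine-tuning steps. A one-line check of the arithmetic, using the two numbers from the new config lines above:

# TRAIN_STEPS minus the step of the new INITIAL_CHECKPOINT_PATH
# (checkpoint_2000000) gives the fine-tuning budget.
print(2_005_000 - 2_000_000)  # 5000 fine-tuning steps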
tasks.py CHANGED
@@ -59,7 +59,7 @@ seqio.TaskRegistry.add(
      categorise_preprocessor,
      seqio.preprocessors.tokenize_and_append_eos,
  ],
- #metric_fns=[metrics.bleu],
+ metric_fns=[metrics.accuracy],
  output_features=DEFAULT_OUTPUT_FEATURES,
- )
+ )

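The tasks.py change enables accuracy as the task's evaluation metric in place of the commented-out BLEU entry. Assuming `metrics` in tasks.py is `t5.evaluation.metrics`, a toy check of what the new metric_fns entry computes on decoded target/prediction strings:

from t5.evaluation import metrics

# accuracy(targets, predictions) compares the decoded strings pairwise
# and returns the match rate as a percentage.
print(metrics.accuracy(targets=["1", "2", "2"],
                       predictions=["1", "2", "3"]))
# -> {'accuracy': 66.66...}
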
train_base.sh CHANGED
@@ -1,7 +1,7 @@
  PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
  T5X_DIR="../../t5x" # directory where the t5x is cloned.
  # Needs to be updated when moving to TPU v4; it should then be in another zone.
- MODEL_DIR="gs://nb-t5x/eval_norwegian_1_163_000"
+ MODEL_DIR="gs://nb-t5x/eval_norwegian_NCC_2_000_000"
  export PYTHONPATH=${PROJECT_DIR}

  python3 ${T5X_DIR}/t5x/train.py \