updated eval script
Browse files- __pycache__/tasks.cpython-38.pyc +0 -0
- eval.py +258 -0
- eval_base.sh +2 -2
- eval_categorisation_base.gin +1 -1
- finetune_categorisation_base.gin +2 -2
- tasks.py +2 -2
- train_base.sh +1 -1
__pycache__/tasks.cpython-38.pyc
CHANGED
Binary files a/__pycache__/tasks.cpython-38.pyc and b/__pycache__/tasks.cpython-38.pyc differ
|
|
eval.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# pylint:disable=line-too-long
|
16 |
+
# pyformat: disable
|
17 |
+
r"""This script runs inference-evaluation on a T5X-compatible model.
|
18 |
+
|
19 |
+
"""
|
20 |
+
# pyformat: enable
|
21 |
+
# pylint:enable=line-too-long
|
22 |
+
|
23 |
+
import functools
|
24 |
+
import os
|
25 |
+
import socket
|
26 |
+
from datetime import datetime
|
27 |
+
import jsonlines
|
28 |
+
from typing import Optional, Sequence, Type
|
29 |
+
|
30 |
+
# pylint:disable=g-import-not-at-top
|
31 |
+
# TODO(adarob): Re-enable once users are notified and tests are updated.
|
32 |
+
os.environ['FLAX_LAZY_RNG'] = 'no'
|
33 |
+
from absl import logging
|
34 |
+
from clu import metric_writers
|
35 |
+
import jax
|
36 |
+
from jax.experimental import multihost_utils
|
37 |
+
import seqio
|
38 |
+
from t5x import gin_utils
|
39 |
+
from t5x import models
|
40 |
+
from t5x import partitioning
|
41 |
+
from t5x import utils
|
42 |
+
from typing_extensions import Protocol
|
43 |
+
|
44 |
+
# Automatically search for gin files relative to the T5X package.
|
45 |
+
_DEFAULT_GIN_SEARCH_PATHS = [
|
46 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
class SummarizeConfigFn(Protocol):
|
51 |
+
|
52 |
+
def __call__(self, model_dir: str,
|
53 |
+
summary_writer: Optional[metric_writers.SummaryWriter],
|
54 |
+
step: int) -> None:
|
55 |
+
...
|
56 |
+
|
57 |
+
|
58 |
+
def evaluate(
|
59 |
+
*,
|
60 |
+
model: models.BaseTransformerModel,
|
61 |
+
dataset_cfg: utils.DatasetConfig,
|
62 |
+
restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
|
63 |
+
partitioner: partitioning.BasePartitioner,
|
64 |
+
output_dir: str,
|
65 |
+
inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
|
66 |
+
summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
|
67 |
+
fallback_init_rng: Optional[int] = None):
|
68 |
+
"""Evaluation function.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
model: The model object to use for inference.
|
72 |
+
dataset_cfg: Specification for the dataset to infer based on.
|
73 |
+
restore_checkpoint_cfg: Specification for the model parameter checkpoint to
|
74 |
+
load.
|
75 |
+
partitioner: Partitioner for the model parameters and data across devices.
|
76 |
+
output_dir: Path to directory to write temporary files and final results.
|
77 |
+
inference_evaluator_cls: seqio.Evaluator class to use for inference
|
78 |
+
evaluation, potentially with bound configuration args.
|
79 |
+
summarize_config_fn: A function that takes in the model directory, an
|
80 |
+
optional SummaryWriter, and the step number, and writes a summary of the
|
81 |
+
configuration. SummaryWriter will be None in most cases.
|
82 |
+
fallback_init_rng: A random seed used for parameter initialization during
|
83 |
+
model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
|
84 |
+
set to True. If None, parameter initialization is not allowed during model
|
85 |
+
loading and having fallback_to_scratch enabled will result in an error.
|
86 |
+
"""
|
87 |
+
logging.info('Process ID: %d', jax.process_index())
|
88 |
+
if dataset_cfg.module:
|
89 |
+
utils.import_module(dataset_cfg.module)
|
90 |
+
batch_size = dataset_cfg.batch_size
|
91 |
+
|
92 |
+
summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)
|
93 |
+
|
94 |
+
ds_vocabs = utils.get_vocabulary(dataset_cfg)
|
95 |
+
if (ds_vocabs[0] != model.input_vocabulary or
|
96 |
+
ds_vocabs[1] != model.output_vocabulary):
|
97 |
+
raise ValueError(f'Model and Task vocabularies do not match:\n'
|
98 |
+
f' task={dataset_cfg.mixture_or_task_name}\n'
|
99 |
+
f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
|
100 |
+
f' model.input_vocabulary={model.input_vocabulary}\n'
|
101 |
+
f' model.output_vocabulary={model.output_vocabulary}\n')
|
102 |
+
|
103 |
+
# ----------------------------------------------------------------------------
|
104 |
+
# SeqIO (inference-based) evaluation setup
|
105 |
+
# ----------------------------------------------------------------------------
|
106 |
+
# Init evaluator to set up cached datasets
|
107 |
+
evaluator = inference_evaluator_cls(
|
108 |
+
mixture_or_task_name=dataset_cfg.mixture_or_task_name,
|
109 |
+
feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
|
110 |
+
eval_split=dataset_cfg.split,
|
111 |
+
use_cached=dataset_cfg.use_cached,
|
112 |
+
seed=dataset_cfg.seed,
|
113 |
+
sequence_length=dataset_cfg.task_feature_lengths,
|
114 |
+
log_dir=os.path.join(output_dir, 'inference_eval'))
|
115 |
+
if not evaluator.eval_tasks:
|
116 |
+
raise ValueError(
|
117 |
+
f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")
|
118 |
+
|
119 |
+
# ----------------------------------------------------------------------------
|
120 |
+
# T5X model loading.
|
121 |
+
# ----------------------------------------------------------------------------
|
122 |
+
|
123 |
+
# Initialize optimizer from the existing checkpoint.
|
124 |
+
input_shapes = {
|
125 |
+
k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
|
126 |
+
}
|
127 |
+
|
128 |
+
train_state_initializer = utils.TrainStateInitializer(
|
129 |
+
optimizer_def=None, # Do not load optimizer state.
|
130 |
+
init_fn=model.get_initial_variables,
|
131 |
+
input_shapes=input_shapes,
|
132 |
+
partitioner=partitioner)
|
133 |
+
train_state_axes = train_state_initializer.train_state_axes
|
134 |
+
# Log the variable shapes information and write to a file.
|
135 |
+
log_file = os.path.join(output_dir, 'model-info.txt')
|
136 |
+
utils.log_model_info(log_file,
|
137 |
+
train_state_initializer.global_train_state_shape,
|
138 |
+
partitioner)
|
139 |
+
|
140 |
+
predict_fn = None
|
141 |
+
score_fn = None
|
142 |
+
|
143 |
+
# Disable strictness since we are dropping the optimizer state.
|
144 |
+
restore_checkpoint_cfg.strict = False
|
145 |
+
|
146 |
+
if fallback_init_rng is not None:
|
147 |
+
fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
|
148 |
+
for train_state in train_state_initializer.from_checkpoints(
|
149 |
+
[restore_checkpoint_cfg], init_rng=fallback_init_rng):
|
150 |
+
|
151 |
+
# Compile the model only once.
|
152 |
+
if not predict_fn:
|
153 |
+
predict_fn = utils.get_infer_fn(
|
154 |
+
infer_step=model.predict_batch,
|
155 |
+
batch_size=batch_size,
|
156 |
+
train_state_axes=train_state_axes,
|
157 |
+
partitioner=partitioner)
|
158 |
+
|
159 |
+
score_fn = utils.get_infer_fn(
|
160 |
+
infer_step=model.score_batch,
|
161 |
+
batch_size=batch_size,
|
162 |
+
train_state_axes=train_state_axes,
|
163 |
+
partitioner=partitioner)
|
164 |
+
|
165 |
+
# ----------------------------------------------------------------------------
|
166 |
+
# Main training loop
|
167 |
+
# ----------------------------------------------------------------------------
|
168 |
+
|
169 |
+
# Run final evaluation (with decoding) on the full eval dataset.
|
170 |
+
all_metrics, _, _ = evaluator.evaluate(
|
171 |
+
compute_metrics=jax.process_index() == 0,
|
172 |
+
step=int(train_state.step),
|
173 |
+
predict_fn=functools.partial(
|
174 |
+
predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
|
175 |
+
score_fn=functools.partial(score_fn, train_state=train_state))
|
176 |
+
all_metrics.result() # Ensure metrics are finished being computed.
|
177 |
+
# Wait until computations are done before continuing.
|
178 |
+
multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')
|
179 |
+
|
180 |
+
## Write this to the local log directory
|
181 |
+
now = datetime.now()
|
182 |
+
logtime = now.strftime("%d-%m-%Y %H:%M:%S")
|
183 |
+
|
184 |
+
if not os.path.exists("log"):
|
185 |
+
os.makedirs("log")
|
186 |
+
|
187 |
+
logname ="./log/"+"eval_results_"+socket.gethostname()+".jsonl"
|
188 |
+
|
189 |
+
output = {}
|
190 |
+
output["model"] = restore_checkpoint_cfg.path
|
191 |
+
output["eval_date"] = logtime
|
192 |
+
output["split"] = dataset_cfg.split
|
193 |
+
output["result"] = all_metrics.result()[dataset_cfg.mixture_or_task_name]
|
194 |
+
|
195 |
+
with jsonlines.open(logname, mode="a") as writer:
|
196 |
+
writer.write(output)
|
197 |
+
|
198 |
+
logging.info('Finished.')
|
199 |
+
|
200 |
+
|
201 |
+
if __name__ == '__main__':
|
202 |
+
from absl import app
|
203 |
+
from absl import flags
|
204 |
+
import gin
|
205 |
+
|
206 |
+
FLAGS = flags.FLAGS
|
207 |
+
|
208 |
+
jax.config.parse_flags_with_absl()
|
209 |
+
|
210 |
+
flags.DEFINE_multi_string(
|
211 |
+
'gin_file',
|
212 |
+
default=None,
|
213 |
+
help='Path to gin configuration file. Multiple paths may be passed and '
|
214 |
+
'will be imported in the given order, with later configurations '
|
215 |
+
'overriding earlier ones.')
|
216 |
+
|
217 |
+
flags.DEFINE_multi_string(
|
218 |
+
'gin_bindings', default=[], help='Individual gin bindings.')
|
219 |
+
|
220 |
+
flags.DEFINE_list(
|
221 |
+
'gin_search_paths',
|
222 |
+
default=['.'],
|
223 |
+
help='Comma-separated list of gin config path prefixes to be prepended '
|
224 |
+
'to suffixes given via `--gin_file`. If a file appears in. Only the '
|
225 |
+
'first prefix that produces a valid path for each suffix will be '
|
226 |
+
'used.')
|
227 |
+
|
228 |
+
flags.DEFINE_string(
|
229 |
+
'tfds_data_dir', None,
|
230 |
+
'If set, this directory will be used to store datasets prepared by '
|
231 |
+
'TensorFlow Datasets that are not available in the public TFDS GCS '
|
232 |
+
'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
|
233 |
+
'all `Task`s.')
|
234 |
+
|
235 |
+
|
236 |
+
def main(argv: Sequence[str]):
|
237 |
+
"""Wrapper for pdb post mortems."""
|
238 |
+
_main(argv)
|
239 |
+
|
240 |
+
def _main(argv: Sequence[str]):
|
241 |
+
"""True main function."""
|
242 |
+
if len(argv) > 1:
|
243 |
+
raise app.UsageError('Too many command-line arguments.')
|
244 |
+
|
245 |
+
if FLAGS.tfds_data_dir:
|
246 |
+
seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
|
247 |
+
|
248 |
+
# Create gin-configurable version of `eval`.
|
249 |
+
evaluate_using_gin = gin.configurable(evaluate)
|
250 |
+
|
251 |
+
gin_utils.parse_gin_flags(
|
252 |
+
# User-provided gin paths take precedence if relative paths conflict.
|
253 |
+
FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
|
254 |
+
FLAGS.gin_file,
|
255 |
+
FLAGS.gin_bindings)
|
256 |
+
evaluate_using_gin()
|
257 |
+
|
258 |
+
gin_utils.run(main)
|
eval_base.sh
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
EVAL_OUTPUT_DIR="gs://nb-t5x/eval/"
|
3 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
4 |
-
CHECKPOINT_PATH="gs://nb-t5x
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
-
python3
|
8 |
--gin_search_paths=${PROJECT_DIR} \
|
9 |
--gin_file="eval_categorisation_base.gin" \
|
10 |
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
EVAL_OUTPUT_DIR="gs://nb-t5x/eval/"
|
3 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
4 |
+
CHECKPOINT_PATH="gs://nb-t5x/eval_norwegian_NCC_2_000_000/checkpoint_2005000"
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
+
python3 eval.py \
|
8 |
--gin_search_paths=${PROJECT_DIR} \
|
9 |
--gin_file="eval_categorisation_base.gin" \
|
10 |
--gin.CHECKPOINT_PATH=\"${CHECKPOINT_PATH}\" \
|
eval_categorisation_base.gin
CHANGED
@@ -24,7 +24,7 @@ eval_script.evaluate:
|
|
24 |
utils.DatasetConfig:
|
25 |
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
26 |
task_feature_lengths = None # Auto-computes the max feature lengths.
|
27 |
-
split = '
|
28 |
batch_size = 32
|
29 |
shuffle = False
|
30 |
seed = 42
|
|
|
24 |
utils.DatasetConfig:
|
25 |
mixture_or_task_name = %MIXTURE_OR_TASK_NAME
|
26 |
task_feature_lengths = None # Auto-computes the max feature lengths.
|
27 |
+
split = 'validation'
|
28 |
batch_size = 32
|
29 |
shuffle = False
|
30 |
seed = 42
|
finetune_categorisation_base.gin
CHANGED
@@ -12,7 +12,7 @@ include "t5x/configs/runs/finetune.gin"
|
|
12 |
|
13 |
MIXTURE_OR_TASK_NAME = "categorise"
|
14 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
|
15 |
-
TRAIN_STEPS =
|
16 |
USE_CACHED_TASKS = False
|
17 |
DROPOUT_RATE = 0.1
|
18 |
RANDOM_SEED = 0
|
@@ -29,7 +29,7 @@ RANDOM_SEED = 0
|
|
29 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
|
30 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
|
31 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
|
32 |
-
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/
|
33 |
|
34 |
#train_script.train:
|
35 |
# eval_period = 500
|
|
|
12 |
|
13 |
MIXTURE_OR_TASK_NAME = "categorise"
|
14 |
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 2}
|
15 |
+
TRAIN_STEPS = 2_005_000 # 1000000 pre-trained steps + 10000 fine-tuning steps.
|
16 |
USE_CACHED_TASKS = False
|
17 |
DROPOUT_RATE = 0.1
|
18 |
RANDOM_SEED = 0
|
|
|
29 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_1360000"
|
30 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_run1_lr_1/checkpoint_1100000"
|
31 |
#INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/pk_nb_t5x_base_scandinavian/checkpoint_1100000"
|
32 |
+
INITIAL_CHECKPOINT_PATH = "gs://nb-t5x-us-central2/norwegian_t5x_base/checkpoint_2000000"
|
33 |
|
34 |
#train_script.train:
|
35 |
# eval_period = 500
|
tasks.py
CHANGED
@@ -59,7 +59,7 @@ seqio.TaskRegistry.add(
|
|
59 |
categorise_preprocessor,
|
60 |
seqio.preprocessors.tokenize_and_append_eos,
|
61 |
],
|
62 |
-
|
63 |
output_features=DEFAULT_OUTPUT_FEATURES,
|
64 |
-
)
|
65 |
|
|
|
59 |
categorise_preprocessor,
|
60 |
seqio.preprocessors.tokenize_and_append_eos,
|
61 |
],
|
62 |
+
metric_fns=[metrics.accuracy],
|
63 |
output_features=DEFAULT_OUTPUT_FEATURES,
|
64 |
+
)
|
65 |
|
train_base.sh
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
3 |
#Needs to be updated when moving to tpu-v4 it should then be in another zone
|
4 |
-
MODEL_DIR="gs://nb-t5x/
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
python3 ${T5X_DIR}/t5x/train.py \
|
|
|
1 |
PROJECT_DIR=${HOME}"/models/t5-parliament-categorisation"
|
2 |
T5X_DIR="../../t5x" # directory where the t5x is cloned.
|
3 |
#Needs to be updated when moving to tpu-v4 it should then be in another zone
|
4 |
+
MODEL_DIR="gs://nb-t5x/eval_norwegian_NCC_2_000_000"
|
5 |
export PYTHONPATH=${PROJECT_DIR}
|
6 |
|
7 |
python3 ${T5X_DIR}/t5x/train.py \
|