r"""This script runs inference-evaluation on a T5X-compatible model. |
|
|
|
""" |
|
|
|
|
|
|
|
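
# Example invocation (an illustrative sketch: the gin file path and binding
# below are placeholders, not files that ship with this script):
#
#   python eval.py \
#     --gin_search_paths=path/to/configs \
#     --gin_file=eval_model.gin \
#     --gin_bindings='utils.DatasetConfig.split = "validation"'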

import functools
import os
import socket
from datetime import datetime
from typing import Optional, Sequence, Type

import jsonlines
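
# Keep this assignment above the flax-dependent imports below so the override
# is in place when flax is first imported.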
os.environ['FLAX_LAZY_RNG'] = 'no'

from absl import logging
from clu import metric_writers
import jax
from jax.experimental import multihost_utils
import seqio
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils
from typing_extensions import Protocol

_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]


class SummarizeConfigFn(Protocol):

  def __call__(self, model_dir: str,
               summary_writer: Optional[metric_writers.SummaryWriter],
               step: int) -> None:
    ...
|

def evaluate(
    *,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
    fallback_init_rng: Optional[int] = None):
  """Runs inference evaluation on the given model and dataset.

  Args:
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset on which to run inference.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: Partitioner for the model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    inference_evaluator_cls: seqio.Evaluator class to use for inference
      evaluation, potentially with bound configuration args.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch
      is set to True. If None, parameter initialization is not allowed during
      model loading, and having fallback_to_scratch enabled will result in an
      error.
  """
  logging.info('Process ID: %d', jax.process_index())
  if dataset_cfg.module:
    utils.import_module(dataset_cfg.module)
  batch_size = dataset_cfg.batch_size

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)
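
  # Fail fast if the task's vocabularies differ from the model's.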
  ds_vocabs = utils.get_vocabulary(dataset_cfg)
  if (ds_vocabs[0] != model.input_vocabulary or
      ds_vocabs[1] != model.output_vocabulary):
    raise ValueError(f'Model and Task vocabularies do not match:\n'
                     f'  task={dataset_cfg.mixture_or_task_name}\n'
                     f'  ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n'
                     f'  model.input_vocabulary={model.input_vocabulary}\n'
                     f'  model.output_vocabulary={model.output_vocabulary}\n')
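
  # Construct the inference evaluator. Tasks without metric functions are
  # filtered out by seqio, which the check below turns into an explicit error.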
  evaluator = inference_evaluator_cls(
      mixture_or_task_name=dataset_cfg.mixture_or_task_name,
      feature_converter=model.FEATURE_CONVERTER_CLS(pack=False),
      eval_split=dataset_cfg.split,
      use_cached=dataset_cfg.use_cached,
      seed=dataset_cfg.seed,
      sequence_length=dataset_cfg.task_feature_lengths,
      log_dir=os.path.join(output_dir, 'inference_eval'))
  if not evaluator.eval_tasks:
    raise ValueError(
        f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.")
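
  # Prepend the batch dimension to each model feature shape; these shapes are
  # used to initialize the train state below.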
  input_shapes = {
      k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items()
  }
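
  # With optimizer_def=None, the initializer builds an inference-only train
  # state: model parameters without optimizer state.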
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=None,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      partitioner=partitioner)
  train_state_axes = train_state_initializer.train_state_axes

  log_file = os.path.join(output_dir, 'model-info.txt')
  utils.log_model_info(log_file,
                       train_state_initializer.global_train_state_shape,
                       partitioner)

  predict_fn = None
  score_fn = None
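
  # Relax strictness so checkpoints whose stored variables do not exactly match
  # the model's expected structure can still be restored.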
restore_checkpoint_cfg.strict = False |

  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  for train_state in train_state_initializer.from_checkpoints(
      [restore_checkpoint_cfg], init_rng=fallback_init_rng):
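
    # Compile the model only once.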
    if not predict_fn:
      predict_fn = utils.get_infer_fn(
          infer_step=model.predict_batch,
          batch_size=batch_size,
          train_state_axes=train_state_axes,
          partitioner=partitioner)

      score_fn = utils.get_infer_fn(
          infer_step=model.score_batch,
          batch_size=batch_size,
          train_state_axes=train_state_axes,
          partitioner=partitioner)

    all_metrics, _, _ = evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        step=int(train_state.step),
        predict_fn=functools.partial(
            predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)),
        score_fn=functools.partial(score_fn, train_state=train_state))
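    # Block until the asynchronous metric computation has finished.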
    all_metrics.result()

    multihost_utils.sync_global_devices(f'step_{train_state.step}:complete')

  now = datetime.now()
  logtime = now.strftime('%d-%m-%Y %H:%M:%S')

  logname = os.path.join(output_dir,
                         'eval_results_' + socket.gethostname() + '.jsonl')

  # Only process 0 computed metrics above, so only it writes the results file.
  if jax.process_index() == 0:
    output = {}
    output['model'] = restore_checkpoint_cfg.path
    output['task'] = dataset_cfg.mixture_or_task_name
    output['eval_date'] = logtime
    output['split'] = dataset_cfg.split
    output['feature_length'] = dataset_cfg.task_feature_lengths
    output['eval_batch_size'] = dataset_cfg.batch_size
    output['result'] = all_metrics.result()[dataset_cfg.mixture_or_task_name]
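
    # Each appended line is one self-contained JSON record; an illustrative
    # example (all field values below are placeholders):
    #   {"model": "/path/to/checkpoint", "task": "my_task",
    #    "eval_date": "01-01-2024 12:00:00", "split": "validation",
    #    "feature_length": {"inputs": 512, "targets": 128},
    #    "eval_batch_size": 32, "result": {"accuracy": 91.2}}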
    with jsonlines.open(logname, mode='a') as writer:
      writer.write(output)

  logging.info('Finished.')
|
|
if __name__ == '__main__':
  from absl import app
  from absl import flags
  import gin

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. If a file appears in multiple '
      'prefixes, only the first prefix that produces a valid path for each '
      'suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')


  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)


  def _main(argv: Sequence[str]):
    """True main function."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)
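
    # Create a gin-configurable version of `evaluate` so its arguments can be
    # bound from the gin files and bindings parsed below.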
    evaluate_using_gin = gin.configurable(evaluate)

    gin_utils.parse_gin_flags(
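        # User-provided gin paths take precedence if relative paths conflict.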
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)
    evaluate_using_gin()


  gin_utils.run(main)