# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=line-too-long
# pyformat: disable
r"""This script runs inference-evaluation on a T5X-compatible model.
"""
# pyformat: enable
# pylint:enable=line-too-long
import functools | |
import os | |
from typing import Optional, Sequence, Type | |
# pylint:disable=g-import-not-at-top | |
# TODO(adarob): Re-enable once users are notified and tests are updated. | |
os.environ['FLAX_LAZY_RNG'] = 'no' | |
from absl import logging | |
from clu import metric_writers | |
import jax | |
from jax.experimental import multihost_utils | |
import seqio | |
from t5x import gin_utils | |
from t5x import models | |
from t5x import partitioning | |
from t5x import utils | |
from typing_extensions import Protocol | |
# Automatically search for gin files relative to the T5X package. | |
_DEFAULT_GIN_SEARCH_PATHS = [ | |
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
] | |
class SummarizeConfigFn(Protocol): | |
def __call__(self, model_dir: str, | |
summary_writer: Optional[metric_writers.SummaryWriter], | |
step: int) -> None: | |
... | |
def evaluate( | |
*, | |
model: models.BaseTransformerModel, | |
dataset_cfg: utils.DatasetConfig, | |
restore_checkpoint_cfg: utils.RestoreCheckpointConfig, | |
partitioner: partitioning.BasePartitioner, | |
output_dir: str, | |
inference_evaluator_cls: Type[seqio.Evaluator] = seqio.Evaluator, | |
summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config, | |
fallback_init_rng: Optional[int] = None): | |
"""Evaluation function. | |
Args: | |
model: The model object to use for inference. | |
dataset_cfg: Specification for the dataset to infer based on. | |
restore_checkpoint_cfg: Specification for the model parameter checkpoint to | |
load. | |
partitioner: Partitioner for the model parameters and data across devices. | |
output_dir: Path to directory to write temporary files and final results. | |
inference_evaluator_cls: seqio.Evaluator class to use for inference | |
evaluation, potentially with bound configuration args. | |
summarize_config_fn: A function that takes in the model directory, an | |
optional SummaryWriter, and the step number, and writes a summary of the | |
configuration. SummaryWriter will be None in most cases. | |
fallback_init_rng: A random seed used for parameter initialization during | |
model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is | |
set to True. If None, parameter initialization is not allowed during model | |
loading and having fallback_to_scratch enabled will result in an error. | |
""" | |
logging.info('Process ID: %d', jax.process_index()) | |
if dataset_cfg.module: | |
utils.import_module(dataset_cfg.module) | |
batch_size = dataset_cfg.batch_size | |
summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0) | |
ds_vocabs = utils.get_vocabulary(dataset_cfg) | |
if (ds_vocabs[0] != model.input_vocabulary or | |
ds_vocabs[1] != model.output_vocabulary): | |
raise ValueError(f'Model and Task vocabularies do not match:\n' | |
f' task={dataset_cfg.mixture_or_task_name}\n' | |
f' ds_vocabs=({ds_vocabs[0]}, {ds_vocabs[1]})\n' | |
f' model.input_vocabulary={model.input_vocabulary}\n' | |
f' model.output_vocabulary={model.output_vocabulary}\n') | |
# ---------------------------------------------------------------------------- | |
# SeqIO (inference-based) evaluation setup | |
# ---------------------------------------------------------------------------- | |
# Init evaluator to set up cached datasets | |
evaluator = inference_evaluator_cls( | |
mixture_or_task_name=dataset_cfg.mixture_or_task_name, | |
feature_converter=model.FEATURE_CONVERTER_CLS(pack=False), | |
eval_split=dataset_cfg.split, | |
use_cached=dataset_cfg.use_cached, | |
seed=dataset_cfg.seed, | |
sequence_length=dataset_cfg.task_feature_lengths, | |
log_dir=os.path.join(output_dir, 'inference_eval')) | |
if not evaluator.eval_tasks: | |
raise ValueError( | |
f"'{dataset_cfg.mixture_or_task_name}' has no metrics for evaluation.") | |
# ---------------------------------------------------------------------------- | |
# T5X model loading. | |
# ---------------------------------------------------------------------------- | |
# Initialize optimizer from the existing checkpoint. | |
input_shapes = { | |
k: (batch_size,) + s for k, s in evaluator.model_feature_shapes.items() | |
} | |
train_state_initializer = utils.TrainStateInitializer( | |
optimizer_def=None, # Do not load optimizer state. | |
init_fn=model.get_initial_variables, | |
input_shapes=input_shapes, | |
partitioner=partitioner) | |
train_state_axes = train_state_initializer.train_state_axes | |
# Log the variable shapes information and write to a file. | |
log_file = os.path.join(output_dir, 'model-info.txt') | |
utils.log_model_info(log_file, | |
train_state_initializer.global_train_state_shape, | |
partitioner) | |
predict_fn = None | |
score_fn = None | |
# Disable strictness since we are dropping the optimizer state. | |
restore_checkpoint_cfg.strict = False | |
if fallback_init_rng is not None: | |
fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) | |
for train_state in train_state_initializer.from_checkpoints( | |
[restore_checkpoint_cfg], init_rng=fallback_init_rng): | |
# Compile the model only once. | |
if not predict_fn: | |
predict_fn = utils.get_infer_fn( | |
infer_step=model.predict_batch, | |
batch_size=batch_size, | |
train_state_axes=train_state_axes, | |
partitioner=partitioner) | |
predict_with_aux_fn = utils.get_infer_fn( | |
infer_step=model.predict_batch_with_aux, | |
batch_size=batch_size, | |
train_state_axes=train_state_axes, | |
partitioner=partitioner) | |
score_fn = utils.get_infer_fn( | |
infer_step=model.score_batch, | |
batch_size=batch_size, | |
train_state_axes=train_state_axes, | |
partitioner=partitioner) | |
# ---------------------------------------------------------------------------- | |
# Main training loop | |
# ---------------------------------------------------------------------------- | |
# Run final evaluation (with decoding) on the full eval dataset. | |
all_metrics, _, _ = evaluator.evaluate( | |
compute_metrics=jax.process_index() == 0, | |
step=int(train_state.step), | |
predict_fn=functools.partial( | |
predict_fn, train_state=train_state, rng=jax.random.PRNGKey(0)), | |
score_fn=functools.partial(score_fn, train_state=train_state), | |
predict_with_aux_fn=functools.partial( | |
predict_with_aux_fn, | |
train_state=train_state, | |
rng=jax.random.PRNGKey(0))) | |
all_metrics.result() # Ensure metrics are finished being computed. | |
# Wait until computations are done before continuing. | |
multihost_utils.sync_global_devices(f'step_{train_state.step}:complete') | |
logging.info('Finished.') | |
if __name__ == '__main__': | |
from absl import app | |
from absl import flags | |
import gin | |
FLAGS = flags.FLAGS | |
jax.config.parse_flags_with_absl() | |
flags.DEFINE_multi_string( | |
'gin_file', | |
default=None, | |
help='Path to gin configuration file. Multiple paths may be passed and ' | |
'will be imported in the given order, with later configurations ' | |
'overriding earlier ones.') | |
flags.DEFINE_multi_string( | |
'gin_bindings', default=[], help='Individual gin bindings.') | |
flags.DEFINE_list( | |
'gin_search_paths', | |
default=['.'], | |
help='Comma-separated list of gin config path prefixes to be prepended ' | |
'to suffixes given via `--gin_file`. If a file appears in. Only the ' | |
'first prefix that produces a valid path for each suffix will be ' | |
'used.') | |
flags.DEFINE_string( | |
'tfds_data_dir', None, | |
'If set, this directory will be used to store datasets prepared by ' | |
'TensorFlow Datasets that are not available in the public TFDS GCS ' | |
'bucket. Note that this flag overrides the `tfds_data_dir` attribute of ' | |
'all `Task`s.') | |
def main(argv: Sequence[str]): | |
"""Wrapper for pdb post mortems.""" | |
_main(argv) | |
def _main(argv: Sequence[str]): | |
"""True main function.""" | |
if len(argv) > 1: | |
raise app.UsageError('Too many command-line arguments.') | |
if FLAGS.tfds_data_dir: | |
seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir) | |
# Create gin-configurable version of `eval`. | |
evaluate_using_gin = gin.configurable(evaluate) | |
gin_utils.parse_gin_flags( | |
# User-provided gin paths take precedence if relative paths conflict. | |
FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, | |
FLAGS.gin_file, | |
FLAGS.gin_bindings) | |
evaluate_using_gin() | |
gin_utils.run(main) | |