# Provenance: added by juancopi81 in commit b100e1c ("Add t5x and mt3 models").
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint:disable=line-too-long
# pyformat: disable
r"""This script runs inference on a T5X-compatible model.
"""
# pyformat: enable
# pylint:enable=line-too-long
import concurrent.futures
import functools
import hashlib
import json
import os
import re
import shutil
import time
from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type
# TODO(adarob): Re-enable once users are notified and tests are updated.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
import seqio
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
from typing_extensions import Protocol
# Automatically search for gin files relative to the T5X package.
# The path is the parent of the directory containing this file (two
# `dirname` calls), presumably the directory that holds the `t5x` package.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]
# Sentinel telling tf.data to tune buffer sizes / parallelism dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE
class SummarizeConfigFn(Protocol):
  """Structural type for callables that write a summary of the configuration.

  `infer` invokes implementations with the keyword arguments `model_dir`,
  `summary_writer` and `step`.
  """

  def __call__(self, model_dir: str,
               summary_writer: Optional[metric_writers.SummaryWriter],
               step: int) -> None:
    """Writes a configuration summary for `model_dir` at `step`.

    Args:
      model_dir: Directory the summary is associated with.
      summary_writer: Optional writer to emit the summary to; may be None.
      step: Step number to associate with the summary.
    """
    ...
class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
  """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions.

  Every submitted future is tracked until it completes. Any exception raised by
  a completed task is re-raised on the calling thread the next time `submit` or
  `shutdown` is called, so background failures are not silently dropped.

  NOTE: this class should be used only from the main thread.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Futures that had not yet finished at the last `check_for_exceptions`.
    self._incomplete_futures: List[concurrent.futures.Future] = []

  def check_for_exceptions(self, wait: bool = False):
    """Raises any exceptions from complete futures on the main thread.

    Args:
      wait: If True, block until every tracked future has finished. If False,
        only poll: futures that are still running are kept for a later check
        and never block the caller.

    Raises:
      The first exception found among the inspected, completed futures.
    """
    still_incomplete_futures = []
    for future in self._incomplete_futures:
      try:
        # `timeout=None` blocks until the future is done; `timeout=0` polls.
        exception = future.exception(timeout=None if wait else 0)
      except concurrent.futures.TimeoutError:
        # Still running: keep tracking it. `continue` is required so that a
        # stale (or unbound) `exception` from a previous iteration is not read.
        still_incomplete_futures.append(future)
        continue
      if exception is not None:
        raise exception
    self._incomplete_futures = still_incomplete_futures

  def submit(self, *args, **kwargs) -> concurrent.futures.Future:
    """Submit function to threadpool, capturing the returned future.

    Before returning, re-raises any exception from tasks that have already
    finished, without waiting on tasks that are still running.
    """
    future = super().submit(*args, **kwargs)
    self._incomplete_futures.append(future)
    self.check_for_exceptions(wait=False)
    return future

  def shutdown(self, *args, wait: bool = False, **kwargs):
    """Checks tracked futures for exceptions, then shuts the pool down.

    Args:
      *args: Forwarded to `ThreadPoolExecutor.shutdown`.
      wait: Forwarded to `check_for_exceptions`: if True, block until every
        tracked task finishes and re-raise the first failure.
      **kwargs: Forwarded to `ThreadPoolExecutor.shutdown`.
    """
    self.check_for_exceptions(wait=wait)
    super().shutdown(*args, **kwargs)
def create_task_from_tfexample_file(paths: Sequence[str],
                                    file_type: str,
                                    inputs_key: str,
                                    targets_key: Optional[str],
                                    features: Mapping[str, seqio.Feature],
                                    task_id: Optional[str] = None) -> str:
  """Registers ad-hoc Task for file-based dataset of TFExamples.

  Args:
    paths: Input file paths; all files should have type `file_type` and contain
      binary-serialized TFExample protos. They are expanded with
      `tf.io.gfile.glob` only to verify that at least one file exists; the
      original `paths` are what the registered data source reads.
    file_type: Input file type. Currently only 'tfrecord' is supported.
    inputs_key: Name of TFExample feature containing the input text for T5X. The
      value of this feature should be a UTF8-encoded string.
    targets_key: Optional name of a TFExample feature containing the target text
      (relevant only in scoring mode). The value of this feature should be a
      UTF8-encoded string.
    features: Should have entries for keys 'inputs' and (if targets_key is not
      None) 'targets', mapping to `seqio.Feature` objects that specify
      attributes like vocabulary, add_eos, etc. These attributes are used for
      preprocessing and featurizing the input text.
    task_id: Task name identifier. By default, it is set to a unique and
      deterministic hash id (the first 10 hex digits of an MD5 digest of the
      paths and feature keys). Overrideable via this argument.

  Returns:
    Name of the newly-registered Task (`infer_{task_id}`). This Task has a split
    named 'infer' that contains the preprocessed and featurized input dataset.

  Raises:
    ValueError: If `paths` matches no files, or `file_type` is not supported.
  """
  # tf.io.gfile.glob supports lists, in contrast to gfile.glob.
  files = tf.io.gfile.glob(paths)
  if files:
    logging.info('Using tfexample files %s', files)
  else:
    # Fail early if there's something wrong with the input file pattern.
    raise ValueError('Missing or invalid paths: %s' % paths)

  readers = {
      'tfrecord': tf.data.TFRecordDataset,
  }
  if file_type not in readers:
    raise ValueError(f'Unsupported file_type {file_type!r}; supported types '
                     f'are {sorted(readers)}.')
  reader = readers[file_type]

  feature_description = {inputs_key: tf.io.FixedLenFeature([], tf.string)}
  if targets_key:
    feature_description[targets_key] = tf.io.FixedLenFeature([], tf.string)

  # Create a unique, deterministic task name.
  if task_id is None:
    task_id = hashlib.md5(
        ':'.join(list(paths) +
                 [inputs_key, targets_key or '']).encode()).hexdigest()[:10]

  # NOTE(review): when `targets_key` is None the rekey map below contains
  # 'targets': None; this relies on seqio.preprocessors.rekey tolerating a
  # None source key — confirm against the seqio version in use.
  task = seqio.TaskRegistry.add(
      name=f'infer_{task_id}',
      source=seqio.TFExampleDataSource({'infer': paths},
                                       feature_description=feature_description,
                                       reader_cls=reader),
      preprocessors=[
          functools.partial(
              seqio.preprocessors.rekey,
              key_map={
                  'inputs': inputs_key,
                  'targets': targets_key
              }), seqio.preprocessors.tokenize_and_append_eos
      ],
      output_features=features)

  return task.name
def merge_chunks_to_file(
    output_dir: str,
    output_fname: str,
    tmp_dir: str,
    step: Optional[int],
) -> None:
  """Merge the predictions from different chunks into a unified file.

  Concatenates, in sorted order, every file in `tmp_dir` whose name is
  `{output_fname}-chunk` followed by exactly five characters, and writes the
  result to `{output_dir}/{output_fname}`.

  Args:
    output_dir: Directory that receives the merged file.
    output_fname: Name of the merged file and prefix of the chunk file names.
    tmp_dir: Directory containing the chunk files.
    step: Unused; accepted so the function satisfies the `MergeFn` signature.

  Raises:
    FileNotFoundError: If no chunk files are found.
    AssertionError: If the chunk count does not match the last chunk index + 1
      (i.e. some chunks are missing). Only checked when assertions are enabled.
  """
  logging.info('Merging chunk results.')
  # Merge chunks into single file. Chunk indices are expected to be zero-padded
  # to five digits, so a lexicographic sort also orders them numerically.
  chunk_paths = sorted(
      gfile.glob(os.path.join(tmp_dir, f'{output_fname}-chunk?????')))

  if not chunk_paths:
    raise FileNotFoundError(
        'No chunk results found! One possible explanation is that your '
        'input did not contain any examples')

  # Chunks are numbered from 0, so the expected count is last index + 1.
  expected_num_chunks = int(chunk_paths[-1][-5:]) + 1
  assert expected_num_chunks == len(chunk_paths), (
      f'Expecting {expected_num_chunks} chunk paths, found '
      f'{len(chunk_paths)}')

  output_path = os.path.join(output_dir, output_fname)
  del step
  with gfile.GFile(output_path, 'wb') as merged:
    for chunk_path in chunk_paths:
      with gfile.GFile(chunk_path, 'rb') as ef:
        shutil.copyfileobj(ef, merged)
  logging.info('Results written to %s.', output_path)
# (predictions, aux_values): per-example predictions (token IDs or scores) and a
# mapping from auxiliary output name to its per-example values.
_Inferences = Tuple[Sequence[Any], Mapping[str, Any]]
def write_inferences_to_file(
    path: str,
    inferences: _Inferences,
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    json_encoder_cls: Type[json.JSONEncoder] = seqio.TensorAndNumpyEncoder,
    include_all_inputs: bool = False,
    input_fields_to_include: Optional[Sequence[str]] = None,
    output_ids: bool = False) -> None:
  """Write model predictions, along with pretokenized inputs, to JSONL file.

  Writes one JSON object per example in `task_ds`, in dataset order.

  Args:
    path: File path to write to.
    inferences: A tuple containing (predictions, aux_values). If mode is
      'predict' then the `predictions` will be token IDs. If it's
      'score' then it'll be a collection of scores. `aux_values` will be an
      empty dictionary unless mode is 'predict_with_aux', in which case it'll
      contain the model's auxiliary outputs.
    task_ds: Original task dataset. Features from task with suffix
      `_pretokenized` are added to the outputs.
    mode: Prediction mode, either 'predict', 'score' or 'predict_with_aux'.
    vocabulary: Task output vocabulary. Used in the `predict` and
      `predict_with_aux` modes in order to decode predicted outputs into
      strings.
    json_encoder_cls: a JSON encoder class used to customize JSON serialization
      via json.dumps.
    include_all_inputs: if True, will include all model inputs in the output
      JSONL file (including raw tokens) in addition to the pretokenized inputs.
    input_fields_to_include: List of input fields to include in the output JSONL
      file. This list should be None if `include_all_inputs` is set to True.
    output_ids: if True, will output the token ID sequence for the output, in
      addition to the decoded text. Trailing tokens from the first 0 onwards
      are dropped.

  Raises:
    ValueError: If `vocabulary` is missing in a predict mode, if both
      `include_all_inputs` and `input_fields_to_include` are set, or if `mode`
      is invalid. All checks run before `path` is opened.
  """
  all_predictions, all_aux_values = inferences

  if mode in ('predict', 'predict_with_aux') and vocabulary is None:
    raise ValueError('The `vocabulary` parameter is required in `predict` and '
                     '`predict_with_aux` modes')

  if include_all_inputs and input_fields_to_include is not None:
    raise ValueError(
        'include_all_inputs and input_fields_to_include should not be set'
        ' simultaneously.')

  # Validate up front so an invalid mode does not truncate `path`.
  if mode not in ('predict', 'score', 'predict_with_aux'):
    raise ValueError(f'Invalid mode: {mode}')

  def _json_compat(value):
    """Converts `value` into a type that `json.dumps` can serialize."""
    if isinstance(value, bytes):
      return value.decode('utf-8')
    elif isinstance(value, (jnp.bfloat16, jnp.floating)):
      return float(value)
    elif isinstance(value, jnp.integer):
      # Keep integers exact; a float cannot represent all values above 2**53.
      return int(value)
    elif isinstance(value, (jnp.ndarray, np.ndarray)):
      # Convert arrays (of any rank) into (nested) Python lists.
      return value.tolist()
    else:
      return value

  def _prediction_fields(predictions):
    """Decodes `predictions`, adding the raw token IDs if `output_ids`."""
    assert vocabulary is not None
    fields = {
        'prediction':
            _json_compat(vocabulary.decode_tf(tf.constant(predictions)).numpy())
    }
    if output_ids:
      pred = _json_compat(tf.constant(predictions).numpy())
      # Truncate padding tokens.
      assert isinstance(pred, list)
      pred = pred[:pred.index(0)] if 0 in pred else pred
      fields['prediction_tokens'] = pred
    return fields

  with gfile.GFile(path, 'w') as f:
    for i, inp in task_ds.enumerate().as_numpy_iterator():
      predictions = all_predictions[i]
      aux_values = {aux_field: v[i] for aux_field, v in all_aux_values.items()}

      if include_all_inputs:
        inputs = inp
      elif input_fields_to_include is not None:
        inputs = {
            k: v for k, v in inp.items() if k in input_fields_to_include or
            (k.endswith('_pretokenized') and
             k[:-len('_pretokenized')] in input_fields_to_include)
        }
      else:
        inputs = {k: v for k, v in inp.items() if k.endswith('_pretokenized')}

      json_dict = {'inputs': {k: _json_compat(v) for k, v in inputs.items()}}
      if mode == 'predict':
        json_dict.update(_prediction_fields(predictions))
      elif mode == 'score':
        json_dict['score'] = _json_compat(predictions)
      else:  # mode == 'predict_with_aux' (validated above).
        json_dict.update(_prediction_fields(predictions))
        json_dict['aux'] = jax.tree_util.tree_map(_json_compat, aux_values)

      json_str = json.dumps(json_dict, cls=json_encoder_cls)
      f.write(json_str + '\n')
# Signature of `write_fn` in `infer`: (chunk_path, inferences, task_ds, mode,
# output_vocabulary) -> None. `write_inferences_to_file` is the default.
WriteFn = Callable[[
    str,
    _Inferences,
    tf.data.Dataset,
    str,
    Optional[seqio.Vocabulary],
], None]

# Signature of `merge_fn` in `infer`: (output_dir, output_fname, tmp_dir, step)
# -> None. `merge_chunks_to_file` is the default.
MergeFn = Callable[[str, str, str, Optional[int]], None]
def _extract_tokens_and_aux_values(inference_fn_outputs) -> _Inferences:
  """Extracts tokens and aux scores from a cached dataset.

  `inference_fn_outputs` is either an iterable of `(index, tokens)` pairs, or a
  tuple `(pairs, aux_values)` where `aux_values` maps each aux field name to a
  sequence aligned position-by-position with `pairs`. The pairs may be in any
  order; outputs are reordered by ascending index (stable for ties).

  Returns:
    A `(tokens, aux_values)` tuple. `aux_values` is a new dict (the input dict
    is not modified) and is empty when no aux component was given. Empty input
    yields empty outputs.
  """
  if isinstance(inference_fn_outputs, tuple):
    indices_and_tokens, raw_aux_values = inference_fn_outputs
  else:
    indices_and_tokens, raw_aux_values = inference_fn_outputs, {}

  indices_and_tokens = list(indices_and_tokens)
  if not indices_and_tokens:
    return [], {aux_key: [] for aux_key in raw_aux_values}

  indices, tokens = zip(*indices_and_tokens)
  # One permutation is shared by tokens and every aux field to keep them
  # aligned.
  permutation = np.argsort(indices, kind='stable')
  sorted_tokens = [tokens[j] for j in permutation]
  sorted_aux_values = {
      aux_key: [aux_values[j] for j in permutation]
      for aux_key, aux_values in raw_aux_values.items()
  }
  return sorted_tokens, sorted_aux_values
def infer(
    *,
    mode: str,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    checkpoint_period: int,
    shard_id: int = 0,
    num_shards: int = 1,
    merge_chunked_results: bool = True,
    write_fn: WriteFn = write_inferences_to_file,
    checkpoint_ds_iter: bool = True,
    fallback_init_rng: Optional[int] = None,
    merge_fn: MergeFn = merge_chunks_to_file,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
):
  """Infer function.

  Runs inference over every subtask of `dataset_cfg.mixture_or_task_name`, for
  shard `shard_id` of `num_shards`. Each task is processed in chunks of
  `checkpoint_period * batch_size` examples. Process 0 writes each chunk's
  results to a temporary directory on a background thread. If
  `merge_chunked_results` is set, process 0 merges them into one file per task
  at the end.

  Args:
    mode: Either 'predict' to decode targets, 'score' to compute the log
      likelihood of given targets, or 'predict_with_aux' for both.
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load. NOTE: its `strict` attribute is set to False in place.
    partitioner: Partitioner for model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    checkpoint_period: The intermediate results and dataset iterator will be
      checkpointed on each multiple of this number of batches to enable
      continuation after a failure.
    shard_id: Index of dataset shard for this instance to use if splitting the
      work across multiple jobs.
    num_shards: Total number of dataset shards to split dataset across.
    merge_chunked_results: Whether to merge results of all chunks into a single
      json file.
    write_fn: Callable function used to serialize and write inferences out to
      files.
    checkpoint_ds_iter: if True, will checkpoint the dataset iterator every
      `checkpoint_period` to enable faster restore. This must be disabled for
      certain datasets, for example since stateful iterators (e.g. from
      seqio.FunctionTask) cannot be checkpointed.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch is
      set to True. If None, parameter initialization is not allowed during model
      loading and having fallback_to_scratch enabled will result in an error.
    merge_fn: Callable function used to merge inferences from multiple files.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.

  Raises:
    ValueError: If `mode` is unsupported or the model and task vocabularies
      differ.
  """
  logging.info('Process ID: %d', jax.process_index())
  # Validate before any side effects (such as writing the config summary).
  if mode not in ('predict', 'score', 'predict_with_aux'):
    raise ValueError(
        "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. "
        f"Got '{mode}'")

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)

  # Remove double-slashes in directory path to avoid inconsistencies.
  output_dir = re.sub(r'(?<!gs:)([\/]{2,})', '/', output_dir)
  ds_vocabs = utils.get_vocabulary(dataset_cfg)
  if (ds_vocabs[0] != model.input_vocabulary or
      ds_vocabs[1] != model.output_vocabulary):
    raise ValueError(
        'Model and Task vocabularies do not match.\n'
        f'Task Input: {ds_vocabs[0]}, Model Input: {model.input_vocabulary}\n'
        f'Task Output: {ds_vocabs[1]}, Model Output: {model.output_vocabulary}')

  batch_size = dataset_cfg.batch_size

  # Set up dataset.
  if dataset_cfg.module:
    utils.import_module(dataset_cfg.module)
  host_shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)
  task_or_mixture = seqio.get_mixture_or_task(dataset_cfg.mixture_or_task_name)

  feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)

  def _get_dataset(dataset_provider):
    # TODO(adarob): assert pack is false, shuffle is false, seed?
    return dataset_provider.get_dataset(
        sequence_length=dataset_cfg.task_feature_lengths,
        split=dataset_cfg.split,
        shuffle=False,
        num_epochs=1,
        shard_info=host_shard_info,
        use_cached=dataset_cfg.use_cached,
        seed=dataset_cfg.seed)

  # Each "chunk" should be how often we checkpoint the input dataset and flush
  # the inferences to disk.
  logging.info('Inferring with checkpoints every %d batches of %d examples.',
               checkpoint_period, batch_size)

  logging.info('Initializing model, optimizer, and step functions.')
  element_spec = feature_converter(
      _get_dataset(task_or_mixture),
      dataset_cfg.task_feature_lengths).element_spec
  input_shapes = {
      k: (batch_size,) + spec.shape for k, spec in element_spec.items()
  }
  input_types = {
      k: jnp.dtype(spec.dtype.as_numpy_dtype)
      for k, spec in element_spec.items()
  }
  # Initialize optimizer from the existing checkpoint.
  # TODO(adarob): Support inference over multiple checkpoints.
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=None,  # Do not load optimizer state.
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)

  # Log the variable shapes information and write to a file.
  model_info_log_file = os.path.join(output_dir, 'model-info.txt')
  if shard_id == 0:
    utils.log_model_info(model_info_log_file,
                         train_state_initializer.global_train_state_shape,
                         partitioner)

  # Disable strictness since we are dropping the optimizer state.
  restore_checkpoint_cfg.strict = False

  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  train_state = train_state_initializer.from_checkpoint(
      [restore_checkpoint_cfg], init_rng=fallback_init_rng)

  if mode == 'predict':
    infer_step = model.predict_batch
  elif mode == 'predict_with_aux':
    infer_step = model.predict_batch_with_aux
  else:  # mode == 'score'
    infer_step = model.score_batch

  infer_fn = functools.partial(
      utils.get_infer_fn(
          infer_step=infer_step,
          batch_size=batch_size,
          train_state_axes=train_state_initializer.train_state_axes,
          partitioner=partitioner),
      train_state=train_state)

  def infer_task(task: seqio.Task):
    """Runs inference for one task, writing (and optionally merging) results."""
    output_fname = f'{task.name}-{mode}.jsonl-{shard_id:05}-of-{num_shards:05}'
    # Check for finished output first, before creating the temporary directory,
    # writer thread pool and input pipeline, none of which would be used.
    if gfile.exists(os.path.join(output_dir, output_fname)):
      logging.info(
          "File %s exists. Skipping inference for shard %d/%d of task '%s'",
          output_fname, shard_id, num_shards, task.name)
      return

    tmp_dir = os.path.join(output_dir,
                           f'tmp-{task.name}-{shard_id:05}-of-{num_shards:05}')
    if jax.process_index() == 0:
      gfile.makedirs(tmp_dir)

    # Use `max_workers=1` to ensure writes occur sequentially.
    write_thread_pool = FailFastThreadPoolExecutor(max_workers=1)

    logging.info("Loading dataset for task '%s'.", task.name)
    ds = _get_dataset(task)

    model_ds = feature_converter(
        ds, task_feature_lengths=dataset_cfg.task_feature_lengths)

    # Zip task and model features.
    # (task, model)
    infer_ds = tf.data.Dataset.zip((ds, model_ds))

    # Create batches the size of each chunk and index them.
    # (i, [(task, model)] * chunk_size)
    infer_ds = infer_ds.padded_batch(
        checkpoint_period * batch_size, drop_remainder=False).enumerate()

    infer_ds_iter: Iterator[Tuple[int, Any]] = iter(infer_ds.prefetch(AUTOTUNE))

    if checkpoint_ds_iter:
      # Create checkpoint manager and restore state, if applicable.
      ckpt_path = os.path.join(tmp_dir, 'input.ckpt')

      input_ckpt = tf.train.Checkpoint(ds=infer_ds_iter)

      if gfile.glob(ckpt_path + '*'):
        logging.info('Restoring input iterator from %s', ckpt_path)
        input_ckpt.read(ckpt_path).assert_consumed()

    logging.info("Starting inference loop for shard %d of %d of task '%s'.",
                 shard_id, num_shards, task.name)

    def _write_chunk_and_canonicalize_ckpt(chunk: int, chunk_path: str,
                                           inferences: _Inferences,
                                           task_ds: tf.data.Dataset,
                                           chunk_ckpt_path: Optional[str]):
      """Writes one chunk, marks it COMPLETED, then promotes its checkpoint."""
      write_tick = time.time()
      logging.info('Writing chunk %d results to %s', chunk, chunk_path)
      write_fn(chunk_path, inferences, task_ds, mode,
               task.output_features['targets'].vocabulary)
      with gfile.GFile(chunk_path + '.COMPLETED', 'w') as f:
        f.write('')
      write_time = time.time() - write_tick
      # `inferences` is a (predictions, aux_values) pair; count the examples,
      # not the pair.
      num_examples = len(inferences[0])
      logging.info('Writing completed in %02f seconds (%02f examples/sec).',
                   write_time,
                   num_examples / write_time)
      update_measurement_series('writing_total_sec', chunk, write_time)
      update_measurement_series('writing_examples_per_sec', chunk,
                                num_examples / write_time)

      if chunk_ckpt_path:
        # Canonicalize checkpoint. Only done after the chunk is fully written,
        # so a restart never resumes past unwritten results.
        for fname in gfile.glob(chunk_ckpt_path + '*'):
          gfile.rename(
              fname, fname.replace(chunk_ckpt_path, ckpt_path), overwrite=True)

    # Main Loop over "chunks".
    for chunk, chunk_batch in infer_ds_iter:
      logging.info('Starting chunk %d', chunk)
      chunk_tick = time.time()
      # Load the dataset for the next chunk. We can't use `infer_ds_iter`
      # directly since `infer_fn` needs to know the exact size of each chunk,
      # which may be smaller for the final one.
      chunk_ds = tf.data.Dataset.from_tensor_slices(chunk_batch)

      # Unzip chunk dataset in to pretokenized and model datasets.
      task_ds = chunk_ds.map(lambda p, m: p, num_parallel_calls=AUTOTUNE)
      model_ds = chunk_ds.map(lambda p, m: m, num_parallel_calls=AUTOTUNE)

      # Get a chunk-specific RNG key.
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

      chunk_path = os.path.join(tmp_dir, f'{output_fname}-chunk{chunk:05}')
      if gfile.exists(chunk_path + '.COMPLETED') and not checkpoint_ds_iter:
        logging.info('Skipping chunk %s. Chunk file already exists.', chunk)
        continue

      logging.info('Running inference on %d batches.', checkpoint_period)
      inferences = _extract_tokens_and_aux_values(
          infer_fn(model_ds.enumerate(), rng=chunk_rng))

      if jax.process_index() == 0:
        chunk_time = time.time() - chunk_tick
        # `inferences` is a (predictions, aux_values) pair; count examples.
        num_examples = len(inferences[0])
        logging.info('chunk completed in %02f seconds (%02f examples/sec).',
                     chunk_time,
                     num_examples / chunk_time)
        update_measurement_series('inference_total_sec', chunk, chunk_time)
        update_measurement_series('inference_examples_per_sec', chunk,
                                  num_examples / chunk_time)

        chunk_ckpt_path = None
        if checkpoint_ds_iter:
          # Store iterator checkpoint in temporary location before writing the
          # model output asynchronously. After outputs are written, the
          # checkpoint will be moved to the canonical location to be used if
          # restart occurs.
          ckpt_tick = time.time()
          chunk_ckpt_path = input_ckpt.write(
              os.path.join(tmp_dir, f'{chunk}.ckpt'))
          logging.info(
              'Checkpoint written to temporary location in %02f seconds.',
              time.time() - ckpt_tick)

        # These will execute sequentially since the ThreadPool size is 1.
        write_thread_pool.submit(
            _write_chunk_and_canonicalize_ckpt,
            chunk=chunk,
            chunk_path=chunk_path,
            inferences=inferences,
            task_ds=task_ds,
            chunk_ckpt_path=chunk_ckpt_path)

      # Wait for checkpoint to be written before continuing.
      multihost_utils.sync_global_devices(
          f'{task.name}:checkpoint_chunk{chunk:05}')

    logging.info("Finished inference for task '%s'.", task.name)

    logging.info('Waiting for chunk writes to complete.')
    write_thread_pool.shutdown(wait=True)

    if jax.process_index() == 0 and merge_chunked_results:
      step = None if train_state is None else int(train_state.step)
      merge_fn(output_dir, output_fname, tmp_dir, step)

      logging.info('Deleting temporary files.')
      gfile.rmtree(tmp_dir)

    # Wait for host 0 to finish writing before exiting.
    multihost_utils.sync_global_devices(f'{task.name}:complete')

  for task in seqio.get_subtasks(task_or_mixture):
    logging.info("Starting inference for task '%s'", task.name)
    infer_task(task)
  logging.info('DONE')
def update_measurement_series(series_name: str, step: int, value: float):
  """Hook for recording `value` of `series_name` at `step`.

  Not implemented externally: every argument is ignored and nothing is
  recorded.
  """
  _ = (series_name, step, value)
if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import gin
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_integer(
      'shard_id',
      default=None,
      help='Index to use for splitting the Task across multiple inference '
      'runs. NB: If set, this overrides --gin.infer.shard_id')

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. If a file appears under more than '
      'one prefix, only the first prefix that produces a valid path for each '
      'suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function: parses gin config and runs `infer`."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    # Create gin-configurable version of `infer`.
    infer_using_gin = gin.configurable(infer)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)

    # See http://yaqs/7882016229479677952 for further gin-config discussion.
    # (Google-internal link.)
    def _get_gin_parameter(key: str) -> Any:
      """Returns the bound gin value, resolving configurable references."""
      value = gin.query_parameter(key)
      if isinstance(value, gin.config.ConfigurableReference):
        if value.evaluate:
          # `@fn()` style reference: return the result of calling it.
          return value.scoped_configurable_fn()
        # `@fn` style reference: return the callable itself.
        return value.scoped_configurable_fn
      return value

    shard_id = (
        FLAGS.shard_id
        if FLAGS.shard_id is not None else _get_gin_parameter('infer.shard_id'))
    if shard_id == 0:
      gin_utils.summarize_gin_config(
          model_dir=_get_gin_parameter('infer.output_dir'),
          summary_writer=None,
          step=0)
    if FLAGS.shard_id is not None:
      # We fall back to this flag since XM does not support sweeps over flags
      # with '.' in them (it treats them like nested dictionaries).
      # TODO(adarob): Figure out a workaround so we can deprecate this flag.
      infer_using_gin(shard_id=FLAGS.shard_id)
    else:
      infer_using_gin()

  gin_utils.run(main)