# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint:disable=line-too-long
# pyformat: disable
r"""This script runs inference on a T5X-compatible model."""
# pyformat: enable
# pylint:enable=line-too-long
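
# Typical command-line invocation (illustrative only; the exact gin files,
# macro names, and paths depend on your configuration and are not defined in
# this file):
#
#   python -m t5x.infer \
#     --gin_file=path/to/model.gin \
#     --gin_file=path/to/infer.gin \
#     --gin.CHECKPOINT_PATH=\"/path/to/checkpoint\" \
#     --gin.INFER_OUTPUT_DIR=\"/path/to/output_dir\"
#
# Individual arguments of `infer` below can also be overridden with
# `--gin.infer.<arg>=<value>` bindings.
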
import concurrent.futures
import functools
import hashlib
import json
import os
import re
import shutil
import time
from typing import Any, Callable, Iterator, List, Mapping, Optional, Sequence, Tuple, Type

# TODO(adarob): Re-enable once users are notified and tests are updated.
# Must be set before flax imports.
# pylint:disable=g-import-not-at-top
os.environ['FLAX_LAZY_RNG'] = 'no'
from absl import logging
from clu import metric_writers
import jax
from jax.experimental import multihost_utils
import jax.numpy as jnp
import numpy as np
import seqio
from t5x import gin_utils
from t5x import models
from t5x import partitioning
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
from typing_extensions import Protocol

# Automatically search for gin files relative to the T5X package.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
]

AUTOTUNE = tf.data.experimental.AUTOTUNE


class SummarizeConfigFn(Protocol):

  def __call__(self, model_dir: str,
               summary_writer: Optional[metric_writers.SummaryWriter],
               step: int) -> None:
    ...


class FailFastThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
  """Wrapper for ThreadPoolExecutor that crashes main thread on exceptions.

  NOTE: this class should be used only from the main thread.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._incomplete_futures: List[concurrent.futures.Future] = []

  def check_for_exceptions(self, wait: bool = False):
    """Raises any exceptions from complete futures on the main thread."""
    still_incomplete_futures = []
    for future in self._incomplete_futures:
      try:
        # Block for completion only when `wait` is True.
        exception = future.exception(timeout=None if wait else 0)
      except concurrent.futures.TimeoutError:
        # Not done yet; check again on the next call.
        still_incomplete_futures.append(future)
        continue
      if exception is not None:
        raise exception
    self._incomplete_futures = still_incomplete_futures

  def submit(self, *args, **kwargs) -> concurrent.futures.Future:
    """Submit function to threadpool, capturing the returned future."""
    future = super().submit(*args, **kwargs)
    self._incomplete_futures.append(future)
    self.check_for_exceptions(wait=False)
    return future

  def shutdown(self, *args, wait: bool = False, **kwargs):
    self.check_for_exceptions(wait=wait)
    super().shutdown(*args, **kwargs)
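
# NOTE: `infer_task` below uses `FailFastThreadPoolExecutor` with
# `max_workers=1`, so chunk outputs are written asynchronously with respect to
# inference but sequentially with respect to each other; any exception raised
# during a write is re-raised on the main thread by `submit` or `shutdown`.
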

def create_task_from_tfexample_file(paths: Sequence[str],
                                    file_type: str,
                                    inputs_key: str,
                                    targets_key: Optional[str],
                                    features: Mapping[str, seqio.Feature],
                                    task_id: Optional[str] = None) -> str:
  """Registers ad-hoc Task for file-based dataset of TFExamples.

  Args:
    paths: Input file paths; all files should have type `file_type` and contain
      binary-serialized TFExample protos.
    file_type: Input file type; e.g., 'tfrecord', 'recordio', 'sstable'. For
      keyed formats like 'sstable', we ignore the keys and use only the values.
    inputs_key: Name of TFExample feature containing the input text for T5X.
      The value of this feature should be a UTF8-encoded string.
    targets_key: Optional name of a TFExample feature containing the target
      text (relevant only in scoring mode). The value of this feature should be
      a UTF8-encoded string.
    features: Should have entries for keys 'inputs' and (if targets_key is not
      None) 'targets', mapping to `seqio.Feature` objects that specify
      attributes like vocabulary, add_eos, etc. These attributes are used for
      preprocessing and featurizing the input text.
    task_id: Task name identifier. By default, it is set to a unique and
      deterministic hash id. Overrideable via this argument.

  Returns:
    Name of the newly-registered Task. This Task has a split named 'infer' that
    contains the preprocessed and featurized input dataset.
  """
  # tf.io.gfile.glob supports lists, in contrast to gfile.glob.
  files = tf.io.gfile.glob(paths)
  if files:
    logging.info('Using tfexample files %s', files)
  else:
    # Fail early if there's something wrong with the input file pattern.
    raise ValueError('Missing or invalid paths: %s' % paths)

  reader = {
      'tfrecord': tf.data.TFRecordDataset,
  }[file_type]

  feature_description = {inputs_key: tf.io.FixedLenFeature([], tf.string)}
  if targets_key:
    feature_description[targets_key] = tf.io.FixedLenFeature([], tf.string)

  # Create a unique, deterministic task name.
  if task_id is None:
    task_id = hashlib.md5(
        ':'.join(list(paths) +
                 [inputs_key, targets_key or '']).encode()).hexdigest()[:10]

  task = seqio.TaskRegistry.add(
      name=f'infer_{task_id}',
      source=seqio.TFExampleDataSource({'infer': paths},
                                       feature_description=feature_description,
                                       reader_cls=reader),
      preprocessors=[
          functools.partial(
              seqio.preprocessors.rekey,
              key_map={
                  'inputs': inputs_key,
                  'targets': targets_key
              }),
          seqio.preprocessors.tokenize_and_append_eos,
      ],
      output_features=features)
  return task.name
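
# Example use of `create_task_from_tfexample_file` (illustrative only; the
# vocabulary and file paths below are placeholders):
#
#   vocab = seqio.SentencePieceVocabulary('/path/to/sentencepiece.model')
#   task_name = create_task_from_tfexample_file(
#       paths=['/path/to/inputs.tfrecord'],
#       file_type='tfrecord',
#       inputs_key='inputs',
#       targets_key=None,
#       features={'inputs': seqio.Feature(vocab),
#                 'targets': seqio.Feature(vocab)})
#
# The returned name can then be referenced by the `mixture_or_task_name` field
# of the `utils.DatasetConfig` passed to `infer`, with `split='infer'`.
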

def merge_chunks_to_file(
    output_dir: str,
    output_fname: str,
    tmp_dir: str,
    step: Optional[int],
) -> None:
  """Merge the predictions from different chunks into a unified file."""
  logging.info('Merging chunk results.')
  # Merge chunks into single file.
  chunk_paths = sorted(
      gfile.glob(os.path.join(tmp_dir, f'{output_fname}-chunk?????')))
  if not chunk_paths:
    raise FileNotFoundError(
        'No chunk results found! One possible explanation is that your '
        'input did not contain any examples.')
  assert int(chunk_paths[-1][-5:]) + 1 == len(chunk_paths), (
      f'Expecting {int(chunk_paths[-1][-5:]) + 1} chunk paths, found '
      f'{len(chunk_paths)}.')
  output_path = os.path.join(output_dir, output_fname)
  del step
  with gfile.GFile(output_path, 'wb') as merged:
    for chunk_path in chunk_paths:
      with gfile.GFile(chunk_path, 'rb') as ef:
        shutil.copyfileobj(ef, merged)
  logging.info('Results written to %s.', output_path)


_Inferences = Tuple[Sequence[Any], Mapping[str, Any]]


def write_inferences_to_file(
    path: str,
    inferences: _Inferences,
    task_ds: tf.data.Dataset,
    mode: str,
    vocabulary: Optional[seqio.Vocabulary] = None,
    json_encoder_cls: Type[json.JSONEncoder] = seqio.TensorAndNumpyEncoder,
    include_all_inputs: bool = False,
    input_fields_to_include: Optional[Sequence[str]] = None,
    output_ids: bool = False) -> None:
  """Write model predictions, along with pretokenized inputs, to JSONL file.

  Args:
    path: File path to write to.
    inferences: A tuple containing (predictions, aux_values). If mode is
      'predict' then the `predictions` will be token IDs. If it's 'score' then
      it'll be a collection of scores. `aux_values` will be an empty dictionary
      unless mode is 'predict_with_aux', in which case it'll contain the
      model's auxiliary outputs.
    task_ds: Original task dataset. Features from task with suffix
      `_pretokenized` are added to the outputs.
    mode: Prediction mode, either 'predict', 'score' or 'predict_with_aux'.
    vocabulary: Task output vocabulary. Only used in `predict` mode in order to
      decode predicted outputs into strings.
    json_encoder_cls: a JSON encoder class used to customize JSON serialization
      via json.dumps.
    include_all_inputs: if True, will include all model inputs in the output
      JSONL file (including raw tokens) in addition to the pretokenized inputs.
    input_fields_to_include: List of input fields to include in the output
      JSONL file. This list should be None if `include_all_inputs` is set to
      True.
    output_ids: if True, will output the token ID sequence for the output, in
      addition to the decoded text.
  """
  all_predictions, all_aux_values = inferences

  if mode in ('predict', 'predict_with_aux') and vocabulary is None:
    raise ValueError('The `vocabulary` parameter is required in `predict` and '
                     '`predict_with_aux` modes')

  def _json_compat(value):
    if isinstance(value, bytes):
      return value.decode('utf-8')
    elif isinstance(value, (jnp.bfloat16, jnp.floating)):
      return float(value)
    elif isinstance(value, jnp.integer):
      return float(value)
    elif isinstance(value, (jnp.ndarray, np.ndarray)):
      # Flatten array features.
      return value.tolist()
    else:
      return value

  if include_all_inputs and input_fields_to_include is not None:
    raise ValueError(
        'include_all_inputs and input_fields_to_include should not be set'
        ' simultaneously.')

  with gfile.GFile(path, 'w') as f:
    for i, inp in task_ds.enumerate().as_numpy_iterator():
      predictions = all_predictions[i]
      aux_values = {aux_field: v[i] for aux_field, v in all_aux_values.items()}

      if include_all_inputs:
        inputs = inp
      elif input_fields_to_include is not None:
        inputs = {
            k: v for k, v in inp.items() if k in input_fields_to_include or
            (k.endswith('_pretokenized') and
             k[:-len('_pretokenized')] in input_fields_to_include)
        }
      else:
        inputs = {k: v for k, v in inp.items() if k.endswith('_pretokenized')}

      json_dict = {}
      json_dict['inputs'] = {k: _json_compat(v) for k, v in inputs.items()}

      if mode == 'predict':
        assert vocabulary is not None
        json_dict['prediction'] = _json_compat(
            vocabulary.decode_tf(tf.constant(predictions)).numpy())
        if output_ids:
          pred = _json_compat(tf.constant(predictions).numpy())
          # Truncate padding tokens.
          assert isinstance(pred, list)
          pred = pred[:pred.index(0)] if 0 in pred else pred
          json_dict['prediction_tokens'] = pred
      elif mode == 'score':
        json_dict['score'] = _json_compat(predictions)
      elif mode == 'predict_with_aux':
        assert vocabulary is not None
        json_dict['prediction'] = _json_compat(
            vocabulary.decode_tf(tf.constant(predictions)).numpy())
        if output_ids:
          pred = _json_compat(tf.constant(predictions).numpy())
          # Truncate padding tokens.
          pred = pred[:pred.index(0)] if 0 in pred else pred
          json_dict['prediction_tokens'] = pred
        json_dict['aux'] = jax.tree_map(_json_compat, aux_values)
      else:
        raise ValueError(f'Invalid mode: {mode}')
      json_str = json.dumps(json_dict, cls=json_encoder_cls)
      f.write(json_str + '\n')
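
# For reference, a 'predict'-mode line written above looks roughly like the
# following (illustrative values):
#
#   {"inputs": {"inputs_pretokenized": "translate English to German: hello"},
#    "prediction": "hallo"}
#
# In 'score' mode the line carries a "score" field instead of "prediction",
# and 'predict_with_aux' additionally emits an "aux" field.
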

WriteFn = Callable[[
    str,
    _Inferences,
    tf.data.Dataset,
    str,
    Optional[seqio.Vocabulary],
], None]

MergeFn = Callable[[str, str, str, Optional[int]], None]
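
# `WriteFn` and `MergeFn` describe the call signatures expected of the
# `write_fn` and `merge_fn` arguments to `infer` below; the defaults are
# `write_inferences_to_file` and `merge_chunks_to_file` above.
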

def _extract_tokens_and_aux_values(inference_fn_outputs) -> _Inferences:
  """Extracts tokens and aux scores from a cached dataset."""
  all_aux_values = {}
  if isinstance(inference_fn_outputs, tuple):
    indices_and_tokens, all_aux_values = inference_fn_outputs
    indices, tokens = zip(*indices_and_tokens)
    # Outputs may arrive out of order; use the attached indices to restore the
    # original dataset order for both tokens and aux values.
    permutation = np.argsort(indices)
    tokens = [tokens[permutation[i]] for i in range(len(permutation))]
    for aux_keys, aux_values in all_aux_values.items():
      all_aux_values[aux_keys] = [
          aux_values[permutation[i]] for i in range(len(permutation))
      ]
  else:
    indices_and_tokens = inference_fn_outputs
    _, tokens = zip(*sorted(indices_and_tokens, key=lambda x: x[0]))
  return tokens, all_aux_values


def infer(
    *,
    mode: str,
    model: models.BaseTransformerModel,
    dataset_cfg: utils.DatasetConfig,
    restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
    partitioner: partitioning.BasePartitioner,
    output_dir: str,
    checkpoint_period: int,
    shard_id: int = 0,
    num_shards: int = 1,
    merge_chunked_results: bool = True,
    write_fn: WriteFn = write_inferences_to_file,
    checkpoint_ds_iter: bool = True,
    fallback_init_rng: Optional[int] = None,
    merge_fn: MergeFn = merge_chunks_to_file,
    summarize_config_fn: SummarizeConfigFn = gin_utils.summarize_gin_config,
):
  """Infer function.

  Args:
    mode: Either 'predict' to decode targets, 'score' to compute the log
      likelihood of given targets, or 'predict_with_aux' for both.
    model: The model object to use for inference.
    dataset_cfg: Specification for the dataset to infer based on.
    restore_checkpoint_cfg: Specification for the model parameter checkpoint to
      load.
    partitioner: Partitioner for model parameters and data across devices.
    output_dir: Path to directory to write temporary files and final results.
    checkpoint_period: The intermediate results and dataset iterator will be
      checkpointed on each multiple of this number of batches to enable
      continuation after a failure.
    shard_id: Index of dataset shard for this instance to use if splitting the
      work across multiple jobs.
    num_shards: Total number of dataset shards to split dataset across.
    merge_chunked_results: Whether to merge results of all chunks into a single
      json file.
    write_fn: Callable function used to serialize and write inferences out to
      files.
    checkpoint_ds_iter: if True, will checkpoint the dataset iterator every
      `checkpoint_period` batches to enable faster restores. This must be
      disabled for certain datasets, e.g. because stateful iterators (such as
      those from seqio.FunctionTask) cannot be checkpointed.
    fallback_init_rng: A random seed used for parameter initialization during
      model re-loading when utils.RestoreCheckpointConfig.fallback_to_scratch
      is set to True. If None, parameter initialization is not allowed during
      model loading and having fallback_to_scratch enabled will result in an
      error.
    merge_fn: Callable function used to merge inferences from multiple files.
    summarize_config_fn: A function that takes in the model directory, an
      optional SummaryWriter, and the step number, and writes a summary of the
      configuration. SummaryWriter will be None in most cases.
  """
  logging.info('Process ID: %d', jax.process_index())

  summarize_config_fn(model_dir=output_dir, summary_writer=None, step=0)

  if mode not in ('predict', 'score', 'predict_with_aux'):
    raise ValueError(
        "`mode` must be one of 'predict', 'score' or 'predict_with_aux'. "
        f"Got '{mode}'")

  # Remove double-slashes in directory path to avoid inconsistencies.
  output_dir = re.sub(r'(?<!gs:)([\/]{2,})', '/', output_dir)

  ds_vocabs = utils.get_vocabulary(dataset_cfg)
  if (ds_vocabs[0] != model.input_vocabulary or
      ds_vocabs[1] != model.output_vocabulary):
    raise ValueError(
        'Model and Task vocabularies do not match.\n'
        f'Task Input: {ds_vocabs[0]}, Model Input: {model.input_vocabulary}\n'
        f'Task Output: {ds_vocabs[1]}, Model Output: {model.output_vocabulary}')

  batch_size = dataset_cfg.batch_size

  # Set up dataset.
  if dataset_cfg.module:
    utils.import_module(dataset_cfg.module)
  host_shard_info = seqio.ShardInfo(index=shard_id, num_shards=num_shards)
  task_or_mixture = seqio.get_mixture_or_task(dataset_cfg.mixture_or_task_name)

  feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)

  def _get_dataset(dataset_provider):
    # TODO(adarob): assert pack is false, shuffle is false, seed?
    return dataset_provider.get_dataset(
        sequence_length=dataset_cfg.task_feature_lengths,
        split=dataset_cfg.split,
        shuffle=False,
        num_epochs=1,
        shard_info=host_shard_info,
        use_cached=dataset_cfg.use_cached,
        seed=dataset_cfg.seed)

  # Each "chunk" should be how often we checkpoint the input dataset and flush
  # the inferences to disk.
  logging.info('Inferring with checkpoints every %d batches of %d examples.',
               checkpoint_period, batch_size)

  logging.info('Initializing model, optimizer, and step functions.')

  element_spec = feature_converter(
      _get_dataset(task_or_mixture),
      dataset_cfg.task_feature_lengths).element_spec
  input_shapes = {
      k: (batch_size,) + spec.shape for k, spec in element_spec.items()
  }
  input_types = {
      k: jnp.dtype(spec.dtype.as_numpy_dtype)
      for k, spec in element_spec.items()
  }

  # Initialize optimizer from the existing checkpoint.
  # TODO(adarob): Support inference over multiple checkpoints.
  train_state_initializer = utils.TrainStateInitializer(
      optimizer_def=None,  # Do not load optimizer state.
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      input_types=input_types,
      partitioner=partitioner)

  # Log the variable shapes information and write to a file.
  model_info_log_file = os.path.join(output_dir, 'model-info.txt')
  if shard_id == 0:
    utils.log_model_info(model_info_log_file,
                         train_state_initializer.global_train_state_shape,
                         partitioner)

  # Disable strictness since we are dropping the optimizer state.
  restore_checkpoint_cfg.strict = False

  if fallback_init_rng is not None:
    fallback_init_rng = jax.random.PRNGKey(fallback_init_rng)
  train_state = train_state_initializer.from_checkpoint(
      [restore_checkpoint_cfg], init_rng=fallback_init_rng)

  if mode == 'predict':
    infer_step = model.predict_batch
  elif mode == 'predict_with_aux':
    infer_step = model.predict_batch_with_aux
  else:  # mode == 'score'
    infer_step = model.score_batch
  infer_fn = functools.partial(
      utils.get_infer_fn(
          infer_step=infer_step,
          batch_size=batch_size,
          train_state_axes=train_state_initializer.train_state_axes,
          partitioner=partitioner),
      train_state=train_state)

  def infer_task(task: seqio.Task):
    tmp_dir = os.path.join(output_dir,
                           f'tmp-{task.name}-{shard_id:05}-of-{num_shards:05}')
    if jax.process_index() == 0:
      gfile.makedirs(tmp_dir)

    # Use `max_workers=1` to ensure writes occur sequentially.
    write_thread_pool = FailFastThreadPoolExecutor(max_workers=1)

    logging.info("Loading dataset for task '%s'.", task.name)
    ds = _get_dataset(task)

    model_ds = feature_converter(
        ds, task_feature_lengths=dataset_cfg.task_feature_lengths)

    # Zip task and model features.
    # (task, model)
    infer_ds = tf.data.Dataset.zip((ds, model_ds))

    # Create batches the size of each chunk and index them.
    # (i, [(task, model)] * chunk_size)
    infer_ds = infer_ds.padded_batch(
        checkpoint_period * batch_size, drop_remainder=False).enumerate()
    infer_ds_iter: Iterator[Tuple[int, Any]] = iter(infer_ds.prefetch(AUTOTUNE))

    if checkpoint_ds_iter:
      # Create checkpoint manager and restore state, if applicable.
      ckpt_path = os.path.join(tmp_dir, 'input.ckpt')
      input_ckpt = tf.train.Checkpoint(ds=infer_ds_iter)
      if gfile.glob(ckpt_path + '*'):
        logging.info('Restoring input iterator from %s', ckpt_path)
        input_ckpt.read(ckpt_path).assert_consumed()

    output_fname = f'{task.name}-{mode}.jsonl-{shard_id:05}-of-{num_shards:05}'
    if gfile.exists(os.path.join(output_dir, output_fname)):
      logging.info(
          "File %s exists. Skipping inference for shard %d/%d of task '%s'.",
          output_fname, shard_id, num_shards, task.name)
      return
    logging.info("Starting inference loop for shard %d of %d of task '%s'.",
                 shard_id, num_shards, task.name)

    def _write_chunk_and_canonicalize_ckpt(chunk: int, chunk_path: str,
                                           inferences: _Inferences,
                                           task_ds: tf.data.Dataset,
                                           chunk_ckpt_path: Optional[str]):
      write_tick = time.time()
      logging.info('Writing chunk %d results to %s', chunk, chunk_path)
      write_fn(chunk_path, inferences, task_ds, mode,
               task.output_features['targets'].vocabulary)
      with gfile.GFile(chunk_path + '.COMPLETED', 'w') as f:
        f.write('')
      write_time = time.time() - write_tick
      logging.info('Writing completed in %02f seconds (%02f examples/sec).',
                   write_time,
                   len(inferences[0]) / write_time)
      update_measurement_series('writing_total_sec', chunk, write_time)
      update_measurement_series('writing_examples_per_sec', chunk,
                                len(inferences[0]) / write_time)

      if chunk_ckpt_path:
        # Canonicalize checkpoint.
        for fname in gfile.glob(chunk_ckpt_path + '*'):
          gfile.rename(
              fname, fname.replace(chunk_ckpt_path, ckpt_path), overwrite=True)

    # Main Loop over "chunks".
    for chunk, chunk_batch in infer_ds_iter:
      logging.info('Starting chunk %d', chunk)
      chunk_tick = time.time()

      # Load the dataset for the next chunk. We can't use `infer_ds_iter`
      # directly since `infer_fn` needs to know the exact size of each chunk,
      # which may be smaller for the final one.
      chunk_ds = tf.data.Dataset.from_tensor_slices(chunk_batch)
      chunk_ds = chunk_ds.cache().prefetch(AUTOTUNE)

      # Unzip chunk dataset into pretokenized and model datasets.
      task_ds = chunk_ds.map(lambda p, m: p, num_parallel_calls=AUTOTUNE)
      model_ds = chunk_ds.map(lambda p, m: m, num_parallel_calls=AUTOTUNE)

      # Get a chunk-specific RNG key.
      chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

      chunk_path = os.path.join(tmp_dir, f'{output_fname}-chunk{chunk:05}')
      if gfile.exists(chunk_path + '.COMPLETED') and not checkpoint_ds_iter:
        logging.info('Skipping chunk %s. Chunk file already exists.', chunk)
        continue

      logging.info('Running inference on %d batches.', checkpoint_period)
      inferences = _extract_tokens_and_aux_values(
          infer_fn(model_ds.enumerate(), rng=chunk_rng))

      if jax.process_index() == 0:
        chunk_time = time.time() - chunk_tick
        logging.info('chunk completed in %02f seconds (%02f examples/sec).',
                     chunk_time,
                     len(inferences[0]) / chunk_time)
        update_measurement_series('inference_total_sec', chunk, chunk_time)
        update_measurement_series('inference_examples_per_sec', chunk,
                                  len(inferences[0]) / chunk_time)

        chunk_ckpt_path = None
        if checkpoint_ds_iter:
          # Store iterator checkpoint in temporary location before writing the
          # model output asynchronously. After outputs are written, the
          # checkpoint will be moved to the canonical location to be used if
          # restart occurs.
          ckpt_tick = time.time()
          chunk_ckpt_path = input_ckpt.write(
              os.path.join(tmp_dir, f'{chunk}.ckpt'))
          logging.info(
              'Checkpoint written to temporary location in %02f seconds.',
              time.time() - ckpt_tick)

        # These will execute sequentially since the ThreadPool size is 1.
        write_thread_pool.submit(
            _write_chunk_and_canonicalize_ckpt,
            chunk=chunk,
            chunk_path=chunk_path,
            inferences=inferences,
            task_ds=task_ds,
            chunk_ckpt_path=chunk_ckpt_path)

      # Wait for checkpoint to be written before continuing.
      multihost_utils.sync_global_devices(
          f'{task.name}:checkpoint_chunk{chunk:05}')

    logging.info("Finished inference for task '%s'.", task.name)

    logging.info('Waiting for chunk writes to complete.')
    write_thread_pool.shutdown(wait=True)

    if jax.process_index() == 0 and merge_chunked_results:
      step = None if train_state is None else int(train_state.step)
      merge_fn(output_dir, output_fname, tmp_dir, step)
      logging.info('Deleting temporary files.')
      gfile.rmtree(tmp_dir)

    # Wait for host 0 to finish writing before exiting.
    multihost_utils.sync_global_devices(f'{task.name}:complete')

  for task in seqio.get_subtasks(task_or_mixture):
    logging.info("Starting inference for task '%s'", task.name)
    infer_task(task)
  logging.info('DONE')
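
# Sketch of calling `infer` directly from Python (illustrative only; these
# objects are normally constructed via gin, and the task name, lengths, paths,
# and partitioner settings below are placeholders):
#
#   infer(
#       mode='predict',
#       model=model,  # an already-configured models.BaseTransformerModel
#       dataset_cfg=utils.DatasetConfig(
#           mixture_or_task_name='my_infer_task',
#           task_feature_lengths={'inputs': 512, 'targets': 128},
#           split='infer',
#           batch_size=32,
#           shuffle=False,
#           seed=None,
#           use_cached=False,
#           pack=False,
#           module=None),
#       restore_checkpoint_cfg=utils.RestoreCheckpointConfig(
#           path='/path/to/checkpoint', mode='specific'),
#       partitioner=partitioning.PjitPartitioner(num_partitions=1),
#       output_dir='/path/to/output_dir',
#       checkpoint_period=100)
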

def update_measurement_series(series_name: str, step: int, value: float):
  """Not implemented externally."""
  del series_name, step, value


if __name__ == '__main__':
  # pylint:disable=g-import-not-at-top
  from absl import app
  from absl import flags
  import gin
  # pylint:enable=g-import-not-at-top

  FLAGS = flags.FLAGS

  jax.config.parse_flags_with_absl()

  flags.DEFINE_integer(
      'shard_id',
      default=None,
      help='Index to use for splitting the Task across multiple inference '
      'runs. NB: If set, this overrides --gin.infer.shard_id')

  flags.DEFINE_multi_string(
      'gin_file',
      default=None,
      help='Path to gin configuration file. Multiple paths may be passed and '
      'will be imported in the given order, with later configurations '
      'overriding earlier ones.')

  flags.DEFINE_multi_string(
      'gin_bindings', default=[], help='Individual gin bindings.')

  flags.DEFINE_list(
      'gin_search_paths',
      default=['.'],
      help='Comma-separated list of gin config path prefixes to be prepended '
      'to suffixes given via `--gin_file`. Only the first prefix that '
      'produces a valid path for each suffix will be used.')

  flags.DEFINE_string(
      'tfds_data_dir', None,
      'If set, this directory will be used to store datasets prepared by '
      'TensorFlow Datasets that are not available in the public TFDS GCS '
      'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
      'all `Task`s.')

  def main(argv: Sequence[str]):
    """Wrapper for pdb post mortems."""
    _main(argv)

  def _main(argv: Sequence[str]):
    """True main function."""
    if len(argv) > 1:
      raise app.UsageError('Too many command-line arguments.')

    if FLAGS.tfds_data_dir:
      seqio.set_tfds_data_dir_override(FLAGS.tfds_data_dir)

    # Create gin-configurable version of `infer`.
    infer_using_gin = gin.configurable(infer)

    gin_utils.parse_gin_flags(
        # User-provided gin paths take precedence if relative paths conflict.
        FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS,
        FLAGS.gin_file,
        FLAGS.gin_bindings)

    # See http://yaqs/7882016229479677952 for further gin-config discussion.
    def _get_gin_parameter(key: str) -> Any:
      value = gin.query_parameter(key)
      if isinstance(value, gin.config.ConfigurableReference):
        if value.evaluate:
          return value.scoped_configurable_fn()
        return value.scoped_configurable_fn
      return value

    shard_id = (
        FLAGS.shard_id if FLAGS.shard_id is not None else
        _get_gin_parameter('infer.shard_id'))
    if shard_id == 0:
      gin_utils.summarize_gin_config(
          model_dir=_get_gin_parameter('infer.output_dir'),
          summary_writer=None,
          step=0)
    if FLAGS.shard_id is not None:
      # We fall back to this flag since XM does not support sweeps over flags
      # with '.' in them (it treats them like nested dictionaries).
      # TODO(adarob): Figure out a workaround so we can deprecate this flag.
      infer_using_gin(shard_id=FLAGS.shard_id)
    else:
      infer_using_gin()

  gin_utils.run(main)