Spaces:
Build error
Build error
# Defaults for infer.py if using a TFExample file as input.
#
#
# The features from each TFExample are tokenized using the model's vocabulary.
# By default, the inputs feature is assumed to be keyed as 'inputs', but this
# can be overridden with `create_task_from_tfexample_file.inputs_key`.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from.
# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently
#   supported: 'tfrecord', 'recordio', 'sstable'.
# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples.
# - CHECKPOINT_PATH: The model checkpoint to use for inference.
# - INFER_OUTPUT_DIR: The dir to write results to.
#
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - create_task_from_tfexample_file.inputs_key
# - create_task_from_tfexample_file.targets_key
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions
# Gin dynamic registration: the modules imported below are made available to
# this config under the names used in the bindings (`infer_script`, `seqio`,
# `models`, `partitioning`, `utils`).
from __gin__ import dynamic_registration

# `__main__` is the script this config is loaded by; per the header comment
# that is expected to be infer.py.
import __main__ as infer_script
import seqio
from t5x import models
from t5x import partitioning
from t5x import utils
# Must be overridden.
#
# Each macro below is bound to %gin.REQUIRED, so a concrete value has to be
# supplied by an including config or a command-line binding. The macros are
# consumed further down in this file via `%NAME` references.
TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED
TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED
FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED
# Default arguments for the inference entry point. The dataset, partitioner
# and checkpoint-restore arguments are wired to the configurables bound
# elsewhere in this file.
infer_script.infer:
  mode = 'predict'
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  checkpoint_period = 100
  # Defaults describe a single, unsharded run (shard 0 of 1).
  shard_id = 0
  num_shards = 1
# Partitioner defaults: one partition, with the axis rules produced by
# `partitioning.standard_logical_axis_rules()`.
partitioning.PjitPartitioner:
  num_partitions = 1
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
# Dataset defaults for inference.
utils.DatasetConfig:
  # Bound to the result of `create_task_from_tfexample_file` (configured in
  # this file) rather than to a literal task name.
  mixture_or_task_name = @infer_script.create_task_from_tfexample_file()
  task_feature_lengths = %FEATURE_LENGTHS
  split = 'infer'
  batch_size = 32
  shuffle = False
  seed = 0
  pack = False
# Arguments for creating a task from the TFExample file(s) named by the
# required macros.
infer_script.create_task_from_tfexample_file:
  paths = %TF_EXAMPLE_FILE_PATHS
  file_type = %TF_EXAMPLE_FILE_TYPE
  inputs_key = 'inputs'
  targets_key = None
  # The `inputs/` and `outputs/` scopes let the two seqio.Feature instances
  # receive different vocabulary bindings.
  features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()}
# Plumbing to extract the vocabulary directly from MODEL. This is needed to
# tokenize the features from the TFExample, because we aren't provided with
# vocabularies via a Task.
inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary()
models.get_input_vocabulary.model = %MODEL
outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary()
models.get_output_vocabulary.model = %MODEL
# Checkpoint restore defaults: use the checkpoint at CHECKPOINT_PATH with
# mode 'specific' and dtype 'bfloat16'.
utils.RestoreCheckpointConfig:
  mode = 'specific'
  path = %CHECKPOINT_PATH
  dtype = 'bfloat16'