# Default configuration for infer.py when the input is a TFExample file.
#
# Features in each TFExample are tokenized with the model's vocabulary. The
# inputs feature is expected under the key 'inputs' by default; bind
# `create_task_from_tfexample_file.inputs_key` to use a different key.
#
# A binding for MODEL must be supplied as well.
#
# Required to be set:
#
# - TF_EXAMPLE_FILE_PATHS: The path to read TF Examples from.
# - TF_EXAMPLE_FILE_TYPE: The type of file to read TF Examples from. Currently
#                         supported: 'tfrecord', 'recordio', 'sstable'.
# - FEATURE_LENGTHS: The maximum length per feature in the TF Examples.
# - CHECKPOINT_PATH: The model checkpoint to use for inference.
# - INFER_OUTPUT_DIR: The dir to write results to.
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - create_task_from_tfexample_file.inputs_key
# - create_task_from_tfexample_file.targets_key
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions

from __gin__ import dynamic_registration

import __main__ as infer_script
import seqio
from t5x import models
from t5x import partitioning
from t5x import utils


# Placeholders: each of these has to be bound by the including config or on
# the command line.
TF_EXAMPLE_FILE_PATHS = %gin.REQUIRED
TF_EXAMPLE_FILE_TYPE = %gin.REQUIRED
FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# Top-level arguments of the inference entry point.
infer_script.infer:
  model = %MODEL  # MODEL is bound in a separate gin file.
  mode = 'predict'
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  checkpoint_period = 100
  shard_id = 0
  num_shards = 1

# The dataset is a task created from the TFExample file rather than a task
# referenced by name.
utils.DatasetConfig:
  mixture_or_task_name = @infer_script.create_task_from_tfexample_file()
  task_feature_lengths = %FEATURE_LENGTHS
  split = 'infer'
  batch_size = 32
  shuffle = False
  pack = False
  seed = 0

infer_script.create_task_from_tfexample_file:
  paths = %TF_EXAMPLE_FILE_PATHS
  file_type = %TF_EXAMPLE_FILE_TYPE
  inputs_key = 'inputs'
  targets_key = None
  features = {'inputs': @inputs/seqio.Feature(), 'targets': @outputs/seqio.Feature()}

# Plumbing to pull the vocabularies directly out of MODEL. This is needed to
# tokenize the features from the TFExample, since we aren't provided with
# vocabularies via a Task.
models.get_input_vocabulary.model = %MODEL
models.get_output_vocabulary.model = %MODEL
inputs/seqio.Feature.vocabulary = @models.get_input_vocabulary()
outputs/seqio.Feature.vocabulary = @models.get_output_vocabulary()

utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'bfloat16'

partitioning.PjitPartitioner:
  num_partitions = 1
  logical_axis_rules = @partitioning.standard_logical_axis_rules()