# Defaults for infer.py.
#
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME: The SeqIO Task/Mixture to use for inference
# - TASK_FEATURE_LENGTHS: The lengths per key in the SeqIO Task to trim
#                         features to.
# - CHECKPOINT_PATH: The model checkpoint to use for inference
# - INFER_OUTPUT_DIR: The dir to write results to.
#
#
# Commonly overridden options:
#
# - infer.mode
# - infer.checkpoint_period
# - infer.shard_id
# - infer.num_shards
# - DatasetConfig.split
# - DatasetConfig.batch_size
# - DatasetConfig.use_cached
# - RestoreCheckpointConfig.is_tensorflow
# - RestoreCheckpointConfig.mode
# - PjitPartitioner.num_partitions
from __gin__ import dynamic_registration

import __main__ as infer_script
from t5x import partitioning
from t5x import utils

# Must be overridden
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
CHECKPOINT_PATH = %gin.REQUIRED
INFER_OUTPUT_DIR = %gin.REQUIRED

# DEPRECATED: Import the module in your gin file instead.
MIXTURE_OR_TASK_MODULE = None

# Default arguments for the inference entry point. MODEL is not defined in
# this file and must be bound elsewhere.
infer_script.infer:
  mode = 'predict'
  model = %MODEL  # imported from separate gin file
  output_dir = %INFER_OUTPUT_DIR
  dataset_cfg = @utils.DatasetConfig()
  partitioner = @partitioning.PjitPartitioner()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  checkpoint_period = 100
  shard_id = 0
  num_shards = 1

# Partitioner passed to `infer` as `partitioner`.
partitioning.PjitPartitioner:
  num_partitions = 1
  logical_axis_rules = @partitioning.standard_logical_axis_rules()

# Dataset settings passed to `infer` as `dataset_cfg`; the task name and
# feature lengths come from the required macros above.
utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  module = %MIXTURE_OR_TASK_MODULE
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  use_cached = False
  split = 'test'
  batch_size = 32
  shuffle = False
  seed = 0
  pack = False

# Checkpoint restore settings passed to `infer` as `restore_checkpoint_cfg`.
utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'bfloat16'