# t5-parliament-categorisation / eval_categorisation_base.gin
# Last commit by pere: "updated eval script" (b3a728f), 994 bytes.
from __gin__ import dynamic_registration
import tasks
import __main__ as eval_script
from t5.data import mixtures
from t5x import partitioning
from t5x import utils
include "t5x/examples/t5/mt5/base.gin"
# Macros shared by the bindings below. %gin.REQUIRED marks a value that has no
# default in this file and has to be supplied from outside it.
# Checkpoint to restore; consumed by utils.RestoreCheckpointConfig.path.
CHECKPOINT_PATH = %gin.REQUIRED # passed via commandline
# Output directory; handed to eval_script.evaluate as output_dir.
EVAL_OUTPUT_DIR = %gin.REQUIRED # passed via commandline
DROPOUT_RATE = 0.0 # unused boilerplate
# Task or mixture name; consumed by utils.DatasetConfig.mixture_or_task_name.
MIXTURE_OR_TASK_NAME = "categorise"
# Arguments bound to the evaluation entry point (`evaluate` of the script that
# loads this file, imported above as eval_script). The bindings must be
# indented under the `eval_script.evaluate:` header to be scoped to it.
eval_script.evaluate:
  model = %MODEL  # imported from separate gin file
  dataset_cfg = @utils.DatasetConfig()
  restore_checkpoint_cfg = @utils.RestoreCheckpointConfig()
  output_dir = %EVAL_OUTPUT_DIR
  partitioner = @partitioning.PjitPartitioner()
# Dataset settings for the evaluation run; passed to eval_script.evaluate as
# dataset_cfg. The bindings must be indented under the block header.
# NOTE(review): use_cached is not set here; confirm that the DatasetConfig
# default is what this task needs.
utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = None  # Auto-computes the max feature lengths.
  split = 'validation'
  batch_size = 32
  shuffle = False
  seed = 42
# Partition count for the PjitPartitioner that eval_script.evaluate receives as
# `partitioner`. NOTE(review): presumably the model-parallel partition count,
# so the available device count must be compatible with it; confirm against
# the t5x partitioning documentation.
partitioning.PjitPartitioner.num_partitions = 2
# Checkpoint restore settings; passed to eval_script.evaluate as
# restore_checkpoint_cfg. The bindings must be indented under the block header.
utils.RestoreCheckpointConfig:
  path = %CHECKPOINT_PATH
  # NOTE(review): 'specific' presumably means `path` names one exact checkpoint
  # rather than a directory to search; confirm against t5x.utils.
  mode = 'specific'